summaryrefslogtreecommitdiff
path: root/numpy/lib/tests/test_mixins.py
diff options
context:
space:
mode:
authorStephan Hoyer <shoyer@google.com>2017-04-22 18:01:35 -0700
committerCharles Harris <charlesr.harris@gmail.com>2017-04-27 13:37:51 -0600
commit256a8ae75fc36f7d4531557f9572a046508afa07 (patch)
tree80b2bd8c59c173a1fe56c2ed419684cef60de6fd /numpy/lib/tests/test_mixins.py
parentb9359f1d9fede0d4ecc08e868e2b0dcb85dbd7e2 (diff)
downloadnumpy-256a8ae75fc36f7d4531557f9572a046508afa07.tar.gz
BUG: Fix ArrayLike(NDArrayOperatorsMixin) operations with object()
Diffstat (limited to 'numpy/lib/tests/test_mixins.py')
-rw-r--r--numpy/lib/tests/test_mixins.py26
1 files changed, 20 insertions, 6 deletions
diff --git a/numpy/lib/tests/test_mixins.py b/numpy/lib/tests/test_mixins.py
index f45a3c661..bca974fc5 100644
--- a/numpy/lib/tests/test_mixins.py
+++ b/numpy/lib/tests/test_mixins.py
@@ -26,18 +26,23 @@ class ArrayLike(np.lib.mixins.NDArrayOperatorsMixin):
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
out = kwargs.get('out', ())
for x in inputs + out:
- # Only support operations with instances of _HANDLED_TYPES
- # and superclass instances of this type
+ # Only support operations with instances of _HANDLED_TYPES,
+ # or instances of ArrayLike that are superclasses of this
+ # object's type.
if not (isinstance(x, self._HANDLED_TYPES) or
- isinstance(self, type(x))):
+ (isinstance(x, ArrayLike) and
+ isinstance(self, type(x)))):
return NotImplemented
- # Defer to the implementation of the ufunc on unwrapped values
- inputs = tuple(x.value if isinstance(self, type(x)) else x
+ # Defer to the implementation of the ufunc on unwrapped values.
+ # Use ArrayLike instead of type(self) for isinstance to allow
+ # subclasses that don't override __array_ufunc__ to handle
+ # ArrayLike objects.
+ inputs = tuple(x.value if isinstance(x, ArrayLike) else x
for x in inputs)
if out:
kwargs['out'] = tuple(
- x.value if isinstance(self, type(x)) else x
+ x.value if isinstance(x, ArrayLike) else x
for x in out)
result = getattr(ufunc, method)(*inputs, **kwargs)
@@ -128,6 +133,15 @@ class TestNDArrayOperatorsMixin(TestCase):
_assert_equal_type_and_value(x + y, y)
_assert_equal_type_and_value(y + x, y)
+ def test_object(self):
+ x = ArrayLike(0)
+ obj = object()
+ assert_equal(x.__add__(obj), NotImplemented)
+ with assert_raises(TypeError):
+ x + obj
+ with assert_raises(TypeError):
+ obj + x
+
def test_unary_methods(self):
array = np.array([-1, 0, 1, 2])
array_like = ArrayLike(array)