summaryrefslogtreecommitdiff
path: root/numpy/lib/tests/test_mixins.py
diff options
context:
space:
mode:
authorStephan Hoyer <shoyer@gmail.com>2017-04-21 09:35:45 -0700
committerCharles Harris <charlesr.harris@gmail.com>2017-04-27 13:37:51 -0600
commit02600d38f3b2e70c3cd07770f93c3bac5255c8a6 (patch)
tree36cf7474c136258dd69ce3bd54dc67c54ff105cf /numpy/lib/tests/test_mixins.py
parent1e460b74bac7da0d9029b1fd414213f00bb66c9f (diff)
downloadnumpy-02600d38f3b2e70c3cd07770f93c3bac5255c8a6.tar.gz
ENH: Add NDArrayOperatorsMixin mixin class.
This mixin class provides an easy way to implement arithmetic operators that defer to __array_ufunc__ like numpy.ndarray in non-ndarray subclasses.
Diffstat (limited to 'numpy/lib/tests/test_mixins.py')
-rw-r--r--numpy/lib/tests/test_mixins.py189
1 files changed, 189 insertions, 0 deletions
diff --git a/numpy/lib/tests/test_mixins.py b/numpy/lib/tests/test_mixins.py
new file mode 100644
index 000000000..f45a3c661
--- /dev/null
+++ b/numpy/lib/tests/test_mixins.py
@@ -0,0 +1,189 @@
+from __future__ import division, absolute_import, print_function
+
+import numbers
+import operator
+import sys
+
+import numpy as np
+from numpy.testing import (
+ TestCase, run_module_suite, assert_, assert_equal, assert_raises)
+
+
+PY2 = sys.version_info.major < 3
+
+
+# NOTE: This class should be kept as an exact copy of the example from the
+# docstring for NDArrayOperatorsMixin.
+
+class ArrayLike(np.lib.mixins.NDArrayOperatorsMixin):
+ def __init__(self, value):
+ self.value = np.asarray(value)
+
+ # One might also consider adding the built-in list type to this
+ # list, to support operations like np.add(array_like, list)
+ _HANDLED_TYPES = (np.ndarray, numbers.Number)
+
+ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
+ out = kwargs.get('out', ())
+ for x in inputs + out:
+ # Only support operations with instances of _HANDLED_TYPES
+ # and superclass instances of this type
+ if not (isinstance(x, self._HANDLED_TYPES) or
+ isinstance(self, type(x))):
+ return NotImplemented
+
+ # Defer to the implementation of the ufunc on unwrapped values
+ inputs = tuple(x.value if isinstance(self, type(x)) else x
+ for x in inputs)
+ if out:
+ kwargs['out'] = tuple(
+ x.value if isinstance(self, type(x)) else x
+ for x in out)
+ result = getattr(ufunc, method)(*inputs, **kwargs)
+
+ if type(result) is tuple:
+ # multiple return values
+ return tuple(type(self)(x) for x in result)
+ elif method == 'at':
+ # no return value
+ return None
+ else:
+ # one return value
+ return type(self)(result)
+
+ def __repr__(self):
+ return '%s(%r)' % (type(self).__name__, self.value)
+
+
+def _assert_equal_type_and_value(result, expected, err_msg=None):
+ assert_equal(type(result), type(expected), err_msg=err_msg)
+ assert_equal(result.value, expected.value, err_msg=err_msg)
+ assert_equal(getattr(result.value, 'dtype', None),
+ getattr(expected.value, 'dtype', None), err_msg=err_msg)
+
+
+class TestNDArrayOperatorsMixin(TestCase):
+
+ def test_array_like_add(self):
+
+ def check(result):
+ _assert_equal_type_and_value(result, ArrayLike(0))
+
+ check(ArrayLike(0) + 0)
+ check(0 + ArrayLike(0))
+
+ check(ArrayLike(0) + np.array(0))
+ check(np.array(0) + ArrayLike(0))
+
+ check(ArrayLike(np.array(0)) + 0)
+ check(0 + ArrayLike(np.array(0)))
+
+ check(ArrayLike(np.array(0)) + np.array(0))
+ check(np.array(0) + ArrayLike(np.array(0)))
+
+ def test_inplace(self):
+ array_like = ArrayLike(np.array([0]))
+ array_like += 1
+ _assert_equal_type_and_value(array_like, ArrayLike(np.array([1])))
+
+ array = np.array([0])
+ array += ArrayLike(1)
+ _assert_equal_type_and_value(array, ArrayLike(np.array([1])))
+
+ def test_opt_out(self):
+
+ class OptOut(object):
+ """Object that opts out of __array_ufunc__."""
+ __array_ufunc__ = None
+
+ def __add__(self, other):
+ return self
+
+ def __radd__(self, other):
+ return self
+
+ array_like = ArrayLike(1)
+ opt_out = OptOut()
+
+ # supported operations
+ assert_(array_like + opt_out is opt_out)
+ assert_(opt_out + array_like is opt_out)
+
+ # not supported
+ with assert_raises(TypeError):
+ # don't use the Python default, array_like = array_like + opt_out
+ array_like += opt_out
+ with assert_raises(TypeError):
+ array_like - opt_out
+ with assert_raises(TypeError):
+ opt_out - array_like
+
+ def test_subclass(self):
+
+ class SubArrayLike(ArrayLike):
+ """Should take precedence over ArrayLike."""
+
+ x = ArrayLike(0)
+ y = SubArrayLike(1)
+ _assert_equal_type_and_value(x + y, y)
+ _assert_equal_type_and_value(y + x, y)
+
+ def test_unary_methods(self):
+ array = np.array([-1, 0, 1, 2])
+ array_like = ArrayLike(array)
+ for op in [operator.neg,
+ # pos is not yet implemented
+ abs,
+ operator.invert]:
+ _assert_equal_type_and_value(op(array_like), ArrayLike(op(array)))
+
+ def test_binary_methods(self):
+ array = np.array([-1, 0, 1, 2])
+ array_like = ArrayLike(array)
+ operators = [
+ operator.lt,
+ operator.le,
+ operator.eq,
+ operator.ne,
+ operator.gt,
+ operator.ge,
+ operator.add,
+ operator.sub,
+ operator.mul,
+ operator.truediv,
+ operator.floordiv,
+ # TODO: test div on Python 2, only
+ operator.mod,
+ # divmod is not yet implemented
+ pow,
+ operator.lshift,
+ operator.rshift,
+ operator.and_,
+ operator.xor,
+ operator.or_,
+ ]
+ for op in operators:
+ expected = ArrayLike(op(array, 1))
+ actual = op(array_like, 1)
+ err_msg = 'failed for operator {}'.format(op)
+ _assert_equal_type_and_value(expected, actual, err_msg=err_msg)
+
+ def test_ufunc_at(self):
+ array = ArrayLike(np.array([1, 2, 3, 4]))
+ assert_(np.negative.at(array, np.array([0, 1])) is None)
+ _assert_equal_type_and_value(array, ArrayLike([-1, -2, 3, 4]))
+
+ def test_ufunc_two_outputs(self):
+ def check(result):
+ assert_(type(result) is tuple)
+ assert_equal(len(result), 2)
+ mantissa, exponent = np.frexp(2 ** -3)
+ _assert_equal_type_and_value(result[0], ArrayLike(mantissa))
+ _assert_equal_type_and_value(result[1], ArrayLike(exponent))
+
+ check(np.frexp(ArrayLike(2 ** -3)))
+ check(np.frexp(ArrayLike(np.array(2 ** -3))))
+
+
+if __name__ == "__main__":
+ run_module_suite()