summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBlake Griffith <blake.a.griffith@gmail.com>2013-07-22 23:22:55 -0500
committerBlake Griffith <blake.a.griffith@gmail.com>2013-09-05 20:09:44 -0500
commit5c630f06f9012cdfff9b35cf96cea4a696c6b66c (patch)
tree753899866644709a1d043c8b2dd31effcec6cfd3
parent21976ca31eda537824ad877d432634eaf103567b (diff)
downloadnumpy-5c630f06f9012cdfff9b35cf96cea4a696c6b66c.tar.gz
TST: Add ufunc override tests.
-rw-r--r--numpy/core/tests/test_blasdot.py18
-rw-r--r--numpy/core/tests/test_multiarray.py18
-rw-r--r--numpy/core/tests/test_umath.py181
3 files changed, 217 insertions, 0 deletions
diff --git a/numpy/core/tests/test_blasdot.py b/numpy/core/tests/test_blasdot.py
index 624c617d3..caa576abc 100644
--- a/numpy/core/tests/test_blasdot.py
+++ b/numpy/core/tests/test_blasdot.py
@@ -151,3 +151,21 @@ def test_dot_array_order():
assert_almost_equal(c.T.dot(b.T).T, b.dot(c), decimal=prec)
assert_almost_equal(b.dot(c), _dot(b, c), decimal=prec)
assert_almost_equal(c.T.dot(b.T), _dot(c.T, b.T), decimal=prec)
+
+def test_dot_override():
+ class A(object):
+ def __numpy_ufunc__(self, ufunc, method, pos, inputs, **kwargs):
+ return "A"
+
+ class B(object):
+ def __numpy_ufunc__(self, ufunc, method, pos, inputs, **kwargs):
+ return NotImplemented
+
+ a = A()
+ b = B()
+ c = np.array([[1]])
+
+ assert_equal(np.dot(a, b), "A")
+ assert_equal(c.dot(a), "A")
+ assert_raises(TypeError, np.dot, b, c)
+ assert_raises(TypeError, c.dot, b)
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 9870d44f3..0e7f6ec61 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -1324,6 +1324,24 @@ class TestMethods(TestCase):
a.dot(b=b, out=c)
assert_equal(c, np.dot(a, b))
+ def test_dot_override(self):
+ class A(object):
+ def __numpy_ufunc__(self, ufunc, method, pos, inputs, **kwargs):
+ return "A"
+
+ class B(object):
+ def __numpy_ufunc__(self, ufunc, method, pos, inputs, **kwargs):
+ return NotImplemented
+
+ a = A()
+ b = B()
+ c = np.array([[1]])
+
+ assert_equal(np.dot(a, b), "A")
+ assert_equal(c.dot(a), "A")
+ assert_raises(TypeError, np.dot, b, c)
+ assert_raises(TypeError, c.dot, b)
+
def test_diagonal(self):
a = np.arange(12).reshape((3, 4))
assert_equal(a.diagonal(), [0, 5, 10])
diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py
index eb18304ea..3706cfa81 100644
--- a/numpy/core/tests/test_umath.py
+++ b/numpy/core/tests/test_umath.py
@@ -868,6 +868,187 @@ class TestSpecialMethods(TestCase):
assert_equal(ncu.maximum(a, B()), 0)
assert_equal(ncu.maximum(a, C()), 0)
+ def test_ufunc_override(self):
+ class A(object):
+ def __numpy_ufunc__(self, func, method, pos, inputs, **kwargs):
+ return self, func, method, pos, inputs, kwargs
+
+ a = A()
+
+ b = np.matrix([1])
+ c = np.array([1])
+ res0 = np.multiply(a, b)
+ res1 = np.dot(a, b)
+
+ # self
+ assert_equal(res0[0], a)
+ assert_equal(res1[0], a)
+ assert_equal(res0[1], np.multiply)
+ assert_equal(res1[1], np.dot)
+ assert_equal(res0[2], '__call__')
+ assert_equal(res1[2], '__call__')
+ assert_equal(res0[3], 0)
+ assert_equal(res1[3], 0)
+ assert_equal(res0[4], (a, b))
+ assert_equal(res1[4], (a, b))
+ assert_equal(res0[5], {})
+ assert_equal(res1[5], {})
+
+ def test_ufunc_override_mro(self):
+
+ # Some multi arg functions for testing.
+ def tres_mul(a, b, c):
+ return a * b * c
+
+ def quatro_mul(a, b, c, d):
+ return a * b * c * d
+
+ # Make these into ufuncs.
+ three_mul_ufunc = np.frompyfunc(tres_mul, 3, 1)
+ four_mul_ufunc = np.frompyfunc(quatro_mul, 4, 1)
+
+ class A(object):
+ def __numpy_ufunc__(self, func, method, pos, inputs, **kwargs):
+ return "A"
+
+ class ASub(A):
+ def __numpy_ufunc__(self, func, method, pos, inputs, **kwargs):
+ return "ASub"
+
+ class B(object):
+ def __numpy_ufunc__(self, func, method, pos, inputs, **kwargs):
+ return "B"
+
+ class C(object):
+ def __numpy_ufunc__(self, func, method, pos, inputs, **kwargs):
+ return NotImplemented
+
+ class CSub(object):
+ def __numpy_ufunc__(self, func, method, pos, inputs, **kwargs):
+ return NotImplemented
+
+
+
+ a = A()
+ a_sub = ASub()
+ b = B()
+ c = C()
+ c_sub = CSub()
+
+ # Standard
+ res = np.multiply(a, a_sub)
+ assert_equal(res, "ASub")
+ res = np.multiply(a_sub, b)
+ assert_equal(res, "ASub")
+
+ # With 1 NotImplemented
+ res = np.multiply(c, a)
+ assert_equal(res, "A")
+
+ # Both NotImplemented.
+ assert_raises(TypeError, np.multiply, c, c_sub)
+ assert_raises(TypeError, np.multiply, c_sub, c)
+ assert_raises(TypeError, np.multiply, 2, c)
+
+ # Ternary testing.
+ assert_equal(three_mul_ufunc(a, 1, 2), "A")
+ assert_equal(three_mul_ufunc(1, a, 2), "A")
+ assert_equal(three_mul_ufunc(1, 2, a), "A")
+
+ assert_equal(three_mul_ufunc(a, a, 6), "A")
+ assert_equal(three_mul_ufunc(a, 2, a), "A")
+ assert_equal(three_mul_ufunc(a, 2, b), "A")
+ assert_equal(three_mul_ufunc(a, 2, a_sub), "ASub")
+ assert_equal(three_mul_ufunc(a, a_sub, 3), "ASub")
+ assert_equal(three_mul_ufunc(c, a_sub, 3), "ASub")
+ assert_equal(three_mul_ufunc(1, a_sub, c), "ASub")
+
+ assert_equal(three_mul_ufunc(a, b, c), "A")
+ assert_equal(three_mul_ufunc(a, b, c_sub), "A")
+ assert_equal(three_mul_ufunc(1, 2, b), "B")
+
+ assert_raises(TypeError, three_mul_ufunc, 1, 2, c)
+ assert_raises(TypeError, three_mul_ufunc, c_sub, 2, c)
+ assert_raises(TypeError, three_mul_ufunc, c_sub, 2, 3)
+
+ # Quaternary testing.
+ assert_equal(four_mul_ufunc(a, 1, 2, 3), "A")
+ assert_equal(four_mul_ufunc(1, a, 2, 3), "A")
+ assert_equal(four_mul_ufunc(1, 1, a, 3), "A")
+ assert_equal(four_mul_ufunc(1, 1, 2, a), "A")
+
+ assert_equal(four_mul_ufunc(a, b, 2, 3), "A")
+ assert_equal(four_mul_ufunc(1, a, 2, b), "A")
+ assert_equal(four_mul_ufunc(b, 1, a, 3), "B")
+ assert_equal(four_mul_ufunc(a_sub, 1, 2, a), "ASub")
+ assert_equal(four_mul_ufunc(a, 1, 2, a_sub), "ASub")
+
+ assert_raises(TypeError, four_mul_ufunc, 1, 2, 3, c)
+ assert_raises(TypeError, four_mul_ufunc, 1, 2, c_sub, c)
+ assert_raises(TypeError, four_mul_ufunc, 1, c, c_sub, c)
+
+ def test_ufunc_override_methods(self):
+ class A(object):
+ def __numpy_ufunc__(self, ufunc, method, pos, inputs, **kwargs):
+ if method == "__call__":
+ return method
+ if method == "reduce":
+ return method
+ if method == "accumulate":
+ return method
+ if method == "reduceat":
+ return method
+
+ a = A()
+ res = np.multiply(1, a)
+ assert_equal(res, "__call__")
+
+ res = np.multiply.reduce(1, a)
+ assert_equal(res, "reduce")
+
+ res = np.multiply.accumulate(1, a)
+ assert_equal(res, "accumulate")
+
+ res = np.multiply.reduceat(1, a)
+ assert_equal(res, "reduceat")
+
+ res = np.multiply(a, 1)
+ assert_equal(res, "__call__")
+
+ res = np.multiply.reduce(a, 1)
+ assert_equal(res, "reduce")
+
+ res = np.multiply.accumulate(a, 1)
+ assert_equal(res, "accumulate")
+
+ res = np.multiply.reduceat(a, 1)
+ assert_equal(res, "reduceat")
+
+ def test_ufunc_override_out(self):
+ class A(object):
+ def __numpy_ufunc__(self, ufunc, method, pos, inputs, **kwargs):
+ return kwargs
+
+
+ class B(object):
+ def __numpy_ufunc__(self, ufunc, method, pos, inputs, **kwargs):
+ return kwargs
+
+ a = A()
+ b = B()
+ res0 = np.multiply(a, b, 'out_arg')
+ res1 = np.multiply(a, b, out='out_arg')
+ res2 = np.multiply(2, b, 'out_arg')
+ res3 = np.multiply(3, b, out='out_arg')
+ res4 = np.multiply(a, 4, 'out_arg')
+ res5 = np.multiply(a, 5, out='out_arg')
+
+ assert_equal(res0['out'], 'out_arg')
+ assert_equal(res1['out'], 'out_arg')
+ assert_equal(res2['out'], 'out_arg')
+ assert_equal(res3['out'], 'out_arg')
+ assert_equal(res4['out'], 'out_arg')
+ assert_equal(res5['out'], 'out_arg')
class TestChoose(TestCase):
def test_mixed(self):