diff options
author | Blake Griffith <blake.a.griffith@gmail.com> | 2013-07-22 23:22:55 -0500 |
---|---|---|
committer | Blake Griffith <blake.a.griffith@gmail.com> | 2013-09-05 20:09:44 -0500 |
commit | 5c630f06f9012cdfff9b35cf96cea4a696c6b66c (patch) | |
tree | 753899866644709a1d043c8b2dd31effcec6cfd3 | |
parent | 21976ca31eda537824ad877d432634eaf103567b (diff) | |
download | numpy-5c630f06f9012cdfff9b35cf96cea4a696c6b66c.tar.gz |
TST: Add ufunc override tests.
-rw-r--r-- | numpy/core/tests/test_blasdot.py | 18 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 18 | ||||
-rw-r--r-- | numpy/core/tests/test_umath.py | 181 |
3 files changed, 217 insertions, 0 deletions
diff --git a/numpy/core/tests/test_blasdot.py b/numpy/core/tests/test_blasdot.py index 624c617d3..caa576abc 100644 --- a/numpy/core/tests/test_blasdot.py +++ b/numpy/core/tests/test_blasdot.py @@ -151,3 +151,21 @@ def test_dot_array_order(): assert_almost_equal(c.T.dot(b.T).T, b.dot(c), decimal=prec) assert_almost_equal(b.dot(c), _dot(b, c), decimal=prec) assert_almost_equal(c.T.dot(b.T), _dot(c.T, b.T), decimal=prec) + +def test_dot_override(): + class A(object): + def __numpy_ufunc__(self, ufunc, method, pos, inputs, **kwargs): + return "A" + + class B(object): + def __numpy_ufunc__(self, ufunc, method, pos, inputs, **kwargs): + return NotImplemented + + a = A() + b = B() + c = np.array([[1]]) + + assert_equal(np.dot(a, b), "A") + assert_equal(c.dot(a), "A") + assert_raises(TypeError, np.dot, b, c) + assert_raises(TypeError, c.dot, b) diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 9870d44f3..0e7f6ec61 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -1324,6 +1324,24 @@ class TestMethods(TestCase): a.dot(b=b, out=c) assert_equal(c, np.dot(a, b)) + def test_dot_override(self): + class A(object): + def __numpy_ufunc__(self, ufunc, method, pos, inputs, **kwargs): + return "A" + + class B(object): + def __numpy_ufunc__(self, ufunc, method, pos, inputs, **kwargs): + return NotImplemented + + a = A() + b = B() + c = np.array([[1]]) + + assert_equal(np.dot(a, b), "A") + assert_equal(c.dot(a), "A") + assert_raises(TypeError, np.dot, b, c) + assert_raises(TypeError, c.dot, b) + def test_diagonal(self): a = np.arange(12).reshape((3, 4)) assert_equal(a.diagonal(), [0, 5, 10]) diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py index eb18304ea..3706cfa81 100644 --- a/numpy/core/tests/test_umath.py +++ b/numpy/core/tests/test_umath.py @@ -868,6 +868,187 @@ class TestSpecialMethods(TestCase): assert_equal(ncu.maximum(a, B()), 0) assert_equal(ncu.maximum(a, C()), 0) + def test_ufunc_override(self): + class A(object): + def __numpy_ufunc__(self, func, method, pos, inputs, **kwargs): + return self, func, method, pos, inputs, kwargs + + a = A() + + b = np.matrix([1]) + c = np.array([1]) + res0 = np.multiply(a, b) + res1 = np.dot(a, b) + + # self + assert_equal(res0[0], a) + assert_equal(res1[0], a) + assert_equal(res0[1], np.multiply) + assert_equal(res1[1], np.dot) + assert_equal(res0[2], '__call__') + assert_equal(res1[2], '__call__') + assert_equal(res0[3], 0) + assert_equal(res1[3], 0) + assert_equal(res0[4], (a, b)) + assert_equal(res1[4], (a, b)) + assert_equal(res0[5], {}) + assert_equal(res1[5], {}) + + def test_ufunc_override_mro(self): + + # Some multi arg functions for testing. + def tres_mul(a, b, c): + return a * b * c + + def quatro_mul(a, b, c, d): + return a * b * c * d + + # Make these into ufuncs. + three_mul_ufunc = np.frompyfunc(tres_mul, 3, 1) + four_mul_ufunc = np.frompyfunc(quatro_mul, 4, 1) + + class A(object): + def __numpy_ufunc__(self, func, method, pos, inputs, **kwargs): + return "A" + + class ASub(A): + def __numpy_ufunc__(self, func, method, pos, inputs, **kwargs): + return "ASub" + + class B(object): + def __numpy_ufunc__(self, func, method, pos, inputs, **kwargs): + return "B" + + class C(object): + def __numpy_ufunc__(self, func, method, pos, inputs, **kwargs): + return NotImplemented + + class CSub(object): + def __numpy_ufunc__(self, func, method, pos, inputs, **kwargs): + return NotImplemented + + + + a = A() + a_sub = ASub() + b = B() + c = C() + c_sub = CSub() + + # Standard + res = np.multiply(a, a_sub) + assert_equal(res, "ASub") + res = np.multiply(a_sub, b) + assert_equal(res, "ASub") + + # With 1 NotImplemented + res = np.multiply(c, a) + assert_equal(res, "A") + + # Both NotImplemented. + assert_raises(TypeError, np.multiply, c, c_sub) + assert_raises(TypeError, np.multiply, c_sub, c) + assert_raises(TypeError, np.multiply, 2, c) + + # Ternary testing. + assert_equal(three_mul_ufunc(a, 1, 2), "A") + assert_equal(three_mul_ufunc(1, a, 2), "A") + assert_equal(three_mul_ufunc(1, 2, a), "A") + + assert_equal(three_mul_ufunc(a, a, 6), "A") + assert_equal(three_mul_ufunc(a, 2, a), "A") + assert_equal(three_mul_ufunc(a, 2, b), "A") + assert_equal(three_mul_ufunc(a, 2, a_sub), "ASub") + assert_equal(three_mul_ufunc(a, a_sub, 3), "ASub") + assert_equal(three_mul_ufunc(c, a_sub, 3), "ASub") + assert_equal(three_mul_ufunc(1, a_sub, c), "ASub") + + assert_equal(three_mul_ufunc(a, b, c), "A") + assert_equal(three_mul_ufunc(a, b, c_sub), "A") + assert_equal(three_mul_ufunc(1, 2, b), "B") + + assert_raises(TypeError, three_mul_ufunc, 1, 2, c) + assert_raises(TypeError, three_mul_ufunc, c_sub, 2, c) + assert_raises(TypeError, three_mul_ufunc, c_sub, 2, 3) + + # Quaternary testing. + assert_equal(four_mul_ufunc(a, 1, 2, 3), "A") + assert_equal(four_mul_ufunc(1, a, 2, 3), "A") + assert_equal(four_mul_ufunc(1, 1, a, 3), "A") + assert_equal(four_mul_ufunc(1, 1, 2, a), "A") + + assert_equal(four_mul_ufunc(a, b, 2, 3), "A") + assert_equal(four_mul_ufunc(1, a, 2, b), "A") + assert_equal(four_mul_ufunc(b, 1, a, 3), "B") + assert_equal(four_mul_ufunc(a_sub, 1, 2, a), "ASub") + assert_equal(four_mul_ufunc(a, 1, 2, a_sub), "ASub") + + assert_raises(TypeError, four_mul_ufunc, 1, 2, 3, c) + assert_raises(TypeError, four_mul_ufunc, 1, 2, c_sub, c) + assert_raises(TypeError, four_mul_ufunc, 1, c, c_sub, c) + + def test_ufunc_override_methods(self): + class A(object): + def __numpy_ufunc__(self, ufunc, method, pos, inputs, **kwargs): + if method == "__call__": + return method + if method == "reduce": + return method + if method == "accumulate": + return method + if method == "reduceat": + return method + + a = A() + res = np.multiply(1, a) + assert_equal(res, "__call__") + + res = np.multiply.reduce(1, a) + assert_equal(res, "reduce") + + res = np.multiply.accumulate(1, a) + assert_equal(res, "accumulate") + + res = np.multiply.reduceat(1, a) + assert_equal(res, "reduceat") + + res = np.multiply(a, 1) + assert_equal(res, "__call__") + + res = np.multiply.reduce(a, 1) + assert_equal(res, "reduce") + + res = np.multiply.accumulate(a, 1) + assert_equal(res, "accumulate") + + res = np.multiply.reduceat(a, 1) + assert_equal(res, "reduceat") + + def test_ufunc_override_out(self): + class A(object): + def __numpy_ufunc__(self, ufunc, method, pos, inputs, **kwargs): + return kwargs + + + class B(object): + def __numpy_ufunc__(self, ufunc, method, pos, inputs, **kwargs): + return kwargs + + a = A() + b = B() + res0 = np.multiply(a, b, 'out_arg') + res1 = np.multiply(a, b, out='out_arg') + res2 = np.multiply(2, b, 'out_arg') + res3 = np.multiply(3, b, out='out_arg') + res4 = np.multiply(a, 4, 'out_arg') + res5 = np.multiply(a, 5, out='out_arg') + + assert_equal(res0['out'], 'out_arg') + assert_equal(res1['out'], 'out_arg') + assert_equal(res2['out'], 'out_arg') + assert_equal(res3['out'], 'out_arg') + assert_equal(res4['out'], 'out_arg') + assert_equal(res5['out'], 'out_arg') class TestChoose(TestCase): def test_mixed(self): |