diff options
author | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2018-06-11 17:36:23 -0400 |
---|---|---|
committer | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2018-06-28 09:49:14 -0400 |
commit | f1e56ad7c7600e07d31d8dd380063d8eecdf6d4d (patch) | |
tree | 220fa32333a4e8e0ac1735a0742d39a058f2f4a7 /numpy | |
parent | a9b01a2d24aa3aa5c523df6b97467db1c92607d4 (diff) | |
download | numpy-f1e56ad7c7600e07d31d8dd380063d8eecdf6d4d.tar.gz |
MAINT: Ensure __array_ufunc__ on given class is only called once.
Overall, it likely doesn't matter much for performance, but it is
more logical and more consistent with what python does: reverse
operators are not called if the forward one of a given class
already returned NotImplemented.
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/private/ufunc_override.c | 30 | ||||
-rw-r--r-- | numpy/core/tests/test_umath.py | 48 |
2 files changed, 69 insertions, 9 deletions
diff --git a/numpy/core/src/private/ufunc_override.c b/numpy/core/src/private/ufunc_override.c index 116da3267..522b1744d 100644 --- a/numpy/core/src/private/ufunc_override.c +++ b/numpy/core/src/private/ufunc_override.c @@ -73,7 +73,6 @@ PyUFunc_WithOverride(PyObject *args, PyObject *kwds, int out_kwd_is_tuple = 0; int num_override_args = 0; - PyObject *obj; PyObject *out_kwd_obj = NULL; /* * Check inputs @@ -106,7 +105,9 @@ PyUFunc_WithOverride(PyObject *args, PyObject *kwds, } for (i = 0; i < nargs + nout_kwd; ++i) { - PyObject *method; + PyObject *obj; + int new_class = 1; + if (i < nargs) { obj = PyTuple_GET_ITEM(args, i); } @@ -119,12 +120,27 @@ PyUFunc_WithOverride(PyObject *args, PyObject *kwds, } } /* - * Now see if the object provides an __array_ufunc__. However, we should - * ignore the base ndarray.__ufunc__, so we skip any ndarray as well as - * any ndarray subclass instances that did not override __array_ufunc__. + * Have we seen this class before? If so, ignore. */ - method = get_non_default_array_ufunc(obj); - if (method != NULL) { + if (with_override != NULL) { + int j; + for (j = 0; j < num_override_args; j++) { + new_class = (Py_TYPE(obj) != Py_TYPE(with_override[j])); + if (!new_class) { + break; + } + } + } + if (new_class) { + /* + * Now see if the object provides an __array_ufunc__. However, we should + * ignore the base ndarray.__ufunc__, so we skip any ndarray as well as + * any ndarray subclass instances that did not override __array_ufunc__. + */ + PyObject *method = get_non_default_array_ufunc(obj); + if (method == NULL) { + continue; + } if (method == Py_None) { PyErr_Format(PyExc_TypeError, "operand '%.200s' does not support ufuncs " diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py index 4772913be..95107b538 100644 --- a/numpy/core/tests/test_umath.py +++ b/numpy/core/tests/test_umath.py @@ -1745,18 +1745,22 @@ class TestSpecialMethods(object): return "B" class C(object): + def __init__(self): + self.count = 0 + def __array_ufunc__(self, func, method, *inputs, **kwargs): + self.count += 1 return NotImplemented class CSub(C): def __array_ufunc__(self, func, method, *inputs, **kwargs): + self.count += 1 return NotImplemented a = A() a_sub = ASub() b = B() c = C() - c_sub = CSub() # Standard res = np.multiply(a, a_sub) @@ -1767,11 +1771,27 @@ class TestSpecialMethods(object): # With 1 NotImplemented res = np.multiply(c, a) assert_equal(res, "A") + assert_equal(c.count, 1) + # Check our counter works, so we can trust tests below. + res = np.multiply(c, a) + assert_equal(c.count, 2) # Both NotImplemented. + c = C() + c_sub = CSub() assert_raises(TypeError, np.multiply, c, c_sub) + assert_equal(c.count, 1) + assert_equal(c_sub.count, 1) + c.count = c_sub.count = 0 assert_raises(TypeError, np.multiply, c_sub, c) + assert_equal(c.count, 1) + assert_equal(c_sub.count, 1) + c.count = 0 + assert_raises(TypeError, np.multiply, c, c) + assert_equal(c.count, 1) + c.count = 0 assert_raises(TypeError, np.multiply, 2, c) + assert_equal(c.count, 1) # Ternary testing. assert_equal(three_mul_ufunc(a, 1, 2), "A") @@ -1783,11 +1803,19 @@ class TestSpecialMethods(object): assert_equal(three_mul_ufunc(a, 2, b), "A") assert_equal(three_mul_ufunc(a, 2, a_sub), "ASub") assert_equal(three_mul_ufunc(a, a_sub, 3), "ASub") + c.count = 0 assert_equal(three_mul_ufunc(c, a_sub, 3), "ASub") + assert_equal(c.count, 1) + c.count = 0 assert_equal(three_mul_ufunc(1, a_sub, c), "ASub") + assert_equal(c.count, 0) + c.count = 0 assert_equal(three_mul_ufunc(a, b, c), "A") + assert_equal(c.count, 0) + c_sub.count = 0 assert_equal(three_mul_ufunc(a, b, c_sub), "A") + assert_equal(c_sub.count, 0) assert_equal(three_mul_ufunc(1, 2, b), "B") assert_raises(TypeError, three_mul_ufunc, 1, 2, c) @@ -1806,9 +1834,25 @@ class TestSpecialMethods(object): assert_equal(four_mul_ufunc(a_sub, 1, 2, a), "ASub") assert_equal(four_mul_ufunc(a, 1, 2, a_sub), "ASub") + c = C() + c_sub = CSub() assert_raises(TypeError, four_mul_ufunc, 1, 2, 3, c) + assert_equal(c.count, 1) + c.count = 0 assert_raises(TypeError, four_mul_ufunc, 1, 2, c_sub, c) - assert_raises(TypeError, four_mul_ufunc, 1, c, c_sub, c) + assert_equal(c_sub.count, 1) + assert_equal(c.count, 1) + c2 = C() + c.count = c_sub.count = 0 + assert_raises(TypeError, four_mul_ufunc, 1, c, c_sub, c2) + assert_equal(c_sub.count, 1) + assert_equal(c.count, 1) + assert_equal(c2.count, 0) + c.count = c2.count = c_sub.count = 0 + assert_raises(TypeError, four_mul_ufunc, c2, c, c_sub, c) + assert_equal(c_sub.count, 1) + assert_equal(c.count, 0) + assert_equal(c2.count, 1) def test_ufunc_override_methods(self): |