summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/private/ufunc_override.c30
-rw-r--r--numpy/core/tests/test_umath.py48
2 files changed, 69 insertions, 9 deletions
diff --git a/numpy/core/src/private/ufunc_override.c b/numpy/core/src/private/ufunc_override.c
index 116da3267..522b1744d 100644
--- a/numpy/core/src/private/ufunc_override.c
+++ b/numpy/core/src/private/ufunc_override.c
@@ -73,7 +73,6 @@ PyUFunc_WithOverride(PyObject *args, PyObject *kwds,
int out_kwd_is_tuple = 0;
int num_override_args = 0;
- PyObject *obj;
PyObject *out_kwd_obj = NULL;
/*
* Check inputs
@@ -106,7 +105,9 @@ PyUFunc_WithOverride(PyObject *args, PyObject *kwds,
}
for (i = 0; i < nargs + nout_kwd; ++i) {
- PyObject *method;
+ PyObject *obj;
+ int new_class = 1;
+
if (i < nargs) {
obj = PyTuple_GET_ITEM(args, i);
}
@@ -119,12 +120,27 @@ PyUFunc_WithOverride(PyObject *args, PyObject *kwds,
}
}
/*
- * Now see if the object provides an __array_ufunc__. However, we should
- * ignore the base ndarray.__ufunc__, so we skip any ndarray as well as
- * any ndarray subclass instances that did not override __array_ufunc__.
+ * Have we seen this class before? If so, ignore.
*/
- method = get_non_default_array_ufunc(obj);
- if (method != NULL) {
+ if (with_override != NULL) {
+ int j;
+ for (j = 0; j < num_override_args; j++) {
+ new_class = (Py_TYPE(obj) != Py_TYPE(with_override[j]));
+ if (!new_class) {
+ break;
+ }
+ }
+ }
+ if (new_class) {
+ /*
+ * Now see if the object provides an __array_ufunc__. However, we should
+ * ignore the base ndarray.__ufunc__, so we skip any ndarray as well as
+ * any ndarray subclass instances that did not override __array_ufunc__.
+ */
+ PyObject *method = get_non_default_array_ufunc(obj);
+ if (method == NULL) {
+ continue;
+ }
if (method == Py_None) {
PyErr_Format(PyExc_TypeError,
"operand '%.200s' does not support ufuncs "
diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py
index 4772913be..95107b538 100644
--- a/numpy/core/tests/test_umath.py
+++ b/numpy/core/tests/test_umath.py
@@ -1745,18 +1745,22 @@ class TestSpecialMethods(object):
return "B"
class C(object):
+ def __init__(self):
+ self.count = 0
+
def __array_ufunc__(self, func, method, *inputs, **kwargs):
+ self.count += 1
return NotImplemented
class CSub(C):
def __array_ufunc__(self, func, method, *inputs, **kwargs):
+ self.count += 1
return NotImplemented
a = A()
a_sub = ASub()
b = B()
c = C()
- c_sub = CSub()
# Standard
res = np.multiply(a, a_sub)
@@ -1767,11 +1771,27 @@ class TestSpecialMethods(object):
# With 1 NotImplemented
res = np.multiply(c, a)
assert_equal(res, "A")
+ assert_equal(c.count, 1)
+ # Check our counter works, so we can trust tests below.
+ res = np.multiply(c, a)
+ assert_equal(c.count, 2)
# Both NotImplemented.
+ c = C()
+ c_sub = CSub()
assert_raises(TypeError, np.multiply, c, c_sub)
+ assert_equal(c.count, 1)
+ assert_equal(c_sub.count, 1)
+ c.count = c_sub.count = 0
assert_raises(TypeError, np.multiply, c_sub, c)
+ assert_equal(c.count, 1)
+ assert_equal(c_sub.count, 1)
+ c.count = 0
+ assert_raises(TypeError, np.multiply, c, c)
+ assert_equal(c.count, 1)
+ c.count = 0
assert_raises(TypeError, np.multiply, 2, c)
+ assert_equal(c.count, 1)
# Ternary testing.
assert_equal(three_mul_ufunc(a, 1, 2), "A")
@@ -1783,11 +1803,19 @@ class TestSpecialMethods(object):
assert_equal(three_mul_ufunc(a, 2, b), "A")
assert_equal(three_mul_ufunc(a, 2, a_sub), "ASub")
assert_equal(three_mul_ufunc(a, a_sub, 3), "ASub")
+ c.count = 0
assert_equal(three_mul_ufunc(c, a_sub, 3), "ASub")
+ assert_equal(c.count, 1)
+ c.count = 0
assert_equal(three_mul_ufunc(1, a_sub, c), "ASub")
+ assert_equal(c.count, 0)
+ c.count = 0
assert_equal(three_mul_ufunc(a, b, c), "A")
+ assert_equal(c.count, 0)
+ c_sub.count = 0
assert_equal(three_mul_ufunc(a, b, c_sub), "A")
+ assert_equal(c_sub.count, 0)
assert_equal(three_mul_ufunc(1, 2, b), "B")
assert_raises(TypeError, three_mul_ufunc, 1, 2, c)
@@ -1806,9 +1834,25 @@ class TestSpecialMethods(object):
assert_equal(four_mul_ufunc(a_sub, 1, 2, a), "ASub")
assert_equal(four_mul_ufunc(a, 1, 2, a_sub), "ASub")
+ c = C()
+ c_sub = CSub()
assert_raises(TypeError, four_mul_ufunc, 1, 2, 3, c)
+ assert_equal(c.count, 1)
+ c.count = 0
assert_raises(TypeError, four_mul_ufunc, 1, 2, c_sub, c)
- assert_raises(TypeError, four_mul_ufunc, 1, c, c_sub, c)
+ assert_equal(c_sub.count, 1)
+ assert_equal(c.count, 1)
+ c2 = C()
+ c.count = c_sub.count = 0
+ assert_raises(TypeError, four_mul_ufunc, 1, c, c_sub, c2)
+ assert_equal(c_sub.count, 1)
+ assert_equal(c.count, 1)
+ assert_equal(c2.count, 0)
+ c.count = c2.count = c_sub.count = 0
+ assert_raises(TypeError, four_mul_ufunc, c2, c, c_sub, c)
+ assert_equal(c_sub.count, 1)
+ assert_equal(c.count, 0)
+ assert_equal(c2.count, 1)
def test_ufunc_override_methods(self):