diff options
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/src/multiarray/number.c | 20 | ||||
-rw-r--r-- | numpy/core/tests/test_ufunc.py | 30 |
2 files changed, 50 insertions, 0 deletions
diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c index 9e13d0f29..392370667 100644 --- a/numpy/core/src/multiarray/number.c +++ b/numpy/core/src/multiarray/number.c @@ -209,6 +209,26 @@ PyArray_GenericBinaryFunction(PyArrayObject *m1, PyObject *m2, PyObject *op) Py_INCREF(Py_NotImplemented); return Py_NotImplemented; } + + if (!PyArray_Check(m2)) { + /* + * Catch priority inversion and punt, but only if it's guaranteed + * that we were called through m1 and the other guy is not an array + * at all. Note that some arrays need to pass through here even + * with priorities inverted, for example: float(17) * np.matrix(...) + * + * See also: + * - https://github.com/numpy/numpy/issues/3502 + * - https://github.com/numpy/numpy/issues/3503 + */ + double m1_prio = PyArray_GetPriority(m1, NPY_SCALAR_PRIORITY); + double m2_prio = PyArray_GetPriority(m2, NPY_SCALAR_PRIORITY); + if (m1_prio < m2_prio) { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } + } + return PyObject_CallFunction(op, "OO", m1, m2); } diff --git a/numpy/core/tests/test_ufunc.py b/numpy/core/tests/test_ufunc.py index ad489124e..e4698abc6 100644 --- a/numpy/core/tests/test_ufunc.py +++ b/numpy/core/tests/test_ufunc.py @@ -834,5 +834,35 @@ class TestUfunc(TestCase): dtype=rational); assert_equal(result, expected); + def test_custom_array_like(self): + class MyThing(object): + __array_priority__ = 1000 + + rmul_count = 0 + getitem_count = 0 + + def __init__(self, shape): + self.shape = shape + + def __len__(self): + return self.shape[0] + + def __getitem__(self, i): + MyThing.getitem_count += 1 + if not isinstance(i, tuple): + i = (i,) + if len(i) > len(self.shape): + raise IndexError("boo") + + return MyThing(self.shape[len(i):]) + + def __rmul__(self, other): + MyThing.rmul_count += 1 + return self + + np.float64(5)*MyThing((3, 3)) + assert_(MyThing.rmul_count == 1, MyThing.rmul_count) + assert_(MyThing.getitem_count <= 2, MyThing.getitem_count) + if __name__ == "__main__": run_module_suite() |