summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorAndreas Kloeckner <inform@tiker.net>2013-07-08 23:11:03 -0400
committerAndreas Kloeckner <inform@tiker.net>2013-07-08 23:11:03 -0400
commit4441bdd95197ba10651eee8366e67176fb3b5b51 (patch)
treecb6a2b114b1120aeebd80da3106f348e1f8e4f2c /numpy/core
parent79188b21dd85e4115195971522be91a2fcb1a9d2 (diff)
downloadnumpy-4441bdd95197ba10651eee8366e67176fb3b5b51.tar.gz
BUG: Check earlier for higher priority in binary operators, add test
Fixes #3375
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/src/multiarray/number.c20
-rw-r--r--numpy/core/tests/test_ufunc.py30
2 files changed, 50 insertions, 0 deletions
diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c
index 9e13d0f29..392370667 100644
--- a/numpy/core/src/multiarray/number.c
+++ b/numpy/core/src/multiarray/number.c
@@ -209,6 +209,26 @@ PyArray_GenericBinaryFunction(PyArrayObject *m1, PyObject *m2, PyObject *op)
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
+
+ if (!PyArray_Check(m2)) {
+ /*
+ * Catch priority inversion and punt, but only if it's guaranteed
+ * that we were called through m1 and the other guy is not an array
+ * at all. Note that some arrays need to pass through here even
+ * with priorities inverted, for example: float(17) * np.matrix(...)
+ *
+ * See also:
+ * - https://github.com/numpy/numpy/issues/3502
+ * - https://github.com/numpy/numpy/issues/3503
+ */
+ double m1_prio = PyArray_GetPriority(m1, NPY_SCALAR_PRIORITY);
+ double m2_prio = PyArray_GetPriority(m2, NPY_SCALAR_PRIORITY);
+ if (m1_prio < m2_prio) {
+ Py_INCREF(Py_NotImplemented);
+ return Py_NotImplemented;
+ }
+ }
+
return PyObject_CallFunction(op, "OO", m1, m2);
}
diff --git a/numpy/core/tests/test_ufunc.py b/numpy/core/tests/test_ufunc.py
index ad489124e..e4698abc6 100644
--- a/numpy/core/tests/test_ufunc.py
+++ b/numpy/core/tests/test_ufunc.py
@@ -834,5 +834,35 @@ class TestUfunc(TestCase):
dtype=rational);
assert_equal(result, expected);
+ def test_custom_array_like(self):
+ class MyThing(object):
+ __array_priority__ = 1000
+
+ rmul_count = 0
+ getitem_count = 0
+
+ def __init__(self, shape):
+ self.shape = shape
+
+ def __len__(self):
+ return self.shape[0]
+
+ def __getitem__(self, i):
+ MyThing.getitem_count += 1
+ if not isinstance(i, tuple):
+ i = (i,)
+ if len(i) > len(self.shape):
+ raise IndexError("boo")
+
+ return MyThing(self.shape[len(i):])
+
+ def __rmul__(self, other):
+ MyThing.rmul_count += 1
+ return self
+
+ np.float64(5)*MyThing((3, 3))
+ assert_(MyThing.rmul_count == 1, MyThing.rmul_count)
+ assert_(MyThing.getitem_count <= 2, MyThing.getitem_count)
+
if __name__ == "__main__":
run_module_suite()