diff options
-rw-r--r-- | numpy/core/src/multiarray/arrayobject.c | 81 | ||||
-rw-r--r-- | numpy/core/src/umath/ufunc_object.c | 7 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 146 |
3 files changed, 194 insertions, 40 deletions
diff --git a/numpy/core/src/multiarray/arrayobject.c b/numpy/core/src/multiarray/arrayobject.c index ad196e010..dbbabbc99 100644 --- a/numpy/core/src/multiarray/arrayobject.c +++ b/numpy/core/src/multiarray/arrayobject.c @@ -1264,7 +1264,6 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) { PyArrayObject *array_other; PyObject *result = NULL; - PyArray_Descr *dtype = NULL; switch (cmp_op) { case Py_LT: @@ -1280,28 +1279,30 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) Py_INCREF(Py_False); return Py_False; } - /* Make sure 'other' is an array */ - if (PyArray_TYPE(self) == NPY_OBJECT) { - dtype = PyArray_DTYPE(self); - Py_INCREF(dtype); - } - array_other = (PyArrayObject *)PyArray_FromAny(other, dtype, 0, 0, 0, - NULL); + result = PyArray_GenericBinaryFunction(self, + (PyObject *)other, + n_ops.equal); + if (result && result != Py_NotImplemented) + break; + /* - * If not successful, indicate that the items cannot be compared - * this way. + * The ufunc does not support void/structured types, so these + * need to be handled specifically. Only a few cases are supported. */ - if (array_other == NULL) { - PyErr_Clear(); - Py_INCREF(Py_NotImplemented); - return Py_NotImplemented; - } - result = PyArray_GenericBinaryFunction(self, - (PyObject *)array_other, - n_ops.equal); - if ((result == Py_NotImplemented) && - (PyArray_TYPE(self) == NPY_VOID)) { + if (PyArray_TYPE(self) == NPY_VOID) { + array_other = (PyArrayObject *)PyArray_FromAny(other, NULL, 0, 0, 0, + NULL); + /* + * If not successful, indicate that the items cannot be compared + * this way. + */ + if (array_other == NULL) { + PyErr_Clear(); + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } + int _res; _res = PyObject_RichCompareBool @@ -1325,7 +1326,6 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) * two array objects can not be compared together; * indicate that */ - Py_DECREF(array_other); if (result == NULL) { PyErr_Clear(); Py_INCREF(Py_NotImplemented); @@ -1337,27 +1337,29 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) Py_INCREF(Py_True); return Py_True; } - /* Make sure 'other' is an array */ - if (PyArray_TYPE(self) == NPY_OBJECT) { - dtype = PyArray_DTYPE(self); - Py_INCREF(dtype); - } - array_other = (PyArrayObject *)PyArray_FromAny(other, dtype, 0, 0, 0, - NULL); + result = PyArray_GenericBinaryFunction(self, (PyObject *)other, + n_ops.not_equal); + if (result && result != Py_NotImplemented) + break; + /* - * If not successful, indicate that the items cannot be compared - * this way. + * The ufunc does not support void/structured types, so these + * need to be handled specifically. Only a few cases are supported. */ - if (array_other == NULL) { - PyErr_Clear(); - Py_INCREF(Py_NotImplemented); - return Py_NotImplemented; - } - result = PyArray_GenericBinaryFunction(self, (PyObject *)array_other, - n_ops.not_equal); - if ((result == Py_NotImplemented) && - (PyArray_TYPE(self) == NPY_VOID)) { + if (PyArray_TYPE(self) == NPY_VOID) { + array_other = (PyArrayObject *)PyArray_FromAny(other, NULL, 0, 0, 0, + NULL); + /* + * If not successful, indicate that the items cannot be compared + * this way. + */ + if (array_other == NULL) { + PyErr_Clear(); + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } + int _res; _res = PyObject_RichCompareBool( @@ -1377,7 +1379,6 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) return result; } - Py_DECREF(array_other); if (result == NULL) { PyErr_Clear(); Py_INCREF(Py_NotImplemented); diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index 8a028096a..1bf56a0f0 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -492,6 +492,13 @@ _has_reflected_op(PyObject *op, char *name) _GETATTR_(bitwise_and, rand); _GETATTR_(bitwise_xor, rxor); _GETATTR_(bitwise_or, ror); + /* Comparisons */ + _GETATTR_(equal, eq); + _GETATTR_(not_equal, ne); + _GETATTR_(greater, lt); + _GETATTR_(less, gt); + _GETATTR_(greater_equal, le); + _GETATTR_(less_equal, ge); return 0; } diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index e3520f123..2f6716c23 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -2934,6 +2934,152 @@ class TestMapIter(TestCase): assert_equal(b, [ 100.1, 51., 6., 3., 4., 5. ]) +class PriorityNdarray(): + __array_priority__ = 1000 + + def __init__(self, array): + self.array = array + + def __lt__(self, array): + if isinstance(array, PriorityNdarray): + array = array.array + return PriorityNdarray(self.array < array) + + def __gt__(self, array): + if isinstance(array, PriorityNdarray): + array = array.array + return PriorityNdarray(self.array > array) + + def __le__(self, array): + if isinstance(array, PriorityNdarray): + array = array.array + return PriorityNdarray(self.array <= array) + + def __ge__(self, array): + if isinstance(array, PriorityNdarray): + array = array.array + return PriorityNdarray(self.array >= array) + + def __eq__(self, array): + if isinstance(array, PriorityNdarray): + array = array.array + return PriorityNdarray(self.array == array) + + def __ne__(self, array): + if isinstance(array, PriorityNdarray): + array = array.array + return PriorityNdarray(self.array != array) + + +class TestArrayPriority(TestCase): + def test_lt(self): + l = np.asarray([0., -1., 1.], dtype=dtype) + r = np.asarray([0., 1., -1.], dtype=dtype) + lp = PriorityNdarray(l) + rp = PriorityNdarray(r) + res1 = l < r + res2 = l < rp + res3 = lp < r + res4 = lp < rp + + assert_array_equal(res1, res2.array) + assert_array_equal(res1, res3.array) + assert_array_equal(res1, res4.array) + assert_(isinstance(res1, np.ndarray)) + assert_(isinstance(res2, PriorityNdarray)) + assert_(isinstance(res3, PriorityNdarray)) + assert_(isinstance(res4, PriorityNdarray)) + + def test_gt(self): + l = np.asarray([0., -1., 1.], dtype=dtype) + r = np.asarray([0., 1., -1.], dtype=dtype) + lp = PriorityNdarray(l) + rp = PriorityNdarray(r) + res1 = l > r + res2 = l > rp + res3 = lp > r + res4 = lp > rp + + assert_array_equal(res1, res2.array) + assert_array_equal(res1, res3.array) + assert_array_equal(res1, res4.array) + assert_(isinstance(res1, np.ndarray)) + assert_(isinstance(res2, PriorityNdarray)) + assert_(isinstance(res3, PriorityNdarray)) + assert_(isinstance(res4, PriorityNdarray)) + + def test_le(self): + l = np.asarray([0., -1., 1.], dtype=dtype) + r = np.asarray([0., 1., -1.], dtype=dtype) + lp = PriorityNdarray(l) + rp = PriorityNdarray(r) + res1 = l <= r + res2 = l <= rp + res3 = lp <= r + res4 = lp <= rp + + assert_array_equal(res1, res2.array) + assert_array_equal(res1, res3.array) + assert_array_equal(res1, res4.array) + assert_(isinstance(res1, np.ndarray)) + assert_(isinstance(res2, PriorityNdarray)) + assert_(isinstance(res3, PriorityNdarray)) + assert_(isinstance(res4, PriorityNdarray)) + + def test_ge(self): + l = np.asarray([0., -1., 1.], dtype=dtype) + r = np.asarray([0., 1., -1.], dtype=dtype) + lp = PriorityNdarray(l) + rp = PriorityNdarray(r) + res1 = l >= r + res2 = l >= rp + res3 = lp >= r + res4 = lp >= rp + + assert_array_equal(res1, res2.array) + assert_array_equal(res1, res3.array) + assert_array_equal(res1, res4.array) + assert_(isinstance(res1, np.ndarray)) + assert_(isinstance(res2, PriorityNdarray)) + assert_(isinstance(res3, PriorityNdarray)) + assert_(isinstance(res4, PriorityNdarray)) + + def test_eq(self): + l = np.asarray([0., -1., 1.], dtype=dtype) + r = np.asarray([0., 1., -1.], dtype=dtype) + lp = PriorityNdarray(l) + rp = PriorityNdarray(r) + res1 = l == r + res2 = l == rp + res3 = lp == r + res4 = lp == rp + + assert_array_equal(res1, res2.array) + assert_array_equal(res1, res3.array) + assert_array_equal(res1, res4.array) + assert_(isinstance(res1, np.ndarray)) + assert_(isinstance(res2, PriorityNdarray)) + assert_(isinstance(res3, PriorityNdarray)) + assert_(isinstance(res4, PriorityNdarray)) + + def test_ne(self): + l = np.asarray([0., -1., 1.], dtype=dtype) + r = np.asarray([0., 1., -1.], dtype=dtype) + lp = PriorityNdarray(l) + rp = PriorityNdarray(r) + res1 = l != r + res2 = l != rp + res3 = lp != r + res4 = lp != rp + + assert_array_equal(res1, res2.array) + assert_array_equal(res1, res3.array) + assert_array_equal(res1, res4.array) + assert_(isinstance(res1, np.ndarray)) + assert_(isinstance(res2, PriorityNdarray)) + assert_(isinstance(res3, PriorityNdarray)) + assert_(isinstance(res4, PriorityNdarray)) + if __name__ == "__main__": run_module_suite() |