summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/multiarray/arrayobject.c81
-rw-r--r--numpy/core/src/umath/ufunc_object.c7
-rw-r--r--numpy/core/tests/test_multiarray.py146
3 files changed, 194 insertions, 40 deletions
diff --git a/numpy/core/src/multiarray/arrayobject.c b/numpy/core/src/multiarray/arrayobject.c
index ad196e010..dbbabbc99 100644
--- a/numpy/core/src/multiarray/arrayobject.c
+++ b/numpy/core/src/multiarray/arrayobject.c
@@ -1264,7 +1264,6 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
{
PyArrayObject *array_other;
PyObject *result = NULL;
- PyArray_Descr *dtype = NULL;
switch (cmp_op) {
case Py_LT:
@@ -1280,28 +1279,30 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
Py_INCREF(Py_False);
return Py_False;
}
- /* Make sure 'other' is an array */
- if (PyArray_TYPE(self) == NPY_OBJECT) {
- dtype = PyArray_DTYPE(self);
- Py_INCREF(dtype);
- }
- array_other = (PyArrayObject *)PyArray_FromAny(other, dtype, 0, 0, 0,
- NULL);
+ result = PyArray_GenericBinaryFunction(self,
+ (PyObject *)other,
+ n_ops.equal);
+ if (result && result != Py_NotImplemented)
+ break;
+
/*
- * If not successful, indicate that the items cannot be compared
- * this way.
+ * The ufunc does not support void/structured types, so these
+ * need to be handled specifically. Only a few cases are supported.
*/
- if (array_other == NULL) {
- PyErr_Clear();
- Py_INCREF(Py_NotImplemented);
- return Py_NotImplemented;
- }
- result = PyArray_GenericBinaryFunction(self,
- (PyObject *)array_other,
- n_ops.equal);
- if ((result == Py_NotImplemented) &&
- (PyArray_TYPE(self) == NPY_VOID)) {
+ if (PyArray_TYPE(self) == NPY_VOID) {
+ array_other = (PyArrayObject *)PyArray_FromAny(other, NULL, 0, 0, 0,
+ NULL);
+ /*
+ * If not successful, indicate that the items cannot be compared
+ * this way.
+ */
+ if (array_other == NULL) {
+ PyErr_Clear();
+ Py_INCREF(Py_NotImplemented);
+ return Py_NotImplemented;
+ }
+
int _res;
_res = PyObject_RichCompareBool
@@ -1325,7 +1326,6 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
* two array objects can not be compared together;
* indicate that
*/
- Py_DECREF(array_other);
if (result == NULL) {
PyErr_Clear();
Py_INCREF(Py_NotImplemented);
@@ -1337,27 +1337,29 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
Py_INCREF(Py_True);
return Py_True;
}
- /* Make sure 'other' is an array */
- if (PyArray_TYPE(self) == NPY_OBJECT) {
- dtype = PyArray_DTYPE(self);
- Py_INCREF(dtype);
- }
- array_other = (PyArrayObject *)PyArray_FromAny(other, dtype, 0, 0, 0,
- NULL);
+ result = PyArray_GenericBinaryFunction(self, (PyObject *)other,
+ n_ops.not_equal);
+ if (result && result != Py_NotImplemented)
+ break;
+
/*
- * If not successful, indicate that the items cannot be compared
- * this way.
+ * The ufunc does not support void/structured types, so these
+ * need to be handled specifically. Only a few cases are supported.
*/
- if (array_other == NULL) {
- PyErr_Clear();
- Py_INCREF(Py_NotImplemented);
- return Py_NotImplemented;
- }
- result = PyArray_GenericBinaryFunction(self, (PyObject *)array_other,
- n_ops.not_equal);
- if ((result == Py_NotImplemented) &&
- (PyArray_TYPE(self) == NPY_VOID)) {
+ if (PyArray_TYPE(self) == NPY_VOID) {
+ array_other = (PyArrayObject *)PyArray_FromAny(other, NULL, 0, 0, 0,
+ NULL);
+ /*
+ * If not successful, indicate that the items cannot be compared
+ * this way.
+ */
+ if (array_other == NULL) {
+ PyErr_Clear();
+ Py_INCREF(Py_NotImplemented);
+ return Py_NotImplemented;
+ }
+
int _res;
_res = PyObject_RichCompareBool(
@@ -1377,7 +1379,6 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
return result;
}
- Py_DECREF(array_other);
if (result == NULL) {
PyErr_Clear();
Py_INCREF(Py_NotImplemented);
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index 8a028096a..1bf56a0f0 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -492,6 +492,13 @@ _has_reflected_op(PyObject *op, char *name)
_GETATTR_(bitwise_and, rand);
_GETATTR_(bitwise_xor, rxor);
_GETATTR_(bitwise_or, ror);
+ /* Comparisons */
+ _GETATTR_(equal, eq);
+ _GETATTR_(not_equal, ne);
+ _GETATTR_(greater, lt);
+ _GETATTR_(less, gt);
+ _GETATTR_(greater_equal, le);
+ _GETATTR_(less_equal, ge);
return 0;
}
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index e3520f123..2f6716c23 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -2934,6 +2934,152 @@ class TestMapIter(TestCase):
assert_equal(b, [ 100.1, 51., 6., 3., 4., 5. ])
+class PriorityNdarray():
+ __array_priority__ = 1000
+
+ def __init__(self, array):
+ self.array = array
+
+ def __lt__(self, array):
+ if isinstance(array, PriorityNdarray):
+ array = array.array
+ return PriorityNdarray(self.array < array)
+
+ def __gt__(self, array):
+ if isinstance(array, PriorityNdarray):
+ array = array.array
+ return PriorityNdarray(self.array > array)
+
+ def __le__(self, array):
+ if isinstance(array, PriorityNdarray):
+ array = array.array
+ return PriorityNdarray(self.array <= array)
+
+ def __ge__(self, array):
+ if isinstance(array, PriorityNdarray):
+ array = array.array
+ return PriorityNdarray(self.array >= array)
+
+ def __eq__(self, array):
+ if isinstance(array, PriorityNdarray):
+ array = array.array
+ return PriorityNdarray(self.array == array)
+
+ def __ne__(self, array):
+ if isinstance(array, PriorityNdarray):
+ array = array.array
+ return PriorityNdarray(self.array != array)
+
+
+class TestArrayPriority(TestCase):
+ def test_lt(self):
+ l = np.asarray([0., -1., 1.], dtype=dtype)
+ r = np.asarray([0., 1., -1.], dtype=dtype)
+ lp = PriorityNdarray(l)
+ rp = PriorityNdarray(r)
+ res1 = l < r
+ res2 = l < rp
+ res3 = lp < r
+ res4 = lp < rp
+
+ assert_array_equal(res1, res2.array)
+ assert_array_equal(res1, res3.array)
+ assert_array_equal(res1, res4.array)
+ assert_(isinstance(res1, np.ndarray))
+ assert_(isinstance(res2, PriorityNdarray))
+ assert_(isinstance(res3, PriorityNdarray))
+ assert_(isinstance(res4, PriorityNdarray))
+
+ def test_gt(self):
+ l = np.asarray([0., -1., 1.], dtype=dtype)
+ r = np.asarray([0., 1., -1.], dtype=dtype)
+ lp = PriorityNdarray(l)
+ rp = PriorityNdarray(r)
+ res1 = l > r
+ res2 = l > rp
+ res3 = lp > r
+ res4 = lp > rp
+
+ assert_array_equal(res1, res2.array)
+ assert_array_equal(res1, res3.array)
+ assert_array_equal(res1, res4.array)
+ assert_(isinstance(res1, np.ndarray))
+ assert_(isinstance(res2, PriorityNdarray))
+ assert_(isinstance(res3, PriorityNdarray))
+ assert_(isinstance(res4, PriorityNdarray))
+
+ def test_le(self):
+ l = np.asarray([0., -1., 1.], dtype=dtype)
+ r = np.asarray([0., 1., -1.], dtype=dtype)
+ lp = PriorityNdarray(l)
+ rp = PriorityNdarray(r)
+ res1 = l <= r
+ res2 = l <= rp
+ res3 = lp <= r
+ res4 = lp <= rp
+
+ assert_array_equal(res1, res2.array)
+ assert_array_equal(res1, res3.array)
+ assert_array_equal(res1, res4.array)
+ assert_(isinstance(res1, np.ndarray))
+ assert_(isinstance(res2, PriorityNdarray))
+ assert_(isinstance(res3, PriorityNdarray))
+ assert_(isinstance(res4, PriorityNdarray))
+
+ def test_ge(self):
+ l = np.asarray([0., -1., 1.], dtype=dtype)
+ r = np.asarray([0., 1., -1.], dtype=dtype)
+ lp = PriorityNdarray(l)
+ rp = PriorityNdarray(r)
+ res1 = l >= r
+ res2 = l >= rp
+ res3 = lp >= r
+ res4 = lp >= rp
+
+ assert_array_equal(res1, res2.array)
+ assert_array_equal(res1, res3.array)
+ assert_array_equal(res1, res4.array)
+ assert_(isinstance(res1, np.ndarray))
+ assert_(isinstance(res2, PriorityNdarray))
+ assert_(isinstance(res3, PriorityNdarray))
+ assert_(isinstance(res4, PriorityNdarray))
+
+ def test_eq(self):
+ l = np.asarray([0., -1., 1.], dtype=dtype)
+ r = np.asarray([0., 1., -1.], dtype=dtype)
+ lp = PriorityNdarray(l)
+ rp = PriorityNdarray(r)
+ res1 = l == r
+ res2 = l == rp
+ res3 = lp == r
+ res4 = lp == rp
+
+ assert_array_equal(res1, res2.array)
+ assert_array_equal(res1, res3.array)
+ assert_array_equal(res1, res4.array)
+ assert_(isinstance(res1, np.ndarray))
+ assert_(isinstance(res2, PriorityNdarray))
+ assert_(isinstance(res3, PriorityNdarray))
+ assert_(isinstance(res4, PriorityNdarray))
+
+ def test_ne(self):
+ l = np.asarray([0., -1., 1.], dtype=dtype)
+ r = np.asarray([0., 1., -1.], dtype=dtype)
+ lp = PriorityNdarray(l)
+ rp = PriorityNdarray(r)
+ res1 = l != r
+ res2 = l != rp
+ res3 = lp != r
+ res4 = lp != rp
+
+ assert_array_equal(res1, res2.array)
+ assert_array_equal(res1, res3.array)
+ assert_array_equal(res1, res4.array)
+ assert_(isinstance(res1, np.ndarray))
+ assert_(isinstance(res2, PriorityNdarray))
+ assert_(isinstance(res3, PriorityNdarray))
+ assert_(isinstance(res4, PriorityNdarray))
+
if __name__ == "__main__":
run_module_suite()