summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/arrayobject.c67
-rw-r--r--numpy/core/tests/test_deprecations.py2
-rw-r--r--numpy/core/tests/test_multiarray.py14
-rw-r--r--numpy/testing/utils.py29
4 files changed, 81 insertions, 31 deletions
diff --git a/numpy/core/src/multiarray/arrayobject.c b/numpy/core/src/multiarray/arrayobject.c
index ac9e1a2d2..3e5084fb3 100644
--- a/numpy/core/src/multiarray/arrayobject.c
+++ b/numpy/core/src/multiarray/arrayobject.c
@@ -1343,16 +1343,6 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
Py_INCREF(Py_False);
return Py_False;
}
- if (needs_right_binop_forward(obj_self, other, "__eq__", 0) &&
- Py_TYPE(obj_self)->tp_richcompare != Py_TYPE(other)->tp_richcompare) {
- Py_INCREF(Py_NotImplemented);
- return Py_NotImplemented;
- }
- result = PyArray_GenericBinaryFunction(self,
- (PyObject *)other,
- n_ops.equal);
- if (result && result != Py_NotImplemented)
- break;
/*
* The ufunc does not support void/structured types, so these
@@ -1370,6 +1360,12 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
*/
if (array_other == NULL) {
PyErr_Clear();
+ if (DEPRECATE_FUTUREWARNING(
+ "elementwise comparison failed and returning scalar "
+ "instead; this will raise an error or perform "
+ "elementwise comparison in the future.") < 0) {
+ return NULL;
+ }
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
@@ -1378,18 +1374,31 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
PyArray_DESCR(array_other),
NPY_EQUIV_CASTING);
if (_res == 0) {
- Py_DECREF(result);
Py_DECREF(array_other);
+ if (DEPRECATE_FUTUREWARNING(
+ "elementwise comparison failed and returning scalar "
+ "instead; this will raise an error or perform "
+ "elementwise comparison in the future.") < 0) {
+ return NULL;
+ }
Py_INCREF(Py_False);
return Py_False;
}
else {
- Py_DECREF(result);
result = _void_compare(self, array_other, cmp_op);
}
Py_DECREF(array_other);
return result;
}
+
+ if (needs_right_binop_forward(obj_self, other, "__eq__", 0) &&
+ Py_TYPE(obj_self)->tp_richcompare != Py_TYPE(other)->tp_richcompare) {
+ Py_INCREF(Py_NotImplemented);
+ return Py_NotImplemented;
+ }
+ result = PyArray_GenericBinaryFunction(self,
+ (PyObject *)other,
+ n_ops.equal);
/*
* If the comparison results in NULL, then the
* two array objects can not be compared together;
@@ -1402,7 +1411,7 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
*/
PyErr_Clear();
if (DEPRECATE("elementwise comparison failed; "
- "this will raise the error in the future.") < 0) {
+ "this will raise an error in the future.") < 0) {
return NULL;
}
@@ -1419,15 +1428,6 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
Py_INCREF(Py_True);
return Py_True;
}
- if (needs_right_binop_forward(obj_self, other, "__ne__", 0) &&
- Py_TYPE(obj_self)->tp_richcompare != Py_TYPE(other)->tp_richcompare) {
- Py_INCREF(Py_NotImplemented);
- return Py_NotImplemented;
- }
- result = PyArray_GenericBinaryFunction(self, (PyObject *)other,
- n_ops.not_equal);
- if (result && result != Py_NotImplemented)
- break;
/*
* The ufunc does not support void/structured types, so these
@@ -1445,6 +1445,12 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
*/
if (array_other == NULL) {
PyErr_Clear();
+ if (DEPRECATE_FUTUREWARNING(
+ "elementwise comparison failed and returning scalar "
+ "instead; this will raise an error or perform "
+ "elementwise comparison in the future.") < 0) {
+ return NULL;
+ }
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
@@ -1453,19 +1459,30 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
PyArray_DESCR(array_other),
NPY_EQUIV_CASTING);
if (_res == 0) {
- Py_DECREF(result);
Py_DECREF(array_other);
+ if (DEPRECATE_FUTUREWARNING(
+ "elementwise comparison failed and returning scalar "
+ "instead; this will raise an error or perform "
+ "elementwise comparison in the future.") < 0) {
+ return NULL;
+ }
Py_INCREF(Py_True);
return Py_True;
}
else {
- Py_DECREF(result);
result = _void_compare(self, array_other, cmp_op);
Py_DECREF(array_other);
}
return result;
}
+ if (needs_right_binop_forward(obj_self, other, "__ne__", 0) &&
+ Py_TYPE(obj_self)->tp_richcompare != Py_TYPE(other)->tp_richcompare) {
+ Py_INCREF(Py_NotImplemented);
+ return Py_NotImplemented;
+ }
+ result = PyArray_GenericBinaryFunction(self, (PyObject *)other,
+ n_ops.not_equal);
if (result == NULL) {
/*
* Comparisons should raise errors when element-wise comparison
@@ -1473,7 +1490,7 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
*/
PyErr_Clear();
if (DEPRECATE("elementwise comparison failed; "
- "this will raise the error in the future.") < 0) {
+ "this will raise an error in the future.") < 0) {
return NULL;
}
diff --git a/numpy/core/tests/test_deprecations.py b/numpy/core/tests/test_deprecations.py
index 159753797..b4889e491 100644
--- a/numpy/core/tests/test_deprecations.py
+++ b/numpy/core/tests/test_deprecations.py
@@ -383,7 +383,7 @@ class TestComparisonDeprecations(_DeprecationTestCase):
"""
message = "elementwise comparison failed; " \
- "this will raise the error in the future."
+ "this will raise an error in the future."
def test_normal_types(self):
for op in (operator.eq, operator.ne):
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index e00c361d8..b2773c189 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -738,11 +738,21 @@ class TestStructured(TestCase):
# Check that incompatible sub-array shapes don't result to broadcasting
x = np.zeros((1,), dtype=[('a', ('f4', (1, 2))), ('b', 'i1')])
y = np.zeros((1,), dtype=[('a', ('f4', (2,))), ('b', 'i1')])
- assert_equal(x == y, False)
+ # This comparison invokes deprecated behaviour, and will probably
+ # start raising an error eventually. What we really care about in this
+ # test is just that it doesn't return True.
+ with warnings.catch_warnings():
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
+ assert_equal(x == y, False)
x = np.zeros((1,), dtype=[('a', ('f4', (2, 1))), ('b', 'i1')])
y = np.zeros((1,), dtype=[('a', ('f4', (2,))), ('b', 'i1')])
- assert_equal(x == y, False)
+ # This comparison invokes deprecated behaviour, and will probably
+ # start raising an error eventually. What we really care about in this
+ # test is just that it doesn't return True.
+ with warnings.catch_warnings():
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
+ assert_equal(x == y, False)
# Check that structured arrays that are different only in
# byte-order work
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py
index 4527a51d9..6244b3ea7 100644
--- a/numpy/testing/utils.py
+++ b/numpy/testing/utils.py
@@ -592,6 +592,29 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
x = array(x, copy=False, subok=True)
y = array(y, copy=False, subok=True)
+ def safe_comparison(*args, **kwargs):
+ # There are a number of cases where comparing two arrays hits special
+ # cases in array_richcompare, specifically around strings and void
+ # dtypes. Basically, we just can't do comparisons involving these
+ # types, unless both arrays have exactly the *same* type. So
+ # e.g. you can apply == to two string arrays, or two arrays with
+ # identical structured dtypes. But if you compare a non-string array
+ # to a string array, or two arrays with non-identical structured
+ # dtypes, or anything like that, then internally stuff blows up.
+ # Currently, when things blow up, we just return a scalar False or
+ # True. But we also emit a DeprecationWarning, b/c eventually we
+ # should raise an error here. (Ideally we might even make this work
+ # properly, but since that will require rewriting a bunch of how
+ # ufuncs work then we are not counting on that.)
+ #
+ # The point of this little function is to let the DeprecationWarning
+ # pass (or maybe eventually catch the errors and return False, I
+ # dunno, that's a little trickier and we can figure that out when the
+ # time comes).
+ with warnings.catch_warnings():
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
+ return comparison(*args, **kwargs)
+
def isnumber(x):
return x.dtype.char in '?bhilqpBHILQPefdgFDG'
@@ -641,11 +664,11 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
return
if any(x_id):
- val = comparison(x[~x_id], y[~y_id])
+ val = safe_comparison(x[~x_id], y[~y_id])
else:
- val = comparison(x, y)
+ val = safe_comparison(x, y)
else:
- val = comparison(x, y)
+ val = safe_comparison(x, y)
if isinstance(val, bool):
cond = val