diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/arrayobject.c | 67 | ||||
-rw-r--r-- | numpy/core/tests/test_deprecations.py | 2 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 14 | ||||
-rw-r--r-- | numpy/testing/utils.py | 29 |
4 files changed, 81 insertions, 31 deletions
diff --git a/numpy/core/src/multiarray/arrayobject.c b/numpy/core/src/multiarray/arrayobject.c index ac9e1a2d2..3e5084fb3 100644 --- a/numpy/core/src/multiarray/arrayobject.c +++ b/numpy/core/src/multiarray/arrayobject.c @@ -1343,16 +1343,6 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) Py_INCREF(Py_False); return Py_False; } - if (needs_right_binop_forward(obj_self, other, "__eq__", 0) && - Py_TYPE(obj_self)->tp_richcompare != Py_TYPE(other)->tp_richcompare) { - Py_INCREF(Py_NotImplemented); - return Py_NotImplemented; - } - result = PyArray_GenericBinaryFunction(self, - (PyObject *)other, - n_ops.equal); - if (result && result != Py_NotImplemented) - break; /* * The ufunc does not support void/structured types, so these @@ -1370,6 +1360,12 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) */ if (array_other == NULL) { PyErr_Clear(); + if (DEPRECATE_FUTUREWARNING( + "elementwise comparison failed and returning scalar " + "instead; this will raise an error or perform " + "elementwise comparison in the future.") < 0) { + return NULL; + } Py_INCREF(Py_NotImplemented); return Py_NotImplemented; } @@ -1378,18 +1374,31 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) PyArray_DESCR(array_other), NPY_EQUIV_CASTING); if (_res == 0) { - Py_DECREF(result); Py_DECREF(array_other); + if (DEPRECATE_FUTUREWARNING( + "elementwise comparison failed and returning scalar " + "instead; this will raise an error or perform " + "elementwise comparison in the future.") < 0) { + return NULL; + } Py_INCREF(Py_False); return Py_False; } else { - Py_DECREF(result); result = _void_compare(self, array_other, cmp_op); } Py_DECREF(array_other); return result; } + + if (needs_right_binop_forward(obj_self, other, "__eq__", 0) && + Py_TYPE(obj_self)->tp_richcompare != Py_TYPE(other)->tp_richcompare) { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } + result = PyArray_GenericBinaryFunction(self, + (PyObject *)other, + n_ops.equal); /* * If the comparison results in NULL, then the * two array objects can not be compared together; @@ -1402,7 +1411,7 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) */ PyErr_Clear(); if (DEPRECATE("elementwise comparison failed; " - "this will raise the error in the future.") < 0) { + "this will raise an error in the future.") < 0) { return NULL; } @@ -1419,15 +1428,6 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) Py_INCREF(Py_True); return Py_True; } - if (needs_right_binop_forward(obj_self, other, "__ne__", 0) && - Py_TYPE(obj_self)->tp_richcompare != Py_TYPE(other)->tp_richcompare) { - Py_INCREF(Py_NotImplemented); - return Py_NotImplemented; - } - result = PyArray_GenericBinaryFunction(self, (PyObject *)other, - n_ops.not_equal); - if (result && result != Py_NotImplemented) - break; /* * The ufunc does not support void/structured types, so these @@ -1445,6 +1445,12 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) */ if (array_other == NULL) { PyErr_Clear(); + if (DEPRECATE_FUTUREWARNING( + "elementwise comparison failed and returning scalar " + "instead; this will raise an error or perform " + "elementwise comparison in the future.") < 0) { + return NULL; + } Py_INCREF(Py_NotImplemented); return Py_NotImplemented; } @@ -1453,19 +1459,30 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) PyArray_DESCR(array_other), NPY_EQUIV_CASTING); if (_res == 0) { - Py_DECREF(result); Py_DECREF(array_other); + if (DEPRECATE_FUTUREWARNING( + "elementwise comparison failed and returning scalar " + "instead; this will raise an error or perform " + "elementwise comparison in the future.") < 0) { + return NULL; + } Py_INCREF(Py_True); return Py_True; } else { - Py_DECREF(result); result = _void_compare(self, array_other, cmp_op); Py_DECREF(array_other); } return result; } + if (needs_right_binop_forward(obj_self, other, "__ne__", 0) && + Py_TYPE(obj_self)->tp_richcompare != Py_TYPE(other)->tp_richcompare) { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } + result = PyArray_GenericBinaryFunction(self, (PyObject *)other, + n_ops.not_equal); if (result == NULL) { /* * Comparisons should raise errors when element-wise comparison @@ -1473,7 +1490,7 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) */ PyErr_Clear(); if (DEPRECATE("elementwise comparison failed; " - "this will raise the error in the future.") < 0) { + "this will raise an error in the future.") < 0) { return NULL; } diff --git a/numpy/core/tests/test_deprecations.py b/numpy/core/tests/test_deprecations.py index 159753797..b4889e491 100644 --- a/numpy/core/tests/test_deprecations.py +++ b/numpy/core/tests/test_deprecations.py @@ -383,7 +383,7 @@ class TestComparisonDeprecations(_DeprecationTestCase): """ message = "elementwise comparison failed; " \ - "this will raise the error in the future." + "this will raise an error in the future." def test_normal_types(self): for op in (operator.eq, operator.ne): diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index e00c361d8..b2773c189 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -738,11 +738,21 @@ class TestStructured(TestCase): # Check that incompatible sub-array shapes don't result to broadcasting x = np.zeros((1,), dtype=[('a', ('f4', (1, 2))), ('b', 'i1')]) y = np.zeros((1,), dtype=[('a', ('f4', (2,))), ('b', 'i1')]) - assert_equal(x == y, False) + # This comparison invokes deprecated behaviour, and will probably + # start raising an error eventually. What we really care about in this + # test is just that it doesn't return True. + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + assert_equal(x == y, False) x = np.zeros((1,), dtype=[('a', ('f4', (2, 1))), ('b', 'i1')]) y = np.zeros((1,), dtype=[('a', ('f4', (2,))), ('b', 'i1')]) - assert_equal(x == y, False) + # This comparison invokes deprecated behaviour, and will probably + # start raising an error eventually. What we really care about in this + # test is just that it doesn't return True. + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + assert_equal(x == y, False) # Check that structured arrays that are different only in # byte-order work diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py index 4527a51d9..6244b3ea7 100644 --- a/numpy/testing/utils.py +++ b/numpy/testing/utils.py @@ -592,6 +592,29 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, x = array(x, copy=False, subok=True) y = array(y, copy=False, subok=True) + def safe_comparison(*args, **kwargs): + # There are a number of cases where comparing two arrays hits special + # cases in array_richcompare, specifically around strings and void + # dtypes. Basically, we just can't do comparisons involving these + # types, unless both arrays have exactly the *same* type. So + # e.g. you can apply == to two string arrays, or two arrays with + # identical structured dtypes. But if you compare a non-string array + # to a string array, or two arrays with non-identical structured + # dtypes, or anything like that, then internally stuff blows up. + # Currently, when things blow up, we just return a scalar False or + # True. But we also emit a DeprecationWarning, b/c eventually we + # should raise an error here. (Ideally we might even make this work + # properly, but since that will require rewriting a bunch of how + # ufuncs work then we are not counting on that.) + # + # The point of this little function is to let the DeprecationWarning + # pass (or maybe eventually catch the errors and return False, I + # dunno, that's a little trickier and we can figure that out when the + # time comes). + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + return comparison(*args, **kwargs) + def isnumber(x): return x.dtype.char in '?bhilqpBHILQPefdgFDG' @@ -641,11 +664,11 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, return if any(x_id): - val = comparison(x[~x_id], y[~y_id]) + val = safe_comparison(x[~x_id], y[~y_id]) else: - val = comparison(x, y) + val = safe_comparison(x, y) else: - val = comparison(x, y) + val = safe_comparison(x, y) if isinstance(val, bool): cond = val |