diff options
author | Nathaniel J. Smith <njs@pobox.com> | 2015-05-07 20:24:06 -0700 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2015-06-13 12:32:54 -0600 |
commit | 1e3ab40493fadb5daa67c8a55c5360fd934cca7b (patch) | |
tree | b33e523decb2dbbd27e3a89d01ae3532527d91b9 /numpy | |
parent | 4b1f508a57549d8031a23160b40c7f87f47892ed (diff) | |
download | numpy-1e3ab40493fadb5daa67c8a55c5360fd934cca7b.tar.gz |
MAINT: move the special case for void comparison before the regular case
The ndarray richcompare function has special case code for handling
void dtypes (esp. structured dtypes), since there are no ufuncs for
this. Previously, we would attempt to call the relevant
ufunc (e.g. np.equal), and then when this failed (as signaled by the
ufunc returning NotImplemented), we would fall back on the special
case code. This commit moves the special case code to before the
regular code, so that it no longer requires ufuncs to return
NotImplemented.
Technically, it is possible to define ufunc loops for void dtypes
using PyUFunc_RegisterLoopForDescr, so technically I think this commit
changes behaviour: if someone had registered a ufunc loop for one of
these operations, then previously it might have been found and
pre-empted the special case fallback code; now, we use the
special-case code without even checking for any ufunc. But the only
possible use of this functionality would have been if someone wanted
to redefine what == or != meant for a particular structured dtype --
like, they decided that equality for 2-tuples of float32's should be
different from the obvious thing. This does not seem like an important
capability to preserve.
There were also several cases here where on error, an array comparison
would return a scalar instead of raising. This is supposedly
deprecated, but there were call paths that did this that had no
deprecation warning. I added those warnings.
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/arrayobject.c | 67 | ||||
-rw-r--r-- | numpy/core/tests/test_deprecations.py | 2 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 14 | ||||
-rw-r--r-- | numpy/testing/utils.py | 29 |
4 files changed, 81 insertions, 31 deletions
diff --git a/numpy/core/src/multiarray/arrayobject.c b/numpy/core/src/multiarray/arrayobject.c index ac9e1a2d2..3e5084fb3 100644 --- a/numpy/core/src/multiarray/arrayobject.c +++ b/numpy/core/src/multiarray/arrayobject.c @@ -1343,16 +1343,6 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) Py_INCREF(Py_False); return Py_False; } - if (needs_right_binop_forward(obj_self, other, "__eq__", 0) && - Py_TYPE(obj_self)->tp_richcompare != Py_TYPE(other)->tp_richcompare) { - Py_INCREF(Py_NotImplemented); - return Py_NotImplemented; - } - result = PyArray_GenericBinaryFunction(self, - (PyObject *)other, - n_ops.equal); - if (result && result != Py_NotImplemented) - break; /* * The ufunc does not support void/structured types, so these @@ -1370,6 +1360,12 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) */ if (array_other == NULL) { PyErr_Clear(); + if (DEPRECATE_FUTUREWARNING( + "elementwise comparison failed and returning scalar " + "instead; this will raise an error or perform " + "elementwise comparison in the future.") < 0) { + return NULL; + } Py_INCREF(Py_NotImplemented); return Py_NotImplemented; } @@ -1378,18 +1374,31 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) PyArray_DESCR(array_other), NPY_EQUIV_CASTING); if (_res == 0) { - Py_DECREF(result); Py_DECREF(array_other); + if (DEPRECATE_FUTUREWARNING( + "elementwise comparison failed and returning scalar " + "instead; this will raise an error or perform " + "elementwise comparison in the future.") < 0) { + return NULL; + } Py_INCREF(Py_False); return Py_False; } else { - Py_DECREF(result); result = _void_compare(self, array_other, cmp_op); } Py_DECREF(array_other); return result; } + + if (needs_right_binop_forward(obj_self, other, "__eq__", 0) && + Py_TYPE(obj_self)->tp_richcompare != Py_TYPE(other)->tp_richcompare) { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } + result = PyArray_GenericBinaryFunction(self, + (PyObject *)other, + n_ops.equal); /* * If the comparison results in NULL, then the * two array objects can not be compared together; @@ -1402,7 +1411,7 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) */ PyErr_Clear(); if (DEPRECATE("elementwise comparison failed; " - "this will raise the error in the future.") < 0) { + "this will raise an error in the future.") < 0) { return NULL; } @@ -1419,15 +1428,6 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) Py_INCREF(Py_True); return Py_True; } - if (needs_right_binop_forward(obj_self, other, "__ne__", 0) && - Py_TYPE(obj_self)->tp_richcompare != Py_TYPE(other)->tp_richcompare) { - Py_INCREF(Py_NotImplemented); - return Py_NotImplemented; - } - result = PyArray_GenericBinaryFunction(self, (PyObject *)other, - n_ops.not_equal); - if (result && result != Py_NotImplemented) - break; /* * The ufunc does not support void/structured types, so these @@ -1445,6 +1445,12 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) */ if (array_other == NULL) { PyErr_Clear(); + if (DEPRECATE_FUTUREWARNING( + "elementwise comparison failed and returning scalar " + "instead; this will raise an error or perform " + "elementwise comparison in the future.") < 0) { + return NULL; + } Py_INCREF(Py_NotImplemented); return Py_NotImplemented; } @@ -1453,19 +1459,30 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) PyArray_DESCR(array_other), NPY_EQUIV_CASTING); if (_res == 0) { - Py_DECREF(result); Py_DECREF(array_other); + if (DEPRECATE_FUTUREWARNING( + "elementwise comparison failed and returning scalar " + "instead; this will raise an error or perform " + "elementwise comparison in the future.") < 0) { + return NULL; + } Py_INCREF(Py_True); return Py_True; } else { - Py_DECREF(result); result = _void_compare(self, array_other, cmp_op); Py_DECREF(array_other); } return result; } + if (needs_right_binop_forward(obj_self, other, "__ne__", 0) && + Py_TYPE(obj_self)->tp_richcompare != Py_TYPE(other)->tp_richcompare) { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } + result = PyArray_GenericBinaryFunction(self, (PyObject *)other, + n_ops.not_equal); if (result == NULL) { /* * Comparisons should raise errors when element-wise comparison @@ -1473,7 +1490,7 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) */ PyErr_Clear(); if (DEPRECATE("elementwise comparison failed; " - "this will raise the error in the future.") < 0) { + "this will raise an error in the future.") < 0) { return NULL; } diff --git a/numpy/core/tests/test_deprecations.py b/numpy/core/tests/test_deprecations.py index 159753797..b4889e491 100644 --- a/numpy/core/tests/test_deprecations.py +++ b/numpy/core/tests/test_deprecations.py @@ -383,7 +383,7 @@ class TestComparisonDeprecations(_DeprecationTestCase): """ message = "elementwise comparison failed; " \ - "this will raise the error in the future." + "this will raise an error in the future." def test_normal_types(self): for op in (operator.eq, operator.ne): diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index e00c361d8..b2773c189 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -738,11 +738,21 @@ class TestStructured(TestCase): # Check that incompatible sub-array shapes don't result to broadcasting x = np.zeros((1,), dtype=[('a', ('f4', (1, 2))), ('b', 'i1')]) y = np.zeros((1,), dtype=[('a', ('f4', (2,))), ('b', 'i1')]) - assert_equal(x == y, False) + # This comparison invokes deprecated behaviour, and will probably + # start raising an error eventually. What we really care about in this + # test is just that it doesn't return True. + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + assert_equal(x == y, False) x = np.zeros((1,), dtype=[('a', ('f4', (2, 1))), ('b', 'i1')]) y = np.zeros((1,), dtype=[('a', ('f4', (2,))), ('b', 'i1')]) - assert_equal(x == y, False) + # This comparison invokes deprecated behaviour, and will probably + # start raising an error eventually. What we really care about in this + # test is just that it doesn't return True. + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + assert_equal(x == y, False) # Check that structured arrays that are different only in # byte-order work diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py index 4527a51d9..6244b3ea7 100644 --- a/numpy/testing/utils.py +++ b/numpy/testing/utils.py @@ -592,6 +592,29 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, x = array(x, copy=False, subok=True) y = array(y, copy=False, subok=True) + def safe_comparison(*args, **kwargs): + # There are a number of cases where comparing two arrays hits special + # cases in array_richcompare, specifically around strings and void + # dtypes. Basically, we just can't do comparisons involving these + # types, unless both arrays have exactly the *same* type. So + # e.g. you can apply == to two string arrays, or two arrays with + # identical structured dtypes. But if you compare a non-string array + # to a string array, or two arrays with non-identical structured + # dtypes, or anything like that, then internally stuff blows up. + # Currently, when things blow up, we just return a scalar False or + # True. But we also emit a DeprecationWarning, b/c eventually we + # should raise an error here. (Ideally we might even make this work + # properly, but since that will require rewriting a bunch of how + # ufuncs work then we are not counting on that.) + # + # The point of this little function is to let the DeprecationWarning + # pass (or maybe eventually catch the errors and return False, I + # dunno, that's a little trickier and we can figure that out when the + # time comes). + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + return comparison(*args, **kwargs) + def isnumber(x): return x.dtype.char in '?bhilqpBHILQPefdgFDG' @@ -641,11 +664,11 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, return if any(x_id): - val = comparison(x[~x_id], y[~y_id]) + val = safe_comparison(x[~x_id], y[~y_id]) else: - val = comparison(x, y) + val = safe_comparison(x, y) else: - val = comparison(x, y) + val = safe_comparison(x, y) if isinstance(val, bool): cond = val |