summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarten van Kerkwijk <mhvk@astro.utoronto.ca>2018-06-07 16:57:45 -0400
committerMarten van Kerkwijk <mhvk@astro.utoronto.ca>2018-06-18 16:10:26 -0400
commit6c25ca1ddfd8d02375544ba9be03b654a62997ba (patch)
treee77003117f0ac7527dfabe395050d6d539774db7
parent44cfca844f043722906a9b2932a4ba03dce15c75 (diff)
downloadnumpy-6c25ca1ddfd8d02375544ba9be03b654a62997ba.tar.gz
MAINT: move comparison operator special-handling out of ufunc parsing.
The argument parsing for ufuncs contains special logic to enable returning NotImplemented for comparison functions. The ufuncs, however, should not have to know whether they were called via a Python comparison operator, so this PR moves the logic to the operators, which check after the ufunc has failed whether NotImplemented should be returned.
-rw-r--r--numpy/core/src/multiarray/arrayobject.c155
-rw-r--r--numpy/core/src/umath/ufunc_object.c180
-rw-r--r--numpy/core/tests/test_ufunc.py11
3 files changed, 123 insertions, 223 deletions
diff --git a/numpy/core/src/multiarray/arrayobject.c b/numpy/core/src/multiarray/arrayobject.c
index 943edc772..f09121027 100644
--- a/numpy/core/src/multiarray/arrayobject.c
+++ b/numpy/core/src/multiarray/arrayobject.c
@@ -1250,7 +1250,8 @@ PyArray_ChainExceptionsCause(PyObject *exc, PyObject *val, PyObject *tb)
}
}
-/* Silence the current error and emit a deprecation warning instead.
+/*
+ * Silence the current error and emit a deprecation warning instead.
*
* If warnings are raised as errors, this sets the warning __cause__ to the
* silenced error.
@@ -1366,26 +1367,6 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
result = PyArray_GenericBinaryFunction(self,
(PyObject *)other,
n_ops.equal);
- /*
- * If the comparison results in NULL, then the
- * two array objects can not be compared together;
- * indicate that
- */
- if (result == NULL) {
- /*
- * Comparisons should raise errors when element-wise comparison
- * is not possible.
- */
- /* 2015-05-14, 1.10 */
- if (DEPRECATE_silence_error(
- "elementwise == comparison failed; "
- "this will raise an error in the future.") < 0) {
- return NULL;
- }
-
- Py_INCREF(Py_NotImplemented);
- return Py_NotImplemented;
- }
break;
case Py_NE:
RICHCMP_GIVE_UP_IF_NEEDED(obj_self, other);
@@ -1437,21 +1418,6 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
result = PyArray_GenericBinaryFunction(self, (PyObject *)other,
n_ops.not_equal);
- if (result == NULL) {
- /*
- * Comparisons should raise errors when element-wise comparison
- * is not possible.
- */
- /* 2015-05-14, 1.10 */
- if (DEPRECATE_silence_error(
- "elementwise != comparison failed; "
- "this will raise an error in the future.") < 0) {
- return NULL;
- }
-
- Py_INCREF(Py_NotImplemented);
- return Py_NotImplemented;
- }
break;
case Py_GT:
RICHCMP_GIVE_UP_IF_NEEDED(obj_self, other);
@@ -1464,8 +1430,121 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
n_ops.greater_equal);
break;
default:
- result = Py_NotImplemented;
- Py_INCREF(result);
+ Py_INCREF(Py_NotImplemented);
+ return Py_NotImplemented;
+ }
+ if (result == NULL) {
+ /*
+ * 2015-05-14, 1.10; updated 2018-06-18, 1.16.
+ *
+ * Comparisons can raise errors when element-wise comparison is not
+ * possible. Some of these, though, should not be passed on.
+ * In particular, the ufuncs do not have loops for flexible dtype,
+ * so those should be treated separately. Furthermore, for EQ and NE,
+ * we should never fail.
+ *
+ * Our ideal behaviour would be:
+ *
+ * 1. For EQ and NE:
+ * - If self and other are scalars, return NotImplemented,
+ * so that python can assign True or False as appropriate.
+ * - If either is an array, return an array of False or True.
+ *
+ * 2. For LT, LE, GE, GT:
+ * - If self or other was flexible, return NotImplemented
+ * (as is in fact the case), so python can raise a TypeError.
+ * - If other is not convertible to an array, pass on the error
+ * (MHvK, 2018-06-18: not sure about this, but it's what we have).
+ *
+ * However, for backwards compatibility, we cannot yet return arrays,
+ * so we raise warnings instead. Furthermore, we warn on python2
+ * for LT, LE, GE, GT, since fall-back behaviour is poorly defined.
+ */
+ PyObject *exc, *val, *tb;
+ int other_is_flexible, ndim_other;
+
+ PyErr_Fetch(&exc, &val, &tb);
+ /*
+ * Determine whether other has a flexible dtype; here, inconvertible
+ * is counted as inflexible. (This repeats work done in the ufunc,
+ * but OK to waste some time in an unlikely path.)
+ */
+ array_other = (PyArrayObject *)PyArray_FROM_O(other);
+ if (array_other) {
+ other_is_flexible = PyTypeNum_ISFLEXIBLE(
+ PyArray_DESCR(array_other)->type_num);
+ ndim_other = PyArray_NDIM(array_other);
+ Py_DECREF(array_other);
+ }
+ else {
+ PyErr_Clear(); /* we restore the original error if needed */
+ other_is_flexible = 0;
+ ndim_other = 0;
+ }
+ if (cmp_op == Py_EQ || cmp_op == Py_NE) {
+ /* note: for == and !=, a flexible self cannot get here */
+ if (other_is_flexible) {
+ /*
+ * For scalars, returning NotImplemented is correct.
+ * For arrays, we emit a future deprecation warning.
+ */
+ if (ndim_other != 0 || PyArray_NDIM(self) != 0) {
+ if (DEPRECATE_FUTUREWARNING(
+ "elementwise comparison failed; returning scalar "
+ "instead, but in the future will perform "
+ "elementwise comparison") < 0) {
+ /*
+ * In future, we should create a correctly shaped
+ * array of bool. For now, a placeholder error.
+ */
+ PyArray_ChainExceptionsCause(exc, val, tb);
+ return NULL;
+ }
+ }
+ }
+ else {
+ /*
+ * If other did not have a flexible dtype, the error cannot
+ * have been caused by a lack of implementation in the ufunc.
+ */
+ if (DEPRECATE(
+ "elementwise comparison failed; "
+ "this will raise an error in the future.") < 0) {
+ PyArray_ChainExceptionsCause(exc, val, tb);
+ return NULL;
+ }
+ }
+ Py_XDECREF(exc);
+ Py_XDECREF(val);
+ Py_XDECREF(tb);
+ Py_INCREF(Py_NotImplemented);
+ return Py_NotImplemented;
+ }
+ else if (other_is_flexible ||
+ PyTypeNum_ISFLEXIBLE(PyArray_DESCR(self)->type_num)) {
+ /*
+ * For LE, LT, GT, GE and a flexible self or other, we return
+ * NotImplemented, which is the correct answer since the ufuncs do
+ * not in fact implement loops for those. On python 3 this will
+ * get us the desired TypeError, but on python 2, one gets strange
+ * ordering, so we emit a warning.
+ */
+#if !defined(NPY_PY3K)
+ if (DEPRECATE(
+ "unorderable dtypes; returning scalar but in "
+ "the future this will be an error") < 0) {
+ PyArray_ChainExceptionsCause(exc, val, tb);
+ return NULL;
+ }
+#endif
+ Py_XDECREF(exc);
+ Py_XDECREF(val);
+ Py_XDECREF(tb);
+ Py_INCREF(Py_NotImplemented);
+ return Py_NotImplemented;
+ }
+ /* LE, LT, GT, or GE with non-flexible other; just pass on error */
+ PyErr_Restore(exc, val, tb);
}
return result;
}
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index 5e92bc991..a5cb77da4 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -577,9 +577,6 @@ get_ufunc_arguments(PyUFuncObject *ufunc,
PyObject *obj, *context;
PyObject *str_key_obj = NULL;
const char *ufunc_name = ufunc_get_name_cstr(ufunc);
- int type_num;
-
- int any_flexible = 0, any_object = 0, any_flexible_userloops = 0;
int has_sig = 0;
/*
@@ -638,166 +635,6 @@ get_ufunc_arguments(PyUFuncObject *ufunc,
if (out_op[i] == NULL) {
goto fail;
}
-
- type_num = PyArray_DESCR(out_op[i])->type_num;
- if (!any_flexible &&
- PyTypeNum_ISFLEXIBLE(type_num)) {
- any_flexible = 1;
- }
- if (!any_object &&
- PyTypeNum_ISOBJECT(type_num)) {
- any_object = 1;
- }
-
- /*
- * If any operand is a flexible dtype, check to see if any
- * struct dtype ufuncs are registered. A ufunc has been registered
- * for a struct dtype if ufunc's arg_dtypes array is not NULL.
- */
- if (PyTypeNum_ISFLEXIBLE(type_num) &&
- !any_flexible_userloops &&
- ufunc->userloops != NULL) {
- PyUFunc_Loop1d *funcdata;
- PyObject *key, *obj;
- key = PyInt_FromLong(type_num);
- if (key == NULL) {
- continue;
- }
- obj = PyDict_GetItem(ufunc->userloops, key);
- Py_DECREF(key);
- if (obj == NULL) {
- continue;
- }
- funcdata = (PyUFunc_Loop1d *)NpyCapsule_AsVoidPtr(obj);
- while (funcdata != NULL) {
- if (funcdata->arg_dtypes != NULL) {
- any_flexible_userloops = 1;
- break;
- }
- funcdata = funcdata->next;
- }
- }
- }
-
- if (any_flexible && !any_flexible_userloops && !any_object && nin == 2) {
- /* Traditionally, we return -2 here (meaning "NotImplemented") anytime
- * we hit the above condition.
- *
- * This condition basically means "we are doomed", b/c the "flexible"
- * dtypes -- strings and void -- cannot have their own ufunc loops
- * registered (except via the special "flexible userloops" mechanism),
- * and they can't be cast to anything except object (and we only cast
- * to object if any_object is true). So really we should do nothing
- * here and continue and let the proper error be raised. But, we can't
- * quite yet, b/c of backcompat.
- *
- * Most of the time, this NotImplemented either got returned directly
- * to the user (who can't do anything useful with it), or got passed
- * back out of a special function like __mul__. And fortunately, for
- * almost all special functions, the end result of this was a
- * TypeError. Which is also what we get if we just continue without
- * this special case, so this special case is unnecessary.
- *
- * The only thing that actually depended on the NotImplemented is
- * array_richcompare, which did two things with it. First, it needed
- * to see this NotImplemented in order to implement the special-case
- * comparisons for
- *
- * string < <= == != >= > string
- * void == != void
- *
- * Now it checks for those cases first, before trying to call the
- * ufunc, so that's no problem. What it doesn't handle, though, is
- * cases like
- *
- * float < string
- *
- * or
- *
- * float == void
- *
- * For those, it just let the NotImplemented bubble out, and accepted
- * Python's default handling. And unfortunately, for comparisons,
- * Python's default is *not* to raise an error. Instead, it returns
- * something that depends on the operator:
- *
- * == return False
- * != return True
- * < <= >= > Python 2: use "fallback" (= weird and broken) ordering
- * Python 3: raise TypeError (hallelujah)
- *
- * In most cases this is straightforwardly broken, because comparison
- * of two arrays should always return an array, and here we end up
- * returning a scalar. However, there is an exception: if we are
- * comparing two scalars for equality, then it actually is correct to
- * return a scalar bool instead of raising an error. If we just
- * removed this special check entirely, then "np.float64(1) == 'foo'"
- * would raise an error instead of returning False, which is genuinely
- * wrong.
- *
- * The proper end goal here is:
- * 1) == and != should be implemented in a proper vectorized way for
- * all types. The short-term hack for this is just to add a
- * special case to PyUFunc_DefaultLegacyInnerLoopSelector where
- * if it can't find a comparison loop for the given types, and
- * the ufunc is np.equal or np.not_equal, then it returns a loop
- * that just fills the output array with False (resp. True). Then
- * array_richcompare could trust that whenever its special cases
- * don't apply, simply calling the ufunc will do the right thing,
- * even without this special check.
- * 2) < <= >= > should raise an error if no comparison function can
- * be found. array_richcompare already handles all string <>
- * string cases, and void dtypes don't have ordering, so again
- * this would mean that array_richcompare could simply call the
- * ufunc and it would do the right thing (i.e., raise an error),
- * again without needing this special check.
- *
- * So this means that for the transition period, our goal is:
- * == and != on scalars should simply return NotImplemented like
- * they always did, since everything ends up working out correctly
- * in this case only
- * == and != on arrays should issue a FutureWarning and then return
- * NotImplemented
- * < <= >= > on all flexible dtypes on py2 should raise a
- * DeprecationWarning, and then return NotImplemented. On py3 we
- * skip the warning, though, b/c it would just be immediately be
- * followed by an exception anyway.
- *
- * And for all other operations, we let things continue as normal.
- */
- /* strcmp() is a hack but I think we can get away with it for this
- * temporary measure.
- */
- if (!strcmp(ufunc_name, "equal") ||
- !strcmp(ufunc_name, "not_equal")) {
- /* Warn on non-scalar, return NotImplemented regardless */
- if (PyArray_NDIM(out_op[0]) != 0 ||
- PyArray_NDIM(out_op[1]) != 0) {
- if (DEPRECATE_FUTUREWARNING(
- "elementwise comparison failed; returning scalar "
- "instead, but in the future will perform elementwise "
- "comparison") < 0) {
- goto fail;
- }
- }
- Py_DECREF(out_op[0]);
- Py_DECREF(out_op[1]);
- return -2;
- }
- else if (!strcmp(ufunc_name, "less") ||
- !strcmp(ufunc_name, "less_equal") ||
- !strcmp(ufunc_name, "greater") ||
- !strcmp(ufunc_name, "greater_equal")) {
-#if !defined(NPY_PY3K)
- if (DEPRECATE("unorderable dtypes; returning scalar but in "
- "the future this will be an error") < 0) {
- goto fail;
- }
-#endif
- Py_DECREF(out_op[0]);
- Py_DECREF(out_op[1]);
- return -2;
- }
}
/* Get positional output arguments */
@@ -4507,22 +4344,7 @@ ufunc_generic_call(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds)
errval = PyUFunc_GenericFunction(ufunc, args, kwds, mps);
if (errval < 0) {
- if (errval == -1) {
- return NULL;
- }
- else if (ufunc->nin == 2 && ufunc->nout == 1) {
- /*
- * For array_richcompare's benefit -- see the long comment in
- * get_ufunc_arguments.
- */
- Py_INCREF(Py_NotImplemented);
- return Py_NotImplemented;
- }
- else {
- PyErr_SetString(PyExc_TypeError,
- "XX can't happen, please report a bug XX");
- return NULL;
- }
+ return NULL;
}
/* Free the input references */
diff --git a/numpy/core/tests/test_ufunc.py b/numpy/core/tests/test_ufunc.py
index ef9ced354..49a4dbbc9 100644
--- a/numpy/core/tests/test_ufunc.py
+++ b/numpy/core/tests/test_ufunc.py
@@ -1652,17 +1652,16 @@ class TestUfunc(object):
np.bitwise_xor, np.left_shift, np.right_shift, np.fmax,
np.fmin, np.fmod, np.hypot, np.logaddexp, np.logaddexp2,
np.logical_and, np.logical_or, np.logical_xor, np.maximum,
- np.minimum, np.mod
- ]
-
- # These functions still return NotImplemented. Will be fixed in
- # future.
- # bad = [np.greater, np.greater_equal, np.less, np.less_equal, np.not_equal]
+ np.minimum, np.mod,
+ np.greater, np.greater_equal, np.less, np.less_equal,
+ np.equal, np.not_equal]
a = np.array('1')
b = 1
+ c = np.array([1., 2.])
for f in binary_funcs:
assert_raises(TypeError, f, a, b)
+ assert_raises(TypeError, f, c, a)
def test_reduce_noncontig_output(self):
# Check that reduction deals with non-contiguous output arrays