diff options
author | Nathaniel J. Smith <njs@pobox.com> | 2015-05-07 20:24:06 -0700 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2015-06-13 12:32:54 -0600 |
commit | 1e3ab40493fadb5daa67c8a55c5360fd934cca7b (patch) | |
tree | b33e523decb2dbbd27e3a89d01ae3532527d91b9 /numpy/testing/utils.py | |
parent | 4b1f508a57549d8031a23160b40c7f87f47892ed (diff) | |
download | numpy-1e3ab40493fadb5daa67c8a55c5360fd934cca7b.tar.gz |
MAINT: move the special case for void comparison before the regular case
The ndarray richcompare function has special case code for handling
void dtypes (esp. structured dtypes), since there are no ufuncs for
this. Previously, we would attempt to call the relevant
ufunc (e.g. np.equal), and then when this failed (as signaled by the
ufunc returning NotImplemented), we would fall back on the special
case code. This commit moves the special case code to before the
regular code, so that it no longer requires ufuncs to return
NotImplemented.
Technically, it is possible to define ufunc loops for void dtypes
using PyUFunc_RegisterLoopForDescr, so technically I think this commit
changes behaviour: if someone had registered a ufunc loop for one of
these operations, then previously it might have been found and
pre-empted the special case fallback code; now, we use the
special-case code without even checking for any ufunc. But the only
possible use of this functionality would have been if someone wanted
to redefine what == or != meant for a particular structured dtype --
like, they decided that equality for 2-tuples of float32's should be
different from the obvious thing. This does not seem like an important
capability to preserve.
There were also several cases here where on error, an array comparison
would return a scalar instead of raising. This is supposedly
deprecated, but there were call paths that did this that had no
deprecation warning. I added those warnings.
Diffstat (limited to 'numpy/testing/utils.py')
-rw-r--r-- | numpy/testing/utils.py | 29 |
1 files changed, 26 insertions, 3 deletions
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py index 4527a51d9..6244b3ea7 100644 --- a/numpy/testing/utils.py +++ b/numpy/testing/utils.py @@ -592,6 +592,29 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, x = array(x, copy=False, subok=True) y = array(y, copy=False, subok=True) + def safe_comparison(*args, **kwargs): + # There are a number of cases where comparing two arrays hits special + # cases in array_richcompare, specifically around strings and void + # dtypes. Basically, we just can't do comparisons involving these + # types, unless both arrays have exactly the *same* type. So + # e.g. you can apply == to two string arrays, or two arrays with + # identical structured dtypes. But if you compare a non-string array + # to a string array, or two arrays with non-identical structured + # dtypes, or anything like that, then internally stuff blows up. + # Currently, when things blow up, we just return a scalar False or + # True. But we also emit a DeprecationWarning, b/c eventually we + # should raise an error here. (Ideally we might even make this work + # properly, but since that will require rewriting a bunch of how + # ufuncs work then we are not counting on that.) + # + # The point of this little function is to let the DeprecationWarning + # pass (or maybe eventually catch the errors and return False, I + # dunno, that's a little trickier and we can figure that out when the + # time comes). + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + return comparison(*args, **kwargs) + def isnumber(x): return x.dtype.char in '?bhilqpBHILQPefdgFDG' @@ -641,11 +664,11 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, return if any(x_id): - val = comparison(x[~x_id], y[~y_id]) + val = safe_comparison(x[~x_id], y[~y_id]) else: - val = comparison(x, y) + val = safe_comparison(x, y) else: - val = comparison(x, y) + val = safe_comparison(x, y) if isinstance(val, bool): cond = val |