summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarten van Kerkwijk <mhvk@astro.utoronto.ca>2018-05-27 13:15:16 -0400
committerMarten van Kerkwijk <mhvk@astro.utoronto.ca>2018-06-04 09:37:02 -0400
commit3ad49aaaf497c6daadb9b66f295f58a315476e01 (patch)
tree6e0908fdcdb3e7c4130fb932119a83fb2f8c501c
parent5718b3306303b4899c017641a9282714dcf8c9b8 (diff)
downloadnumpy-3ad49aaaf497c6daadb9b66f295f58a315476e01.tar.gz
MAINT: clean up assert_array_compare a bit further.
This brought to light two bugs in tests, which are fixed here, viz., that a sample ndarray subclass that tested propagation of an added parameter was incomplete, in that in propagating the parameter in __array_wrap__ it assumed it was there on self, but that assumption could be broken when a view of self was taken (as is done by x[~flagged] in the test routine), since there was no __array_finalize__ defined. The other subclass bug counted, incorrectly, on only needing to provide one type of comparison, the __lt__ being explicitly tested. But flags are compared with __eq__ and those flags will have the same subclass.
-rw-r--r--numpy/lib/tests/test_ufunclike.py4
-rw-r--r--numpy/testing/_private/utils.py23
-rw-r--r--numpy/testing/tests/test_utils.py6
3 files changed, 23 insertions, 10 deletions
diff --git a/numpy/lib/tests/test_ufunclike.py b/numpy/lib/tests/test_ufunclike.py
index ad006fe17..5604b3744 100644
--- a/numpy/lib/tests/test_ufunclike.py
+++ b/numpy/lib/tests/test_ufunclike.py
@@ -55,6 +55,10 @@ class TestUfunclike(object):
obj.metadata = self.metadata
return obj
+ def __array_finalize__(self, obj):
+ self.metadata = getattr(obj, 'metadata', None)
+ return self
+
a = nx.array([1.1, -1.1])
m = MyArray(a, metadata='foo')
f = ufl.fix(m)
diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py
index 5b7e35366..f8bfe0ba9 100644
--- a/numpy/testing/_private/utils.py
+++ b/numpy/testing/_private/utils.py
@@ -685,7 +685,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
header='', precision=6, equal_nan=True,
equal_inf=True):
__tracebackhide__ = True # Hide traceback for py.test
- from numpy.core import array, isnan, any, inf, ndim
+ from numpy.core import array, isnan, inf, bool_
x = array(x, copy=False, subok=True)
y = array(y, copy=False, subok=True)
@@ -698,14 +698,12 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
def func_assert_same_pos(x, y, func=isnan, hasval='nan'):
"""Handling nan/inf: combine results of running func on x and y,
checking that they are True at the same locations."""
+ # Both the != True comparison here and the cast to bool_ at
+ # the end are done to deal with `masked`, which cannot be
+ # compared usefully, and for which .all() yields masked.
x_id = func(x)
y_id = func(y)
- if not any(x_id) and not any(y_id):
- return False
-
- try:
- assert_array_equal(x_id, y_id)
- except AssertionError:
+ if (x_id == y_id).all() != True:
msg = build_err_msg([x, y],
err_msg + '\nx and y %s location mismatch:'
% (hasval), verbose=verbose, header=header,
@@ -713,7 +711,12 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
raise AssertionError(msg)
# If there is a scalar, then here we know the array has the same
# flag as it everywhere, so we should return the scalar flag.
- return x_id if x_id.ndim == 0 else y_id
+ if x_id.ndim == 0:
+ return bool_(x_id)
+ elif y_id.ndim == 0:
+ return bool_(y_id)
+ else:
+ return y_id
try:
cond = (x.shape == () or y.shape == ()) or x.shape == y.shape
@@ -726,7 +729,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
names=('x', 'y'), precision=precision)
raise AssertionError(msg)
- flagged = False
+ flagged = bool_(False)
if isnumber(x) and isnumber(y):
if equal_nan:
flagged = func_assert_same_pos(x, y, func=isnan, hasval='nan')
@@ -744,7 +747,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
if equal_nan and x.dtype.type == y.dtype.type:
flagged = func_assert_same_pos(x, y, func=isnat, hasval="NaT")
- if ndim(flagged):
+ if flagged.ndim > 0:
x, y = x[~flagged], y[~flagged]
# Only do the comparison if actual values are left
if x.size == 0:
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py
index 235b3750c..465c217d4 100644
--- a/numpy/testing/tests/test_utils.py
+++ b/numpy/testing/tests/test_utils.py
@@ -401,6 +401,9 @@ class TestArrayAlmostEqual(_GenericTest):
# comparison operators, not on them being able to store booleans
# (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
class MyArray(np.ndarray):
+ def __eq__(self, other):
+ return super(MyArray, self).__eq__(other).view(np.ndarray)
+
def __lt__(self, other):
return super(MyArray, self).__lt__(other).view(np.ndarray)
@@ -500,6 +503,9 @@ class TestAlmostEqual(_GenericTest):
# comparison operators, not on them being able to store booleans
# (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
class MyArray(np.ndarray):
+ def __eq__(self, other):
+ return super(MyArray, self).__eq__(other).view(np.ndarray)
+
def __lt__(self, other):
return super(MyArray, self).__lt__(other).view(np.ndarray)