summaryrefslogtreecommitdiff
path: root/numpy/testing
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/testing')
-rw-r--r--numpy/testing/tests/test_utils.py9
-rw-r--r--numpy/testing/utils.py43
2 files changed, 35 insertions, 17 deletions
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py
index 9798a25be..067782dc0 100644
--- a/numpy/testing/tests/test_utils.py
+++ b/numpy/testing/tests/test_utils.py
@@ -189,6 +189,13 @@ class TestArrayAlmostEqual(_GenericTest, unittest.TestCase):
self.assertRaises(AssertionError,
lambda : self._assert_func(ainf, anan))
+ def test_inf(self):
+ a = np.array([[1., 2.], [3., 4.]])
+ b = a.copy()
+ a[0,0] = np.inf
+ self.assertRaises(AssertionError,
+ lambda : self._assert_func(a, b))
+
class TestAlmostEqual(_GenericTest, unittest.TestCase):
def setUp(self):
self._assert_func = assert_almost_equal
@@ -205,6 +212,8 @@ class TestAlmostEqual(_GenericTest, unittest.TestCase):
def test_inf_item(self):
self._assert_func(np.inf, np.inf)
self._assert_func(-np.inf, -np.inf)
+ self.assertRaises(AssertionError,
+ lambda : self._assert_func(np.inf, 1))
def test_simple_item(self):
self._test_not_equal(1, 2)
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py
index 0c0483f39..01ce31c4a 100644
--- a/numpy/testing/utils.py
+++ b/numpy/testing/utils.py
@@ -568,13 +568,25 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True):
def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
header=''):
- from numpy.core import array, isnan, any
+ from numpy.core import array, isnan, isinf, any
x = array(x, copy=False, subok=True)
y = array(y, copy=False, subok=True)
def isnumber(x):
return x.dtype.char in '?bhilqpBHILQPfdgFDG'
+ def chk_same_position(x_id, y_id, hasval='nan'):
+ """Handling nan/inf: check that x and y have the nan/inf at the same
+ locations."""
+ try:
+ assert_array_equal(x_id, y_id)
+ except AssertionError:
+ msg = build_err_msg([x, y],
+ err_msg + '\nx and y %s location mismatch:' \
+ % (hasval), verbose=verbose, header=header,
+ names=('x', 'y'))
+ raise AssertionError(msg)
+
try:
cond = (x.shape==() or y.shape==()) or x.shape == y.shape
if not cond:
@@ -588,27 +600,24 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
raise AssertionError(msg)
if (isnumber(x) and isnumber(y)) and (any(isnan(x)) or any(isnan(y))):
- # Handling nan: we first check that x and y have the nan at the
- # same locations, and then we mask the nan and do the comparison as
- # usual.
- xnanid = isnan(x)
- ynanid = isnan(y)
- try:
- assert_array_equal(xnanid, ynanid)
- except AssertionError:
- msg = build_err_msg([x, y],
- err_msg
- + '\n(x and y nan location mismatch %s, ' \
- '%s mismatch)' % (xnanid, ynanid),
- verbose=verbose, header=header,
- names=('x', 'y'))
- raise AssertionError(msg)
+ x_id = isnan(x)
+ y_id = isnan(y)
+ chk_same_position(x_id, y_id, hasval='nan')
# If only one item, it was a nan, so just return
if x.size == y.size == 1:
return
- val = comparison(x[~xnanid], y[~ynanid])
+ val = comparison(x[~x_id], y[~y_id])
+ elif (isnumber(x) and isnumber(y)) and (any(isinf(x)) or any(isinf(y))):
+ x_id = isinf(x)
+ y_id = isinf(y)
+ chk_same_position(x_id, y_id, hasval='inf')
+ # If only one item, it was a inf, so just return
+ if x.size == y.size == 1:
+ return
+ val = comparison(x[~x_id], y[~y_id])
else:
val = comparison(x,y)
+
if isinstance(val, bool):
cond = val
reduced = [0]