diff options
Diffstat (limited to 'numpy/testing')
-rw-r--r-- | numpy/testing/tests/test_utils.py | 8 | ||||
-rw-r--r-- | numpy/testing/utils.py | 11 |
2 files changed, 18 insertions, 1 deletions
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index 94fc4d655..5956a4294 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -53,6 +53,9 @@ class _GenericTest(object): a = np.array([1, 1], dtype=np.object) self._test_equal(a, 1) + def test_array_likes(self): + self._test_equal([1, 2, 3], (1, 2, 3)) + class TestArrayEqual(_GenericTest, unittest.TestCase): def setUp(self): self._assert_func = assert_array_equal @@ -373,6 +376,11 @@ class TestAssertAllclose(unittest.TestCase): assert_allclose(6, 10, rtol=0.5) self.assertRaises(AssertionError, assert_allclose, 10, 6, rtol=0.5) + def test_min_int(self): + a = np.array([np.iinfo(np.int_).min], dtype=np.int_) + # Should not raise: + assert_allclose(a, a) + class TestArrayAlmostEqualNulp(unittest.TestCase): @dec.knownfailureif(True, "Github issue #347") diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py index 2a99fe5cb..97908c7e8 100644 --- a/numpy/testing/utils.py +++ b/numpy/testing/utils.py @@ -793,7 +793,7 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): y: array([ 1. , 2.33333, 5. ]) """ - from numpy.core import around, number, float_ + from numpy.core import around, number, float_, result_type, array from numpy.core.numerictypes import issubdtype from numpy.core.fromnumeric import any as npany def compare(x, y): @@ -810,13 +810,22 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): y = y[~yinfid] except (TypeError, NotImplementedError): pass + + # make sure y is an inexact type to avoid abs(MIN_INT); will cause + # casting of x later. + dtype = result_type(y, 1.) + y = array(y, dtype=dtype, copy=False) z = abs(x-y) + if not issubdtype(z.dtype, number): z = z.astype(float_) # handle object arrays + return around(z, decimal) <= 10.0**(-decimal) + assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose, header=('Arrays are not almost equal to %d decimals' % decimal)) + def assert_array_less(x, y, err_msg='', verbose=True): """ Raise an assertion if two array_like objects are not ordered by less than. |