summaryrefslogtreecommitdiff
path: root/numpy/testing/tests/test_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/testing/tests/test_utils.py')
-rw-r--r--numpy/testing/tests/test_utils.py76
1 files changed, 76 insertions, 0 deletions
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py
index 94fc4d655..aa0a2669f 100644
--- a/numpy/testing/tests/test_utils.py
+++ b/numpy/testing/tests/test_utils.py
@@ -53,6 +53,9 @@ class _GenericTest(object):
a = np.array([1, 1], dtype=np.object)
self._test_equal(a, 1)
+ def test_array_likes(self):
+ self._test_equal([1, 2, 3], (1, 2, 3))
+
class TestArrayEqual(_GenericTest, unittest.TestCase):
def setUp(self):
self._assert_func = assert_array_equal
@@ -131,6 +134,49 @@ class TestArrayEqual(_GenericTest, unittest.TestCase):
self._test_not_equal(c, b)
+class TestBuildErrorMessage(unittest.TestCase):
+ def test_build_err_msg_defaults(self):
+ x = np.array([1.00001, 2.00002, 3.00003])
+ y = np.array([1.00002, 2.00003, 3.00004])
+ err_msg = 'There is a mismatch'
+
+ a = build_err_msg([x, y], err_msg)
+ b = ('\nItems are not equal: There is a mismatch\n ACTUAL: array([ '
+ '1.00001, 2.00002, 3.00003])\n DESIRED: array([ 1.00002, '
+ '2.00003, 3.00004])')
+ self.assertEqual(a, b)
+
+ def test_build_err_msg_no_verbose(self):
+ x = np.array([1.00001, 2.00002, 3.00003])
+ y = np.array([1.00002, 2.00003, 3.00004])
+ err_msg = 'There is a mismatch'
+
+ a = build_err_msg([x, y], err_msg, verbose=False)
+ b = '\nItems are not equal: There is a mismatch'
+ self.assertEqual(a, b)
+
+ def test_build_err_msg_custom_names(self):
+ x = np.array([1.00001, 2.00002, 3.00003])
+ y = np.array([1.00002, 2.00003, 3.00004])
+ err_msg = 'There is a mismatch'
+
+ a = build_err_msg([x, y], err_msg, names=('FOO', 'BAR'))
+ b = ('\nItems are not equal: There is a mismatch\n FOO: array([ '
+ '1.00001, 2.00002, 3.00003])\n BAR: array([ 1.00002, 2.00003, '
+ '3.00004])')
+ self.assertEqual(a, b)
+
+ def test_build_err_msg_custom_precision(self):
+ x = np.array([1.000000001, 2.00002, 3.00003])
+ y = np.array([1.000000002, 2.00003, 3.00004])
+ err_msg = 'There is a mismatch'
+
+ a = build_err_msg([x, y], err_msg, precision=10)
+ b = ('\nItems are not equal: There is a mismatch\n ACTUAL: array([ '
+ '1.000000001, 2.00002 , 3.00003 ])\n DESIRED: array([ '
+ '1.000000002, 2.00003 , 3.00004 ])')
+ self.assertEqual(a, b)
+
class TestEqual(TestArrayEqual):
def setUp(self):
self._assert_func = assert_equal
@@ -236,6 +282,31 @@ class TestAlmostEqual(_GenericTest, unittest.TestCase):
self._test_not_equal(x, y)
self._test_not_equal(x, z)
+ def test_error_message(self):
+ """Check the message is formatted correctly for the decimal value"""
+ x = np.array([1.00000000001, 2.00000000002, 3.00003])
+ y = np.array([1.00000000002, 2.00000000003, 3.00004])
+
+ # test with a different amount of decimal digits
+ # note that we only check for the formatting of the arrays themselves
+ b = ('x: array([ 1.00000000001, 2.00000000002, 3.00003 '
+ ' ])\n y: array([ 1.00000000002, 2.00000000003, 3.00004 ])')
+ try:
+ self._assert_func(x, y, decimal=12)
+ except AssertionError as e:
+ # remove anything that's not the array string
+ self.assertEqual(str(e).split('%)\n ')[1], b)
+
+ # with the default value of decimal digits, only the 3rd element differs
+ # note that we only check for the formatting of the arrays themselves
+ b = ('x: array([ 1. , 2. , 3.00003])\n y: array([ 1. , '
+ '2. , 3.00004])')
+ try:
+ self._assert_func(x, y)
+ except AssertionError as e:
+ # remove anything that's not the array string
+ self.assertEqual(str(e).split('%)\n ')[1], b)
+
class TestApproxEqual(unittest.TestCase):
def setUp(self):
self._assert_func = assert_approx_equal
@@ -373,6 +444,11 @@ class TestAssertAllclose(unittest.TestCase):
assert_allclose(6, 10, rtol=0.5)
self.assertRaises(AssertionError, assert_allclose, 10, 6, rtol=0.5)
+ def test_min_int(self):
+ a = np.array([np.iinfo(np.int_).min], dtype=np.int_)
+ # Should not raise:
+ assert_allclose(a, a)
+
class TestArrayAlmostEqualNulp(unittest.TestCase):
@dec.knownfailureif(True, "Github issue #347")