summaryrefslogtreecommitdiff
path: root/numpy/testing/tests/test_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/testing/tests/test_utils.py')
-rw-r--r--numpy/testing/tests/test_utils.py132
1 files changed, 77 insertions, 55 deletions
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py
index 43afafaa8..4f1b46d4f 100644
--- a/numpy/testing/tests/test_utils.py
+++ b/numpy/testing/tests/test_utils.py
@@ -17,6 +17,7 @@ from numpy.testing import (
clear_and_catch_warnings, suppress_warnings, assert_string_equal, assert_,
tempdir, temppath, assert_no_gc_cycles, HAS_REFCOUNT
)
+from numpy.core.overrides import ARRAY_FUNCTION_ENABLED
class _GenericTest(object):
@@ -179,6 +180,8 @@ class TestArrayEqual(_GenericTest):
self._test_not_equal(a, b)
self._test_not_equal(b, a)
+ @pytest.mark.skipif(
+ not ARRAY_FUNCTION_ENABLED, reason='requires __array_function__')
def test_subclass_that_does_not_implement_npall(self):
class MyArray(np.ndarray):
def __array_function__(self, *args, **kwargs):
@@ -186,9 +189,8 @@ class TestArrayEqual(_GenericTest):
a = np.array([1., 2.]).view(MyArray)
b = np.array([2., 3.]).view(MyArray)
- if np.core.overrides.ENABLE_ARRAY_FUNCTION:
- with assert_raises(TypeError):
- np.all(a)
+ with assert_raises(TypeError):
+ np.all(a)
self._test_equal(a, a)
self._test_not_equal(a, b)
self._test_not_equal(b, a)
@@ -327,24 +329,29 @@ class TestEqual(TestArrayEqual):
self._test_not_equal(x, y)
def test_error_message(self):
- try:
+ with pytest.raises(AssertionError) as exc_info:
self._assert_func(np.array([1, 2]), np.array([[1, 2]]))
- except AssertionError as e:
- msg = str(e)
- msg2 = msg.replace("shapes (2L,), (1L, 2L)", "shapes (2,), (1, 2)")
- msg_reference = textwrap.dedent("""\
+ msg = str(exc_info.value)
+ msg2 = msg.replace("shapes (2L,), (1L, 2L)", "shapes (2,), (1, 2)")
+ msg_reference = textwrap.dedent("""\
- Arrays are not equal
+ Arrays are not equal
- (shapes (2,), (1, 2) mismatch)
- x: array([1, 2])
- y: array([[1, 2]])""")
- try:
- assert_equal(msg, msg_reference)
- except AssertionError:
- assert_equal(msg2, msg_reference)
- else:
- raise AssertionError("Did not raise")
+ (shapes (2,), (1, 2) mismatch)
+ x: array([1, 2])
+ y: array([[1, 2]])""")
+
+ try:
+ assert_equal(msg, msg_reference)
+ except AssertionError:
+ assert_equal(msg2, msg_reference)
+
+ def test_object(self):
+ #gh-12942
+ import datetime
+ a = np.array([datetime.datetime(2000, 1, 1),
+ datetime.datetime(2000, 1, 2)])
+ self._test_not_equal(a, a[::-1])
class TestArrayAlmostEqual(_GenericTest):
@@ -509,38 +516,53 @@ class TestAlmostEqual(_GenericTest):
x = np.array([1.00000000001, 2.00000000002, 3.00003])
y = np.array([1.00000000002, 2.00000000003, 3.00004])
- # test with a different amount of decimal digits
- # note that we only check for the formatting of the arrays themselves
- b = ('x: array([1.00000000001, 2.00000000002, 3.00003 '
- ' ])\n y: array([1.00000000002, 2.00000000003, 3.00004 ])')
- try:
+ # Test with a different amount of decimal digits
+ with pytest.raises(AssertionError) as exc_info:
self._assert_func(x, y, decimal=12)
- except AssertionError as e:
- # remove anything that's not the array string
- assert_equal(str(e).split('%)\n ')[1], b)
-
- # with the default value of decimal digits, only the 3rd element differs
- # note that we only check for the formatting of the arrays themselves
- b = ('x: array([1. , 2. , 3.00003])\n y: array([1. , '
- '2. , 3.00004])')
- try:
+ msgs = str(exc_info.value).split('\n')
+ assert_equal(msgs[3], 'Mismatched elements: 3 / 3 (100%)')
+ assert_equal(msgs[4], 'Max absolute difference: 1.e-05')
+ assert_equal(msgs[5], 'Max relative difference: 3.33328889e-06')
+ assert_equal(
+ msgs[6],
+ ' x: array([1.00000000001, 2.00000000002, 3.00003 ])')
+ assert_equal(
+ msgs[7],
+ ' y: array([1.00000000002, 2.00000000003, 3.00004 ])')
+
+ # With the default value of decimal digits, only the 3rd element
+ # differs. Note that we only check for the formatting of the arrays
+ # themselves.
+ with pytest.raises(AssertionError) as exc_info:
self._assert_func(x, y)
- except AssertionError as e:
- # remove anything that's not the array string
- assert_equal(str(e).split('%)\n ')[1], b)
-
- # Check the error message when input includes inf or nan
+ msgs = str(exc_info.value).split('\n')
+ assert_equal(msgs[3], 'Mismatched elements: 1 / 3 (33.3%)')
+ assert_equal(msgs[4], 'Max absolute difference: 1.e-05')
+ assert_equal(msgs[5], 'Max relative difference: 3.33328889e-06')
+ assert_equal(msgs[6], ' x: array([1. , 2. , 3.00003])')
+ assert_equal(msgs[7], ' y: array([1. , 2. , 3.00004])')
+
+ # Check the error message when input includes inf
x = np.array([np.inf, 0])
y = np.array([np.inf, 1])
- try:
+ with pytest.raises(AssertionError) as exc_info:
+ self._assert_func(x, y)
+ msgs = str(exc_info.value).split('\n')
+ assert_equal(msgs[3], 'Mismatched elements: 1 / 2 (50%)')
+ assert_equal(msgs[4], 'Max absolute difference: 1.')
+ assert_equal(msgs[5], 'Max relative difference: 1.')
+ assert_equal(msgs[6], ' x: array([inf, 0.])')
+ assert_equal(msgs[7], ' y: array([inf, 1.])')
+
+ # Check the error message when dividing by zero
+ x = np.array([1, 2])
+ y = np.array([0, 0])
+ with pytest.raises(AssertionError) as exc_info:
self._assert_func(x, y)
- except AssertionError as e:
- msgs = str(e).split('\n')
- # assert error percentage is 50%
- assert_equal(msgs[3], '(mismatch 50.0%)')
- # assert output array contains inf
- assert_equal(msgs[4], ' x: array([inf, 0.])')
- assert_equal(msgs[5], ' y: array([inf, 1.])')
+ msgs = str(exc_info.value).split('\n')
+ assert_equal(msgs[3], 'Mismatched elements: 2 / 2 (100%)')
+ assert_equal(msgs[4], 'Max absolute difference: 2')
+ assert_equal(msgs[5], 'Max relative difference: inf')
def test_subclass_that_cannot_be_bool(self):
# While we cannot guarantee testing functions will always work for
@@ -829,12 +851,13 @@ class TestAssertAllclose(object):
def test_report_fail_percentage(self):
a = np.array([1, 1, 1, 1])
b = np.array([1, 1, 1, 2])
- try:
+
+ with pytest.raises(AssertionError) as exc_info:
assert_allclose(a, b)
- msg = ''
- except AssertionError as exc:
- msg = exc.args[0]
- assert_("mismatch 25.0%" in msg)
+ msg = str(exc_info.value)
+ assert_('Mismatched elements: 1 / 4 (25%)\n'
+ 'Max absolute difference: 1\n'
+ 'Max relative difference: 0.5' in msg)
def test_equal_nan(self):
a = np.array([np.nan])
@@ -1117,12 +1140,10 @@ class TestStringEqual(object):
assert_string_equal("hello", "hello")
assert_string_equal("hello\nmultiline", "hello\nmultiline")
- try:
+ with pytest.raises(AssertionError) as exc_info:
assert_string_equal("foo\nbar", "hello\nbar")
- except AssertionError as exc:
- assert_equal(str(exc), "Differences in strings:\n- foo\n+ hello")
- else:
- raise AssertionError("exception not raised")
+ msg = str(exc_info.value)
+ assert_equal(msg, "Differences in strings:\n- foo\n+ hello")
assert_raises(AssertionError,
lambda: assert_string_equal("foo", "hello"))
@@ -1488,6 +1509,7 @@ class TestAssertNoGcCycles(object):
with assert_raises(AssertionError):
assert_no_gc_cycles(make_cycle)
+ @pytest.mark.slow
def test_fails(self):
"""
Test that in cases where the garbage cannot be collected, we raise an