summaryrefslogtreecommitdiff
path: root/numpy/testing/tests
diff options
context:
space:
mode:
authorJon Morris <jontwo@users.noreply.github.com>2022-06-24 16:51:36 +0100
committerGitHub <noreply@github.com>2022-06-24 11:51:36 -0400
commitcafec60a5e28af98fb8798049edd7942720d2d74 (patch)
tree00d6ea28360e1a59972d07de162bd97e782b2b24 /numpy/testing/tests
parent019c8c9b2a7c084eb01cf4d8569799a5537d884d (diff)
downloadnumpy-cafec60a5e28af98fb8798049edd7942720d2d74.tar.gz
ENH: Add strict parameter to assert_array_equal. (#21595)
Fixes #9542 Co-authored-by: Bas van Beek <43369155+BvB93@users.noreply.github.com>
Diffstat (limited to 'numpy/testing/tests')
-rw-r--r--numpy/testing/tests/test_utils.py37
1 files changed, 37 insertions, 0 deletions
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py
index 49eeecc8e..c82343f0c 100644
--- a/numpy/testing/tests/test_utils.py
+++ b/numpy/testing/tests/test_utils.py
@@ -214,6 +214,43 @@ class TestArrayEqual(_GenericTest):
np.array([1, 2, 3], np.float32),
np.array([1, 1e-40, 3], np.float32))
+ def test_array_vs_scalar_is_equal(self):
+ """Test comparing an array with a scalar when all values are equal."""
+ a = np.array([1., 1., 1.])
+ b = 1.
+
+ self._test_equal(a, b)
+
+ def test_array_vs_scalar_not_equal(self):
+ """Test comparing an array with a scalar when not all values equal."""
+ a = np.array([1., 2., 3.])
+ b = 1.
+
+ self._test_not_equal(a, b)
+
+ def test_array_vs_scalar_strict(self):
+ """Test comparing an array with a scalar with strict option."""
+ a = np.array([1., 1., 1.])
+ b = 1.
+
+ with pytest.raises(AssertionError):
+ assert_array_equal(a, b, strict=True)
+
+ def test_array_vs_array_strict(self):
+ """Test comparing two arrays with strict option."""
+ a = np.array([1., 1., 1.])
+ b = np.array([1., 1., 1.])
+
+ assert_array_equal(a, b, strict=True)
+
+ def test_array_vs_float_array_strict(self):
+ """Test comparing two arrays with strict option."""
+ a = np.array([1, 1, 1])
+ b = np.array([1., 1., 1.])
+
+ with pytest.raises(AssertionError):
+ assert_array_equal(a, b, strict=True)
+
class TestBuildErrorMessage: