summaryrefslogtreecommitdiff
path: root/numpy/testing/tests
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/testing/tests')
-rw-r--r--numpy/testing/tests/test_utils.py37
1 files changed, 37 insertions, 0 deletions
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py
index 49eeecc8e..c82343f0c 100644
--- a/numpy/testing/tests/test_utils.py
+++ b/numpy/testing/tests/test_utils.py
@@ -214,6 +214,43 @@ class TestArrayEqual(_GenericTest):
np.array([1, 2, 3], np.float32),
np.array([1, 1e-40, 3], np.float32))
+ def test_array_vs_scalar_is_equal(self):
+ """Test comparing an array with a scalar when all values are equal."""
+ a = np.array([1., 1., 1.])
+ b = 1.
+
+ self._test_equal(a, b)
+
+ def test_array_vs_scalar_not_equal(self):
+ """Test comparing an array with a scalar when not all values equal."""
+ a = np.array([1., 2., 3.])
+ b = 1.
+
+ self._test_not_equal(a, b)
+
+ def test_array_vs_scalar_strict(self):
+ """Test comparing an array with a scalar with strict option."""
+ a = np.array([1., 1., 1.])
+ b = 1.
+
+ with pytest.raises(AssertionError):
+ assert_array_equal(a, b, strict=True)
+
+ def test_array_vs_array_strict(self):
+ """Test comparing two arrays with strict option."""
+ a = np.array([1., 1., 1.])
+ b = np.array([1., 1., 1.])
+
+ assert_array_equal(a, b, strict=True)
+
+ def test_array_vs_float_array_strict(self):
+ """Test comparing two arrays with strict option."""
+ a = np.array([1, 1, 1])
+ b = np.array([1., 1., 1.])
+
+ with pytest.raises(AssertionError):
+ assert_array_equal(a, b, strict=True)
+
class TestBuildErrorMessage: