diff options
Diffstat (limited to 'numpy/testing/tests')
-rw-r--r-- | numpy/testing/tests/test_utils.py | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index 49eeecc8e..c82343f0c 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -214,6 +214,43 @@ class TestArrayEqual(_GenericTest): np.array([1, 2, 3], np.float32), np.array([1, 1e-40, 3], np.float32)) + def test_array_vs_scalar_is_equal(self): + """Test comparing an array with a scalar when all values are equal.""" + a = np.array([1., 1., 1.]) + b = 1. + + self._test_equal(a, b) + + def test_array_vs_scalar_not_equal(self): + """Test comparing an array with a scalar when not all values equal.""" + a = np.array([1., 2., 3.]) + b = 1. + + self._test_not_equal(a, b) + + def test_array_vs_scalar_strict(self): + """Test comparing an array with a scalar with strict option.""" + a = np.array([1., 1., 1.]) + b = 1. + + with pytest.raises(AssertionError): + assert_array_equal(a, b, strict=True) + + def test_array_vs_array_strict(self): + """Test comparing two arrays with strict option.""" + a = np.array([1., 1., 1.]) + b = np.array([1., 1., 1.]) + + assert_array_equal(a, b, strict=True) + + def test_array_vs_float_array_strict(self): + """Test comparing two arrays with strict option.""" + a = np.array([1, 1, 1]) + b = np.array([1., 1., 1.]) + + with pytest.raises(AssertionError): + assert_array_equal(a, b, strict=True) + class TestBuildErrorMessage: |