diff options
author | Jon Morris <jontwo@users.noreply.github.com> | 2022-06-24 16:51:36 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-06-24 11:51:36 -0400 |
commit | cafec60a5e28af98fb8798049edd7942720d2d74 (patch) | |
tree | 00d6ea28360e1a59972d07de162bd97e782b2b24 /numpy/testing/tests | |
parent | 019c8c9b2a7c084eb01cf4d8569799a5537d884d (diff) | |
download | numpy-cafec60a5e28af98fb8798049edd7942720d2d74.tar.gz |
ENH: Add strict parameter to assert_array_equal. (#21595)
Fixes #9542
Co-authored-by: Bas van Beek <43369155+BvB93@users.noreply.github.com>
Diffstat (limited to 'numpy/testing/tests')
-rw-r--r-- | numpy/testing/tests/test_utils.py | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index 49eeecc8e..c82343f0c 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -214,6 +214,43 @@ class TestArrayEqual(_GenericTest): np.array([1, 2, 3], np.float32), np.array([1, 1e-40, 3], np.float32)) + def test_array_vs_scalar_is_equal(self): + """Test comparing an array with a scalar when all values are equal.""" + a = np.array([1., 1., 1.]) + b = 1. + + self._test_equal(a, b) + + def test_array_vs_scalar_not_equal(self): + """Test comparing an array with a scalar when not all values equal.""" + a = np.array([1., 2., 3.]) + b = 1. + + self._test_not_equal(a, b) + + def test_array_vs_scalar_strict(self): + """Test comparing an array with a scalar with strict option.""" + a = np.array([1., 1., 1.]) + b = 1. + + with pytest.raises(AssertionError): + assert_array_equal(a, b, strict=True) + + def test_array_vs_array_strict(self): + """Test comparing two arrays with strict option.""" + a = np.array([1., 1., 1.]) + b = np.array([1., 1., 1.]) + + assert_array_equal(a, b, strict=True) + + def test_array_vs_float_array_strict(self): + """Test comparing two arrays with strict option.""" + a = np.array([1, 1, 1]) + b = np.array([1., 1., 1.]) + + with pytest.raises(AssertionError): + assert_array_equal(a, b, strict=True) + class TestBuildErrorMessage: |