diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/numeric.py | 34 | ||||
-rw-r--r-- | numpy/core/tests/test_numeric.py | 30 |
2 files changed, 61 insertions, 3 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 40c7a6ae6..8bd4e241b 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -2310,12 +2310,12 @@ def isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False): return cond[()] # Flatten 0d arrays to scalars -def _array_equal_dispatcher(a1, a2): +def _array_equal_dispatcher(a1, a2, equal_nan=None): return (a1, a2) @array_function_dispatch(_array_equal_dispatcher) -def array_equal(a1, a2): +def array_equal(a1, a2, equal_nan=False): """ True if two arrays have the same shape and elements, False otherwise. @@ -2323,6 +2323,12 @@ def array_equal(a1, a2): ---------- a1, a2 : array_like Input arrays. + equal_nan : bool + Whether to compare NaN's as equal. If the dtype of a1 and a2 is + complex, values will be considered equal if either the real or the + imaginary component of a given value is ``nan``. + + .. versionadded:: 1.19.0 Returns ------- @@ -2346,7 +2352,21 @@ def array_equal(a1, a2): False >>> np.array_equal([1, 2], [1, 4]) False + >>> a = np.array([1, np.nan]) + >>> np.array_equal(a, a) + False + >>> np.array_equal(a, a, equal_nan=True) + True + When ``equal_nan`` is True, complex values with nan components are + considered equal if either the real *or* the imaginary components are nan. + + >>> a = np.array([1 + 1j]) + >>> b = a.copy() + >>> a.real = np.nan + >>> b.imag = np.nan + >>> np.array_equal(a, b, equal_nan=True) + True """ try: a1, a2 = asarray(a1), asarray(a2) @@ -2354,7 +2374,15 @@ def array_equal(a1, a2): return False if a1.shape != a2.shape: return False - return bool(asarray(a1 == a2).all()) + if not equal_nan: + return bool(asarray(a1 == a2).all()) + # Handling NaN values if equal_nan is True + a1nan, a2nan = isnan(a1), isnan(a2) + # NaN's occur at different locations + if not (a1nan == a2nan).all(): + return False + # Shapes of a1, a2 and masks are guaranteed to be consistent by this point + return bool(asarray(a1[~a1nan] == a2[~a1nan]).all()) def _array_equiv_dispatcher(a1, a2): diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py index bcc6a0c4e..acd442e2f 100644 --- a/numpy/core/tests/test_numeric.py +++ b/numpy/core/tests/test_numeric.py @@ -1446,6 +1446,36 @@ class TestArrayComparisons: assert_(res) assert_(type(res) is bool) + def test_array_equal_equal_nan(self): + # Test array_equal with equal_nan kwarg + a1 = np.array([1, 2, np.nan]) + a2 = np.array([1, np.nan, 2]) + a3 = np.array([1, 2, np.inf]) + + # equal_nan=False by default + assert_(not np.array_equal(a1, a1)) + assert_(np.array_equal(a1, a1, equal_nan=True)) + assert_(not np.array_equal(a1, a2, equal_nan=True)) + # nan's not conflated with inf's + assert_(not np.array_equal(a1, a3, equal_nan=True)) + # 0-D arrays + a = np.array(np.nan) + assert_(not np.array_equal(a, a)) + assert_(np.array_equal(a, a, equal_nan=True)) + # Non-float dtype - equal_nan should have no effect + a = np.array([1, 2, 3], dtype=int) + assert_(np.array_equal(a, a)) + assert_(np.array_equal(a, a, equal_nan=True)) + # Multi-dimensional array + a = np.array([[0, 1], [np.nan, 1]]) + assert_(not np.array_equal(a, a)) + assert_(np.array_equal(a, a, equal_nan=True)) + # Complex values + a, b = [np.array([1 + 1j])]*2 + a.real, b.imag = np.nan, np.nan + assert_(not np.array_equal(a, b, equal_nan=False)) + assert_(np.array_equal(a, b, equal_nan=True)) + def test_none_compares_elementwise(self): a = np.array([None, 1, None], dtype=object) assert_equal(a == None, [True, False, True]) |