diff options
author | Ross Barnowski <rossbar@berkeley.edu> | 2020-05-11 10:18:37 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-05-11 12:18:37 -0500 |
commit | ba9f93d8ad9763ef9867bac5761a40fa4770aa82 (patch) | |
tree | f00f84508a9eb341aa39f3ce54a3d440a263d15e | |
parent | 9ff05dc5d5afd4e801f1acf27b33ca2ed74c48f8 (diff) | |
download | numpy-ba9f93d8ad9763ef9867bac5761a40fa4770aa82.tar.gz |
ENH: Add equal_nan keyword argument to array_equal (gh-16128)
Match the corresponding kwarg and behavior from related functions
like allclose.
Co-authored-by: Sebastian Berg <sebastian@sipsolutions.net>
-rw-r--r-- | doc/release/upcoming_changes/16128.new_feature.rst | 6 | ||||
-rw-r--r-- | numpy/core/numeric.py | 34 | ||||
-rw-r--r-- | numpy/core/tests/test_numeric.py | 30 |
3 files changed, 67 insertions, 3 deletions
diff --git a/doc/release/upcoming_changes/16128.new_feature.rst b/doc/release/upcoming_changes/16128.new_feature.rst new file mode 100644 index 000000000..03b061c07 --- /dev/null +++ b/doc/release/upcoming_changes/16128.new_feature.rst @@ -0,0 +1,6 @@ +``equal_nan`` parameter for `numpy.array_equal` +------------------------------------------------ +The keyword argument ``equal_nan`` was added to `numpy.array_equal`. +``equal_nan`` is a boolean value that toggles whether or not ``nan`` values +are considered equal in comparison (default is ``False``). This matches API +used in related functions such as `numpy.isclose` and `numpy.allclose`. diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 40c7a6ae6..8bd4e241b 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -2310,12 +2310,12 @@ def isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False): return cond[()] # Flatten 0d arrays to scalars -def _array_equal_dispatcher(a1, a2): +def _array_equal_dispatcher(a1, a2, equal_nan=None): return (a1, a2) @array_function_dispatch(_array_equal_dispatcher) -def array_equal(a1, a2): +def array_equal(a1, a2, equal_nan=False): """ True if two arrays have the same shape and elements, False otherwise. @@ -2323,6 +2323,12 @@ def array_equal(a1, a2): ---------- a1, a2 : array_like Input arrays. + equal_nan : bool + Whether to compare NaN's as equal. If the dtype of a1 and a2 is + complex, values will be considered equal if either the real or the + imaginary component of a given value is ``nan``. + + .. versionadded:: 1.19.0 Returns ------- @@ -2346,7 +2352,21 @@ def array_equal(a1, a2): False >>> np.array_equal([1, 2], [1, 4]) False + >>> a = np.array([1, np.nan]) + >>> np.array_equal(a, a) + False + >>> np.array_equal(a, a, equal_nan=True) + True + When ``equal_nan`` is True, complex values with nan components are + considered equal if either the real *or* the imaginary components are nan. + + >>> a = np.array([1 + 1j]) + >>> b = a.copy() + >>> a.real = np.nan + >>> b.imag = np.nan + >>> np.array_equal(a, b, equal_nan=True) + True """ try: a1, a2 = asarray(a1), asarray(a2) @@ -2354,7 +2374,15 @@ def array_equal(a1, a2): return False if a1.shape != a2.shape: return False - return bool(asarray(a1 == a2).all()) + if not equal_nan: + return bool(asarray(a1 == a2).all()) + # Handling NaN values if equal_nan is True + a1nan, a2nan = isnan(a1), isnan(a2) + # NaN's occur at different locations + if not (a1nan == a2nan).all(): + return False + # Shapes of a1, a2 and masks are guaranteed to be consistent by this point + return bool(asarray(a1[~a1nan] == a2[~a1nan]).all()) def _array_equiv_dispatcher(a1, a2): diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py index bcc6a0c4e..acd442e2f 100644 --- a/numpy/core/tests/test_numeric.py +++ b/numpy/core/tests/test_numeric.py @@ -1446,6 +1446,36 @@ class TestArrayComparisons: assert_(res) assert_(type(res) is bool) + def test_array_equal_equal_nan(self): + # Test array_equal with equal_nan kwarg + a1 = np.array([1, 2, np.nan]) + a2 = np.array([1, np.nan, 2]) + a3 = np.array([1, 2, np.inf]) + + # equal_nan=False by default + assert_(not np.array_equal(a1, a1)) + assert_(np.array_equal(a1, a1, equal_nan=True)) + assert_(not np.array_equal(a1, a2, equal_nan=True)) + # nan's not conflated with inf's + assert_(not np.array_equal(a1, a3, equal_nan=True)) + # 0-D arrays + a = np.array(np.nan) + assert_(not np.array_equal(a, a)) + assert_(np.array_equal(a, a, equal_nan=True)) + # Non-float dtype - equal_nan should have no effect + a = np.array([1, 2, 3], dtype=int) + assert_(np.array_equal(a, a)) + assert_(np.array_equal(a, a, equal_nan=True)) + # Multi-dimensional array + a = np.array([[0, 1], [np.nan, 1]]) + assert_(not np.array_equal(a, a)) + assert_(np.array_equal(a, a, equal_nan=True)) + # Complex values + a, b = [np.array([1 + 1j])]*2 + a.real, b.imag = np.nan, np.nan + assert_(not np.array_equal(a, b, equal_nan=False)) + assert_(np.array_equal(a, b, equal_nan=True)) + def test_none_compares_elementwise(self): a = np.array([None, 1, None], dtype=object) assert_equal(a == None, [True, False, True]) |