summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRoss Barnowski <rossbar@berkeley.edu>2020-05-11 10:18:37 -0700
committerGitHub <noreply@github.com>2020-05-11 12:18:37 -0500
commitba9f93d8ad9763ef9867bac5761a40fa4770aa82 (patch)
treef00f84508a9eb341aa39f3ce54a3d440a263d15e
parent9ff05dc5d5afd4e801f1acf27b33ca2ed74c48f8 (diff)
downloadnumpy-ba9f93d8ad9763ef9867bac5761a40fa4770aa82.tar.gz
ENH: Add equal_nan keyword argument to array_equal (gh-16128)
Match the corresponding kwarg and behavior from related functions like allclose. Co-authored-by: Sebastian Berg <sebastian@sipsolutions.net>
-rw-r--r--doc/release/upcoming_changes/16128.new_feature.rst6
-rw-r--r--numpy/core/numeric.py34
-rw-r--r--numpy/core/tests/test_numeric.py30
3 files changed, 67 insertions, 3 deletions
diff --git a/doc/release/upcoming_changes/16128.new_feature.rst b/doc/release/upcoming_changes/16128.new_feature.rst
new file mode 100644
index 000000000..03b061c07
--- /dev/null
+++ b/doc/release/upcoming_changes/16128.new_feature.rst
@@ -0,0 +1,6 @@
+``equal_nan`` parameter for `numpy.array_equal`
+------------------------------------------------
+The keyword argument ``equal_nan`` was added to `numpy.array_equal`.
+``equal_nan`` is a boolean value that toggles whether or not ``nan`` values
+are considered equal in comparison (default is ``False``). This matches API
+used in related functions such as `numpy.isclose` and `numpy.allclose`.
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 40c7a6ae6..8bd4e241b 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -2310,12 +2310,12 @@ def isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False):
return cond[()] # Flatten 0d arrays to scalars
-def _array_equal_dispatcher(a1, a2):
+def _array_equal_dispatcher(a1, a2, equal_nan=None):
return (a1, a2)
@array_function_dispatch(_array_equal_dispatcher)
-def array_equal(a1, a2):
+def array_equal(a1, a2, equal_nan=False):
"""
True if two arrays have the same shape and elements, False otherwise.
@@ -2323,6 +2323,12 @@ def array_equal(a1, a2):
----------
a1, a2 : array_like
Input arrays.
+ equal_nan : bool
+ Whether to compare NaN's as equal. If the dtype of a1 and a2 is
+ complex, values will be considered equal if either the real or the
+ imaginary component of a given value is ``nan``.
+
+ .. versionadded:: 1.19.0
Returns
-------
@@ -2346,7 +2352,21 @@ def array_equal(a1, a2):
False
>>> np.array_equal([1, 2], [1, 4])
False
+ >>> a = np.array([1, np.nan])
+ >>> np.array_equal(a, a)
+ False
+ >>> np.array_equal(a, a, equal_nan=True)
+ True
+ When ``equal_nan`` is True, complex values with nan components are
+ considered equal if either the real *or* the imaginary components are nan.
+
+ >>> a = np.array([1 + 1j])
+ >>> b = a.copy()
+ >>> a.real = np.nan
+ >>> b.imag = np.nan
+ >>> np.array_equal(a, b, equal_nan=True)
+ True
"""
try:
a1, a2 = asarray(a1), asarray(a2)
@@ -2354,7 +2374,15 @@ def array_equal(a1, a2):
return False
if a1.shape != a2.shape:
return False
- return bool(asarray(a1 == a2).all())
+ if not equal_nan:
+ return bool(asarray(a1 == a2).all())
+ # Handling NaN values if equal_nan is True
+ a1nan, a2nan = isnan(a1), isnan(a2)
+ # NaN's occur at different locations
+ if not (a1nan == a2nan).all():
+ return False
+ # Shapes of a1, a2 and masks are guaranteed to be consistent by this point
+ return bool(asarray(a1[~a1nan] == a2[~a1nan]).all())
def _array_equiv_dispatcher(a1, a2):
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index bcc6a0c4e..acd442e2f 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -1446,6 +1446,36 @@ class TestArrayComparisons:
assert_(res)
assert_(type(res) is bool)
+ def test_array_equal_equal_nan(self):
+ # Test array_equal with equal_nan kwarg
+ a1 = np.array([1, 2, np.nan])
+ a2 = np.array([1, np.nan, 2])
+ a3 = np.array([1, 2, np.inf])
+
+ # equal_nan=False by default
+ assert_(not np.array_equal(a1, a1))
+ assert_(np.array_equal(a1, a1, equal_nan=True))
+ assert_(not np.array_equal(a1, a2, equal_nan=True))
+ # nan's not conflated with inf's
+ assert_(not np.array_equal(a1, a3, equal_nan=True))
+ # 0-D arrays
+ a = np.array(np.nan)
+ assert_(not np.array_equal(a, a))
+ assert_(np.array_equal(a, a, equal_nan=True))
+ # Non-float dtype - equal_nan should have no effect
+ a = np.array([1, 2, 3], dtype=int)
+ assert_(np.array_equal(a, a))
+ assert_(np.array_equal(a, a, equal_nan=True))
+ # Multi-dimensional array
+ a = np.array([[0, 1], [np.nan, 1]])
+ assert_(not np.array_equal(a, a))
+ assert_(np.array_equal(a, a, equal_nan=True))
+ # Complex values
+ a, b = [np.array([1 + 1j])]*2
+ a.real, b.imag = np.nan, np.nan
+ assert_(not np.array_equal(a, b, equal_nan=False))
+ assert_(np.array_equal(a, b, equal_nan=True))
+
def test_none_compares_elementwise(self):
a = np.array([None, 1, None], dtype=object)
assert_equal(a == None, [True, False, True])