diff options
author | Joe Kington <joferkington@gmail.com> | 2012-03-03 22:45:15 -0600 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2012-03-04 15:06:29 -0700 |
commit | 56e3e526e026125ba16d30339c4411042b950b06 (patch) | |
tree | eb708a61adb8d26361bcee4f84cee030617d1c3d | |
parent | 63394270fc28dc5615fe018728af894a2ffb2858 (diff) | |
download | numpy-56e3e526e026125ba16d30339c4411042b950b06.tar.gz |
ENH: Allow "isclose()" to work with subclasses of ndarray (such as masked arrays).
-rw-r--r-- | numpy/core/numeric.py | 17 | ||||
-rw-r--r-- | numpy/core/tests/test_numeric.py | 6 |
2 files changed, 16 insertions, 7 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 60111e1f2..e4693431c 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -2086,19 +2086,24 @@ def isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False): if isscalar(a) and isscalar(b): result = bool(result) return result - x = array(a, copy=False, ndmin=1) - y = array(b, copy=False, ndmin=1) + x = array(a, copy=False, subok=True, ndmin=1) + y = array(b, copy=False, subok=True, ndmin=1) xfin = isfinite(x) yfin = isfinite(y) if all(xfin) and all(yfin): return within_tol(x, y, atol, rtol) else: + finite = xfin & yfin + # Because we're using boolean indexing, x & y must be the same shape. + # Ideally, we'd just do x, y = broadcast_arrays(x, y). It's in + # lib.stride_tricks, though, so we can't import it here. + cond = zeros_like(finite, subok=True) + x = x * ones_like(cond) + y = y * ones_like(cond) # Avoid subtraction with infinite/nan values... - cond = zeros(broadcast(x, y).shape, dtype=bool) - mask = xfin & yfin - cond[mask] = within_tol(x[mask], y[mask], atol, rtol) + cond[finite] = within_tol(x[finite], y[finite], atol, rtol) # Check for equality of infinite values... - cond[~mask] = (x[~mask] == y[~mask]) + cond[~finite] = (x[~finite] == y[~finite]) if equal_nan: # Make NaN == NaN cond[isnan(x) & isnan(y)] = True diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py index 90af9b0d7..7c1ba2dba 100644 --- a/numpy/core/tests/test_numeric.py +++ b/numpy/core/tests/test_numeric.py @@ -1230,11 +1230,15 @@ class TestIsclose(object): tests = [([inf, 0], [inf, atol*2]), ([atol, 1, 1e6*(1 + 2*rtol) + atol], [0, nan, 1e6]), - (arange(3), [0, 1, 2.1]) + (arange(3), [0, 1, 2.1]), + (nan, [nan, nan, nan]), + ([0], [atol, inf, -inf, nan]), ] results = [[True, False], [True, False, False], [True, True, False], + [False, False, False], + [True, False, False, False], ] for (x, y), result in zip(tests, results): |