diff options
author | Pauli Virtanen <pav@iki.fi> | 2010-07-18 20:55:33 +0000 |
---|---|---|
committer | Pauli Virtanen <pav@iki.fi> | 2010-07-18 20:55:33 +0000 |
commit | 47702d223e7696419faf73b0d881eb1be9a7a8b0 (patch) | |
tree | 294ae30a74682c003b760bb33c573b4d5a04f449 | |
parent | eab3c1e98d296d6307b0f89c33c57319d3ac7975 (diff) | |
download | numpy-47702d223e7696419faf73b0d881eb1be9a7a8b0.tar.gz |
BUG: core/umath: make complex number comparisons False when *either* element is nan
This also fixes NaN propagation in np.maximum/minimum, for example in
the case for
maximum(1, complex(0, nan)) -> (0, nan)
which previously yielded (1, 0).
-rw-r--r-- | numpy/core/src/umath/loops.c.src | 12 | ||||
-rw-r--r-- | numpy/core/tests/test_regression.py | 4 | ||||
-rw-r--r-- | numpy/core/tests/test_umath.py | 27 |
3 files changed, 35 insertions, 8 deletions
diff --git a/numpy/core/src/umath/loops.c.src b/numpy/core/src/umath/loops.c.src index cef45bdcb..e77da0691 100644 --- a/numpy/core/src/umath/loops.c.src +++ b/numpy/core/src/umath/loops.c.src @@ -1299,10 +1299,14 @@ NPY_NO_EXPORT void ***************************************************************************** */ -#define CGE(xr,xi,yr,yi) (xr > yr || (xr == yr && xi >= yi)) -#define CLE(xr,xi,yr,yi) (xr < yr || (xr == yr && xi <= yi)) -#define CGT(xr,xi,yr,yi) (xr > yr || (xr == yr && xi > yi)) -#define CLT(xr,xi,yr,yi) (xr < yr || (xr == yr && xi < yi)) +#define CGE(xr,xi,yr,yi) ((xr > yr && !npy_isnan(xi) && !npy_isnan(yi)) \ + || (xr == yr && xi >= yi)) +#define CLE(xr,xi,yr,yi) ((xr < yr && !npy_isnan(xi) && !npy_isnan(yi)) \ + || (xr == yr && xi <= yi)) +#define CGT(xr,xi,yr,yi) ((xr > yr && !npy_isnan(xi) && !npy_isnan(yi)) \ + || (xr == yr && xi > yi)) +#define CLT(xr,xi,yr,yi) ((xr < yr && !npy_isnan(xi) && !npy_isnan(yi)) \ + || (xr == yr && xi < yi)) #define CEQ(xr,xi,yr,yi) (xr == yr && xi == yi) #define CNE(xr,xi,yr,yi) (xr != yr || xi != yi) diff --git a/numpy/core/tests/test_regression.py b/numpy/core/tests/test_regression.py index 01cf5dad4..b676e633c 100644 --- a/numpy/core/tests/test_regression.py +++ b/numpy/core/tests/test_regression.py @@ -1338,5 +1338,9 @@ class TestRegression(TestCase): assert_(ret is out) assert_array_equal(ret, data.std(axis=1)) + def test_complex_nan_maximum(self): + cnan = complex(0, np.nan) + assert_equal(np.maximum(1, cnan), cnan) + if __name__ == "__main__": run_module_suite() diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py index 6fcb2dcf0..2741287ec 100644 --- a/numpy/core/tests/test_umath.py +++ b/numpy/core/tests/test_umath.py @@ -378,7 +378,7 @@ class TestMaximum(TestCase): def test_complex_nans(self): nan = np.nan - for cnan in [nan, nan*1j, nan + nan*1j] : + for cnan in [complex(nan, 0), complex(0, nan), complex(nan, nan)] : arg1 = np.array([0, cnan, cnan], dtype=np.complex) arg2 = np.array([cnan, 0, cnan], dtype=np.complex) out = np.array([nan, nan, nan], dtype=np.complex) @@ -399,7 +399,7 @@ class TestMinimum(TestCase): def test_complex_nans(self): nan = np.nan - for cnan in [nan, nan*1j, nan + nan*1j] : + for cnan in [complex(nan, 0), complex(0, nan), complex(nan, nan)] : arg1 = np.array([0, cnan, cnan], dtype=np.complex) arg2 = np.array([cnan, 0, cnan], dtype=np.complex) out = np.array([nan, nan, nan], dtype=np.complex) @@ -420,7 +420,7 @@ class TestFmax(TestCase): def test_complex_nans(self): nan = np.nan - for cnan in [nan, nan*1j, nan + nan*1j] : + for cnan in [complex(nan, 0), complex(0, nan), complex(nan, nan)] : arg1 = np.array([0, cnan, cnan], dtype=np.complex) arg2 = np.array([cnan, 0, cnan], dtype=np.complex) out = np.array([0, 0, nan], dtype=np.complex) @@ -441,7 +441,7 @@ class TestFmin(TestCase): def test_complex_nans(self): nan = np.nan - for cnan in [nan, nan*1j, nan + nan*1j] : + for cnan in [complex(nan, 0), complex(0, nan), complex(nan, nan)] : arg1 = np.array([0, cnan, cnan], dtype=np.complex) arg2 = np.array([cnan, 0, cnan], dtype=np.complex) out = np.array([0, 0, nan], dtype=np.complex) @@ -1050,5 +1050,24 @@ def test_reduceat(): assert_array_almost_equal(h1, h2) +def test_complex_nan_comparisons(): + nans = [complex(np.nan, 0), complex(0, np.nan), complex(np.nan, np.nan)] + fins = [complex(1, 0), complex(-1, 0), complex(0, 1), complex(0, -1), + complex(1, 1), complex(-1, -1), complex(0, 0)] + + for x in nans + fins: + x = np.array([x]) + for y in nans + fins: + y = np.array([y]) + + if np.isfinite(x) and np.isfinite(y): + continue + + assert_equal(x < y, False, err_msg="%r < %r" % (x, y)) + assert_equal(x > y, False, err_msg="%r > %r" % (x, y)) + assert_equal(x <= y, False, err_msg="%r <= %r" % (x, y)) + assert_equal(x >= y, False, err_msg="%r >= %r" % (x, y)) + assert_equal(x == y, False, err_msg="%r == %r" % (x, y)) + if __name__ == "__main__": run_module_suite() |