summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorPauli Virtanen <pav@iki.fi>2010-07-18 20:55:33 +0000
committerPauli Virtanen <pav@iki.fi>2010-07-18 20:55:33 +0000
commit47702d223e7696419faf73b0d881eb1be9a7a8b0 (patch)
tree294ae30a74682c003b760bb33c573b4d5a04f449 /numpy
parenteab3c1e98d296d6307b0f89c33c57319d3ac7975 (diff)
downloadnumpy-47702d223e7696419faf73b0d881eb1be9a7a8b0.tar.gz
BUG: core/umath: make complex number comparisons False when *either* element is nan
This also fixes NaN propagation in np.maximum/minimum, for example in the case for maximum(1, complex(0, nan)) -> (0, nan) which previously yielded (1, 0).
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/umath/loops.c.src12
-rw-r--r--numpy/core/tests/test_regression.py4
-rw-r--r--numpy/core/tests/test_umath.py27
3 files changed, 35 insertions, 8 deletions
diff --git a/numpy/core/src/umath/loops.c.src b/numpy/core/src/umath/loops.c.src
index cef45bdcb..e77da0691 100644
--- a/numpy/core/src/umath/loops.c.src
+++ b/numpy/core/src/umath/loops.c.src
@@ -1299,10 +1299,14 @@ NPY_NO_EXPORT void
*****************************************************************************
*/
-#define CGE(xr,xi,yr,yi) (xr > yr || (xr == yr && xi >= yi))
-#define CLE(xr,xi,yr,yi) (xr < yr || (xr == yr && xi <= yi))
-#define CGT(xr,xi,yr,yi) (xr > yr || (xr == yr && xi > yi))
-#define CLT(xr,xi,yr,yi) (xr < yr || (xr == yr && xi < yi))
+#define CGE(xr,xi,yr,yi) ((xr > yr && !npy_isnan(xi) && !npy_isnan(yi)) \
+ || (xr == yr && xi >= yi))
+#define CLE(xr,xi,yr,yi) ((xr < yr && !npy_isnan(xi) && !npy_isnan(yi)) \
+ || (xr == yr && xi <= yi))
+#define CGT(xr,xi,yr,yi) ((xr > yr && !npy_isnan(xi) && !npy_isnan(yi)) \
+ || (xr == yr && xi > yi))
+#define CLT(xr,xi,yr,yi) ((xr < yr && !npy_isnan(xi) && !npy_isnan(yi)) \
+ || (xr == yr && xi < yi))
#define CEQ(xr,xi,yr,yi) (xr == yr && xi == yi)
#define CNE(xr,xi,yr,yi) (xr != yr || xi != yi)
diff --git a/numpy/core/tests/test_regression.py b/numpy/core/tests/test_regression.py
index 01cf5dad4..b676e633c 100644
--- a/numpy/core/tests/test_regression.py
+++ b/numpy/core/tests/test_regression.py
@@ -1338,5 +1338,9 @@ class TestRegression(TestCase):
assert_(ret is out)
assert_array_equal(ret, data.std(axis=1))
+ def test_complex_nan_maximum(self):
+ cnan = complex(0, np.nan)
+ assert_equal(np.maximum(1, cnan), cnan)
+
if __name__ == "__main__":
run_module_suite()
diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py
index 6fcb2dcf0..2741287ec 100644
--- a/numpy/core/tests/test_umath.py
+++ b/numpy/core/tests/test_umath.py
@@ -378,7 +378,7 @@ class TestMaximum(TestCase):
def test_complex_nans(self):
nan = np.nan
- for cnan in [nan, nan*1j, nan + nan*1j] :
+ for cnan in [complex(nan, 0), complex(0, nan), complex(nan, nan)] :
arg1 = np.array([0, cnan, cnan], dtype=np.complex)
arg2 = np.array([cnan, 0, cnan], dtype=np.complex)
out = np.array([nan, nan, nan], dtype=np.complex)
@@ -399,7 +399,7 @@ class TestMinimum(TestCase):
def test_complex_nans(self):
nan = np.nan
- for cnan in [nan, nan*1j, nan + nan*1j] :
+ for cnan in [complex(nan, 0), complex(0, nan), complex(nan, nan)] :
arg1 = np.array([0, cnan, cnan], dtype=np.complex)
arg2 = np.array([cnan, 0, cnan], dtype=np.complex)
out = np.array([nan, nan, nan], dtype=np.complex)
@@ -420,7 +420,7 @@ class TestFmax(TestCase):
def test_complex_nans(self):
nan = np.nan
- for cnan in [nan, nan*1j, nan + nan*1j] :
+ for cnan in [complex(nan, 0), complex(0, nan), complex(nan, nan)] :
arg1 = np.array([0, cnan, cnan], dtype=np.complex)
arg2 = np.array([cnan, 0, cnan], dtype=np.complex)
out = np.array([0, 0, nan], dtype=np.complex)
@@ -441,7 +441,7 @@ class TestFmin(TestCase):
def test_complex_nans(self):
nan = np.nan
- for cnan in [nan, nan*1j, nan + nan*1j] :
+ for cnan in [complex(nan, 0), complex(0, nan), complex(nan, nan)] :
arg1 = np.array([0, cnan, cnan], dtype=np.complex)
arg2 = np.array([cnan, 0, cnan], dtype=np.complex)
out = np.array([0, 0, nan], dtype=np.complex)
@@ -1050,5 +1050,24 @@ def test_reduceat():
assert_array_almost_equal(h1, h2)
+def test_complex_nan_comparisons():
+ nans = [complex(np.nan, 0), complex(0, np.nan), complex(np.nan, np.nan)]
+ fins = [complex(1, 0), complex(-1, 0), complex(0, 1), complex(0, -1),
+ complex(1, 1), complex(-1, -1), complex(0, 0)]
+
+ for x in nans + fins:
+ x = np.array([x])
+ for y in nans + fins:
+ y = np.array([y])
+
+ if np.isfinite(x) and np.isfinite(y):
+ continue
+
+ assert_equal(x < y, False, err_msg="%r < %r" % (x, y))
+ assert_equal(x > y, False, err_msg="%r > %r" % (x, y))
+ assert_equal(x <= y, False, err_msg="%r <= %r" % (x, y))
+ assert_equal(x >= y, False, err_msg="%r >= %r" % (x, y))
+ assert_equal(x == y, False, err_msg="%r == %r" % (x, y))
+
if __name__ == "__main__":
run_module_suite()