diff options
author | Pauli Virtanen <pav@iki.fi> | 2010-07-18 20:55:56 +0000 |
---|---|---|
committer | Pauli Virtanen <pav@iki.fi> | 2010-07-18 20:55:56 +0000 |
commit | e659183918d5bbcb7ddaefe8a7dc6b0025d0a31b (patch) | |
tree | 798fd0db27cf2efb7402d9573b99aa336da5fcdb | |
parent | 47702d223e7696419faf73b0d881eb1be9a7a8b0 (diff) | |
download | numpy-e659183918d5bbcb7ddaefe8a7dc6b0025d0a31b.tar.gz |
BUG: core: fix argmax and argmin NaN handling to conform with max/min (#1429)
This makes `argmax` and `argmix` treat NaN as a maximal element.
Effectively, this causes propagation of NaNs, which is consistent with
the current behavior of amax & amin.
-rw-r--r-- | numpy/core/src/multiarray/arraytypes.c.src | 44 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 27 |
2 files changed, 70 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/arraytypes.c.src b/numpy/core/src/multiarray/arraytypes.c.src index 3c6ad10dd..a05289fb0 100644 --- a/numpy/core/src/multiarray/arraytypes.c.src +++ b/numpy/core/src/multiarray/arraytypes.c.src @@ -2735,6 +2735,8 @@ finish: * #type = Bool, byte, ubyte, short, ushort, int, uint, long, ulong, * longlong, ulonglong, float, double, longdouble, * float, double, longdouble, datetime, timedelta# + * #isfloat = 0*11, 1*6, 0*2# + * #iscomplex = 0*14, 1*3, 0*2# * #incr = ip++*14, ip+=2*3, ip++*2# */ static int @@ -2742,14 +2744,54 @@ static int { intp i; @type@ mp = *ip; +#if @iscomplex@ + @type@ mp_im = ip[1]; +#endif *max_ind = 0; + +#if @isfloat@ + if (npy_isnan(mp)) { + /* nan encountered; it's maximal */ + return 0; + } +#endif +#if @iscomplex@ + if (npy_isnan(mp_im)) { + /* nan encountered; it's maximal */ + return 0; + } +#endif + for (i = 1; i < n; i++) { @incr@; - if (*ip > mp) { + /* + * Propagate nans, similarly as max() and min() + */ +#if @iscomplex@ + /* Lexical order for complex numbers */ + if ((ip[0] > mp) || ((ip[0] == mp) && (ip[1] > mp_im)) + || npy_isnan(ip[0]) || npy_isnan(ip[1])) { + mp = ip[0]; + mp_im = ip[1]; + *max_ind = i; + if (npy_isnan(mp) || npy_isnan(mp_im)) { + /* nan encountered, it's maximal */ + break; + } + } +#else + if (!(*ip <= mp)) { /* negated, for correct nan handling */ mp = *ip; *max_ind = i; +#if @isfloat@ + if (npy_isnan(mp)) { + /* nan encountered, it's maximal */ + break; + } +#endif } +#endif } return 0; } diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index b3a4337ca..45cc1c858 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -671,6 +671,27 @@ class TestStringCompare(TestCase): class TestArgmax(TestCase): + + nan_arr = [ + ([0, 1, 2, 3, np.nan], 4), + ([0, 1, 2, np.nan, 3], 3), + ([np.nan, 0, 1, 2, 3], 0), + ([np.nan, 0, np.nan, 2, 3], 0), + ([0, 1, 2, 3, complex(0,np.nan)], 4), + ([0, 1, 2, 3, complex(np.nan,0)], 4), + ([0, 1, 2, complex(np.nan,0), 3], 3), + ([0, 1, 2, complex(0,np.nan), 3], 3), + ([complex(0,np.nan), 0, 1, 2, 3], 0), + ([complex(np.nan, np.nan), 0, 1, 2, 3], 0), + ([complex(np.nan, 0), complex(np.nan, 2), complex(np.nan, 1)], 0), + ([complex(np.nan, np.nan), complex(np.nan, 2), complex(np.nan, 1)], 0), + ([complex(np.nan, 0), complex(np.nan, 2), complex(np.nan, np.nan)], 0), + + ([complex(0, 0), complex(0, 2), complex(0, 1)], 1), + ([complex(1, 0), complex(0, 2), complex(0, 1)], 0), + ([complex(1, 0), complex(0, 2), complex(1, 1)], 2), + ] + def test_all(self): a = np.random.normal(0,1,(4,5,6,7,8)) for i in xrange(a.ndim): @@ -680,6 +701,12 @@ class TestArgmax(TestCase): axes.remove(i) assert all(amax == aargmax.choose(*a.transpose(i,*axes))) + def test_combinations(self): + for arr, pos in self.nan_arr: + assert_equal(np.argmax(arr), pos, err_msg="%r"%arr) + assert_equal(arr[np.argmax(arr)], np.max(arr), err_msg="%r"%arr) + + class TestMinMax(TestCase): def test_scalar(self): assert_raises(ValueError, np.amax, 1, 1) |