diff options
Diffstat (limited to 'numpy/core')
| -rw-r--r-- | numpy/core/src/multiarray/arraytypes.c.src | 74 | ||||
| -rw-r--r-- | numpy/core/tests/test_multiarray.py | 21 |
2 files changed, 63 insertions, 32 deletions
diff --git a/numpy/core/src/multiarray/arraytypes.c.src b/numpy/core/src/multiarray/arraytypes.c.src index 5a1e2f4d5..58c234745 100644 --- a/numpy/core/src/multiarray/arraytypes.c.src +++ b/numpy/core/src/multiarray/arraytypes.c.src @@ -3099,28 +3099,33 @@ static int static int OBJECT_argmax(PyObject **ip, npy_intp n, npy_intp *max_ind, - PyArrayObject *NPY_UNUSED(aip)) + PyArrayObject *NPY_UNUSED(aip)) { npy_intp i; - PyObject *mp = ip[0]; *max_ind = 0; - i = 1; - while (i < n && mp == NULL) { - mp = ip[i]; - i++; - } - for (; i < n; i++) { - ip++; -#if defined(NPY_PY3K) - if (*ip != NULL && PyObject_RichCompareBool(*ip, mp, Py_GT) == 1) { -#else - if (*ip != NULL && PyObject_Compare(*ip, mp) > 0) { -#endif - mp = *ip; - *max_ind = i; + /* Skip over all leading NULL entries */ + for (i = 0; i < n && ip[i] == NULL; ++i); + if (i < n) { + /* Found first non-NULL entry */ + PyObject *mp = ip[i]; + *max_ind = i; + for (i = i + 1; i < n; ++i) { + PyObject *val = ip[i]; + if (val != NULL) { + int greater_than = PyObject_RichCompareBool(val, mp, Py_GT); + + if (greater_than < 0) { + return 0; + } + if (greater_than) { + mp = val; + *max_ind = i; + } + } } } + return 0; } @@ -3158,28 +3163,33 @@ static int static int OBJECT_argmin(PyObject **ip, npy_intp n, npy_intp *min_ind, - PyArrayObject *NPY_UNUSED(aip)) + PyArrayObject *NPY_UNUSED(aip)) { npy_intp i; - PyObject *mp = ip[0]; *min_ind = 0; - i = 1; - while (i < n && mp == NULL) { - mp = ip[i]; - i++; - } - for (; i < n; i++) { - ip++; -#if defined(NPY_PY3K) - if (*ip != NULL && PyObject_RichCompareBool(mp, *ip, Py_GT) == 1) { -#else - if (*ip != NULL && PyObject_Compare(mp, *ip) > 0) { -#endif - mp = *ip; - *min_ind = i; + /* Skip over all leading NULL entries */ + for (i = 0; i < n && ip[i] == NULL; ++i); + if (i < n) { + /* Found first non-NULL entry */ + PyObject *mp = ip[i]; + *min_ind = i; + for (i = i + 1; i < n ; ++i) { + PyObject *val = ip[i]; + if (val != NULL) { + int less_than = PyObject_RichCompareBool(val, mp, Py_LT); + + if (less_than < 0) { + return 0; + } + if (less_than) { + mp = val; + *min_ind = i; + } + } } } + return 0; } diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 3cab3cf3e..d7e919269 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -8,6 +8,7 @@ import warnings import operator import io import itertools +import ctypes if sys.version_info[0] >= 3: import builtins else: @@ -2802,6 +2803,16 @@ class TestArgmax(TestCase): assert_equal(a.argmax(out=out1, axis=0), np.argmax(a, out=out2, axis=0)) assert_equal(out1, out2) + def test_object_argmax_with_NULLs(self): + # See gh-6032 + a = np.empty(4, dtype='O') + ctypes.memset(a.ctypes.data, 0, a.nbytes) + assert_equal(a.argmax(), 0) + a[3] = 10 + assert_equal(a.argmax(), 3) + a[1] = 30 + assert_equal(a.argmax(), 1) + class TestArgmin(TestCase): @@ -2940,6 +2951,16 @@ class TestArgmin(TestCase): assert_equal(a.argmin(out=out1, axis=0), np.argmin(a, out=out2, axis=0)) assert_equal(out1, out2) + def test_object_argmin_with_NULLs(self): + # See gh-6032 + a = np.empty(4, dtype='O') + ctypes.memset(a.ctypes.data, 0, a.nbytes) + assert_equal(a.argmin(), 0) + a[3] = 30 + assert_equal(a.argmin(), 3) + a[1] = 10 + assert_equal(a.argmin(), 1) + class TestMinMax(TestCase): |
