summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/src/multiarray/arraytypes.c.src74
-rw-r--r--numpy/core/tests/test_multiarray.py21
2 files changed, 63 insertions, 32 deletions
diff --git a/numpy/core/src/multiarray/arraytypes.c.src b/numpy/core/src/multiarray/arraytypes.c.src
index 5a1e2f4d5..58c234745 100644
--- a/numpy/core/src/multiarray/arraytypes.c.src
+++ b/numpy/core/src/multiarray/arraytypes.c.src
@@ -3099,28 +3099,33 @@ static int
static int
OBJECT_argmax(PyObject **ip, npy_intp n, npy_intp *max_ind,
- PyArrayObject *NPY_UNUSED(aip))
+ PyArrayObject *NPY_UNUSED(aip))
{
npy_intp i;
- PyObject *mp = ip[0];
*max_ind = 0;
- i = 1;
- while (i < n && mp == NULL) {
- mp = ip[i];
- i++;
- }
- for (; i < n; i++) {
- ip++;
-#if defined(NPY_PY3K)
- if (*ip != NULL && PyObject_RichCompareBool(*ip, mp, Py_GT) == 1) {
-#else
- if (*ip != NULL && PyObject_Compare(*ip, mp) > 0) {
-#endif
- mp = *ip;
- *max_ind = i;
+ /* Skip over all leading NULL entries */
+ for (i = 0; i < n && ip[i] == NULL; ++i);
+ if (i < n) {
+ /* Found first non-NULL entry */
+ PyObject *mp = ip[i];
+ *max_ind = i;
+ for (i = i + 1; i < n; ++i) {
+ PyObject *val = ip[i];
+ if (val != NULL) {
+ int greater_than = PyObject_RichCompareBool(val, mp, Py_GT);
+
+ if (greater_than < 0) {
+ return 0;
+ }
+ if (greater_than) {
+ mp = val;
+ *max_ind = i;
+ }
+ }
}
}
+
return 0;
}
@@ -3158,28 +3163,33 @@ static int
static int
OBJECT_argmin(PyObject **ip, npy_intp n, npy_intp *min_ind,
- PyArrayObject *NPY_UNUSED(aip))
+ PyArrayObject *NPY_UNUSED(aip))
{
npy_intp i;
- PyObject *mp = ip[0];
*min_ind = 0;
- i = 1;
- while (i < n && mp == NULL) {
- mp = ip[i];
- i++;
- }
- for (; i < n; i++) {
- ip++;
-#if defined(NPY_PY3K)
- if (*ip != NULL && PyObject_RichCompareBool(mp, *ip, Py_GT) == 1) {
-#else
- if (*ip != NULL && PyObject_Compare(mp, *ip) > 0) {
-#endif
- mp = *ip;
- *min_ind = i;
+ /* Skip over all leading NULL entries */
+ for (i = 0; i < n && ip[i] == NULL; ++i);
+ if (i < n) {
+ /* Found first non-NULL entry */
+ PyObject *mp = ip[i];
+ *min_ind = i;
+ for (i = i + 1; i < n ; ++i) {
+ PyObject *val = ip[i];
+ if (val != NULL) {
+ int less_than = PyObject_RichCompareBool(val, mp, Py_LT);
+
+ if (less_than < 0) {
+ return 0;
+ }
+ if (less_than) {
+ mp = val;
+ *min_ind = i;
+ }
+ }
}
}
+
return 0;
}
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 3cab3cf3e..d7e919269 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -8,6 +8,7 @@ import warnings
import operator
import io
import itertools
+import ctypes
if sys.version_info[0] >= 3:
import builtins
else:
@@ -2802,6 +2803,16 @@ class TestArgmax(TestCase):
assert_equal(a.argmax(out=out1, axis=0), np.argmax(a, out=out2, axis=0))
assert_equal(out1, out2)
+ def test_object_argmax_with_NULLs(self):
+ # See gh-6032
+ a = np.empty(4, dtype='O')
+ ctypes.memset(a.ctypes.data, 0, a.nbytes)
+ assert_equal(a.argmax(), 0)
+ a[3] = 10
+ assert_equal(a.argmax(), 3)
+ a[1] = 30
+ assert_equal(a.argmax(), 1)
+
class TestArgmin(TestCase):
@@ -2940,6 +2951,16 @@ class TestArgmin(TestCase):
assert_equal(a.argmin(out=out1, axis=0), np.argmin(a, out=out2, axis=0))
assert_equal(out1, out2)
+ def test_object_argmin_with_NULLs(self):
+ # See gh-6032
+ a = np.empty(4, dtype='O')
+ ctypes.memset(a.ctypes.data, 0, a.nbytes)
+ assert_equal(a.argmin(), 0)
+ a[3] = 30
+ assert_equal(a.argmin(), 3)
+ a[1] = 10
+ assert_equal(a.argmin(), 1)
+
class TestMinMax(TestCase):