summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatti Picus <matti.picus@gmail.com>2021-07-12 09:56:19 +0300
committerGitHub <noreply@github.com>2021-07-12 09:56:19 +0300
commitc750d24ef3a9a897fb19d3f0b4910f07f078e302 (patch)
treeb12051442b5659b1e8c2d83451aaf7ca8454f4e6
parent5fafca2184a56e11b109aac75660a3eef75de425 (diff)
parent8f854122e90241e979db6d30bcd7544b9d385e3b (diff)
downloadnumpy-c750d24ef3a9a897fb19d3f0b4910f07f078e302.tar.gz
Merge pull request #19440 from czgdp1807/refac
MAINT: factored out _PyArray_ArgMinMaxCommon
-rw-r--r--numpy/core/src/multiarray/calculation.c176
-rw-r--r--numpy/core/tests/test_multiarray.py251
2 files changed, 140 insertions, 287 deletions
diff --git a/numpy/core/src/multiarray/calculation.c b/numpy/core/src/multiarray/calculation.c
index e89018889..21e52c32b 100644
--- a/numpy/core/src/multiarray/calculation.c
+++ b/numpy/core/src/multiarray/calculation.c
@@ -35,12 +35,13 @@ power_of_ten(int n)
}
NPY_NO_EXPORT PyObject *
-_PyArray_ArgMaxWithKeepdims(PyArrayObject *op,
- int axis, PyArrayObject *out, int keepdims)
+_PyArray_ArgMinMaxCommon(PyArrayObject *op,
+ int axis, PyArrayObject *out, int keepdims,
+ npy_bool is_argmax)
{
PyArrayObject *ap = NULL, *rp = NULL;
- PyArray_ArgFunc* arg_func;
- char *ip;
+ PyArray_ArgFunc* arg_func = NULL;
+ char *ip, *func_name;
npy_intp *rptr;
npy_intp i, n, m;
int elsize;
@@ -115,7 +116,14 @@ _PyArray_ArgMaxWithKeepdims(PyArrayObject *op,
}
}
- arg_func = PyArray_DESCR(ap)->f->argmax;
+ if (is_argmax) {
+ func_name = "argmax";
+ arg_func = PyArray_DESCR(ap)->f->argmax;
+ }
+ else {
+ func_name = "argmin";
+ arg_func = PyArray_DESCR(ap)->f->argmin;
+ }
if (arg_func == NULL) {
PyErr_SetString(PyExc_TypeError,
"data type not ordered");
@@ -124,8 +132,9 @@ _PyArray_ArgMaxWithKeepdims(PyArrayObject *op,
elsize = PyArray_DESCR(ap)->elsize;
m = PyArray_DIMS(ap)[PyArray_NDIM(ap)-1];
if (m == 0) {
- PyErr_SetString(PyExc_ValueError,
- "attempt to get argmax of an empty sequence");
+ PyErr_Format(PyExc_ValueError,
+ "attempt to get %s of an empty sequence",
+ func_name);
goto fail;
}
@@ -142,8 +151,9 @@ _PyArray_ArgMaxWithKeepdims(PyArrayObject *op,
if ((PyArray_NDIM(out) != out_ndim) ||
!PyArray_CompareLists(PyArray_DIMS(out), out_shape,
out_ndim)) {
- PyErr_SetString(PyExc_ValueError,
- "output array does not match result of np.argmax.");
+ PyErr_Format(PyExc_ValueError,
+ "output array does not match result of np.%s.",
+ func_name);
goto fail;
}
rp = (PyArrayObject *)PyArray_FromArray(out,
@@ -179,155 +189,27 @@ _PyArray_ArgMaxWithKeepdims(PyArrayObject *op,
return NULL;
}
+NPY_NO_EXPORT PyObject*
+_PyArray_ArgMaxWithKeepdims(PyArrayObject *op,
+ int axis, PyArrayObject *out, int keepdims)
+{
+ return _PyArray_ArgMinMaxCommon(op, axis, out, keepdims, 1);
+}
+
/*NUMPY_API
* ArgMax
*/
NPY_NO_EXPORT PyObject *
PyArray_ArgMax(PyArrayObject *op, int axis, PyArrayObject *out)
{
- return _PyArray_ArgMaxWithKeepdims(op, axis, out, 0);
+ return _PyArray_ArgMinMaxCommon(op, axis, out, 0, 1);
}
NPY_NO_EXPORT PyObject *
_PyArray_ArgMinWithKeepdims(PyArrayObject *op,
int axis, PyArrayObject *out, int keepdims)
{
- PyArrayObject *ap = NULL, *rp = NULL;
- PyArray_ArgFunc* arg_func;
- char *ip;
- npy_intp *rptr;
- npy_intp i, n, m;
- int elsize;
- // Keep a copy because axis changes via call to PyArray_CheckAxis
- int axis_copy = axis;
- npy_intp _shape_buf[NPY_MAXDIMS];
- npy_intp *out_shape;
- // Keep the number of dimensions and shape of
- // original array. Helps when `keepdims` is True.
- npy_intp* original_op_shape = PyArray_DIMS(op);
- int out_ndim = PyArray_NDIM(op);
- NPY_BEGIN_THREADS_DEF;
-
- if ((ap = (PyArrayObject *)PyArray_CheckAxis(op, &axis, 0)) == NULL) {
- return NULL;
- }
- /*
- * We need to permute the array so that axis is placed at the end.
- * And all other dimensions are shifted left.
- */
- if (axis != PyArray_NDIM(ap)-1) {
- PyArray_Dims newaxes;
- npy_intp dims[NPY_MAXDIMS];
- int i;
-
- newaxes.ptr = dims;
- newaxes.len = PyArray_NDIM(ap);
- for (i = 0; i < axis; i++) {
- dims[i] = i;
- }
- for (i = axis; i < PyArray_NDIM(ap) - 1; i++) {
- dims[i] = i + 1;
- }
- dims[PyArray_NDIM(ap) - 1] = axis;
- op = (PyArrayObject *)PyArray_Transpose(ap, &newaxes);
- Py_DECREF(ap);
- if (op == NULL) {
- return NULL;
- }
- }
- else {
- op = ap;
- }
-
- /* Will get native-byte order contiguous copy. */
- ap = (PyArrayObject *)PyArray_ContiguousFromAny((PyObject *)op,
- PyArray_DESCR(op)->type_num, 1, 0);
- Py_DECREF(op);
- if (ap == NULL) {
- return NULL;
- }
-
- // Decides the shape of the output array.
- if (!keepdims) {
- out_ndim = PyArray_NDIM(ap) - 1;
- out_shape = PyArray_DIMS(ap);
- } else {
- out_shape = _shape_buf;
- if (axis_copy == NPY_MAXDIMS) {
- for (int i = 0; i < out_ndim; i++) {
- out_shape[i] = 1;
- }
- } else {
- /*
- * While `ap` may be transposed, we can ignore this for `out` because the
- * transpose only reorders the size 1 `axis` (not changing memory layout).
- */
- memcpy(out_shape, original_op_shape, out_ndim * sizeof(npy_intp));
- out_shape[axis] = 1;
- }
- }
-
- arg_func = PyArray_DESCR(ap)->f->argmin;
- if (arg_func == NULL) {
- PyErr_SetString(PyExc_TypeError,
- "data type not ordered");
- goto fail;
- }
- elsize = PyArray_DESCR(ap)->elsize;
- m = PyArray_DIMS(ap)[PyArray_NDIM(ap)-1];
- if (m == 0) {
- PyErr_SetString(PyExc_ValueError,
- "attempt to get argmin of an empty sequence");
- goto fail;
- }
-
- if (!out) {
- rp = (PyArrayObject *)PyArray_NewFromDescr(
- Py_TYPE(ap), PyArray_DescrFromType(NPY_INTP),
- out_ndim, out_shape, NULL, NULL,
- 0, (PyObject *)ap);
- if (rp == NULL) {
- goto fail;
- }
- }
- else {
- if ((PyArray_NDIM(out) != out_ndim) ||
- !PyArray_CompareLists(PyArray_DIMS(out), out_shape, out_ndim)) {
- PyErr_SetString(PyExc_ValueError,
- "output array does not match result of np.argmin.");
- goto fail;
- }
- rp = (PyArrayObject *)PyArray_FromArray(out,
- PyArray_DescrFromType(NPY_INTP),
- NPY_ARRAY_CARRAY | NPY_ARRAY_WRITEBACKIFCOPY);
- if (rp == NULL) {
- goto fail;
- }
- }
-
- NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap));
- n = PyArray_SIZE(ap)/m;
- rptr = (npy_intp *)PyArray_DATA(rp);
- for (ip = PyArray_DATA(ap), i = 0; i < n; i++, ip += elsize*m) {
- arg_func(ip, m, rptr, ap);
- rptr += 1;
- }
- NPY_END_THREADS_DESCR(PyArray_DESCR(ap));
-
- Py_DECREF(ap);
- /* Trigger the UPDATEIFCOPY/WRITEBACKIFCOPY if necessary */
- if (out != NULL && out != rp) {
- PyArray_ResolveWritebackIfCopy(rp);
- Py_DECREF(rp);
- rp = out;
- Py_INCREF(rp);
- }
- return (PyObject *)rp;
-
- fail:
- Py_DECREF(ap);
- Py_XDECREF(rp);
- return NULL;
+ return _PyArray_ArgMinMaxCommon(op, axis, out, keepdims, 0);
}
/*NUMPY_API
@@ -336,7 +218,7 @@ _PyArray_ArgMinWithKeepdims(PyArrayObject *op,
NPY_NO_EXPORT PyObject *
PyArray_ArgMin(PyArrayObject *op, int axis, PyArrayObject *out)
{
- return _PyArray_ArgMinWithKeepdims(op, axis, out, 0);
+ return _PyArray_ArgMinMaxCommon(op, axis, out, 0, 0);
}
/*NUMPY_API
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 7c8fc8e3e..5f0a725d2 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -4275,6 +4275,94 @@ class TestArgmaxArgminCommon:
method(arr.T, axis=axis,
out=wrong_outarray, keepdims=True)
+ @pytest.mark.parametrize('method', ['max', 'min'])
+ def test_all(self, method):
+ a = np.random.normal(0, 1, (4, 5, 6, 7, 8))
+ arg_method = getattr(a, 'arg' + method)
+ val_method = getattr(a, method)
+ for i in range(a.ndim):
+ a_maxmin = val_method(i)
+ aarg_maxmin = arg_method(i)
+ axes = list(range(a.ndim))
+ axes.remove(i)
+ assert_(np.all(a_maxmin == aarg_maxmin.choose(
+ *a.transpose(i, *axes))))
+
+ @pytest.mark.parametrize('method', ['argmax', 'argmin'])
+ def test_output_shape(self, method):
+ # see also gh-616
+ a = np.ones((10, 5))
+ arg_method = getattr(a, method)
+ # Check some simple shape mismatches
+ out = np.ones(11, dtype=np.int_)
+ assert_raises(ValueError, arg_method, -1, out)
+
+ out = np.ones((2, 5), dtype=np.int_)
+ assert_raises(ValueError, arg_method, -1, out)
+
+ # these could be relaxed possibly (used to allow even the previous)
+ out = np.ones((1, 10), dtype=np.int_)
+ assert_raises(ValueError, arg_method, -1, out)
+
+ out = np.ones(10, dtype=np.int_)
+ arg_method(-1, out=out)
+ assert_equal(out, arg_method(-1))
+
+ @pytest.mark.parametrize('ndim', [0, 1])
+ @pytest.mark.parametrize('method', ['argmax', 'argmin'])
+ def test_ret_is_out(self, ndim, method):
+ a = np.ones((4,) + (3,)*ndim)
+ arg_method = getattr(a, method)
+ out = np.empty((3,)*ndim, dtype=np.intp)
+ ret = arg_method(axis=0, out=out)
+ assert ret is out
+
+ @pytest.mark.parametrize('np_array, method, idx, val',
+ [(np.zeros, 'argmax', 5942, "as"),
+ (np.ones, 'argmin', 6001, "0")])
+ def test_unicode(self, np_array, method, idx, val):
+ d = np_array(6031, dtype='<U9')
+ arg_method = getattr(d, method)
+ d[idx] = val
+ assert_equal(arg_method(), idx)
+
+ @pytest.mark.parametrize('arr_method, np_method',
+ [('argmax', np.argmax),
+ ('argmin', np.argmin)])
+ def test_np_vs_ndarray(self, arr_method, np_method):
+ # make sure both ndarray.argmax/argmin and
+ # numpy.argmax/argmin support out/axis args
+ a = np.random.normal(size=(2, 3))
+ arg_method = getattr(a, arr_method)
+
+ # check positional args
+ out1 = np.zeros(2, dtype=int)
+ out2 = np.zeros(2, dtype=int)
+ assert_equal(arg_method(1, out1), np_method(a, 1, out2))
+ assert_equal(out1, out2)
+
+ # check keyword args
+ out1 = np.zeros(3, dtype=int)
+ out2 = np.zeros(3, dtype=int)
+ assert_equal(arg_method(out=out1, axis=0),
+ np_method(a, out=out2, axis=0))
+ assert_equal(out1, out2)
+
+ @pytest.mark.leaks_references(reason="replaces None with NULL.")
+ @pytest.mark.parametrize('method, vals',
+ [('argmax', (10, 30)),
+ ('argmin', (30, 10))])
+ def test_object_with_NULLs(self, method, vals):
+ # See gh-6032
+ a = np.empty(4, dtype='O')
+ arg_method = getattr(a, method)
+ ctypes.memset(a.ctypes.data, 0, a.nbytes)
+ assert_equal(arg_method(), 0)
+ a[3] = vals[0]
+ assert_equal(arg_method(), 3)
+ a[1] = vals[1]
+ assert_equal(arg_method(), 1)
+
class TestArgmax:
nan_arr = [
@@ -4340,81 +4428,30 @@ class TestArgmax:
([True, False, True, False, False], 0),
]
- def test_all(self):
- a = np.random.normal(0, 1, (4, 5, 6, 7, 8))
- for i in range(a.ndim):
- amax = a.max(i)
- aargmax = a.argmax(i)
- axes = list(range(a.ndim))
- axes.remove(i)
- assert_(np.all(amax == aargmax.choose(*a.transpose(i,*axes))))
-
- def test_combinations(self):
- for arr, pos in self.nan_arr:
- with suppress_warnings() as sup:
- sup.filter(RuntimeWarning,
- "invalid value encountered in reduce")
- max_val = np.max(arr)
-
- assert_equal(np.argmax(arr), pos, err_msg="%r" % arr)
- assert_equal(arr[np.argmax(arr)], max_val, err_msg="%r" % arr)
-
- def test_output_shape(self):
- # see also gh-616
- a = np.ones((10, 5))
- # Check some simple shape mismatches
- out = np.ones(11, dtype=np.int_)
- assert_raises(ValueError, a.argmax, -1, out)
-
- out = np.ones((2, 5), dtype=np.int_)
- assert_raises(ValueError, a.argmax, -1, out)
-
- # these could be relaxed possibly (used to allow even the previous)
- out = np.ones((1, 10), dtype=np.int_)
- assert_raises(ValueError, a.argmax, -1, out)
-
- out = np.ones(10, dtype=np.int_)
- a.argmax(-1, out=out)
- assert_equal(out, a.argmax(-1))
+ @pytest.mark.parametrize('data', nan_arr)
+ def test_combinations(self, data):
+ arr, pos = data
+ with suppress_warnings() as sup:
+ sup.filter(RuntimeWarning,
+ "invalid value encountered in reduce")
+ val = np.max(arr)
- @pytest.mark.parametrize('ndim', [0, 1])
- def test_ret_is_out(self, ndim):
- a = np.ones((4,) + (3,)*ndim)
- out = np.empty((3,)*ndim, dtype=np.intp)
- ret = a.argmax(axis=0, out=out)
- assert ret is out
+ assert_equal(np.argmax(arr), pos, err_msg="%r" % arr)
+ assert_equal(arr[np.argmax(arr)], val, err_msg="%r" % arr)
+
+ def test_maximum_signed_integers(self):
- def test_argmax_unicode(self):
- d = np.zeros(6031, dtype='<U9')
- d[5942] = "as"
- assert_equal(d.argmax(), 5942)
+ a = np.array([1, 2**7 - 1, -2**7], dtype=np.int8)
+ assert_equal(np.argmax(a), 1)
- def test_np_vs_ndarray(self):
- # make sure both ndarray.argmax and numpy.argmax support out/axis args
- a = np.random.normal(size=(2,3))
+ a = np.array([1, 2**15 - 1, -2**15], dtype=np.int16)
+ assert_equal(np.argmax(a), 1)
- # check positional args
- out1 = np.zeros(2, dtype=int)
- out2 = np.zeros(2, dtype=int)
- assert_equal(a.argmax(1, out1), np.argmax(a, 1, out2))
- assert_equal(out1, out2)
-
- # check keyword args
- out1 = np.zeros(3, dtype=int)
- out2 = np.zeros(3, dtype=int)
- assert_equal(a.argmax(out=out1, axis=0), np.argmax(a, out=out2, axis=0))
- assert_equal(out1, out2)
+ a = np.array([1, 2**31 - 1, -2**31], dtype=np.int32)
+ assert_equal(np.argmax(a), 1)
- @pytest.mark.leaks_references(reason="replaces None with NULL.")
- def test_object_argmax_with_NULLs(self):
- # See gh-6032
- a = np.empty(4, dtype='O')
- ctypes.memset(a.ctypes.data, 0, a.nbytes)
- assert_equal(a.argmax(), 0)
- a[3] = 10
- assert_equal(a.argmax(), 3)
- a[1] = 30
- assert_equal(a.argmax(), 1)
+ a = np.array([1, 2**63 - 1, -2**63], dtype=np.int64)
+ assert_equal(np.argmax(a), 1)
class TestArgmin:
@@ -4482,15 +4519,6 @@ class TestArgmin:
([False, True, False, True, True], 0),
]
- def test_all(self):
- a = np.random.normal(0, 1, (4, 5, 6, 7, 8))
- for i in range(a.ndim):
- amin = a.min(i)
- aargmin = a.argmin(i)
- axes = list(range(a.ndim))
- axes.remove(i)
- assert_(np.all(amin == aargmin.choose(*a.transpose(i,*axes))))
-
def test_combinations(self):
for arr, pos in self.nan_arr:
with suppress_warnings() as sup:
@@ -4503,75 +4531,18 @@ class TestArgmin:
def test_minimum_signed_integers(self):
- a = np.array([1, -2**7, -2**7 + 1], dtype=np.int8)
+ a = np.array([1, -2**7, -2**7 + 1, 2**7 - 1], dtype=np.int8)
assert_equal(np.argmin(a), 1)
- a = np.array([1, -2**15, -2**15 + 1], dtype=np.int16)
+ a = np.array([1, -2**15, -2**15 + 1, 2**15 - 1], dtype=np.int16)
assert_equal(np.argmin(a), 1)
- a = np.array([1, -2**31, -2**31 + 1], dtype=np.int32)
+ a = np.array([1, -2**31, -2**31 + 1, 2**31 - 1], dtype=np.int32)
assert_equal(np.argmin(a), 1)
- a = np.array([1, -2**63, -2**63 + 1], dtype=np.int64)
+ a = np.array([1, -2**63, -2**63 + 1, 2**63 - 1], dtype=np.int64)
assert_equal(np.argmin(a), 1)
- def test_output_shape(self):
- # see also gh-616
- a = np.ones((10, 5))
- # Check some simple shape mismatches
- out = np.ones(11, dtype=np.int_)
- assert_raises(ValueError, a.argmin, -1, out)
-
- out = np.ones((2, 5), dtype=np.int_)
- assert_raises(ValueError, a.argmin, -1, out)
-
- # these could be relaxed possibly (used to allow even the previous)
- out = np.ones((1, 10), dtype=np.int_)
- assert_raises(ValueError, a.argmin, -1, out)
-
- out = np.ones(10, dtype=np.int_)
- a.argmin(-1, out=out)
- assert_equal(out, a.argmin(-1))
-
- @pytest.mark.parametrize('ndim', [0, 1])
- def test_ret_is_out(self, ndim):
- a = np.ones((4,) + (3,)*ndim)
- out = np.empty((3,)*ndim, dtype=np.intp)
- ret = a.argmin(axis=0, out=out)
- assert ret is out
-
- def test_argmin_unicode(self):
- d = np.ones(6031, dtype='<U9')
- d[6001] = "0"
- assert_equal(d.argmin(), 6001)
-
- def test_np_vs_ndarray(self):
- # make sure both ndarray.argmin and numpy.argmin support out/axis args
- a = np.random.normal(size=(2, 3))
-
- # check positional args
- out1 = np.zeros(2, dtype=int)
- out2 = np.ones(2, dtype=int)
- assert_equal(a.argmin(1, out1), np.argmin(a, 1, out2))
- assert_equal(out1, out2)
-
- # check keyword args
- out1 = np.zeros(3, dtype=int)
- out2 = np.ones(3, dtype=int)
- assert_equal(a.argmin(out=out1, axis=0), np.argmin(a, out=out2, axis=0))
- assert_equal(out1, out2)
-
- @pytest.mark.leaks_references(reason="replaces None with NULL.")
- def test_object_argmin_with_NULLs(self):
- # See gh-6032
- a = np.empty(4, dtype='O')
- ctypes.memset(a.ctypes.data, 0, a.nbytes)
- assert_equal(a.argmin(), 0)
- a[3] = 30
- assert_equal(a.argmin(), 3)
- a[1] = 10
- assert_equal(a.argmin(), 1)
-
class TestMinMax: