diff options
author | Matti Picus <matti.picus@gmail.com> | 2021-07-12 09:56:19 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-07-12 09:56:19 +0300 |
commit | c750d24ef3a9a897fb19d3f0b4910f07f078e302 (patch) | |
tree | b12051442b5659b1e8c2d83451aaf7ca8454f4e6 | |
parent | 5fafca2184a56e11b109aac75660a3eef75de425 (diff) | |
parent | 8f854122e90241e979db6d30bcd7544b9d385e3b (diff) | |
download | numpy-c750d24ef3a9a897fb19d3f0b4910f07f078e302.tar.gz |
Merge pull request #19440 from czgdp1807/refac
MAINT: factored out _PyArray_ArgMinMaxCommon
-rw-r--r-- | numpy/core/src/multiarray/calculation.c | 176 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 251 |
2 files changed, 140 insertions, 287 deletions
diff --git a/numpy/core/src/multiarray/calculation.c b/numpy/core/src/multiarray/calculation.c index e89018889..21e52c32b 100644 --- a/numpy/core/src/multiarray/calculation.c +++ b/numpy/core/src/multiarray/calculation.c @@ -35,12 +35,13 @@ power_of_ten(int n) } NPY_NO_EXPORT PyObject * -_PyArray_ArgMaxWithKeepdims(PyArrayObject *op, - int axis, PyArrayObject *out, int keepdims) +_PyArray_ArgMinMaxCommon(PyArrayObject *op, + int axis, PyArrayObject *out, int keepdims, + npy_bool is_argmax) { PyArrayObject *ap = NULL, *rp = NULL; - PyArray_ArgFunc* arg_func; - char *ip; + PyArray_ArgFunc* arg_func = NULL; + char *ip, *func_name; npy_intp *rptr; npy_intp i, n, m; int elsize; @@ -115,7 +116,14 @@ _PyArray_ArgMaxWithKeepdims(PyArrayObject *op, } } - arg_func = PyArray_DESCR(ap)->f->argmax; + if (is_argmax) { + func_name = "argmax"; + arg_func = PyArray_DESCR(ap)->f->argmax; + } + else { + func_name = "argmin"; + arg_func = PyArray_DESCR(ap)->f->argmin; + } if (arg_func == NULL) { PyErr_SetString(PyExc_TypeError, "data type not ordered"); @@ -124,8 +132,9 @@ _PyArray_ArgMaxWithKeepdims(PyArrayObject *op, elsize = PyArray_DESCR(ap)->elsize; m = PyArray_DIMS(ap)[PyArray_NDIM(ap)-1]; if (m == 0) { - PyErr_SetString(PyExc_ValueError, - "attempt to get argmax of an empty sequence"); + PyErr_Format(PyExc_ValueError, + "attempt to get %s of an empty sequence", + func_name); goto fail; } @@ -142,8 +151,9 @@ _PyArray_ArgMaxWithKeepdims(PyArrayObject *op, if ((PyArray_NDIM(out) != out_ndim) || !PyArray_CompareLists(PyArray_DIMS(out), out_shape, out_ndim)) { - PyErr_SetString(PyExc_ValueError, - "output array does not match result of np.argmax."); + PyErr_Format(PyExc_ValueError, + "output array does not match result of np.%s.", + func_name); goto fail; } rp = (PyArrayObject *)PyArray_FromArray(out, @@ -179,155 +189,27 @@ _PyArray_ArgMaxWithKeepdims(PyArrayObject *op, return NULL; } +NPY_NO_EXPORT PyObject* +_PyArray_ArgMaxWithKeepdims(PyArrayObject *op, + int axis, PyArrayObject *out, int keepdims) +{ + return _PyArray_ArgMinMaxCommon(op, axis, out, keepdims, 1); +} + /*NUMPY_API * ArgMax */ NPY_NO_EXPORT PyObject * PyArray_ArgMax(PyArrayObject *op, int axis, PyArrayObject *out) { - return _PyArray_ArgMaxWithKeepdims(op, axis, out, 0); + return _PyArray_ArgMinMaxCommon(op, axis, out, 0, 1); } NPY_NO_EXPORT PyObject * _PyArray_ArgMinWithKeepdims(PyArrayObject *op, int axis, PyArrayObject *out, int keepdims) { - PyArrayObject *ap = NULL, *rp = NULL; - PyArray_ArgFunc* arg_func; - char *ip; - npy_intp *rptr; - npy_intp i, n, m; - int elsize; - // Keep a copy because axis changes via call to PyArray_CheckAxis - int axis_copy = axis; - npy_intp _shape_buf[NPY_MAXDIMS]; - npy_intp *out_shape; - // Keep the number of dimensions and shape of - // original array. Helps when `keepdims` is True. - npy_intp* original_op_shape = PyArray_DIMS(op); - int out_ndim = PyArray_NDIM(op); - NPY_BEGIN_THREADS_DEF; - - if ((ap = (PyArrayObject *)PyArray_CheckAxis(op, &axis, 0)) == NULL) { - return NULL; - } - /* - * We need to permute the array so that axis is placed at the end. - * And all other dimensions are shifted left. - */ - if (axis != PyArray_NDIM(ap)-1) { - PyArray_Dims newaxes; - npy_intp dims[NPY_MAXDIMS]; - int i; - - newaxes.ptr = dims; - newaxes.len = PyArray_NDIM(ap); - for (i = 0; i < axis; i++) { - dims[i] = i; - } - for (i = axis; i < PyArray_NDIM(ap) - 1; i++) { - dims[i] = i + 1; - } - dims[PyArray_NDIM(ap) - 1] = axis; - op = (PyArrayObject *)PyArray_Transpose(ap, &newaxes); - Py_DECREF(ap); - if (op == NULL) { - return NULL; - } - } - else { - op = ap; - } - - /* Will get native-byte order contiguous copy. */ - ap = (PyArrayObject *)PyArray_ContiguousFromAny((PyObject *)op, - PyArray_DESCR(op)->type_num, 1, 0); - Py_DECREF(op); - if (ap == NULL) { - return NULL; - } - - // Decides the shape of the output array. - if (!keepdims) { - out_ndim = PyArray_NDIM(ap) - 1; - out_shape = PyArray_DIMS(ap); - } else { - out_shape = _shape_buf; - if (axis_copy == NPY_MAXDIMS) { - for (int i = 0; i < out_ndim; i++) { - out_shape[i] = 1; - } - } else { - /* - * While `ap` may be transposed, we can ignore this for `out` because the - * transpose only reorders the size 1 `axis` (not changing memory layout). - */ - memcpy(out_shape, original_op_shape, out_ndim * sizeof(npy_intp)); - out_shape[axis] = 1; - } - } - - arg_func = PyArray_DESCR(ap)->f->argmin; - if (arg_func == NULL) { - PyErr_SetString(PyExc_TypeError, - "data type not ordered"); - goto fail; - } - elsize = PyArray_DESCR(ap)->elsize; - m = PyArray_DIMS(ap)[PyArray_NDIM(ap)-1]; - if (m == 0) { - PyErr_SetString(PyExc_ValueError, - "attempt to get argmin of an empty sequence"); - goto fail; - } - - if (!out) { - rp = (PyArrayObject *)PyArray_NewFromDescr( - Py_TYPE(ap), PyArray_DescrFromType(NPY_INTP), - out_ndim, out_shape, NULL, NULL, - 0, (PyObject *)ap); - if (rp == NULL) { - goto fail; - } - } - else { - if ((PyArray_NDIM(out) != out_ndim) || - !PyArray_CompareLists(PyArray_DIMS(out), out_shape, out_ndim)) { - PyErr_SetString(PyExc_ValueError, - "output array does not match result of np.argmin."); - goto fail; - } - rp = (PyArrayObject *)PyArray_FromArray(out, - PyArray_DescrFromType(NPY_INTP), - NPY_ARRAY_CARRAY | NPY_ARRAY_WRITEBACKIFCOPY); - if (rp == NULL) { - goto fail; - } - } - - NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap)); - n = PyArray_SIZE(ap)/m; - rptr = (npy_intp *)PyArray_DATA(rp); - for (ip = PyArray_DATA(ap), i = 0; i < n; i++, ip += elsize*m) { - arg_func(ip, m, rptr, ap); - rptr += 1; - } - NPY_END_THREADS_DESCR(PyArray_DESCR(ap)); - - Py_DECREF(ap); - /* Trigger the UPDATEIFCOPY/WRITEBACKIFCOPY if necessary */ - if (out != NULL && out != rp) { - PyArray_ResolveWritebackIfCopy(rp); - Py_DECREF(rp); - rp = out; - Py_INCREF(rp); - } - return (PyObject *)rp; - - fail: - Py_DECREF(ap); - Py_XDECREF(rp); - return NULL; + return _PyArray_ArgMinMaxCommon(op, axis, out, keepdims, 0); } /*NUMPY_API @@ -336,7 +218,7 @@ _PyArray_ArgMinWithKeepdims(PyArrayObject *op, NPY_NO_EXPORT PyObject * PyArray_ArgMin(PyArrayObject *op, int axis, PyArrayObject *out) { - return _PyArray_ArgMinWithKeepdims(op, axis, out, 0); + return _PyArray_ArgMinMaxCommon(op, axis, out, 0, 0); } /*NUMPY_API diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 7c8fc8e3e..5f0a725d2 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -4275,6 +4275,94 @@ class TestArgmaxArgminCommon: method(arr.T, axis=axis, out=wrong_outarray, keepdims=True) + @pytest.mark.parametrize('method', ['max', 'min']) + def test_all(self, method): + a = np.random.normal(0, 1, (4, 5, 6, 7, 8)) + arg_method = getattr(a, 'arg' + method) + val_method = getattr(a, method) + for i in range(a.ndim): + a_maxmin = val_method(i) + aarg_maxmin = arg_method(i) + axes = list(range(a.ndim)) + axes.remove(i) + assert_(np.all(a_maxmin == aarg_maxmin.choose( + *a.transpose(i, *axes)))) + + @pytest.mark.parametrize('method', ['argmax', 'argmin']) + def test_output_shape(self, method): + # see also gh-616 + a = np.ones((10, 5)) + arg_method = getattr(a, method) + # Check some simple shape mismatches + out = np.ones(11, dtype=np.int_) + assert_raises(ValueError, arg_method, -1, out) + + out = np.ones((2, 5), dtype=np.int_) + assert_raises(ValueError, arg_method, -1, out) + + # these could be relaxed possibly (used to allow even the previous) + out = np.ones((1, 10), dtype=np.int_) + assert_raises(ValueError, arg_method, -1, out) + + out = np.ones(10, dtype=np.int_) + arg_method(-1, out=out) + assert_equal(out, arg_method(-1)) + + @pytest.mark.parametrize('ndim', [0, 1]) + @pytest.mark.parametrize('method', ['argmax', 'argmin']) + def test_ret_is_out(self, ndim, method): + a = np.ones((4,) + (3,)*ndim) + arg_method = getattr(a, method) + out = np.empty((3,)*ndim, dtype=np.intp) + ret = arg_method(axis=0, out=out) + assert ret is out + + @pytest.mark.parametrize('np_array, method, idx, val', + [(np.zeros, 'argmax', 5942, "as"), + (np.ones, 'argmin', 6001, "0")]) + def test_unicode(self, np_array, method, idx, val): + d = np_array(6031, dtype='<U9') + arg_method = getattr(d, method) + d[idx] = val + assert_equal(arg_method(), idx) + + @pytest.mark.parametrize('arr_method, np_method', + [('argmax', np.argmax), + ('argmin', np.argmin)]) + def test_np_vs_ndarray(self, arr_method, np_method): + # make sure both ndarray.argmax/argmin and + # numpy.argmax/argmin support out/axis args + a = np.random.normal(size=(2, 3)) + arg_method = getattr(a, arr_method) + + # check positional args + out1 = np.zeros(2, dtype=int) + out2 = np.zeros(2, dtype=int) + assert_equal(arg_method(1, out1), np_method(a, 1, out2)) + assert_equal(out1, out2) + + # check keyword args + out1 = np.zeros(3, dtype=int) + out2 = np.zeros(3, dtype=int) + assert_equal(arg_method(out=out1, axis=0), + np_method(a, out=out2, axis=0)) + assert_equal(out1, out2) + + @pytest.mark.leaks_references(reason="replaces None with NULL.") + @pytest.mark.parametrize('method, vals', + [('argmax', (10, 30)), + ('argmin', (30, 10))]) + def test_object_with_NULLs(self, method, vals): + # See gh-6032 + a = np.empty(4, dtype='O') + arg_method = getattr(a, method) + ctypes.memset(a.ctypes.data, 0, a.nbytes) + assert_equal(arg_method(), 0) + a[3] = vals[0] + assert_equal(arg_method(), 3) + a[1] = vals[1] + assert_equal(arg_method(), 1) + class TestArgmax: nan_arr = [ @@ -4340,81 +4428,30 @@ class TestArgmax: ([True, False, True, False, False], 0), ] - def test_all(self): - a = np.random.normal(0, 1, (4, 5, 6, 7, 8)) - for i in range(a.ndim): - amax = a.max(i) - aargmax = a.argmax(i) - axes = list(range(a.ndim)) - axes.remove(i) - assert_(np.all(amax == aargmax.choose(*a.transpose(i,*axes)))) - - def test_combinations(self): - for arr, pos in self.nan_arr: - with suppress_warnings() as sup: - sup.filter(RuntimeWarning, - "invalid value encountered in reduce") - max_val = np.max(arr) - - assert_equal(np.argmax(arr), pos, err_msg="%r" % arr) - assert_equal(arr[np.argmax(arr)], max_val, err_msg="%r" % arr) - - def test_output_shape(self): - # see also gh-616 - a = np.ones((10, 5)) - # Check some simple shape mismatches - out = np.ones(11, dtype=np.int_) - assert_raises(ValueError, a.argmax, -1, out) - - out = np.ones((2, 5), dtype=np.int_) - assert_raises(ValueError, a.argmax, -1, out) - - # these could be relaxed possibly (used to allow even the previous) - out = np.ones((1, 10), dtype=np.int_) - assert_raises(ValueError, a.argmax, -1, out) - - out = np.ones(10, dtype=np.int_) - a.argmax(-1, out=out) - assert_equal(out, a.argmax(-1)) + @pytest.mark.parametrize('data', nan_arr) + def test_combinations(self, data): + arr, pos = data + with suppress_warnings() as sup: + sup.filter(RuntimeWarning, + "invalid value encountered in reduce") + val = np.max(arr) - @pytest.mark.parametrize('ndim', [0, 1]) - def test_ret_is_out(self, ndim): - a = np.ones((4,) + (3,)*ndim) - out = np.empty((3,)*ndim, dtype=np.intp) - ret = a.argmax(axis=0, out=out) - assert ret is out + assert_equal(np.argmax(arr), pos, err_msg="%r" % arr) + assert_equal(arr[np.argmax(arr)], val, err_msg="%r" % arr) + + def test_maximum_signed_integers(self): - def test_argmax_unicode(self): - d = np.zeros(6031, dtype='<U9') - d[5942] = "as" - assert_equal(d.argmax(), 5942) + a = np.array([1, 2**7 - 1, -2**7], dtype=np.int8) + assert_equal(np.argmax(a), 1) - def test_np_vs_ndarray(self): - # make sure both ndarray.argmax and numpy.argmax support out/axis args - a = np.random.normal(size=(2,3)) + a = np.array([1, 2**15 - 1, -2**15], dtype=np.int16) + assert_equal(np.argmax(a), 1) - # check positional args - out1 = np.zeros(2, dtype=int) - out2 = np.zeros(2, dtype=int) - assert_equal(a.argmax(1, out1), np.argmax(a, 1, out2)) - assert_equal(out1, out2) - - # check keyword args - out1 = np.zeros(3, dtype=int) - out2 = np.zeros(3, dtype=int) - assert_equal(a.argmax(out=out1, axis=0), np.argmax(a, out=out2, axis=0)) - assert_equal(out1, out2) + a = np.array([1, 2**31 - 1, -2**31], dtype=np.int32) + assert_equal(np.argmax(a), 1) - @pytest.mark.leaks_references(reason="replaces None with NULL.") - def test_object_argmax_with_NULLs(self): - # See gh-6032 - a = np.empty(4, dtype='O') - ctypes.memset(a.ctypes.data, 0, a.nbytes) - assert_equal(a.argmax(), 0) - a[3] = 10 - assert_equal(a.argmax(), 3) - a[1] = 30 - assert_equal(a.argmax(), 1) + a = np.array([1, 2**63 - 1, -2**63], dtype=np.int64) + assert_equal(np.argmax(a), 1) class TestArgmin: @@ -4482,15 +4519,6 @@ class TestArgmin: ([False, True, False, True, True], 0), ] - def test_all(self): - a = np.random.normal(0, 1, (4, 5, 6, 7, 8)) - for i in range(a.ndim): - amin = a.min(i) - aargmin = a.argmin(i) - axes = list(range(a.ndim)) - axes.remove(i) - assert_(np.all(amin == aargmin.choose(*a.transpose(i,*axes)))) - def test_combinations(self): for arr, pos in self.nan_arr: with suppress_warnings() as sup: @@ -4503,75 +4531,18 @@ class TestArgmin: def test_minimum_signed_integers(self): - a = np.array([1, -2**7, -2**7 + 1], dtype=np.int8) + a = np.array([1, -2**7, -2**7 + 1, 2**7 - 1], dtype=np.int8) assert_equal(np.argmin(a), 1) - a = np.array([1, -2**15, -2**15 + 1], dtype=np.int16) + a = np.array([1, -2**15, -2**15 + 1, 2**15 - 1], dtype=np.int16) assert_equal(np.argmin(a), 1) - a = np.array([1, -2**31, -2**31 + 1], dtype=np.int32) + a = np.array([1, -2**31, -2**31 + 1, 2**31 - 1], dtype=np.int32) assert_equal(np.argmin(a), 1) - a = np.array([1, -2**63, -2**63 + 1], dtype=np.int64) + a = np.array([1, -2**63, -2**63 + 1, 2**63 - 1], dtype=np.int64) assert_equal(np.argmin(a), 1) - def test_output_shape(self): - # see also gh-616 - a = np.ones((10, 5)) - # Check some simple shape mismatches - out = np.ones(11, dtype=np.int_) - assert_raises(ValueError, a.argmin, -1, out) - - out = np.ones((2, 5), dtype=np.int_) - assert_raises(ValueError, a.argmin, -1, out) - - # these could be relaxed possibly (used to allow even the previous) - out = np.ones((1, 10), dtype=np.int_) - assert_raises(ValueError, a.argmin, -1, out) - - out = np.ones(10, dtype=np.int_) - a.argmin(-1, out=out) - assert_equal(out, a.argmin(-1)) - - @pytest.mark.parametrize('ndim', [0, 1]) - def test_ret_is_out(self, ndim): - a = np.ones((4,) + (3,)*ndim) - out = np.empty((3,)*ndim, dtype=np.intp) - ret = a.argmin(axis=0, out=out) - assert ret is out - - def test_argmin_unicode(self): - d = np.ones(6031, dtype='<U9') - d[6001] = "0" - assert_equal(d.argmin(), 6001) - - def test_np_vs_ndarray(self): - # make sure both ndarray.argmin and numpy.argmin support out/axis args - a = np.random.normal(size=(2, 3)) - - # check positional args - out1 = np.zeros(2, dtype=int) - out2 = np.ones(2, dtype=int) - assert_equal(a.argmin(1, out1), np.argmin(a, 1, out2)) - assert_equal(out1, out2) - - # check keyword args - out1 = np.zeros(3, dtype=int) - out2 = np.ones(3, dtype=int) - assert_equal(a.argmin(out=out1, axis=0), np.argmin(a, out=out2, axis=0)) - assert_equal(out1, out2) - - @pytest.mark.leaks_references(reason="replaces None with NULL.") - def test_object_argmin_with_NULLs(self): - # See gh-6032 - a = np.empty(4, dtype='O') - ctypes.memset(a.ctypes.data, 0, a.nbytes) - assert_equal(a.argmin(), 0) - a[3] = 30 - assert_equal(a.argmin(), 3) - a[1] = 10 - assert_equal(a.argmin(), 1) - class TestMinMax: |