diff options
-rw-r--r-- | numpy/core/src/multiarray/methods.c | 74 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 35 |
2 files changed, 99 insertions, 10 deletions
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c index e2026ec1c..3679a34b8 100644 --- a/numpy/core/src/multiarray/methods.c +++ b/numpy/core/src/multiarray/methods.c @@ -146,8 +146,15 @@ array_take(PyArrayObject *self, PyObject *args, PyObject *kwds) PyArray_ClipmodeConverter, &mode)) return NULL; - return PyArray_Return((PyArrayObject *) - PyArray_TakeFrom(self, indices, dimension, out, mode)); + PyObject *ret = PyArray_TakeFrom(self, indices, dimension, out, mode); + + /* this matches the unpacking behavior of ufuncs */ + if (out == NULL) { + return PyArray_Return((PyArrayObject *)ret); + } + else { + return ret; + } } static PyObject * @@ -303,7 +310,15 @@ array_argmax(PyArrayObject *self, PyObject *args, PyObject *kwds) PyArray_OutputConverter, &out)) return NULL; - return PyArray_Return((PyArrayObject *)PyArray_ArgMax(self, axis, out)); + PyObject *ret = PyArray_ArgMax(self, axis, out); + + /* this matches the unpacking behavior of ufuncs */ + if (out == NULL) { + return PyArray_Return((PyArrayObject *)ret); + } + else { + return ret; + } } static PyObject * @@ -318,7 +333,15 @@ array_argmin(PyArrayObject *self, PyObject *args, PyObject *kwds) PyArray_OutputConverter, &out)) return NULL; - return PyArray_Return((PyArrayObject *)PyArray_ArgMin(self, axis, out)); + PyObject *ret = PyArray_ArgMin(self, axis, out); + + /* this matches the unpacking behavior of ufuncs */ + if (out == NULL) { + return PyArray_Return((PyArrayObject *)ret); + } + else { + return ret; + } } static PyObject * @@ -1218,7 +1241,15 @@ array_choose(PyArrayObject *self, PyObject *args, PyObject *kwds) return NULL; } - return PyArray_Return((PyArrayObject *)PyArray_Choose(self, choices, out, clipmode)); + PyObject *ret = PyArray_Choose(self, choices, out, clipmode); + + /* this matches the unpacking behavior of ufuncs */ + if (out == NULL) { + return PyArray_Return((PyArrayObject *)ret); + } + else { + return ret; + } } static PyObject * @@ -2319,8 +2350,16 @@ array_compress(PyArrayObject *self, PyObject *args, PyObject *kwds) PyArray_OutputConverter, &out)) { return NULL; } - return PyArray_Return( - (PyArrayObject *)PyArray_Compress(self, condition, axis, out)); + + PyObject *ret = PyArray_Compress(self, condition, axis, out); + + /* this matches the unpacking behavior of ufuncs */ + if (out == NULL) { + return PyArray_Return((PyArrayObject *)ret); + } + else { + return ret; + } } @@ -2355,7 +2394,15 @@ array_trace(PyArrayObject *self, PyObject *args, PyObject *kwds) rtype = _CHKTYPENUM(dtype); Py_XDECREF(dtype); - return PyArray_Return((PyArrayObject *)PyArray_Trace(self, offset, axis1, axis2, rtype, out)); + PyObject *ret = PyArray_Trace(self, offset, axis1, axis2, rtype, out); + + /* this matches the unpacking behavior of ufuncs */ + if (out == NULL) { + return PyArray_Return((PyArrayObject *)ret); + } + else { + return ret; + } } #undef _CHKTYPENUM @@ -2440,7 +2487,16 @@ array_round(PyArrayObject *self, PyObject *args, PyObject *kwds) PyArray_OutputConverter, &out)) { return NULL; } - return PyArray_Return((PyArrayObject *)PyArray_Round(self, decimals, out)); + + PyObject *ret = PyArray_Round(self, decimals, out); + + /* this matches the unpacking behavior of ufuncs */ + if (out == NULL) { + return PyArray_Return((PyArrayObject *)ret); + } + else { + return ret; + } } diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index f36c27c6c..a698370b6 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -1582,6 +1582,11 @@ class TestMethods: # gh-12031, caused SEGFAULT assert_raises(TypeError, oned.choose,np.void(0), [oned]) + out = np.array(0) + ret = np.choose(np.array(1), [10, 20, 30], out=out) + assert out is ret + assert_equal(out[()], 20) + # gh-6272 check overlap on out x = np.arange(5) y = np.choose([0,0,0], [x[:3], x[:3], x[:3]], out=x[1:4], mode='wrap') @@ -1658,7 +1663,7 @@ class TestMethods: out = np.zeros_like(arr) res = arr.round(*round_args, out=out) assert_equal(out, expected) - assert_equal(out, res) + assert out is res check_round(np.array([1.2, 1.5]), [1, 2]) check_round(np.array(1.5), 2) @@ -3023,6 +3028,10 @@ class TestMethods: assert_equal(b.trace(0, 1, 2), [3, 11]) assert_equal(b.trace(offset=1, axis1=0, axis2=2), [1, 3]) + out = np.array(1) + ret = a.trace(out=out) + assert ret is out + def test_trace_subclass(self): # The class would need to overwrite trace to ensure single-element # output also has the right subclass. @@ -4126,6 +4135,13 @@ class TestArgmax: a.argmax(-1, out=out) assert_equal(out, a.argmax(-1)) + @pytest.mark.parametrize('ndim', [0, 1]) + def test_ret_is_out(self, ndim): + a = np.ones((4,) + (3,)*ndim) + out = np.empty((3,)*ndim, dtype=np.intp) + ret = a.argmax(axis=0, out=out) + assert ret is out + def test_argmax_unicode(self): d = np.zeros(6031, dtype='<U9') d[5942] = "as" @@ -4275,6 +4291,13 @@ class TestArgmin: a.argmin(-1, out=out) assert_equal(out, a.argmin(-1)) + @pytest.mark.parametrize('ndim', [0, 1]) + def test_ret_is_out(self, ndim): + a = np.ones((4,) + (3,)*ndim) + out = np.empty((3,)*ndim, dtype=np.intp) + ret = a.argmin(axis=0, out=out) + assert ret is out + def test_argmin_unicode(self): d = np.ones(6031, dtype='<U9') d[6001] = "0" @@ -4552,6 +4575,16 @@ class TestTake: y = np.take(x, [1, 2, 3], out=x[2:5], mode='wrap') assert_equal(y, np.array([1, 2, 3])) + @pytest.mark.parametrize('shape', [(1, 2), (1,), ()]) + def test_ret_is_out(self, shape): + # 0d arrays should not be an exception to this rule + x = np.arange(5) + inds = np.zeros(shape, dtype=np.intp) + out = np.zeros(shape, dtype=x.dtype) + ret = np.take(x, inds, out=out) + assert ret is out + + class TestLexsort: @pytest.mark.parametrize('dtype',[ np.uint8, np.uint16, np.uint32, np.uint64, |