summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/multiarray/methods.c74
-rw-r--r--numpy/core/tests/test_multiarray.py35
2 files changed, 99 insertions, 10 deletions
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c
index e2026ec1c..3679a34b8 100644
--- a/numpy/core/src/multiarray/methods.c
+++ b/numpy/core/src/multiarray/methods.c
@@ -146,8 +146,15 @@ array_take(PyArrayObject *self, PyObject *args, PyObject *kwds)
PyArray_ClipmodeConverter, &mode))
return NULL;
- return PyArray_Return((PyArrayObject *)
- PyArray_TakeFrom(self, indices, dimension, out, mode));
+ PyObject *ret = PyArray_TakeFrom(self, indices, dimension, out, mode);
+
+ /* this matches the unpacking behavior of ufuncs */
+ if (out == NULL) {
+ return PyArray_Return((PyArrayObject *)ret);
+ }
+ else {
+ return ret;
+ }
}
static PyObject *
@@ -303,7 +310,15 @@ array_argmax(PyArrayObject *self, PyObject *args, PyObject *kwds)
PyArray_OutputConverter, &out))
return NULL;
- return PyArray_Return((PyArrayObject *)PyArray_ArgMax(self, axis, out));
+ PyObject *ret = PyArray_ArgMax(self, axis, out);
+
+ /* this matches the unpacking behavior of ufuncs */
+ if (out == NULL) {
+ return PyArray_Return((PyArrayObject *)ret);
+ }
+ else {
+ return ret;
+ }
}
static PyObject *
@@ -318,7 +333,15 @@ array_argmin(PyArrayObject *self, PyObject *args, PyObject *kwds)
PyArray_OutputConverter, &out))
return NULL;
- return PyArray_Return((PyArrayObject *)PyArray_ArgMin(self, axis, out));
+ PyObject *ret = PyArray_ArgMin(self, axis, out);
+
+ /* this matches the unpacking behavior of ufuncs */
+ if (out == NULL) {
+ return PyArray_Return((PyArrayObject *)ret);
+ }
+ else {
+ return ret;
+ }
}
static PyObject *
@@ -1218,7 +1241,15 @@ array_choose(PyArrayObject *self, PyObject *args, PyObject *kwds)
return NULL;
}
- return PyArray_Return((PyArrayObject *)PyArray_Choose(self, choices, out, clipmode));
+ PyObject *ret = PyArray_Choose(self, choices, out, clipmode);
+
+ /* this matches the unpacking behavior of ufuncs */
+ if (out == NULL) {
+ return PyArray_Return((PyArrayObject *)ret);
+ }
+ else {
+ return ret;
+ }
}
static PyObject *
@@ -2319,8 +2350,16 @@ array_compress(PyArrayObject *self, PyObject *args, PyObject *kwds)
PyArray_OutputConverter, &out)) {
return NULL;
}
- return PyArray_Return(
- (PyArrayObject *)PyArray_Compress(self, condition, axis, out));
+
+ PyObject *ret = PyArray_Compress(self, condition, axis, out);
+
+ /* this matches the unpacking behavior of ufuncs */
+ if (out == NULL) {
+ return PyArray_Return((PyArrayObject *)ret);
+ }
+ else {
+ return ret;
+ }
}
@@ -2355,7 +2394,15 @@ array_trace(PyArrayObject *self, PyObject *args, PyObject *kwds)
rtype = _CHKTYPENUM(dtype);
Py_XDECREF(dtype);
- return PyArray_Return((PyArrayObject *)PyArray_Trace(self, offset, axis1, axis2, rtype, out));
+ PyObject *ret = PyArray_Trace(self, offset, axis1, axis2, rtype, out);
+
+ /* this matches the unpacking behavior of ufuncs */
+ if (out == NULL) {
+ return PyArray_Return((PyArrayObject *)ret);
+ }
+ else {
+ return ret;
+ }
}
#undef _CHKTYPENUM
@@ -2440,7 +2487,16 @@ array_round(PyArrayObject *self, PyObject *args, PyObject *kwds)
PyArray_OutputConverter, &out)) {
return NULL;
}
- return PyArray_Return((PyArrayObject *)PyArray_Round(self, decimals, out));
+
+ PyObject *ret = PyArray_Round(self, decimals, out);
+
+ /* this matches the unpacking behavior of ufuncs */
+ if (out == NULL) {
+ return PyArray_Return((PyArrayObject *)ret);
+ }
+ else {
+ return ret;
+ }
}
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index f36c27c6c..a698370b6 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -1582,6 +1582,11 @@ class TestMethods:
# gh-12031, caused SEGFAULT
assert_raises(TypeError, oned.choose,np.void(0), [oned])
+ out = np.array(0)
+ ret = np.choose(np.array(1), [10, 20, 30], out=out)
+ assert out is ret
+ assert_equal(out[()], 20)
+
# gh-6272 check overlap on out
x = np.arange(5)
y = np.choose([0,0,0], [x[:3], x[:3], x[:3]], out=x[1:4], mode='wrap')
@@ -1658,7 +1663,7 @@ class TestMethods:
out = np.zeros_like(arr)
res = arr.round(*round_args, out=out)
assert_equal(out, expected)
- assert_equal(out, res)
+ assert out is res
check_round(np.array([1.2, 1.5]), [1, 2])
check_round(np.array(1.5), 2)
@@ -3023,6 +3028,10 @@ class TestMethods:
assert_equal(b.trace(0, 1, 2), [3, 11])
assert_equal(b.trace(offset=1, axis1=0, axis2=2), [1, 3])
+ out = np.array(1)
+ ret = a.trace(out=out)
+ assert ret is out
+
def test_trace_subclass(self):
# The class would need to overwrite trace to ensure single-element
# output also has the right subclass.
@@ -4126,6 +4135,13 @@ class TestArgmax:
a.argmax(-1, out=out)
assert_equal(out, a.argmax(-1))
+ @pytest.mark.parametrize('ndim', [0, 1])
+ def test_ret_is_out(self, ndim):
+ a = np.ones((4,) + (3,)*ndim)
+ out = np.empty((3,)*ndim, dtype=np.intp)
+ ret = a.argmax(axis=0, out=out)
+ assert ret is out
+
def test_argmax_unicode(self):
d = np.zeros(6031, dtype='<U9')
d[5942] = "as"
@@ -4275,6 +4291,13 @@ class TestArgmin:
a.argmin(-1, out=out)
assert_equal(out, a.argmin(-1))
+ @pytest.mark.parametrize('ndim', [0, 1])
+ def test_ret_is_out(self, ndim):
+ a = np.ones((4,) + (3,)*ndim)
+ out = np.empty((3,)*ndim, dtype=np.intp)
+ ret = a.argmin(axis=0, out=out)
+ assert ret is out
+
def test_argmin_unicode(self):
d = np.ones(6031, dtype='<U9')
d[6001] = "0"
@@ -4552,6 +4575,16 @@ class TestTake:
y = np.take(x, [1, 2, 3], out=x[2:5], mode='wrap')
assert_equal(y, np.array([1, 2, 3]))
+ @pytest.mark.parametrize('shape', [(1, 2), (1,), ()])
+ def test_ret_is_out(self, shape):
+ # 0d arrays should not be an exception to this rule
+ x = np.arange(5)
+ inds = np.zeros(shape, dtype=np.intp)
+ out = np.zeros(shape, dtype=x.dtype)
+ ret = np.take(x, inds, out=out)
+ assert ret is out
+
+
class TestLexsort:
@pytest.mark.parametrize('dtype',[
np.uint8, np.uint16, np.uint32, np.uint64,