diff options
| author | Eric Wieser <wieser.eric@gmail.com> | 2020-05-17 16:07:31 +0100 |
|---|---|---|
| committer | Eric Wieser <wieser.eric@gmail.com> | 2020-05-18 22:39:53 +0100 |
| commit | e73fdb402c13da019eeb3de4d67002d81a335801 (patch) | |
| tree | 022c1306c13082440f4fe1a03923dd17c7366098 /numpy | |
| parent | bd8be5417632c019dbc1d36400052805f95a372c (diff) | |
| download | numpy-e73fdb402c13da019eeb3de4d67002d81a335801.tar.gz | |
BUG: Ensure out argument is returned by identity for 0d arrays
This makes the following functions consistent with the behavior of ufuncs:
* `np.argmax`
* `np.argmin`
* `np.choose`
* `np.trace`
* `np.round`
* `np.take`
Previously they would unpack these into scalars.
Diffstat (limited to 'numpy')
| -rw-r--r-- | numpy/core/src/multiarray/methods.c | 74 | ||||
| -rw-r--r-- | numpy/core/tests/test_multiarray.py | 35 |
2 files changed, 99 insertions, 10 deletions
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c index e2026ec1c..3679a34b8 100644 --- a/numpy/core/src/multiarray/methods.c +++ b/numpy/core/src/multiarray/methods.c @@ -146,8 +146,15 @@ array_take(PyArrayObject *self, PyObject *args, PyObject *kwds) PyArray_ClipmodeConverter, &mode)) return NULL; - return PyArray_Return((PyArrayObject *) - PyArray_TakeFrom(self, indices, dimension, out, mode)); + PyObject *ret = PyArray_TakeFrom(self, indices, dimension, out, mode); + + /* this matches the unpacking behavior of ufuncs */ + if (out == NULL) { + return PyArray_Return((PyArrayObject *)ret); + } + else { + return ret; + } } static PyObject * @@ -303,7 +310,15 @@ array_argmax(PyArrayObject *self, PyObject *args, PyObject *kwds) PyArray_OutputConverter, &out)) return NULL; - return PyArray_Return((PyArrayObject *)PyArray_ArgMax(self, axis, out)); + PyObject *ret = PyArray_ArgMax(self, axis, out); + + /* this matches the unpacking behavior of ufuncs */ + if (out == NULL) { + return PyArray_Return((PyArrayObject *)ret); + } + else { + return ret; + } } static PyObject * @@ -318,7 +333,15 @@ array_argmin(PyArrayObject *self, PyObject *args, PyObject *kwds) PyArray_OutputConverter, &out)) return NULL; - return PyArray_Return((PyArrayObject *)PyArray_ArgMin(self, axis, out)); + PyObject *ret = PyArray_ArgMin(self, axis, out); + + /* this matches the unpacking behavior of ufuncs */ + if (out == NULL) { + return PyArray_Return((PyArrayObject *)ret); + } + else { + return ret; + } } static PyObject * @@ -1218,7 +1241,15 @@ array_choose(PyArrayObject *self, PyObject *args, PyObject *kwds) return NULL; } - return PyArray_Return((PyArrayObject *)PyArray_Choose(self, choices, out, clipmode)); + PyObject *ret = PyArray_Choose(self, choices, out, clipmode); + + /* this matches the unpacking behavior of ufuncs */ + if (out == NULL) { + return PyArray_Return((PyArrayObject *)ret); + } + else { + return ret; + } } static PyObject * @@ -2319,8 +2350,16 @@ array_compress(PyArrayObject *self, PyObject *args, PyObject *kwds) PyArray_OutputConverter, &out)) { return NULL; } - return PyArray_Return( - (PyArrayObject *)PyArray_Compress(self, condition, axis, out)); + + PyObject *ret = PyArray_Compress(self, condition, axis, out); + + /* this matches the unpacking behavior of ufuncs */ + if (out == NULL) { + return PyArray_Return((PyArrayObject *)ret); + } + else { + return ret; + } } @@ -2355,7 +2394,15 @@ array_trace(PyArrayObject *self, PyObject *args, PyObject *kwds) rtype = _CHKTYPENUM(dtype); Py_XDECREF(dtype); - return PyArray_Return((PyArrayObject *)PyArray_Trace(self, offset, axis1, axis2, rtype, out)); + PyObject *ret = PyArray_Trace(self, offset, axis1, axis2, rtype, out); + + /* this matches the unpacking behavior of ufuncs */ + if (out == NULL) { + return PyArray_Return((PyArrayObject *)ret); + } + else { + return ret; + } } #undef _CHKTYPENUM @@ -2440,7 +2487,16 @@ array_round(PyArrayObject *self, PyObject *args, PyObject *kwds) PyArray_OutputConverter, &out)) { return NULL; } - return PyArray_Return((PyArrayObject *)PyArray_Round(self, decimals, out)); + + PyObject *ret = PyArray_Round(self, decimals, out); + + /* this matches the unpacking behavior of ufuncs */ + if (out == NULL) { + return PyArray_Return((PyArrayObject *)ret); + } + else { + return ret; + } } diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index f36c27c6c..a698370b6 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -1582,6 +1582,11 @@ class TestMethods: # gh-12031, caused SEGFAULT assert_raises(TypeError, oned.choose,np.void(0), [oned]) + out = np.array(0) + ret = np.choose(np.array(1), [10, 20, 30], out=out) + assert out is ret + assert_equal(out[()], 20) + # gh-6272 check overlap on out x = np.arange(5) y = np.choose([0,0,0], [x[:3], x[:3], x[:3]], out=x[1:4], mode='wrap') @@ -1658,7 +1663,7 @@ class TestMethods: out = np.zeros_like(arr) res = arr.round(*round_args, out=out) assert_equal(out, expected) - assert_equal(out, res) + assert out is res check_round(np.array([1.2, 1.5]), [1, 2]) check_round(np.array(1.5), 2) @@ -3023,6 +3028,10 @@ class TestMethods: assert_equal(b.trace(0, 1, 2), [3, 11]) assert_equal(b.trace(offset=1, axis1=0, axis2=2), [1, 3]) + out = np.array(1) + ret = a.trace(out=out) + assert ret is out + def test_trace_subclass(self): # The class would need to overwrite trace to ensure single-element # output also has the right subclass. @@ -4126,6 +4135,13 @@ class TestArgmax: a.argmax(-1, out=out) assert_equal(out, a.argmax(-1)) + @pytest.mark.parametrize('ndim', [0, 1]) + def test_ret_is_out(self, ndim): + a = np.ones((4,) + (3,)*ndim) + out = np.empty((3,)*ndim, dtype=np.intp) + ret = a.argmax(axis=0, out=out) + assert ret is out + def test_argmax_unicode(self): d = np.zeros(6031, dtype='<U9') d[5942] = "as" @@ -4275,6 +4291,13 @@ class TestArgmin: a.argmin(-1, out=out) assert_equal(out, a.argmin(-1)) + @pytest.mark.parametrize('ndim', [0, 1]) + def test_ret_is_out(self, ndim): + a = np.ones((4,) + (3,)*ndim) + out = np.empty((3,)*ndim, dtype=np.intp) + ret = a.argmin(axis=0, out=out) + assert ret is out + def test_argmin_unicode(self): d = np.ones(6031, dtype='<U9') d[6001] = "0" @@ -4552,6 +4575,16 @@ class TestTake: y = np.take(x, [1, 2, 3], out=x[2:5], mode='wrap') assert_equal(y, np.array([1, 2, 3])) + @pytest.mark.parametrize('shape', [(1, 2), (1,), ()]) + def test_ret_is_out(self, shape): + # 0d arrays should not be an exception to this rule + x = np.arange(5) + inds = np.zeros(shape, dtype=np.intp) + out = np.zeros(shape, dtype=x.dtype) + ret = np.take(x, inds, out=out) + assert ret is out + + class TestLexsort: @pytest.mark.parametrize('dtype',[ np.uint8, np.uint16, np.uint32, np.uint64, |
