diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2020-03-31 22:10:37 +0100 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2020-03-31 22:34:23 +0100 |
commit | 0c85dae2dcdff58d56a619dfa7cec0157e91d78d (patch) | |
tree | 617b8ffb6fb08a5c589f17e6949b0f588f7c8f78 /numpy | |
parent | b570f73d39505926478fda1b1a496ac3e707d33d (diff) | |
download | numpy-0c85dae2dcdff58d56a619dfa7cec0157e91d78d.tar.gz |
BUG: Setting a 0d array's strides to themselves should be legal
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/getset.c | 7 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 5 |
2 files changed, 9 insertions, 3 deletions
diff --git a/numpy/core/src/multiarray/getset.c b/numpy/core/src/multiarray/getset.c index 8c1b7f943..80a1cd4a1 100644 --- a/numpy/core/src/multiarray/getset.c +++ b/numpy/core/src/multiarray/getset.c @@ -13,6 +13,7 @@ #include "npy_import.h" #include "common.h" +#include "conversion_utils.h" #include "ctors.h" #include "scalartypes.h" #include "descriptor.h" @@ -110,7 +111,7 @@ array_strides_get(PyArrayObject *self) static int array_strides_set(PyArrayObject *self, PyObject *obj) { - PyArray_Dims newstrides = {NULL, 0}; + PyArray_Dims newstrides = {NULL, -1}; PyArrayObject *new; npy_intp numbytes = 0; npy_intp offset = 0; @@ -123,8 +124,8 @@ array_strides_set(PyArrayObject *self, PyObject *obj) "Cannot delete array strides"); return -1; } - if (!PyArray_IntpConverter(obj, &newstrides) || - newstrides.ptr == NULL) { + if (!PyArray_OptionalIntpConverter(obj, &newstrides) || + newstrides.len == -1) { PyErr_SetString(PyExc_TypeError, "invalid strides"); return -1; } diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 829679dab..7d2e2df7c 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -363,6 +363,11 @@ class TestAttributes: a.strides = 1 a[::2].strides = 2 + # test 0d + arr_0d = np.array(0) + arr_0d.strides = () + assert_raises(TypeError, set_strides, arr_0d, None) + def test_fill(self): for t in "?bhilqpBHILQPfdgFDGO": x = np.empty((3, 2, 1), t) |