summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2020-03-31 18:09:08 -0500
committerGitHub <noreply@github.com>2020-03-31 18:09:08 -0500
commit12f9da5615e6c1e387a9101a8923867343dfaeca (patch)
treec4ce75560eb3ded062c9ee89ade645f4378da4ca /numpy/core
parent18bfeaf2716001afa8785638a75f90259dbc8aa1 (diff)
parent0c85dae2dcdff58d56a619dfa7cec0157e91d78d (diff)
downloadnumpy-12f9da5615e6c1e387a9101a8923867343dfaeca.tar.gz
Merge pull request #15884 from eric-wieser/fix-set-empty-strides
BUG: Setting a 0d array's strides to themselves should be legal
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/src/multiarray/getset.c7
-rw-r--r--numpy/core/tests/test_multiarray.py5
2 files changed, 9 insertions, 3 deletions
diff --git a/numpy/core/src/multiarray/getset.c b/numpy/core/src/multiarray/getset.c
index 8c1b7f943..80a1cd4a1 100644
--- a/numpy/core/src/multiarray/getset.c
+++ b/numpy/core/src/multiarray/getset.c
@@ -13,6 +13,7 @@
#include "npy_import.h"
#include "common.h"
+#include "conversion_utils.h"
#include "ctors.h"
#include "scalartypes.h"
#include "descriptor.h"
@@ -110,7 +111,7 @@ array_strides_get(PyArrayObject *self)
static int
array_strides_set(PyArrayObject *self, PyObject *obj)
{
- PyArray_Dims newstrides = {NULL, 0};
+ PyArray_Dims newstrides = {NULL, -1};
PyArrayObject *new;
npy_intp numbytes = 0;
npy_intp offset = 0;
@@ -123,8 +124,8 @@ array_strides_set(PyArrayObject *self, PyObject *obj)
"Cannot delete array strides");
return -1;
}
- if (!PyArray_IntpConverter(obj, &newstrides) ||
- newstrides.ptr == NULL) {
+ if (!PyArray_OptionalIntpConverter(obj, &newstrides) ||
+ newstrides.len == -1) {
PyErr_SetString(PyExc_TypeError, "invalid strides");
return -1;
}
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 0b99008e0..5515ff446 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -356,6 +356,11 @@ class TestAttributes:
a.strides = 1
a[::2].strides = 2
+ # test 0d
+ arr_0d = np.array(0)
+ arr_0d.strides = ()
+ assert_raises(TypeError, set_strides, arr_0d, None)
+
def test_fill(self):
for t in "?bhilqpBHILQPfdgFDGO":
x = np.empty((3, 2, 1), t)