summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2020-03-31 22:10:37 +0100
committerEric Wieser <wieser.eric@gmail.com>2020-03-31 22:34:23 +0100
commit0c85dae2dcdff58d56a619dfa7cec0157e91d78d (patch)
tree617b8ffb6fb08a5c589f17e6949b0f588f7c8f78 /numpy
parentb570f73d39505926478fda1b1a496ac3e707d33d (diff)
downloadnumpy-0c85dae2dcdff58d56a619dfa7cec0157e91d78d.tar.gz
BUG: Setting a 0d array's strides to themselves should be legal
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/getset.c7
-rw-r--r--numpy/core/tests/test_multiarray.py5
2 files changed, 9 insertions, 3 deletions
diff --git a/numpy/core/src/multiarray/getset.c b/numpy/core/src/multiarray/getset.c
index 8c1b7f943..80a1cd4a1 100644
--- a/numpy/core/src/multiarray/getset.c
+++ b/numpy/core/src/multiarray/getset.c
@@ -13,6 +13,7 @@
#include "npy_import.h"
#include "common.h"
+#include "conversion_utils.h"
#include "ctors.h"
#include "scalartypes.h"
#include "descriptor.h"
@@ -110,7 +111,7 @@ array_strides_get(PyArrayObject *self)
static int
array_strides_set(PyArrayObject *self, PyObject *obj)
{
- PyArray_Dims newstrides = {NULL, 0};
+ PyArray_Dims newstrides = {NULL, -1};
PyArrayObject *new;
npy_intp numbytes = 0;
npy_intp offset = 0;
@@ -123,8 +124,8 @@ array_strides_set(PyArrayObject *self, PyObject *obj)
"Cannot delete array strides");
return -1;
}
- if (!PyArray_IntpConverter(obj, &newstrides) ||
- newstrides.ptr == NULL) {
+ if (!PyArray_OptionalIntpConverter(obj, &newstrides) ||
+ newstrides.len == -1) {
PyErr_SetString(PyExc_TypeError, "invalid strides");
return -1;
}
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 829679dab..7d2e2df7c 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -363,6 +363,11 @@ class TestAttributes:
a.strides = 1
a[::2].strides = 2
+ # test 0d
+ arr_0d = np.array(0)
+ arr_0d.strides = ()
+ assert_raises(TypeError, set_strides, arr_0d, None)
+
def test_fill(self):
for t in "?bhilqpBHILQPfdgFDGO":
x = np.empty((3, 2, 1), t)