summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/_add_newdocs.py2
-rw-r--r--numpy/core/src/multiarray/getset.c7
-rw-r--r--numpy/core/src/multiarray/nditer_pywrap.c11
-rw-r--r--numpy/core/tests/test_multiarray.py5
-rw-r--r--numpy/core/tests/test_nditer.py10
5 files changed, 23 insertions, 12 deletions
diff --git a/numpy/core/_add_newdocs.py b/numpy/core/_add_newdocs.py
index 68902d25a..b963f32b8 100644
--- a/numpy/core/_add_newdocs.py
+++ b/numpy/core/_add_newdocs.py
@@ -152,7 +152,7 @@ add_newdoc('numpy.core', 'flatiter', ('copy',
add_newdoc('numpy.core', 'nditer',
"""
- nditer(op, flags=None, op_flags=None, op_dtypes=None, order='K', casting='safe', op_axes=None, itershape=(), buffersize=0)
+ nditer(op, flags=None, op_flags=None, op_dtypes=None, order='K', casting='safe', op_axes=None, itershape=None, buffersize=0)
Efficient multi-dimensional iterator object to iterate over arrays.
To get started using this object, see the
diff --git a/numpy/core/src/multiarray/getset.c b/numpy/core/src/multiarray/getset.c
index 8c1b7f943..80a1cd4a1 100644
--- a/numpy/core/src/multiarray/getset.c
+++ b/numpy/core/src/multiarray/getset.c
@@ -13,6 +13,7 @@
#include "npy_import.h"
#include "common.h"
+#include "conversion_utils.h"
#include "ctors.h"
#include "scalartypes.h"
#include "descriptor.h"
@@ -110,7 +111,7 @@ array_strides_get(PyArrayObject *self)
static int
array_strides_set(PyArrayObject *self, PyObject *obj)
{
- PyArray_Dims newstrides = {NULL, 0};
+ PyArray_Dims newstrides = {NULL, -1};
PyArrayObject *new;
npy_intp numbytes = 0;
npy_intp offset = 0;
@@ -123,8 +124,8 @@ array_strides_set(PyArrayObject *self, PyObject *obj)
"Cannot delete array strides");
return -1;
}
- if (!PyArray_IntpConverter(obj, &newstrides) ||
- newstrides.ptr == NULL) {
+ if (!PyArray_OptionalIntpConverter(obj, &newstrides) ||
+ newstrides.len == -1) {
PyErr_SetString(PyExc_TypeError, "invalid strides");
return -1;
}
diff --git a/numpy/core/src/multiarray/nditer_pywrap.c b/numpy/core/src/multiarray/nditer_pywrap.c
index add40f460..505c5a841 100644
--- a/numpy/core/src/multiarray/nditer_pywrap.c
+++ b/numpy/core/src/multiarray/nditer_pywrap.c
@@ -17,6 +17,7 @@
#include "npy_pycompat.h"
#include "alloc.h"
#include "common.h"
+#include "conversion_utils.h"
#include "ctors.h"
/* Functions not part of the public NumPy C API */
@@ -748,7 +749,7 @@ npyiter_init(NewNpyArrayIterObject *self, PyObject *args, PyObject *kwds)
int oa_ndim = -1;
int op_axes_arrays[NPY_MAXARGS][NPY_MAXDIMS];
int *op_axes[NPY_MAXARGS];
- PyArray_Dims itershape = {NULL, 0};
+ PyArray_Dims itershape = {NULL, -1};
int buffersize = 0;
if (self->iter != NULL) {
@@ -765,7 +766,7 @@ npyiter_init(NewNpyArrayIterObject *self, PyObject *args, PyObject *kwds)
npyiter_order_converter, &order,
PyArray_CastingConverter, &casting,
&op_axes_in,
- PyArray_IntpConverter, &itershape,
+ PyArray_OptionalIntpConverter, &itershape,
&buffersize)) {
npy_free_cache_dim_obj(itershape);
return -1;
@@ -800,7 +801,7 @@ npyiter_init(NewNpyArrayIterObject *self, PyObject *args, PyObject *kwds)
}
}
- if (itershape.len > 0) {
+ if (itershape.len != -1) {
if (oa_ndim == -1) {
oa_ndim = itershape.len;
memset(op_axes, 0, sizeof(op_axes[0]) * nop);
@@ -812,10 +813,6 @@ npyiter_init(NewNpyArrayIterObject *self, PyObject *args, PyObject *kwds)
goto fail;
}
}
- else if (itershape.ptr != NULL) {
- npy_free_cache_dim_obj(itershape);
- itershape.ptr = NULL;
- }
self->iter = NpyIter_AdvancedNew(nop, op, flags, order, casting, op_flags,
op_request_dtypes,
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 0b99008e0..5515ff446 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -356,6 +356,11 @@ class TestAttributes:
a.strides = 1
a[::2].strides = 2
+ # test 0d
+ arr_0d = np.array(0)
+ arr_0d.strides = ()
+ assert_raises(TypeError, set_strides, arr_0d, None)
+
def test_fill(self):
for t in "?bhilqpBHILQPfdgFDGO":
x = np.empty((3, 2, 1), t)
diff --git a/numpy/core/tests/test_nditer.py b/numpy/core/tests/test_nditer.py
index 24272bb0d..c106c528d 100644
--- a/numpy/core/tests/test_nditer.py
+++ b/numpy/core/tests/test_nditer.py
@@ -2688,7 +2688,15 @@ def test_0d_iter():
i = nditer(np.arange(5), ['multi_index'], [['readonly']], op_axes=[()])
assert_equal(i.ndim, 0)
assert_equal(len(i), 1)
- # note that itershape=(), still behaves like None due to the conversions
+
+ i = nditer(np.arange(5), ['multi_index'], [['readonly']],
+ op_axes=[()], itershape=())
+ assert_equal(i.ndim, 0)
+ assert_equal(len(i), 1)
+
+ # passing an itershape alone is not enough, the op_axes are also needed
+ with assert_raises(ValueError):
+ nditer(np.arange(5), ['multi_index'], [['readonly']], itershape=())
# Test a more complex buffered casting case (same as another test above)
sdt = [('a', 'f4'), ('b', 'i8'), ('c', 'c8', (2, 3)), ('d', 'O')]