diff options
author | njsmith <njs@pobox.com> | 2013-01-15 13:19:31 -0800 |
---|---|---|
committer | njsmith <njs@pobox.com> | 2013-01-15 13:19:31 -0800 |
commit | ee6c66a2a71f46f02a823a8e107592ca3b76858e (patch) | |
tree | ad16ad0e5ea3073639b0b292941f0a8b211445c8 | |
parent | eb109d2bbea6e61c289907b9962d57c768eee28f (diff) | |
parent | d1a055b0cd5ad82106901a274d58563df4cd5070 (diff) | |
download | numpy-ee6c66a2a71f46f02a823a8e107592ca3b76858e.tar.gz |
Merge pull request #2897 from ContinuumIO/ndindex_fix
Fix for #2895 ndindex failing
-rw-r--r-- | numpy/core/src/multiarray/ctors.c | 5 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 7 | ||||
-rw-r--r-- | numpy/lib/index_tricks.py | 26 | ||||
-rw-r--r-- | numpy/lib/tests/test_index_tricks.py | 10 |
4 files changed, 41 insertions, 7 deletions
diff --git a/numpy/core/src/multiarray/ctors.c b/numpy/core/src/multiarray/ctors.c index 980d73a32..bdf2e6e2b 100644 --- a/numpy/core/src/multiarray/ctors.c +++ b/numpy/core/src/multiarray/ctors.c @@ -2075,11 +2075,6 @@ PyArray_FromInterface(PyObject *origin) /* Case for data access through pointer */ if (attr && PyTuple_Check(attr)) { PyObject *dataptr; - if (n == 0) { - PyErr_SetString(PyExc_ValueError, - "__array_interface__ shape must be at least size 1"); - goto fail; - } if (PyTuple_GET_SIZE(attr) != 2) { PyErr_SetString(PyExc_TypeError, "__array_interface__ data must be a 2-tuple with " diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 3d5bae220..319147970 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -2873,6 +2873,13 @@ def test_array_interface(): f.iface['shape'] = (2,) assert_raises(ValueError, np.array, f) + # test scalar with no shape + class ArrayLike(object): + array = np.array(1) + __array_interface__ = array.__array_interface__ + assert_equal(np.array(ArrayLike()), 1) + + def test_flat_element_deletion(): it = np.ones(3).flat try: diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index b07fde27d..852c5f6fd 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -502,7 +502,6 @@ class ndenumerate(object): def __iter__(self): return self - class ndindex(object): """ An N-dimensional iterator object to index arrays. @@ -532,13 +531,36 @@ class ndindex(object): (2, 1, 0) """ + # This is a hack to handle 0-d arrays correctly. + # Fixing nditer would be more work but should be done eventually, + # and then this entire __new__ method can be removed. + def __new__(cls, *shape): + if len(shape) == 0 or (len(shape) == 1 and len(shape[0]) == 0): + class zero_dim_iter(object): + def __init__(self): + self._N = 1 + def __iter__(self): + return self + def ndincr(self): + self.next() + def next(self): + if self._N > 0: + self._N -= 1 + return () + raise StopIteration + return zero_dim_iter() + else: + return super(ndindex, cls).__new__(cls) + def __init__(self, *shape): + if len(shape) == 1 and isinstance(shape[0], tuple): + shape = shape[0] x = as_strided(_nx.zeros(1), shape=shape, strides=_nx.zeros_like(shape)) self._it = _nx.nditer(x, flags=['multi_index'], order='C') def __iter__(self): return self - + def ndincr(self): """ Increment the multi-dimensional index by one. diff --git a/numpy/lib/tests/test_index_tricks.py b/numpy/lib/tests/test_index_tricks.py index 0ede40d5a..43160ffb7 100644 --- a/numpy/lib/tests/test_index_tricks.py +++ b/numpy/lib/tests/test_index_tricks.py @@ -242,6 +242,16 @@ def test_ndindex(): expected = [ix for ix, e in np.ndenumerate(np.zeros((1, 2, 3)))] assert_array_equal(x, expected) + x = list(np.ndindex((1, 2, 3))) + assert_array_equal(x, expected) + + # Make sure size argument is optional + x = list(np.ndindex()) + assert_equal(x, [()]) + + x = list(np.ndindex(())) + assert_equal(x, [()]) + if __name__ == "__main__": run_module_suite() |