summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/_internal.py5
-rw-r--r--numpy/core/tests/test_dtype.py5
-rw-r--r--numpy/lib/index_tricks.py45
-rw-r--r--numpy/lib/tests/test_index_tricks.py8
4 files changed, 24 insertions, 39 deletions
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py
index fbe580dee..686b773f2 100644
--- a/numpy/core/_internal.py
+++ b/numpy/core/_internal.py
@@ -133,7 +133,7 @@ def _reconstruct(subtype, shape, dtype):
format_re = re.compile(asbytes(
r'(?P<order1>[<>|=]?)'
- r'(?P<repeats> *[(]?[ ,0-9]*[)]? *)'
+ r'(?P<repeats> *[(]?[ ,0-9L]*[)]? *)'
r'(?P<order2>[<>|=]?)'
r'(?P<dtype>[A-Za-z0-9.]*(?:\[[a-zA-Z0-9,.]+\])?)'))
sep_re = re.compile(asbytes(r'\s*,\s*'))
@@ -285,7 +285,7 @@ def _newnames(datatype, order):
# Given an array with fields and a sequence of field names
# construct a new array with just those fields copied over
def _index_fields(ary, fields):
- from multiarray import empty, dtype
+ from multiarray import empty, dtype, array
dt = ary.dtype
names = [name for name in fields if name in dt.names]
@@ -298,7 +298,6 @@ def _index_fields(ary, fields):
# Return a copy for now until behavior is fully deprecated
# in favor of returning view
copy_dtype = {'names':view_dtype['names'], 'formats':view_dtype['formats']}
- from numpy import array
return array(view, dtype=copy_dtype, copy=True)
# Given a string containing a PEP 3118 format specifier,
diff --git a/numpy/core/tests/test_dtype.py b/numpy/core/tests/test_dtype.py
index 246ebba6b..5645a0824 100644
--- a/numpy/core/tests/test_dtype.py
+++ b/numpy/core/tests/test_dtype.py
@@ -413,6 +413,11 @@ class TestString(TestCase):
assert_equal(repr(dt),
"dtype([('a', '<M8[D]'), ('b', '<m8[us]')])")
+ @dec.skipif(sys.version_info[0] > 2)
+ def test_dtype_str_with_long_in_shape(self):
+ # Pull request #376
+ dt = np.dtype('(1L,)i4')
+
class TestDtypeAttributeDeletion(object):
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py
index 6f2aa1d02..b07fde27d 100644
--- a/numpy/lib/index_tricks.py
+++ b/numpy/lib/index_tricks.py
@@ -18,6 +18,7 @@ import function_base
import numpy.matrixlib as matrix
from function_base import diff
from numpy.lib._compiled_base import ravel_multi_index, unravel_index
+from numpy.lib.stride_tricks import as_strided
makemat = matrix.matrix
def ix_(*args):
@@ -531,37 +532,20 @@ class ndindex(object):
(2, 1, 0)
"""
+ def __init__(self, *shape):
+ x = as_strided(_nx.zeros(1), shape=shape, strides=_nx.zeros_like(shape))
+ self._it = _nx.nditer(x, flags=['multi_index'], order='C')
- def __init__(self, *args):
- if len(args) == 1 and isinstance(args[0], tuple):
- args = args[0]
- self.nd = len(args)
- self.ind = [0]*self.nd
- self.index = 0
- self.maxvals = args
- tot = 1
- for k in range(self.nd):
- tot *= args[k]
- self.total = tot
-
- def _incrementone(self, axis):
- if (axis < 0): # base case
- return
- if (self.ind[axis] < self.maxvals[axis]-1):
- self.ind[axis] += 1
- else:
- self.ind[axis] = 0
- self._incrementone(axis-1)
+ def __iter__(self):
+ return self
def ndincr(self):
"""
Increment the multi-dimensional index by one.
- `ndincr` takes care of the "wrapping around" of the axes.
- It is called by `ndindex.next` and not normally used directly.
-
+ This method is for backward compatibility only: do not use.
"""
- self._incrementone(self.nd-1)
+ self.next()
def next(self):
"""
@@ -573,17 +557,8 @@ class ndindex(object):
Returns a tuple containing the indices of the current iteration.
"""
- if (self.index >= self.total):
- raise StopIteration
- val = tuple(self.ind)
- self.index += 1
- self.ndincr()
- return val
-
- def __iter__(self):
- return self
-
-
+ self._it.next()
+ return self._it.multi_index
# You can do all this with slice() plus a few special objects,
diff --git a/numpy/lib/tests/test_index_tricks.py b/numpy/lib/tests/test_index_tricks.py
index beda2d146..0ede40d5a 100644
--- a/numpy/lib/tests/test_index_tricks.py
+++ b/numpy/lib/tests/test_index_tricks.py
@@ -2,7 +2,7 @@ from numpy.testing import *
import numpy as np
from numpy import ( array, ones, r_, mgrid, unravel_index, zeros, where,
ndenumerate, fill_diagonal, diag_indices,
- diag_indices_from, s_, index_exp )
+ diag_indices_from, s_, index_exp, ndindex )
class TestRavelUnravelIndex(TestCase):
def test_basic(self):
@@ -237,5 +237,11 @@ def test_diag_indices_from():
assert_array_equal(c, np.arange(4))
+def test_ndindex():
+ x = list(np.ndindex(1, 2, 3))
+ expected = [ix for ix, e in np.ndenumerate(np.zeros((1, 2, 3)))]
+ assert_array_equal(x, expected)
+
+
if __name__ == "__main__":
run_module_suite()