diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/_internal.py | 5 | ||||
-rw-r--r-- | numpy/core/tests/test_dtype.py | 5 | ||||
-rw-r--r-- | numpy/lib/index_tricks.py | 45 | ||||
-rw-r--r-- | numpy/lib/tests/test_index_tricks.py | 8 |
4 files changed, 24 insertions, 39 deletions
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py index fbe580dee..686b773f2 100644 --- a/numpy/core/_internal.py +++ b/numpy/core/_internal.py @@ -133,7 +133,7 @@ def _reconstruct(subtype, shape, dtype): format_re = re.compile(asbytes( r'(?P<order1>[<>|=]?)' - r'(?P<repeats> *[(]?[ ,0-9]*[)]? *)' + r'(?P<repeats> *[(]?[ ,0-9L]*[)]? *)' r'(?P<order2>[<>|=]?)' r'(?P<dtype>[A-Za-z0-9.]*(?:\[[a-zA-Z0-9,.]+\])?)')) sep_re = re.compile(asbytes(r'\s*,\s*')) @@ -285,7 +285,7 @@ def _newnames(datatype, order): # Given an array with fields and a sequence of field names # construct a new array with just those fields copied over def _index_fields(ary, fields): - from multiarray import empty, dtype + from multiarray import empty, dtype, array dt = ary.dtype names = [name for name in fields if name in dt.names] @@ -298,7 +298,6 @@ def _index_fields(ary, fields): # Return a copy for now until behavior is fully deprecated # in favor of returning view copy_dtype = {'names':view_dtype['names'], 'formats':view_dtype['formats']} - from numpy import array return array(view, dtype=copy_dtype, copy=True) # Given a string containing a PEP 3118 format specifier, diff --git a/numpy/core/tests/test_dtype.py b/numpy/core/tests/test_dtype.py index 246ebba6b..5645a0824 100644 --- a/numpy/core/tests/test_dtype.py +++ b/numpy/core/tests/test_dtype.py @@ -413,6 +413,11 @@ class TestString(TestCase): assert_equal(repr(dt), "dtype([('a', '<M8[D]'), ('b', '<m8[us]')])") + @dec.skipif(sys.version_info[0] > 2) + def test_dtype_str_with_long_in_shape(self): + # Pull request #376 + dt = np.dtype('(1L,)i4') + class TestDtypeAttributeDeletion(object): diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index 6f2aa1d02..b07fde27d 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -18,6 +18,7 @@ import function_base import numpy.matrixlib as matrix from function_base import diff from numpy.lib._compiled_base import ravel_multi_index, unravel_index +from numpy.lib.stride_tricks import as_strided makemat = matrix.matrix def ix_(*args): @@ -531,37 +532,20 @@ class ndindex(object): (2, 1, 0) """ + def __init__(self, *shape): + x = as_strided(_nx.zeros(1), shape=shape, strides=_nx.zeros_like(shape)) + self._it = _nx.nditer(x, flags=['multi_index'], order='C') - def __init__(self, *args): - if len(args) == 1 and isinstance(args[0], tuple): - args = args[0] - self.nd = len(args) - self.ind = [0]*self.nd - self.index = 0 - self.maxvals = args - tot = 1 - for k in range(self.nd): - tot *= args[k] - self.total = tot - - def _incrementone(self, axis): - if (axis < 0): # base case - return - if (self.ind[axis] < self.maxvals[axis]-1): - self.ind[axis] += 1 - else: - self.ind[axis] = 0 - self._incrementone(axis-1) + def __iter__(self): + return self def ndincr(self): """ Increment the multi-dimensional index by one. - `ndincr` takes care of the "wrapping around" of the axes. - It is called by `ndindex.next` and not normally used directly. - + This method is for backward compatibility only: do not use. """ - self._incrementone(self.nd-1) + self.next() def next(self): """ @@ -573,17 +557,8 @@ class ndindex(object): Returns a tuple containing the indices of the current iteration. """ - if (self.index >= self.total): - raise StopIteration - val = tuple(self.ind) - self.index += 1 - self.ndincr() - return val - - def __iter__(self): - return self - - + self._it.next() + return self._it.multi_index # You can do all this with slice() plus a few special objects, diff --git a/numpy/lib/tests/test_index_tricks.py b/numpy/lib/tests/test_index_tricks.py index beda2d146..0ede40d5a 100644 --- a/numpy/lib/tests/test_index_tricks.py +++ b/numpy/lib/tests/test_index_tricks.py @@ -2,7 +2,7 @@ from numpy.testing import * import numpy as np from numpy import ( array, ones, r_, mgrid, unravel_index, zeros, where, ndenumerate, fill_diagonal, diag_indices, - diag_indices_from, s_, index_exp ) + diag_indices_from, s_, index_exp, ndindex ) class TestRavelUnravelIndex(TestCase): def test_basic(self): @@ -237,5 +237,11 @@ def test_diag_indices_from(): assert_array_equal(c, np.arange(4)) +def test_ndindex(): + x = list(np.ndindex(1, 2, 3)) + expected = [ix for ix, e in np.ndenumerate(np.zeros((1, 2, 3)))] + assert_array_equal(x, expected) + + if __name__ == "__main__": run_module_suite() |