From 942bdb26bee477a8ac2c6b1c1c963c52c7c08403 Mon Sep 17 00:00:00 2001 From: Stefan van der Walt Date: Mon, 27 Aug 2012 09:24:59 -0700 Subject: Improve ndindex execution speed. --- numpy/lib/index_tricks.py | 49 ++++++++--------------------------------------- 1 file changed, 8 insertions(+), 41 deletions(-) (limited to 'numpy/lib/index_tricks.py') diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index 6f2aa1d02..112a79bd1 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -18,6 +18,7 @@ import function_base import numpy.matrixlib as matrix from function_base import diff from numpy.lib._compiled_base import ravel_multi_index, unravel_index +from numpy.lib.stride_tricks import as_strided makemat = matrix.matrix def ix_(*args): @@ -531,37 +532,12 @@ class ndindex(object): (2, 1, 0) """ + def __init__(self, *shape): + x = as_strided(_nx.zeros(1), shape=shape, strides=_nx.zeros_like(shape)) + self._it = _nx.nditer(x, flags=['multi_index']) - def __init__(self, *args): - if len(args) == 1 and isinstance(args[0], tuple): - args = args[0] - self.nd = len(args) - self.ind = [0]*self.nd - self.index = 0 - self.maxvals = args - tot = 1 - for k in range(self.nd): - tot *= args[k] - self.total = tot - - def _incrementone(self, axis): - if (axis < 0): # base case - return - if (self.ind[axis] < self.maxvals[axis]-1): - self.ind[axis] += 1 - else: - self.ind[axis] = 0 - self._incrementone(axis-1) - - def ndincr(self): - """ - Increment the multi-dimensional index by one. - - `ndincr` takes care of the "wrapping around" of the axes. - It is called by `ndindex.next` and not normally used directly. - - """ - self._incrementone(self.nd-1) + def __iter__(self): + return self def next(self): """ @@ -573,17 +549,8 @@ class ndindex(object): Returns a tuple containing the indices of the current iteration. """ - if (self.index >= self.total): - raise StopIteration - val = tuple(self.ind) - self.index += 1 - self.ndincr() - return val - - def __iter__(self): - return self - - + self._it.next() + return self._it.multi_index # You can do all this with slice() plus a few special objects, -- cgit v1.2.1