diff options
Diffstat (limited to 'numpy/lib/index_tricks.py')
-rw-r--r-- | numpy/lib/index_tricks.py | 52 |
1 files changed, 51 insertions, 1 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index 604bd6059..899fd5f47 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -1,6 +1,6 @@ ## Automatically adapted for numpy Sep 19, 2005 by convertcode.py -__all__ = ['mgrid','ogrid','r_', 'c_', 'index_exp', 'ix_','ndenumerate'] +__all__ = ['mgrid','ogrid','r_', 'c_', 'index_exp', 'ix_','ndenumerate','ndindex'] import sys import types @@ -265,6 +265,56 @@ class ndenumerate(object): def __iter__(self): return self +class ndindex(object): + """Pass in a sequence of integers corresponding + to the number of dimensions in the counter. This iterator + will then return an N-dimensional counter. + + Example: + >>> for index in ndindex(4,3,2): + print index + (0,0,0) + (0,0,1) + (0,1,0) + ... + (3,1,1) + (3,2,0) + (3,2,1) + """ + + def __init__(self, *args): + self.nd = len(args) + self.ind = [0]*self.nd + self.index = 0 + self.maxvals = args + tot = 1 + for k in range(self.nd): + tot *= args[k] + self.total = tot + + def _incrementone(self, axis): + if (axis < 0): # base case + return + if (self.ind[axis] < self.maxvals[axis]-1): + self.ind[axis]+=1 + else: + self.ind[axis] = 0 + self._incrementone(axis-1) + + def ndincr(self): + self._incrementone(self.nd-1) + + def next(self): + if (self.index >= self.total): + raise StopIteration + val = tuple(self.ind) + self.index+=1 + self.ndincr() + return val + + def __iter__(self): + return self + # You can do all this with slice() plus a few special objects, # but there's a lot to remember. This version is simpler because |