summaryrefslogtreecommitdiff
path: root/numpy/lib/index_tricks.py
diff options
context:
space:
mode:
authorTravis Oliphant <oliphant@enthought.com>2006-01-27 01:00:48 +0000
committerTravis Oliphant <oliphant@enthought.com>2006-01-27 01:00:48 +0000
commitb7cc1f63d4bd51ea8df16305d8213a10d7493f3a (patch)
tree91d06fce1cf908a68b6141538c76c3b2eab7107b /numpy/lib/index_tricks.py
parent606c857c8e5c9ee844a58dd5919d14a6eaff191e (diff)
downloadnumpy-b7cc1f63d4bd51ea8df16305d8213a10d7493f3a.tar.gz
Add ndindex iterator to complement ndenumerate iterator.
Diffstat (limited to 'numpy/lib/index_tricks.py')
-rw-r--r--numpy/lib/index_tricks.py52
1 files changed, 51 insertions, 1 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py
index 604bd6059..899fd5f47 100644
--- a/numpy/lib/index_tricks.py
+++ b/numpy/lib/index_tricks.py
@@ -1,6 +1,6 @@
## Automatically adapted for numpy Sep 19, 2005 by convertcode.py
-__all__ = ['mgrid','ogrid','r_', 'c_', 'index_exp', 'ix_','ndenumerate']
+__all__ = ['mgrid','ogrid','r_', 'c_', 'index_exp', 'ix_','ndenumerate','ndindex']
import sys
import types
@@ -265,6 +265,56 @@ class ndenumerate(object):
def __iter__(self):
return self
+class ndindex(object):
+ """Pass in a sequence of integers corresponding
+ to the number of dimensions in the counter. This iterator
+ will then return an N-dimensional counter.
+
+ Example:
+ >>> for index in ndindex(4,3,2):
+ print index
+ (0,0,0)
+ (0,0,1)
+ (0,1,0)
+ ...
+ (3,1,1)
+ (3,2,0)
+ (3,2,1)
+ """
+
+ def __init__(self, *args):
+ self.nd = len(args)
+ self.ind = [0]*self.nd
+ self.index = 0
+ self.maxvals = args
+ tot = 1
+ for k in range(self.nd):
+ tot *= args[k]
+ self.total = tot
+
+ def _incrementone(self, axis):
+ if (axis < 0): # base case
+ return
+ if (self.ind[axis] < self.maxvals[axis]-1):
+ self.ind[axis]+=1
+ else:
+ self.ind[axis] = 0
+ self._incrementone(axis-1)
+
+ def ndincr(self):
+ self._incrementone(self.nd-1)
+
+ def next(self):
+ if (self.index >= self.total):
+ raise StopIteration
+ val = tuple(self.ind)
+ self.index+=1
+ self.ndincr()
+ return val
+
+ def __iter__(self):
+ return self
+
# You can do all this with slice() plus a few special objects,
# but there's a lot to remember. This version is simpler because