summaryrefslogtreecommitdiff
path: root/numpy/lib/index_tricks.py
diff options
context:
space:
mode:
authorStefan van der Walt <stefan@sun.ac.za>2012-08-27 09:24:59 -0700
committerStefan van der Walt <stefan@sun.ac.za>2012-09-02 07:55:59 -0700
commit942bdb26bee477a8ac2c6b1c1c963c52c7c08403 (patch)
tree8a9d4ad6f9c47d4397dcb37255bef6a3b20b1a36 /numpy/lib/index_tricks.py
parente60c70d7ca5dbe45860c44673ddab02d47770155 (diff)
downloadnumpy-942bdb26bee477a8ac2c6b1c1c963c52c7c08403.tar.gz
Improve ndindex execution speed.
Diffstat (limited to 'numpy/lib/index_tricks.py')
-rw-r--r--numpy/lib/index_tricks.py49
1 files changed, 8 insertions, 41 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py
index 6f2aa1d02..112a79bd1 100644
--- a/numpy/lib/index_tricks.py
+++ b/numpy/lib/index_tricks.py
@@ -18,6 +18,7 @@ import function_base
import numpy.matrixlib as matrix
from function_base import diff
from numpy.lib._compiled_base import ravel_multi_index, unravel_index
+from numpy.lib.stride_tricks import as_strided
makemat = matrix.matrix
def ix_(*args):
@@ -531,37 +532,12 @@ class ndindex(object):
(2, 1, 0)
"""
+ def __init__(self, *shape):
+ x = as_strided(_nx.zeros(1), shape=shape, strides=_nx.zeros_like(shape))
+ self._it = _nx.nditer(x, flags=['multi_index'])
- def __init__(self, *args):
- if len(args) == 1 and isinstance(args[0], tuple):
- args = args[0]
- self.nd = len(args)
- self.ind = [0]*self.nd
- self.index = 0
- self.maxvals = args
- tot = 1
- for k in range(self.nd):
- tot *= args[k]
- self.total = tot
-
- def _incrementone(self, axis):
- if (axis < 0): # base case
- return
- if (self.ind[axis] < self.maxvals[axis]-1):
- self.ind[axis] += 1
- else:
- self.ind[axis] = 0
- self._incrementone(axis-1)
-
- def ndincr(self):
- """
- Increment the multi-dimensional index by one.
-
- `ndincr` takes care of the "wrapping around" of the axes.
- It is called by `ndindex.next` and not normally used directly.
-
- """
- self._incrementone(self.nd-1)
+ def __iter__(self):
+ return self
def next(self):
"""
@@ -573,17 +549,8 @@ class ndindex(object):
Returns a tuple containing the indices of the current iteration.
"""
- if (self.index >= self.total):
- raise StopIteration
- val = tuple(self.ind)
- self.index += 1
- self.ndincr()
- return val
-
- def __iter__(self):
- return self
-
-
+ self._it.next()
+ return self._it.multi_index
# You can do all this with slice() plus a few special objects,