summaryrefslogtreecommitdiff
path: root/numpy/lib/index_tricks.py
diff options
context:
space:
mode:
authorTravis Oliphant <oliphant@enthought.com>2006-05-12 02:51:10 +0000
committerTravis Oliphant <oliphant@enthought.com>2006-05-12 02:51:10 +0000
commit7cd1911082f7c2a431cd8bee944ac0abc61c8f9b (patch)
treecbbc3288f08edea649373b5af64b0c73c707582a /numpy/lib/index_tricks.py
parent916b29a1040d0d83ac6beef1138112873581b61e (diff)
downloadnumpy-7cd1911082f7c2a431cd8bee944ac0abc61c8f9b.tar.gz
Fix crash for zero-size arrays.
Diffstat (limited to 'numpy/lib/index_tricks.py')
-rw-r--r--numpy/lib/index_tricks.py53
1 files changed, 49 insertions, 4 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py
index 14655f92a..444c67e03 100644
--- a/numpy/lib/index_tricks.py
+++ b/numpy/lib/index_tricks.py
@@ -1,6 +1,11 @@
## Automatically adapted for numpy Sep 19, 2005 by convertcode.py
-__all__ = ['mgrid','ogrid','r_', 'c_', 'index_exp', 'ix_','ndenumerate','ndindex']
+__all__ = ['fromflat',
+ 'mgrid',
+ 'ogrid',
+ 'r_', 'c_', 's_',
+ 'index_exp', 'ix_',
+ 'ndenumerate','ndindex']
import sys
import types
@@ -11,6 +16,42 @@ import function_base
import numpy.core.defmatrix as matrix
makemat = matrix.matrix
+
+# fromflat contributed by Stefan van der Walt
+def fromflat(x,dims):
+ """Convert a flat index into an index tuple for a matrix of given shape.
+
+ e.g. for a 2x2 matrix, fromflat(2,(2,2)) translates to (1,0).
+
+ Example usage:
+ p = x.argmax()
+ idx = fromflat(p)
+ x[idx] == x.max()
+
+ Note: x.flat[p] == x.max()
+
+ Thus, it may be easier to use flattened indexing than to re-map
+ the index to a tuple.
+ """
+ if x > _nx.prod(dims)-1 or x < 0:
+ raise ValueError("Invalid index, must be 0 <= x <= number of elements.")
+
+ idx = _nx.empty_like(dims)
+
+ # Take dimensions
+ # [a,b,c,d]
+ # Reverse and drop first element
+ # [d,c,b]
+ # Prepend [1]
+ # [1,d,c,b]
+ # Calculate cumulative product
+ # [1,d,dc,dcb]
+ # Reverse
+ # [dcb,dc,d,1]
+ dim_prod = _nx.cumprod([1] + list(dims)[:0:-1])[::-1]
+ # Indeces become [x/dcb % a, x/dc % b, x/d % c, x/1 % d]
+ return tuple(x/dim_prod % dims)
+
def ix_(*args):
""" Construct an open mesh from multiple sequences.
@@ -252,7 +293,6 @@ class c_class(concatenator):
c_ = c_class()
-
class ndenumerate(object):
"""
A simple nd index iterator over an array.
@@ -326,6 +366,8 @@ class ndindex(object):
return self
+
+
# You can do all this with slice() plus a few special objects,
# but there's a lot to remember. This version is simpler because
# it uses the standard array indexing syntax.
@@ -351,9 +393,11 @@ class _index_expression_class(object):
used in the construction of complex index expressions.
"""
maxint = sys.maxint
+ def __init__(self, maketuple):
+ self.maketuple = maketuple
def __getitem__(self, item):
- if type(item) != type(()):
+ if self.maketuple and type(item) != type(()):
return (item,)
else:
return item
@@ -366,6 +410,7 @@ class _index_expression_class(object):
stop = None
return self[start:stop:None]
-index_exp = _index_expression_class()
+index_exp = _index_expression_class(1)
+s_ = _index_expression_class(0)
# End contribution from Konrad.