summaryrefslogtreecommitdiff
path: root/scipy/base/shape_base.py
diff options
context:
space:
mode:
authorTravis Oliphant <oliphant@enthought.com>2005-09-20 10:58:41 +0000
committerTravis Oliphant <oliphant@enthought.com>2005-09-20 10:58:41 +0000
commita3874db675f2fa224085c50534c8c793ca7d577f (patch)
treeee6a375ab97b1f3bef446fdecec059359106a6a5 /scipy/base/shape_base.py
parent1a859e9b3f7a861fd9723745c92a820264d19a3b (diff)
downloadnumpy-a3874db675f2fa224085c50534c8c793ca7d577f.tar.gz
More fixes.
Diffstat (limited to 'scipy/base/shape_base.py')
-rw-r--r--scipy/base/shape_base.py21
1 files changed, 17 insertions, 4 deletions
diff --git a/scipy/base/shape_base.py b/scipy/base/shape_base.py
index b0f76ccef..e4269dc55 100644
--- a/scipy/base/shape_base.py
+++ b/scipy/base/shape_base.py
@@ -71,17 +71,30 @@ def apply_along_axis(func1d,axis,arr,*args):
def apply_over_axes(func, a, axes):
- """Apply a function over multiple axes, keeping the same shape
+ """Apply a function repeatedly over multiple axes, keeping the same shape
for the resulting array.
+
+ func is called as res = func(a, axis). The result is assumed
+ to be either the same shape as a or have one less dimension.
+ This call is repeated for each axis in the axes sequence.
"""
val = asarray(a)
- N = len(val.shape)
- if not type(axes) in SequenceType:
+ N = a.ndim
+ if array(axes).ndim == 0:
axes = (axes,)
for axis in axes:
if axis < 0: axis = N + axis
args = (val, axis)
- val = expand_dims(func(*args),axis)
+ res = func(*args)
+ if res.ndim == val.ndim:
+ val = res
+ else:
+ res = expand_dims(res,axis)
+ if res.ndim == val.ndim:
+ val = res
+ else:
+ raise ValueError, "function is not returning"\
+ " an array of correct shape"
return val
def expand_dims(a, axis):