diff options
author | Travis Oliphant <oliphant@enthought.com> | 2005-09-20 10:58:41 +0000 |
---|---|---|
committer | Travis Oliphant <oliphant@enthought.com> | 2005-09-20 10:58:41 +0000 |
commit | a3874db675f2fa224085c50534c8c793ca7d577f (patch) | |
tree | ee6a375ab97b1f3bef446fdecec059359106a6a5 /scipy/base/shape_base.py | |
parent | 1a859e9b3f7a861fd9723745c92a820264d19a3b (diff) | |
download | numpy-a3874db675f2fa224085c50534c8c793ca7d577f.tar.gz |
More fixes.
Diffstat (limited to 'scipy/base/shape_base.py')
-rw-r--r-- | scipy/base/shape_base.py | 21 |
1 files changed, 17 insertions, 4 deletions
diff --git a/scipy/base/shape_base.py b/scipy/base/shape_base.py index b0f76ccef..e4269dc55 100644 --- a/scipy/base/shape_base.py +++ b/scipy/base/shape_base.py @@ -71,17 +71,30 @@ def apply_along_axis(func1d,axis,arr,*args): def apply_over_axes(func, a, axes): - """Apply a function over multiple axes, keeping the same shape + """Apply a function repeatedly over multiple axes, keeping the same shape for the resulting array. + + func is called as res = func(a, axis). The result is assumed + to be either the same shape as a or have one less dimension. + This call is repeated for each axis in the axes sequence. """ val = asarray(a) - N = len(val.shape) - if not type(axes) in SequenceType: + N = a.ndim + if array(axes).ndim == 0: axes = (axes,) for axis in axes: if axis < 0: axis = N + axis args = (val, axis) - val = expand_dims(func(*args),axis) + res = func(*args) + if res.ndim == val.ndim: + val = res + else: + res = expand_dims(res,axis) + if res.ndim == val.ndim: + val = res + else: + raise ValueError, "function is not returning"\ + " an array of correct shape" return val def expand_dims(a, axis): |