summaryrefslogtreecommitdiff
path: root/numpy/lib/shape_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/shape_base.py')
-rw-r--r--numpy/lib/shape_base.py57
1 files changed, 34 insertions, 23 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py
index 120153ad5..9c16982c4 100644
--- a/numpy/lib/shape_base.py
+++ b/numpy/lib/shape_base.py
@@ -103,39 +103,50 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
% (axis, nd))
# arr, with the iteration axis at the end
- dims_in = list(range(nd))
- inarr_view = transpose(arr, dims_in[:axis] + dims_in[axis+1:] + [axis])
+ in_dims = list(range(nd))
+ inarr_view = transpose(arr, in_dims[:axis] + in_dims[axis+1:] + [axis])
- # indices
- inds = ndindex(inarr_view.shape[:nd-1])
+ # compute indices for the iteration axes
+ inds = ndindex(inarr_view.shape[:-1])
# invoke the function on the first item
ind0 = next(inds)
- res = func1d(inarr_view[ind0], *args, **kwargs)
- res = asanyarray(res)
+ res = asanyarray(func1d(inarr_view[ind0], *args, **kwargs))
+
+ # build a buffer for storing evaluations of func1d.
+ # remove the requested axis, and add the new ones on the end.
+ # laid out so that each write is contiguous.
+ # for a tuple index inds, buff[inds] = func1d(inarr_view[inds])
+ buff = zeros(inarr_view.shape[:-1] + res.shape, res.dtype)
+
+ # permutation of axes such that out = buff.transpose(buff_permute)
+ buff_dims = list(range(buff.ndim))
+ buff_permute = (
+ buff_dims[0 : axis] +
+ buff_dims[buff.ndim-res.ndim : buff.ndim] +
+ buff_dims[axis : buff.ndim-res.ndim]
+ )
- # insert as many axes as necessary to create the output
- outshape = arr.shape[:axis] + res.shape + arr.shape[axis+1:]
- outarr = zeros(outshape, res.dtype)
- # matrices call reshape in __array_prepare__, and this is apparently ok!
+ # matrices have a nasty __array_prepare__ and __array_wrap__
if not isinstance(res, matrix):
- outarr = res.__array_prepare__(outarr)
- if outshape != outarr.shape:
- raise ValueError('__array_prepare__ should not change the shape of the resultant array')
-
- # outarr, with inserted dimensions at the end
- # this means that outarr_view[i] = func1d(inarr_view[i])
- dims_out = list(range(outarr.ndim))
- outarr_view = transpose(outarr, dims_out[:axis] + dims_out[axis+res.ndim:] + dims_out[axis:axis+res.ndim])
+ buff = res.__array_prepare__(buff)
- # save the first result
- outarr_view[ind0] = res
+ # save the first result, then compute and save all remaining results
+ buff[ind0] = res
for ind in inds:
- outarr_view[ind] = asanyarray(func1d(inarr_view[ind], *args, **kwargs))
+ buff[ind] = asanyarray(func1d(inarr_view[ind], *args, **kwargs))
- outarr = res.__array_wrap__(outarr)
+ if not isinstance(res, matrix):
+ # wrap the array, to preserve subclasses
+ buff = res.__array_wrap__(buff)
- return outarr
+ # finally, rotate the inserted axes back to where they belong
+ return transpose(buff, buff_permute)
+
+ else:
+ # matrices have to be transposed first, because they collapse dimensions!
+ out_arr = transpose(buff, buff_permute)
+ return res.__array_wrap__(out_arr)
def apply_over_axes(func, a, axes):