diff options
Diffstat (limited to 'numpy/lib/shape_base.py')
-rw-r--r-- | numpy/lib/shape_base.py | 26 |
1 files changed, 8 insertions, 18 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py index c9ef6084e..dfee35f96 100644 --- a/numpy/lib/shape_base.py +++ b/numpy/lib/shape_base.py @@ -8,6 +8,7 @@ from numpy.core.numeric import ( ) from numpy.core.fromnumeric import product, reshape, transpose from numpy.core import vstack, atleast_3d +from numpy.lib.index_tricks import ndindex __all__ = [ @@ -104,14 +105,12 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs): dims_in = list(range(nd)) inarr_view = transpose(arr, dims_in[:axis] + dims_in[axis+1:] + [axis]) - # number of iterations - Ntot = product(inarr_view.shape[:nd-1]) - - # current index - ind = [0]*(nd - 1) + # indices + inds = ndindex(inarr_view.shape[:nd-1]) # invoke the function on the first item - res = func1d(inarr_view[tuple(ind)], *args, **kwargs) + ind0 = next(inds) + res = func1d(inarr_view[ind0], *args, **kwargs) res = asanyarray(res) # insert as many axes as necessary to create the output @@ -125,18 +124,9 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs): outarr_view = transpose(outarr, dims_out[:axis] + dims_out[axis+res.ndim:] + dims_out[axis:axis+res.ndim]) # save the first result - outarr_view[tuple(ind)] = res - k = 1 - while k < Ntot: - # increment the index - ind[-1] += 1 - n = -1 - while (ind[n] >= outshape[n]) and (n > (1-nd)): - ind[n-1] += 1 - ind[n] = 0 - n -= 1 - outarr_view[tuple(ind)] = asanyarray(func1d(inarr_view[tuple(ind)], *args, **kwargs)) - k += 1 + outarr_view[ind0] = res + for ind in inds: + outarr_view[ind] = asanyarray(func1d(inarr_view[ind], *args, **kwargs)) return outarr |