diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2017-01-02 23:36:46 +0000 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2017-02-11 21:09:06 +0000 |
commit | 5307aeda11a5d567a966a16232860fd852734606 (patch) | |
tree | 63439db4bb4c20f83293299ef76f043217f7c0cb /numpy/lib/shape_base.py | |
parent | 52988ea98120c10a927325401257eb3715a51b15 (diff) | |
download | numpy-5307aeda11a5d567a966a16232860fd852734606.tar.gz |
MAINT: Use np.ndindex, which seems just as efficient
Diffstat (limited to 'numpy/lib/shape_base.py')
-rw-r--r-- | numpy/lib/shape_base.py | 26 |
1 files changed, 8 insertions, 18 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py index c9ef6084e..dfee35f96 100644 --- a/numpy/lib/shape_base.py +++ b/numpy/lib/shape_base.py @@ -8,6 +8,7 @@ from numpy.core.numeric import ( ) from numpy.core.fromnumeric import product, reshape, transpose from numpy.core import vstack, atleast_3d +from numpy.lib.index_tricks import ndindex __all__ = [ @@ -104,14 +105,12 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs): dims_in = list(range(nd)) inarr_view = transpose(arr, dims_in[:axis] + dims_in[axis+1:] + [axis]) - # number of iterations - Ntot = product(inarr_view.shape[:nd-1]) - - # current index - ind = [0]*(nd - 1) + # indices + inds = ndindex(inarr_view.shape[:nd-1]) # invoke the function on the first item - res = func1d(inarr_view[tuple(ind)], *args, **kwargs) + ind0 = next(inds) + res = func1d(inarr_view[ind0], *args, **kwargs) res = asanyarray(res) # insert as many axes as necessary to create the output @@ -125,18 +124,9 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs): outarr_view = transpose(outarr, dims_out[:axis] + dims_out[axis+res.ndim:] + dims_out[axis:axis+res.ndim]) # save the first result - outarr_view[tuple(ind)] = res - k = 1 - while k < Ntot: - # increment the index - ind[-1] += 1 - n = -1 - while (ind[n] >= outshape[n]) and (n > (1-nd)): - ind[n-1] += 1 - ind[n] = 0 - n -= 1 - outarr_view[tuple(ind)] = asanyarray(func1d(inarr_view[tuple(ind)], *args, **kwargs)) - k += 1 + outarr_view[ind0] = res + for ind in inds: + outarr_view[ind] = asanyarray(func1d(inarr_view[ind], *args, **kwargs)) return outarr |