summaryrefslogtreecommitdiff
path: root/numpy/lib/shape_base.py
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2017-01-02 23:36:46 +0000
committerEric Wieser <wieser.eric@gmail.com>2017-02-11 21:09:06 +0000
commit5307aeda11a5d567a966a16232860fd852734606 (patch)
tree63439db4bb4c20f83293299ef76f043217f7c0cb /numpy/lib/shape_base.py
parent52988ea98120c10a927325401257eb3715a51b15 (diff)
downloadnumpy-5307aeda11a5d567a966a16232860fd852734606.tar.gz
MAINT: Use np.ndindex, which seems just as efficient
Diffstat (limited to 'numpy/lib/shape_base.py')
-rw-r--r--numpy/lib/shape_base.py26
1 files changed, 8 insertions, 18 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py
index c9ef6084e..dfee35f96 100644
--- a/numpy/lib/shape_base.py
+++ b/numpy/lib/shape_base.py
@@ -8,6 +8,7 @@ from numpy.core.numeric import (
)
from numpy.core.fromnumeric import product, reshape, transpose
from numpy.core import vstack, atleast_3d
+from numpy.lib.index_tricks import ndindex
__all__ = [
@@ -104,14 +105,12 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
dims_in = list(range(nd))
inarr_view = transpose(arr, dims_in[:axis] + dims_in[axis+1:] + [axis])
- # number of iterations
- Ntot = product(inarr_view.shape[:nd-1])
-
- # current index
- ind = [0]*(nd - 1)
+ # indices
+ inds = ndindex(inarr_view.shape[:nd-1])
# invoke the function on the first item
- res = func1d(inarr_view[tuple(ind)], *args, **kwargs)
+ ind0 = next(inds)
+ res = func1d(inarr_view[ind0], *args, **kwargs)
res = asanyarray(res)
# insert as many axes as necessary to create the output
@@ -125,18 +124,9 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
outarr_view = transpose(outarr, dims_out[:axis] + dims_out[axis+res.ndim:] + dims_out[axis:axis+res.ndim])
# save the first result
- outarr_view[tuple(ind)] = res
- k = 1
- while k < Ntot:
- # increment the index
- ind[-1] += 1
- n = -1
- while (ind[n] >= outshape[n]) and (n > (1-nd)):
- ind[n-1] += 1
- ind[n] = 0
- n -= 1
- outarr_view[tuple(ind)] = asanyarray(func1d(inarr_view[tuple(ind)], *args, **kwargs))
- k += 1
+ outarr_view[ind0] = res
+ for ind in inds:
+ outarr_view[ind] = asanyarray(func1d(inarr_view[ind], *args, **kwargs))
return outarr