summaryrefslogtreecommitdiff
path: root/numpy/lib/shape_base.py
diff options
context:
space:
mode:
authorStephan Hoyer <shoyer@gmail.com>2017-02-11 23:49:25 -0800
committerGitHub <noreply@github.com>2017-02-11 23:49:25 -0800
commit51a824051c82b36c1d0f17f81c6a9ef24375cf42 (patch)
treedbe1b2046032b5c7712f4d9fc031a3274d74025a /numpy/lib/shape_base.py
parent5de1a8215edc5982458473b2ef6da091beb4b19d (diff)
parentd4bce016f19dd8c8e2761818741d1c4024a34f21 (diff)
downloadnumpy-51a824051c82b36c1d0f17f81c6a9ef24375cf42.tar.gz
Merge pull request #8441 from eric-wieser/apply_along_axis-nd
BUG: Fix crash on 0d return value in apply_along_axis
Diffstat (limited to 'numpy/lib/shape_base.py')
-rw-r--r--numpy/lib/shape_base.py131
1 files changed, 72 insertions, 59 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py
index 6e92aa041..da0b6a5b2 100644
--- a/numpy/lib/shape_base.py
+++ b/numpy/lib/shape_base.py
@@ -6,8 +6,10 @@ import numpy.core.numeric as _nx
from numpy.core.numeric import (
asarray, zeros, outer, concatenate, isscalar, array, asanyarray
)
-from numpy.core.fromnumeric import product, reshape
+from numpy.core.fromnumeric import product, reshape, transpose
from numpy.core import vstack, atleast_3d
+from numpy.lib.index_tricks import ndindex
+from numpy.matrixlib.defmatrix import matrix # this raises all the right alarm bells
__all__ = [
@@ -45,9 +47,10 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
-------
apply_along_axis : ndarray
The output array. The shape of `outarr` is identical to the shape of
- `arr`, except along the `axis` dimension, where the length of `outarr`
- is equal to the size of the return value of `func1d`. If `func1d`
- returns a scalar `outarr` will have one fewer dimensions than `arr`.
+ `arr`, except along the `axis` dimension. This axis is removed, and
+ replaced with new dimensions equal to the shape of the return value
+ of `func1d`. So if `func1d` returns a scalar `outarr` will have one
+ fewer dimensions than `arr`.
See Also
--------
@@ -64,7 +67,7 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
>>> np.apply_along_axis(my_func, 1, b)
array([ 2., 5., 8.])
- For a function that doesn't return a scalar, the number of dimensions in
+ For a function that returns a 1D array, the number of dimensions in
`outarr` is the same as `arr`.
>>> b = np.array([[8,1,7], [4,3,9], [5,2,6]])
@@ -73,66 +76,76 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
[3, 4, 9],
[2, 5, 6]])
+ For a function that returns a higher dimensional array, those dimensions
+ are inserted in place of the `axis` dimension.
+
+ >>> b = np.array([[1,2,3], [4,5,6], [7,8,9]])
+ >>> np.apply_along_axis(np.diag, -1, b)
+ array([[[1, 0, 0],
+ [0, 2, 0],
+ [0, 0, 3]],
+
+ [[4, 0, 0],
+ [0, 5, 0],
+ [0, 0, 6]],
+
+ [[7, 0, 0],
+ [0, 8, 0],
+ [0, 0, 9]]])
"""
+ # handle negative axes
arr = asanyarray(arr)
nd = arr.ndim
+ if not (-nd <= axis < nd):
+ raise IndexError('axis {0} out of bounds [-{1}, {1})'.format(axis, nd))
if axis < 0:
axis += nd
- if (axis >= nd):
- raise ValueError("axis must be less than arr.ndim; axis=%d, rank=%d."
- % (axis, nd))
- ind = [0]*(nd-1)
- i = zeros(nd, 'O')
- indlist = list(range(nd))
- indlist.remove(axis)
- i[axis] = slice(None, None)
- outshape = asarray(arr.shape).take(indlist)
- i.put(indlist, ind)
- res = func1d(arr[tuple(i.tolist())], *args, **kwargs)
- # if res is a number, then we have a smaller output array
- if isscalar(res):
- outarr = zeros(outshape, asarray(res).dtype)
- outarr[tuple(ind)] = res
- Ntot = product(outshape)
- k = 1
- while k < Ntot:
- # increment the index
- ind[-1] += 1
- n = -1
- while (ind[n] >= outshape[n]) and (n > (1-nd)):
- ind[n-1] += 1
- ind[n] = 0
- n -= 1
- i.put(indlist, ind)
- res = func1d(arr[tuple(i.tolist())], *args, **kwargs)
- outarr[tuple(ind)] = res
- k += 1
- return outarr
+
+ # arr, with the iteration axis at the end
+ in_dims = list(range(nd))
+ inarr_view = transpose(arr, in_dims[:axis] + in_dims[axis+1:] + [axis])
+
+ # compute indices for the iteration axes
+ inds = ndindex(inarr_view.shape[:-1])
+
+ # invoke the function on the first item
+ ind0 = next(inds)
+ res = asanyarray(func1d(inarr_view[ind0], *args, **kwargs))
+
+ # build a buffer for storing evaluations of func1d.
+ # remove the requested axis, and add the new ones on the end.
+ # laid out so that each write is contiguous.
+ # for a tuple index inds, buff[inds] = func1d(inarr_view[inds])
+ buff = zeros(inarr_view.shape[:-1] + res.shape, res.dtype)
+
+ # permutation of axes such that out = buff.transpose(buff_permute)
+ buff_dims = list(range(buff.ndim))
+ buff_permute = (
+ buff_dims[0 : axis] +
+ buff_dims[buff.ndim-res.ndim : buff.ndim] +
+ buff_dims[axis : buff.ndim-res.ndim]
+ )
+
+ # matrices have a nasty __array_prepare__ and __array_wrap__
+ if not isinstance(res, matrix):
+ buff = res.__array_prepare__(buff)
+
+ # save the first result, then compute and save all remaining results
+ buff[ind0] = res
+ for ind in inds:
+ buff[ind] = asanyarray(func1d(inarr_view[ind], *args, **kwargs))
+
+ if not isinstance(res, matrix):
+ # wrap the array, to preserve subclasses
+ buff = res.__array_wrap__(buff)
+
+ # finally, rotate the inserted axes back to where they belong
+ return transpose(buff, buff_permute)
+
else:
- res = asanyarray(res)
- Ntot = product(outshape)
- holdshape = outshape
- outshape = list(arr.shape)
- outshape[axis] = res.size
- outarr = zeros(outshape, res.dtype)
- outarr = res.__array_wrap__(outarr)
- outarr[tuple(i.tolist())] = res
- k = 1
- while k < Ntot:
- # increment the index
- ind[-1] += 1
- n = -1
- while (ind[n] >= holdshape[n]) and (n > (1-nd)):
- ind[n-1] += 1
- ind[n] = 0
- n -= 1
- i.put(indlist, ind)
- res = asanyarray(func1d(arr[tuple(i.tolist())], *args, **kwargs))
- outarr[tuple(i.tolist())] = res
- k += 1
- if res.shape == ():
- outarr = outarr.squeeze(axis)
- return outarr
+ # matrices have to be transposed first, because they collapse dimensions!
+ out_arr = transpose(buff, buff_permute)
+ return res.__array_wrap__(out_arr)
def apply_over_axes(func, a, axes):