diff options
Diffstat (limited to 'numpy/lib/shape_base.py')
-rw-r--r-- | numpy/lib/shape_base.py | 131 |
1 files changed, 72 insertions, 59 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py index 6e92aa041..da0b6a5b2 100644 --- a/numpy/lib/shape_base.py +++ b/numpy/lib/shape_base.py @@ -6,8 +6,10 @@ import numpy.core.numeric as _nx from numpy.core.numeric import ( asarray, zeros, outer, concatenate, isscalar, array, asanyarray ) -from numpy.core.fromnumeric import product, reshape +from numpy.core.fromnumeric import product, reshape, transpose from numpy.core import vstack, atleast_3d +from numpy.lib.index_tricks import ndindex +from numpy.matrixlib.defmatrix import matrix # this raises all the right alarm bells __all__ = [ @@ -45,9 +47,10 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs): ------- apply_along_axis : ndarray The output array. The shape of `outarr` is identical to the shape of - `arr`, except along the `axis` dimension, where the length of `outarr` - is equal to the size of the return value of `func1d`. If `func1d` - returns a scalar `outarr` will have one fewer dimensions than `arr`. + `arr`, except along the `axis` dimension. This axis is removed, and + replaced with new dimensions equal to the shape of the return value + of `func1d`. So if `func1d` returns a scalar `outarr` will have one + fewer dimensions than `arr`. See Also -------- @@ -64,7 +67,7 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs): >>> np.apply_along_axis(my_func, 1, b) array([ 2., 5., 8.]) - For a function that doesn't return a scalar, the number of dimensions in + For a function that returns a 1D array, the number of dimensions in `outarr` is the same as `arr`. >>> b = np.array([[8,1,7], [4,3,9], [5,2,6]]) @@ -73,66 +76,76 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs): [3, 4, 9], [2, 5, 6]]) + For a function that returns a higher dimensional array, those dimensions + are inserted in place of the `axis` dimension. + + >>> b = np.array([[1,2,3], [4,5,6], [7,8,9]]) + >>> np.apply_along_axis(np.diag, -1, b) + array([[[1, 0, 0], + [0, 2, 0], + [0, 0, 3]], + + [[4, 0, 0], + [0, 5, 0], + [0, 0, 6]], + + [[7, 0, 0], + [0, 8, 0], + [0, 0, 9]]]) """ + # handle negative axes arr = asanyarray(arr) nd = arr.ndim + if not (-nd <= axis < nd): + raise IndexError('axis {0} out of bounds [-{1}, {1})'.format(axis, nd)) if axis < 0: axis += nd - if (axis >= nd): - raise ValueError("axis must be less than arr.ndim; axis=%d, rank=%d." - % (axis, nd)) - ind = [0]*(nd-1) - i = zeros(nd, 'O') - indlist = list(range(nd)) - indlist.remove(axis) - i[axis] = slice(None, None) - outshape = asarray(arr.shape).take(indlist) - i.put(indlist, ind) - res = func1d(arr[tuple(i.tolist())], *args, **kwargs) - # if res is a number, then we have a smaller output array - if isscalar(res): - outarr = zeros(outshape, asarray(res).dtype) - outarr[tuple(ind)] = res - Ntot = product(outshape) - k = 1 - while k < Ntot: - # increment the index - ind[-1] += 1 - n = -1 - while (ind[n] >= outshape[n]) and (n > (1-nd)): - ind[n-1] += 1 - ind[n] = 0 - n -= 1 - i.put(indlist, ind) - res = func1d(arr[tuple(i.tolist())], *args, **kwargs) - outarr[tuple(ind)] = res - k += 1 - return outarr + + # arr, with the iteration axis at the end + in_dims = list(range(nd)) + inarr_view = transpose(arr, in_dims[:axis] + in_dims[axis+1:] + [axis]) + + # compute indices for the iteration axes + inds = ndindex(inarr_view.shape[:-1]) + + # invoke the function on the first item + ind0 = next(inds) + res = asanyarray(func1d(inarr_view[ind0], *args, **kwargs)) + + # build a buffer for storing evaluations of func1d. + # remove the requested axis, and add the new ones on the end. + # laid out so that each write is contiguous. + # for a tuple index inds, buff[inds] = func1d(inarr_view[inds]) + buff = zeros(inarr_view.shape[:-1] + res.shape, res.dtype) + + # permutation of axes such that out = buff.transpose(buff_permute) + buff_dims = list(range(buff.ndim)) + buff_permute = ( + buff_dims[0 : axis] + + buff_dims[buff.ndim-res.ndim : buff.ndim] + + buff_dims[axis : buff.ndim-res.ndim] + ) + + # matrices have a nasty __array_prepare__ and __array_wrap__ + if not isinstance(res, matrix): + buff = res.__array_prepare__(buff) + + # save the first result, then compute and save all remaining results + buff[ind0] = res + for ind in inds: + buff[ind] = asanyarray(func1d(inarr_view[ind], *args, **kwargs)) + + if not isinstance(res, matrix): + # wrap the array, to preserve subclasses + buff = res.__array_wrap__(buff) + + # finally, rotate the inserted axes back to where they belong + return transpose(buff, buff_permute) + else: - res = asanyarray(res) - Ntot = product(outshape) - holdshape = outshape - outshape = list(arr.shape) - outshape[axis] = res.size - outarr = zeros(outshape, res.dtype) - outarr = res.__array_wrap__(outarr) - outarr[tuple(i.tolist())] = res - k = 1 - while k < Ntot: - # increment the index - ind[-1] += 1 - n = -1 - while (ind[n] >= holdshape[n]) and (n > (1-nd)): - ind[n-1] += 1 - ind[n] = 0 - n -= 1 - i.put(indlist, ind) - res = asanyarray(func1d(arr[tuple(i.tolist())], *args, **kwargs)) - outarr[tuple(i.tolist())] = res - k += 1 - if res.shape == (): - outarr = outarr.squeeze(axis) - return outarr + # matrices have to be transposed first, because they collapse dimensions! + out_arr = transpose(buff, buff_permute) + return res.__array_wrap__(out_arr) def apply_over_axes(func, a, axes): |