diff options
Diffstat (limited to 'numpy/lib/shape_base.py')
-rw-r--r-- | numpy/lib/shape_base.py | 82 |
1 files changed, 80 insertions, 2 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py index 66f534734..00424d55d 100644 --- a/numpy/lib/shape_base.py +++ b/numpy/lib/shape_base.py @@ -1,5 +1,6 @@ from __future__ import division, absolute_import, print_function +import functools import warnings import numpy.core.numeric as _nx @@ -8,6 +9,7 @@ from numpy.core.numeric import ( ) from numpy.core.fromnumeric import product, reshape, transpose from numpy.core.multiarray import normalize_axis_index +from numpy.core import overrides from numpy.core import vstack, atleast_3d from numpy.lib.index_tricks import ndindex from numpy.matrixlib.defmatrix import matrix # this raises all the right alarm bells @@ -21,6 +23,10 @@ __all__ = [ ] +array_function_dispatch = functools.partial( + overrides.array_function_dispatch, module='numpy') + + def _make_along_axis_idx(arr_shape, indices, axis): # compute dimensions to iterate over if not _nx.issubdtype(indices.dtype, _nx.integer): @@ -44,6 +50,11 @@ def _make_along_axis_idx(arr_shape, indices, axis): return tuple(fancy_index) +def _take_along_axis_dispatcher(arr, indices, axis): + return (arr, indices) + + +@array_function_dispatch(_take_along_axis_dispatcher) def take_along_axis(arr, indices, axis): """ Take values from the input array by matching 1d index and data slices. @@ -160,6 +171,11 @@ def take_along_axis(arr, indices, axis): return arr[_make_along_axis_idx(arr_shape, indices, axis)] +def _put_along_axis_dispatcher(arr, indices, values, axis): + return (arr, indices, values) + + +@array_function_dispatch(_put_along_axis_dispatcher) def put_along_axis(arr, indices, values, axis): """ Put values into the destination array by matching 1d index and data slices. @@ -245,6 +261,11 @@ def put_along_axis(arr, indices, values, axis): arr[_make_along_axis_idx(arr_shape, indices, axis)] = values +def _apply_along_axis_dispatcher(func1d, axis, arr, *args, **kwargs): + return (arr,) + + +@array_function_dispatch(_apply_along_axis_dispatcher) def apply_along_axis(func1d, axis, arr, *args, **kwargs): """ Apply a function to 1-D slices along the given axis. @@ -392,6 +413,11 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs): return res.__array_wrap__(out_arr) +def _apply_over_axes_dispatcher(func, a, axes): + return (a,) + + +@array_function_dispatch(_apply_over_axes_dispatcher) def apply_over_axes(func, a, axes): """ Apply a function repeatedly over multiple axes. @@ -474,9 +500,15 @@ def apply_over_axes(func, a, axes): val = res else: raise ValueError("function is not returning " - "an array of the correct shape") + "an array of the correct shape") return val + +def _expand_dims_dispatcher(a, axis): + return (a,) + + +@array_function_dispatch(_expand_dims_dispatcher) def expand_dims(a, axis): """ Expand the shape of an array. @@ -554,8 +586,15 @@ def expand_dims(a, axis): # axis = normalize_axis_index(axis, a.ndim + 1) return a.reshape(shape[:axis] + (1,) + shape[axis:]) + row_stack = vstack + +def _column_stack_dispatcher(tup): + return tup + + +@array_function_dispatch(_column_stack_dispatcher) def column_stack(tup): """ Stack 1-D arrays as columns into a 2-D array. @@ -597,6 +636,12 @@ def column_stack(tup): arrays.append(arr) return _nx.concatenate(arrays, 1) + +def _dstack_dispatcher(tup): + return tup + + +@array_function_dispatch(_dstack_dispatcher) def dstack(tup): """ Stack arrays in sequence depth wise (along third axis). @@ -649,6 +694,7 @@ def dstack(tup): """ return _nx.concatenate([atleast_3d(_m) for _m in tup], 2) + def _replace_zero_by_x_arrays(sub_arys): for i in range(len(sub_arys)): if _nx.ndim(sub_arys[i]) == 0: @@ -657,6 +703,12 @@ def _replace_zero_by_x_arrays(sub_arys): sub_arys[i] = _nx.empty(0, dtype=sub_arys[i].dtype) return sub_arys + +def _array_split_dispatcher(ary, indices_or_sections, axis=None): + return (ary, indices_or_sections) + + +@array_function_dispatch(_array_split_dispatcher) def array_split(ary, indices_or_sections, axis=0): """ Split an array into multiple sub-arrays. @@ -712,7 +764,12 @@ def array_split(ary, indices_or_sections, axis=0): return sub_arys -def split(ary,indices_or_sections,axis=0): +def _split_dispatcher(ary, indices_or_sections, axis=None): + return (ary, indices_or_sections) + + +@array_function_dispatch(_split_dispatcher) +def split(ary, indices_or_sections, axis=0): """ Split an array into multiple sub-arrays. @@ -789,6 +846,12 @@ def split(ary,indices_or_sections,axis=0): res = array_split(ary, indices_or_sections, axis) return res + +def _hvdsplit_dispatcher(ary, indices_or_sections): + return (ary, indices_or_sections) + + +@array_function_dispatch(_hvdsplit_dispatcher) def hsplit(ary, indices_or_sections): """ Split an array into multiple sub-arrays horizontally (column-wise). @@ -851,6 +914,8 @@ def hsplit(ary, indices_or_sections): else: return split(ary, indices_or_sections, 0) + +@array_function_dispatch(_hvdsplit_dispatcher) def vsplit(ary, indices_or_sections): """ Split an array into multiple sub-arrays vertically (row-wise). @@ -902,6 +967,8 @@ def vsplit(ary, indices_or_sections): raise ValueError('vsplit only works on arrays of 2 or more dimensions') return split(ary, indices_or_sections, 0) + +@array_function_dispatch(_hvdsplit_dispatcher) def dsplit(ary, indices_or_sections): """ Split array into multiple sub-arrays along the 3rd axis (depth). @@ -971,6 +1038,12 @@ def get_array_wrap(*args): return wrappers[-1][-1] return None + +def _kron_dispatcher(a, b): + return (a, b) + + +@array_function_dispatch(_kron_dispatcher) def kron(a, b): """ Kronecker product of two arrays. @@ -1070,6 +1143,11 @@ def kron(a, b): return result +def _tile_dispatcher(A, reps): + return (A, reps) + + +@array_function_dispatch(_tile_dispatcher) def tile(A, reps): """ Construct an array by repeating A the number of times given by reps. |