diff options
author | Stephan Hoyer <shoyer@google.com> | 2018-10-08 12:52:55 -0700 |
---|---|---|
committer | Stephan Hoyer <shoyer@google.com> | 2018-10-08 12:52:55 -0700 |
commit | 4141e24fc201f8cf76180ca69eaa2d89eafaee58 (patch) | |
tree | 3ee2de98a23d2972c9894332cddd9ee0969dc0ad /numpy/lib/function_base.py | |
parent | 2f4bc6f1d561230e9e295ce0d49fef5c4deb7ea0 (diff) | |
download | numpy-4141e24fc201f8cf76180ca69eaa2d89eafaee58.tar.gz |
ENH: __array_function__ for np.lib, part 1
np.lib.arraypad through np.lib.nanfunctions
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r-- | numpy/lib/function_base.py | 159 |
1 files changed, 159 insertions, 0 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index e2a8f4bc2..c52ecdbd8 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -26,6 +26,7 @@ from numpy.core.fromnumeric import ( ravel, nonzero, partition, mean, any, sum ) from numpy.core.numerictypes import typecodes +from numpy.core.overrides import array_function_dispatch from numpy.core.function_base import add_newdoc from numpy.lib.twodim_base import diag from .utils import deprecate @@ -58,6 +59,11 @@ __all__ = [ ] +def _rot90_dispatcher(m, k=None, axes=None): + return (m,) + + +@array_function_dispatch(_rot90_dispatcher) def rot90(m, k=1, axes=(0,1)): """ Rotate an array by 90 degrees in the plane specified by axes. @@ -144,6 +150,11 @@ def rot90(m, k=1, axes=(0,1)): return flip(transpose(m, axes_list), axes[1]) +def _flip_dispatcher(m, axis=None): + return (m,) + + +@array_function_dispatch(_flip_dispatcher) def flip(m, axis=None): """ Reverse the order of elements in an array along the given axis. @@ -268,6 +279,11 @@ def iterable(y): return True +def _average_dispatcher(a, axis=None, weights=None, returned=None): + return (a, weights) + + +@array_function_dispatch(_average_dispatcher) def average(a, axis=None, weights=None, returned=False): """ Compute the weighted average along the specified axis. @@ -474,6 +490,15 @@ def asarray_chkfinite(a, dtype=None, order=None): return a +def _piecewise_dispatcher(x, condlist, funclist, *args, **kw): + yield x + # support the undocumented behavior of allowing scalars + if np.iterable(condlist): + for c in condlist: + yield c + + +@array_function_dispatch(_piecewise_dispatcher) def piecewise(x, condlist, funclist, *args, **kw): """ Evaluate a piecewise-defined function. @@ -595,6 +620,14 @@ def piecewise(x, condlist, funclist, *args, **kw): return y +def _select_dispatcher(condlist, choicelist, default=None): + for c in condlist: + yield c + for c in choicelist: + yield c + + +@array_function_dispatch(_select_dispatcher) def select(condlist, choicelist, default=0): """ Return an array drawn from elements in choicelist, depending on conditions. @@ -698,6 +731,11 @@ def select(condlist, choicelist, default=0): return result +def _copy_dispatcher(a, order=None): + return (a,) + + +@array_function_dispatch(_copy_dispatcher) def copy(a, order='K'): """ Return an array copy of the given object. @@ -747,6 +785,13 @@ def copy(a, order='K'): # Basic operations +def _gradient_dispatcher(f, *varargs, **kwargs): + yield f + for v in varargs: + yield v + + +@array_function_dispatch(_gradient_dispatcher) def gradient(f, *varargs, **kwargs): """ Return the gradient of an N-dimensional array. @@ -1088,6 +1133,11 @@ def gradient(f, *varargs, **kwargs): return outvals +def _diff_dispatcher(a, n=None, axis=None, prepend=None, append=None): + return (a, prepend, append) + + +@array_function_dispatch(_diff_dispatcher) def diff(a, n=1, axis=-1, prepend=np._NoValue, append=np._NoValue): """ Calculate the n-th discrete difference along the given axis. @@ -1216,6 +1266,11 @@ def diff(a, n=1, axis=-1, prepend=np._NoValue, append=np._NoValue): return a +def _interp_dispatcher(x, xp, fp, left=None, right=None, period=None): + return (x, xp, fp) + + +@array_function_dispatch(_interp_dispatcher) def interp(x, xp, fp, left=None, right=None, period=None): """ One-dimensional linear interpolation. @@ -1348,6 +1403,11 @@ def interp(x, xp, fp, left=None, right=None, period=None): return interp_func(x, xp, fp, left, right) +def _angle_dispatcher(z, deg=None): + return (z,) + + +@array_function_dispatch(_angle_dispatcher) def angle(z, deg=False): """ Return the angle of the complex argument. @@ -1395,6 +1455,11 @@ def angle(z, deg=False): return a +def _unwrap_dispatcher(p, discont=None, axis=None): + return (p,) + + +@array_function_dispatch(_unwrap_dispatcher) def unwrap(p, discont=pi, axis=-1): """ Unwrap by changing deltas between values to 2*pi complement. @@ -1451,6 +1516,11 @@ def unwrap(p, discont=pi, axis=-1): return up +def _sort_complex(a): + return (a,) + + +@array_function_dispatch(_sort_complex) def sort_complex(a): """ Sort a complex array using the real part first, then the imaginary part. @@ -1487,6 +1557,11 @@ def sort_complex(a): return b +def _trim_zeros(filt, trim=None): + return (filt,) + + +@array_function_dispatch(_trim_zeros) def trim_zeros(filt, trim='fb'): """ Trim the leading and/or trailing zeros from a 1-D array or sequence. @@ -1556,6 +1631,11 @@ def unique(x): return asarray(items) +def _extract_dispatcher(condition, arr): + return (condition, arr) + + +@array_function_dispatch(_extract_dispatcher) def extract(condition, arr): """ Return the elements of an array that satisfy some condition. @@ -1607,6 +1687,11 @@ def extract(condition, arr): return _nx.take(ravel(arr), nonzero(ravel(condition))[0]) +def _place_dispatcher(arr, mask, vals): + return (arr, mask, vals) + + +@array_function_dispatch(_place_dispatcher) def place(arr, mask, vals): """ Change elements of an array based on conditional and input values. @@ -2161,6 +2246,12 @@ class vectorize(object): return outputs[0] if nout == 1 else outputs +def _cov_dispatcher(m, y=None, rowvar=None, bias=None, ddof=None, + fweights=None, aweights=None): + return (m, y, fweights, aweights) + + +@array_function_dispatch(_cov_dispatcher) def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=None): """ @@ -2370,6 +2461,11 @@ def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, return c.squeeze() +def _corrcoef_dispatcher(x, y=None, rowvar=None, bias=None, ddof=None): + return (x, y) + + +@array_function_dispatch(_corrcoef_dispatcher) def corrcoef(x, y=None, rowvar=True, bias=np._NoValue, ddof=np._NoValue): """ Return Pearson product-moment correlation coefficients. @@ -2938,6 +3034,11 @@ def _i0_2(x): return exp(x) * _chbevl(32.0/x - 2.0, _i0B) / sqrt(x) +def _i0_dispatcher(x): + return (x,) + + +@array_function_dispatch(_i0_dispatcher) def i0(x): """ Modified Bessel function of the first kind, order 0. @@ -3132,6 +3233,11 @@ def kaiser(M, beta): return i0(beta * sqrt(1-((n-alpha)/alpha)**2.0))/i0(float(beta)) +def _sinc_dispatcher(x): + return (x,) + + +@array_function_dispatch(_sinc_dispatcher) def sinc(x): """ Return the sinc function. @@ -3211,6 +3317,11 @@ def sinc(x): return sin(y)/y +def _msort_dispatcher(a): + return (a,) + + +@array_function_dispatch(_msort_dispatcher) def msort(a): """ Return a copy of an array sorted along the first axis. @@ -3294,6 +3405,12 @@ def _ureduce(a, func, **kwargs): return r, keepdim +def _median_dispatcher( + a, axis=None, out=None, overwrite_input=None, keepdims=None): + return (a, out) + + +@array_function_dispatch(_median_dispatcher) def median(a, axis=None, out=None, overwrite_input=False, keepdims=False): """ Compute the median along the specified axis. @@ -3438,6 +3555,12 @@ def _median(a, axis=None, out=None, overwrite_input=False): return mean(part[indexer], axis=axis, out=out) +def _percentile_dispatcher(a, q, axis=None, out=None, overwrite_input=None, + interpolation=None, keepdims=None): + return (a, q, out) + + +@array_function_dispatch(_percentile_dispatcher) def percentile(a, q, axis=None, out=None, overwrite_input=False, interpolation='linear', keepdims=False): """ @@ -3583,6 +3706,12 @@ def percentile(a, q, axis=None, out=None, a, q, axis, out, overwrite_input, interpolation, keepdims) +def _quantile_dispatcher(a, q, axis=None, out=None, overwrite_input=None, + interpolation=None, keepdims=None): + return (a, q, out) + + +@array_function_dispatch(_quantile_dispatcher) def quantile(a, q, axis=None, out=None, overwrite_input=False, interpolation='linear', keepdims=False): """ @@ -3845,6 +3974,11 @@ def _quantile_ureduce_func(a, q, axis=None, out=None, overwrite_input=False, return r +def _trapz_dispatcher(y, x=None, dx=None, axis=None): + return (y, x) + + +@array_function_dispatch(_trapz_dispatcher) def trapz(y, x=None, dx=1.0, axis=-1): """ Integrate along the given axis using the composite trapezoidal rule. @@ -3935,7 +4069,12 @@ def trapz(y, x=None, dx=1.0, axis=-1): return ret +def _meshgrid_dispatcher(*xi, **kwargs): + return xi + + # Based on scitools meshgrid +@array_function_dispatch(_meshgrid_dispatcher) def meshgrid(*xi, **kwargs): """ Return coordinate matrices from coordinate vectors. @@ -4073,6 +4212,11 @@ def meshgrid(*xi, **kwargs): return output +def _delete_dispatcher(arr, obj, axis=None): + return (arr, obj) + + +@array_function_dispatch(_delete_dispatcher) def delete(arr, obj, axis=None): """ Return a new array with sub-arrays along an axis deleted. For a one @@ -4278,6 +4422,11 @@ def delete(arr, obj, axis=None): return new +def _insert_dispatcher(arr, obj, values, axis=None): + return (arr, obj, values) + + +@array_function_dispatch(_insert_dispatcher) def insert(arr, obj, values, axis=None): """ Insert values along the given axis before the given indices. @@ -4484,6 +4633,11 @@ def insert(arr, obj, values, axis=None): return new +def _append_dispatcher(arr, values, axis=None): + return (arr, values) + + +@array_function_dispatch(_append_dispatcher) def append(arr, values, axis=None): """ Append values to the end of an array. @@ -4539,6 +4693,11 @@ def append(arr, values, axis=None): return concatenate((arr, values), axis=axis) +def _digitize_dispatcher(x, bins, right=None): + return (x, bins) + + +@array_function_dispatch(_digitize_dispatcher) def digitize(x, bins, right=False): """ Return the indices of the bins to which each value in input array belongs. |