summaryrefslogtreecommitdiff
path: root/numpy/lib/function_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r--numpy/lib/function_base.py203
1 files changed, 194 insertions, 9 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 2992e92bb..c52ecdbd8 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -9,24 +9,24 @@ except ImportError:
import re
import sys
import warnings
-import operator
import numpy as np
import numpy.core.numeric as _nx
-from numpy.core import linspace, atleast_1d, atleast_2d, transpose
+from numpy.core import atleast_1d, transpose
from numpy.core.numeric import (
ones, zeros, arange, concatenate, array, asarray, asanyarray, empty,
empty_like, ndarray, around, floor, ceil, take, dot, where, intp,
- integer, isscalar, absolute, AxisError
+ integer, isscalar, absolute
)
from numpy.core.umath import (
- pi, multiply, add, arctan2, frompyfunc, cos, less_equal, sqrt, sin,
- mod, exp, log10, not_equal, subtract
+ pi, add, arctan2, frompyfunc, cos, less_equal, sqrt, sin,
+ mod, exp, not_equal, subtract
)
from numpy.core.fromnumeric import (
- ravel, nonzero, sort, partition, mean, any, sum
+ ravel, nonzero, partition, mean, any, sum
)
-from numpy.core.numerictypes import typecodes, number
+from numpy.core.numerictypes import typecodes
+from numpy.core.overrides import array_function_dispatch
from numpy.core.function_base import add_newdoc
from numpy.lib.twodim_base import diag
from .utils import deprecate
@@ -36,7 +36,6 @@ from numpy.core.multiarray import (
)
from numpy.core.umath import _add_newdoc_ufunc as add_newdoc_ufunc
from numpy.compat import long
-from numpy.compat.py3k import basestring
if sys.version_info[0] < 3:
# Force range to be a generator, for np.delete's usage.
@@ -60,6 +59,11 @@ __all__ = [
]
+def _rot90_dispatcher(m, k=None, axes=None):
+ return (m,)
+
+
+@array_function_dispatch(_rot90_dispatcher)
def rot90(m, k=1, axes=(0,1)):
"""
Rotate an array by 90 degrees in the plane specified by axes.
@@ -146,6 +150,11 @@ def rot90(m, k=1, axes=(0,1)):
return flip(transpose(m, axes_list), axes[1])
+def _flip_dispatcher(m, axis=None):
+ return (m,)
+
+
+@array_function_dispatch(_flip_dispatcher)
def flip(m, axis=None):
"""
Reverse the order of elements in an array along the given axis.
@@ -270,6 +279,11 @@ def iterable(y):
return True
+def _average_dispatcher(a, axis=None, weights=None, returned=None):
+ return (a, weights)
+
+
+@array_function_dispatch(_average_dispatcher)
def average(a, axis=None, weights=None, returned=False):
"""
Compute the weighted average along the specified axis.
@@ -476,6 +490,15 @@ def asarray_chkfinite(a, dtype=None, order=None):
return a
+def _piecewise_dispatcher(x, condlist, funclist, *args, **kw):
+ yield x
+ # support the undocumented behavior of allowing scalars
+ if np.iterable(condlist):
+ for c in condlist:
+ yield c
+
+
+@array_function_dispatch(_piecewise_dispatcher)
def piecewise(x, condlist, funclist, *args, **kw):
"""
Evaluate a piecewise-defined function.
@@ -597,6 +620,14 @@ def piecewise(x, condlist, funclist, *args, **kw):
return y
+def _select_dispatcher(condlist, choicelist, default=None):
+ for c in condlist:
+ yield c
+ for c in choicelist:
+ yield c
+
+
+@array_function_dispatch(_select_dispatcher)
def select(condlist, choicelist, default=0):
"""
Return an array drawn from elements in choicelist, depending on conditions.
@@ -700,6 +731,11 @@ def select(condlist, choicelist, default=0):
return result
+def _copy_dispatcher(a, order=None):
+ return (a,)
+
+
+@array_function_dispatch(_copy_dispatcher)
def copy(a, order='K'):
"""
Return an array copy of the given object.
@@ -749,6 +785,13 @@ def copy(a, order='K'):
# Basic operations
+def _gradient_dispatcher(f, *varargs, **kwargs):
+ yield f
+ for v in varargs:
+ yield v
+
+
+@array_function_dispatch(_gradient_dispatcher)
def gradient(f, *varargs, **kwargs):
"""
Return the gradient of an N-dimensional array.
@@ -1090,7 +1133,12 @@ def gradient(f, *varargs, **kwargs):
return outvals
-def diff(a, n=1, axis=-1):
+def _diff_dispatcher(a, n=None, axis=None, prepend=None, append=None):
+ return (a, prepend, append)
+
+
+@array_function_dispatch(_diff_dispatcher)
+def diff(a, n=1, axis=-1, prepend=np._NoValue, append=np._NoValue):
"""
Calculate the n-th discrete difference along the given axis.
@@ -1108,6 +1156,12 @@ def diff(a, n=1, axis=-1):
axis : int, optional
The axis along which the difference is taken, default is the
last axis.
+ prepend, append : array_like, optional
+ Values to prepend or append to "a" along axis prior to
+ performing the difference. Scalar values are expanded to
+ arrays with length 1 in the direction of axis and the shape
+ of the input array in along all other axes. Otherwise the
+ dimension and shape must match "a" except along axis.
Returns
-------
@@ -1176,6 +1230,28 @@ def diff(a, n=1, axis=-1):
nd = a.ndim
axis = normalize_axis_index(axis, nd)
+ combined = []
+ if prepend is not np._NoValue:
+ prepend = np.asanyarray(prepend)
+ if prepend.ndim == 0:
+ shape = list(a.shape)
+ shape[axis] = 1
+ prepend = np.broadcast_to(prepend, tuple(shape))
+ combined.append(prepend)
+
+ combined.append(a)
+
+ if append is not np._NoValue:
+ append = np.asanyarray(append)
+ if append.ndim == 0:
+ shape = list(a.shape)
+ shape[axis] = 1
+ append = np.broadcast_to(append, tuple(shape))
+ combined.append(append)
+
+ if len(combined) > 1:
+ a = np.concatenate(combined, axis)
+
slice1 = [slice(None)] * nd
slice2 = [slice(None)] * nd
slice1[axis] = slice(1, None)
@@ -1190,6 +1266,11 @@ def diff(a, n=1, axis=-1):
return a
+def _interp_dispatcher(x, xp, fp, left=None, right=None, period=None):
+ return (x, xp, fp)
+
+
+@array_function_dispatch(_interp_dispatcher)
def interp(x, xp, fp, left=None, right=None, period=None):
"""
One-dimensional linear interpolation.
@@ -1322,6 +1403,11 @@ def interp(x, xp, fp, left=None, right=None, period=None):
return interp_func(x, xp, fp, left, right)
+def _angle_dispatcher(z, deg=None):
+ return (z,)
+
+
+@array_function_dispatch(_angle_dispatcher)
def angle(z, deg=False):
"""
Return the angle of the complex argument.
@@ -1369,6 +1455,11 @@ def angle(z, deg=False):
return a
+def _unwrap_dispatcher(p, discont=None, axis=None):
+ return (p,)
+
+
+@array_function_dispatch(_unwrap_dispatcher)
def unwrap(p, discont=pi, axis=-1):
"""
Unwrap by changing deltas between values to 2*pi complement.
@@ -1425,6 +1516,11 @@ def unwrap(p, discont=pi, axis=-1):
return up
+def _sort_complex(a):
+ return (a,)
+
+
+@array_function_dispatch(_sort_complex)
def sort_complex(a):
"""
Sort a complex array using the real part first, then the imaginary part.
@@ -1461,6 +1557,11 @@ def sort_complex(a):
return b
+def _trim_zeros(filt, trim=None):
+ return (filt,)
+
+
+@array_function_dispatch(_trim_zeros)
def trim_zeros(filt, trim='fb'):
"""
Trim the leading and/or trailing zeros from a 1-D array or sequence.
@@ -1530,6 +1631,11 @@ def unique(x):
return asarray(items)
+def _extract_dispatcher(condition, arr):
+ return (condition, arr)
+
+
+@array_function_dispatch(_extract_dispatcher)
def extract(condition, arr):
"""
Return the elements of an array that satisfy some condition.
@@ -1581,6 +1687,11 @@ def extract(condition, arr):
return _nx.take(ravel(arr), nonzero(ravel(condition))[0])
+def _place_dispatcher(arr, mask, vals):
+ return (arr, mask, vals)
+
+
+@array_function_dispatch(_place_dispatcher)
def place(arr, mask, vals):
"""
Change elements of an array based on conditional and input values.
@@ -2135,6 +2246,12 @@ class vectorize(object):
return outputs[0] if nout == 1 else outputs
+def _cov_dispatcher(m, y=None, rowvar=None, bias=None, ddof=None,
+ fweights=None, aweights=None):
+ return (m, y, fweights, aweights)
+
+
+@array_function_dispatch(_cov_dispatcher)
def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None,
aweights=None):
"""
@@ -2344,6 +2461,11 @@ def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None,
return c.squeeze()
+def _corrcoef_dispatcher(x, y=None, rowvar=None, bias=None, ddof=None):
+ return (x, y)
+
+
+@array_function_dispatch(_corrcoef_dispatcher)
def corrcoef(x, y=None, rowvar=True, bias=np._NoValue, ddof=np._NoValue):
"""
Return Pearson product-moment correlation coefficients.
@@ -2912,6 +3034,11 @@ def _i0_2(x):
return exp(x) * _chbevl(32.0/x - 2.0, _i0B) / sqrt(x)
+def _i0_dispatcher(x):
+ return (x,)
+
+
+@array_function_dispatch(_i0_dispatcher)
def i0(x):
"""
Modified Bessel function of the first kind, order 0.
@@ -3106,6 +3233,11 @@ def kaiser(M, beta):
return i0(beta * sqrt(1-((n-alpha)/alpha)**2.0))/i0(float(beta))
+def _sinc_dispatcher(x):
+ return (x,)
+
+
+@array_function_dispatch(_sinc_dispatcher)
def sinc(x):
"""
Return the sinc function.
@@ -3185,6 +3317,11 @@ def sinc(x):
return sin(y)/y
+def _msort_dispatcher(a):
+ return (a,)
+
+
+@array_function_dispatch(_msort_dispatcher)
def msort(a):
"""
Return a copy of an array sorted along the first axis.
@@ -3268,6 +3405,12 @@ def _ureduce(a, func, **kwargs):
return r, keepdim
+def _median_dispatcher(
+ a, axis=None, out=None, overwrite_input=None, keepdims=None):
+ return (a, out)
+
+
+@array_function_dispatch(_median_dispatcher)
def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
"""
Compute the median along the specified axis.
@@ -3412,6 +3555,12 @@ def _median(a, axis=None, out=None, overwrite_input=False):
return mean(part[indexer], axis=axis, out=out)
+def _percentile_dispatcher(a, q, axis=None, out=None, overwrite_input=None,
+ interpolation=None, keepdims=None):
+ return (a, q, out)
+
+
+@array_function_dispatch(_percentile_dispatcher)
def percentile(a, q, axis=None, out=None,
overwrite_input=False, interpolation='linear', keepdims=False):
"""
@@ -3557,6 +3706,12 @@ def percentile(a, q, axis=None, out=None,
a, q, axis, out, overwrite_input, interpolation, keepdims)
+def _quantile_dispatcher(a, q, axis=None, out=None, overwrite_input=None,
+ interpolation=None, keepdims=None):
+ return (a, q, out)
+
+
+@array_function_dispatch(_quantile_dispatcher)
def quantile(a, q, axis=None, out=None,
overwrite_input=False, interpolation='linear', keepdims=False):
"""
@@ -3819,6 +3974,11 @@ def _quantile_ureduce_func(a, q, axis=None, out=None, overwrite_input=False,
return r
+def _trapz_dispatcher(y, x=None, dx=None, axis=None):
+ return (y, x)
+
+
+@array_function_dispatch(_trapz_dispatcher)
def trapz(y, x=None, dx=1.0, axis=-1):
"""
Integrate along the given axis using the composite trapezoidal rule.
@@ -3909,7 +4069,12 @@ def trapz(y, x=None, dx=1.0, axis=-1):
return ret
+def _meshgrid_dispatcher(*xi, **kwargs):
+ return xi
+
+
# Based on scitools meshgrid
+@array_function_dispatch(_meshgrid_dispatcher)
def meshgrid(*xi, **kwargs):
"""
Return coordinate matrices from coordinate vectors.
@@ -4047,6 +4212,11 @@ def meshgrid(*xi, **kwargs):
return output
+def _delete_dispatcher(arr, obj, axis=None):
+ return (arr, obj)
+
+
+@array_function_dispatch(_delete_dispatcher)
def delete(arr, obj, axis=None):
"""
Return a new array with sub-arrays along an axis deleted. For a one
@@ -4252,6 +4422,11 @@ def delete(arr, obj, axis=None):
return new
+def _insert_dispatcher(arr, obj, values, axis=None):
+ return (arr, obj, values)
+
+
+@array_function_dispatch(_insert_dispatcher)
def insert(arr, obj, values, axis=None):
"""
Insert values along the given axis before the given indices.
@@ -4458,6 +4633,11 @@ def insert(arr, obj, values, axis=None):
return new
+def _append_dispatcher(arr, values, axis=None):
+ return (arr, values)
+
+
+@array_function_dispatch(_append_dispatcher)
def append(arr, values, axis=None):
"""
Append values to the end of an array.
@@ -4513,6 +4693,11 @@ def append(arr, values, axis=None):
return concatenate((arr, values), axis=axis)
+def _digitize_dispatcher(x, bins, right=None):
+ return (x, bins)
+
+
+@array_function_dispatch(_digitize_dispatcher)
def digitize(x, bins, right=False):
"""
Return the indices of the bins to which each value in input array belongs.