summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/arraypad.py2
-rw-r--r--numpy/lib/arraysetops.py8
-rw-r--r--numpy/lib/financial.py7
-rw-r--r--numpy/lib/function_base.py8
-rw-r--r--numpy/lib/index_tricks.py10
-rw-r--r--numpy/lib/nanfunctions.py7
-rw-r--r--numpy/lib/npyio.py29
-rw-r--r--numpy/lib/polynomial.py57
-rw-r--r--numpy/lib/recfunctions.py88
-rw-r--r--numpy/lib/scimath.py38
-rw-r--r--numpy/lib/shape_base.py82
-rw-r--r--numpy/lib/stride_tricks.py11
-rw-r--r--numpy/lib/tests/test_function_base.py26
-rw-r--r--numpy/lib/tests/test_index_tricks.py5
-rw-r--r--numpy/lib/twodim_base.py42
-rw-r--r--numpy/lib/type_check.py53
-rw-r--r--numpy/lib/ufunclike.py30
17 files changed, 487 insertions, 16 deletions
diff --git a/numpy/lib/arraypad.py b/numpy/lib/arraypad.py
index f76ad456f..d27a3918f 100644
--- a/numpy/lib/arraypad.py
+++ b/numpy/lib/arraypad.py
@@ -995,7 +995,7 @@ def _pad_dispatcher(array, pad_width, mode, **kwargs):
return (array,)
-@array_function_dispatch(_pad_dispatcher)
+@array_function_dispatch(_pad_dispatcher, module='numpy')
def pad(array, pad_width, mode, **kwargs):
"""
Pads an array.
diff --git a/numpy/lib/arraysetops.py b/numpy/lib/arraysetops.py
index ec62cd7a6..850e20123 100644
--- a/numpy/lib/arraysetops.py
+++ b/numpy/lib/arraysetops.py
@@ -27,8 +27,14 @@ To do: Optionally return indices analogously to unique for all functions.
"""
from __future__ import division, absolute_import, print_function
+import functools
+
import numpy as np
-from numpy.core.overrides import array_function_dispatch
+from numpy.core import overrides
+
+
+array_function_dispatch = functools.partial(
+ overrides.array_function_dispatch, module='numpy')
__all__ = [
diff --git a/numpy/lib/financial.py b/numpy/lib/financial.py
index d1a0cd9c0..e1e297492 100644
--- a/numpy/lib/financial.py
+++ b/numpy/lib/financial.py
@@ -13,9 +13,14 @@ otherwise stated.
from __future__ import division, absolute_import, print_function
from decimal import Decimal
+import functools
import numpy as np
-from numpy.core.overrides import array_function_dispatch
+from numpy.core import overrides
+
+
+array_function_dispatch = functools.partial(
+ overrides.array_function_dispatch, module='numpy')
__all__ = ['fv', 'pmt', 'nper', 'ipmt', 'ppmt', 'pv', 'rate',
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index c52ecdbd8..fae6541bc 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -6,6 +6,7 @@ try:
import collections.abc as collections_abc
except ImportError:
import collections as collections_abc
+import functools
import re
import sys
import warnings
@@ -26,7 +27,7 @@ from numpy.core.fromnumeric import (
ravel, nonzero, partition, mean, any, sum
)
from numpy.core.numerictypes import typecodes
-from numpy.core.overrides import array_function_dispatch
+from numpy.core import overrides
from numpy.core.function_base import add_newdoc
from numpy.lib.twodim_base import diag
from .utils import deprecate
@@ -44,6 +45,11 @@ if sys.version_info[0] < 3:
else:
import builtins
+
+array_function_dispatch = functools.partial(
+ overrides.array_function_dispatch, module='numpy')
+
+
# needed in this module for compatibility
from numpy.lib.histograms import histogram, histogramdd
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py
index 26243d231..ff2e00d3e 100644
--- a/numpy/lib/index_tricks.py
+++ b/numpy/lib/index_tricks.py
@@ -1,5 +1,6 @@
from __future__ import division, absolute_import, print_function
+import functools
import sys
import math
@@ -9,14 +10,17 @@ from numpy.core.numeric import (
)
from numpy.core.numerictypes import find_common_type, issubdtype
-from . import function_base
import numpy.matrixlib as matrixlib
from .function_base import diff
from numpy.core.multiarray import ravel_multi_index, unravel_index
-from numpy.core.overrides import array_function_dispatch
+from numpy.core import overrides, linspace
from numpy.lib.stride_tricks import as_strided
+array_function_dispatch = functools.partial(
+ overrides.array_function_dispatch, module='numpy')
+
+
__all__ = [
'ravel_multi_index', 'unravel_index', 'mgrid', 'ogrid', 'r_', 'c_',
's_', 'index_exp', 'ix_', 'ndenumerate', 'ndindex', 'fill_diagonal',
@@ -341,7 +345,7 @@ class AxisConcatenator(object):
step = 1
if isinstance(step, complex):
size = int(abs(step))
- newobj = function_base.linspace(start, stop, num=size)
+ newobj = linspace(start, stop, num=size)
else:
newobj = _nx.arange(start, stop, step)
if ndmin > 1:
diff --git a/numpy/lib/nanfunctions.py b/numpy/lib/nanfunctions.py
index 279c4c5c4..d73d84467 100644
--- a/numpy/lib/nanfunctions.py
+++ b/numpy/lib/nanfunctions.py
@@ -22,10 +22,15 @@ Functions
"""
from __future__ import division, absolute_import, print_function
+import functools
import warnings
import numpy as np
from numpy.lib import function_base
-from numpy.core.overrides import array_function_dispatch
+from numpy.core import overrides
+
+
+array_function_dispatch = functools.partial(
+ overrides.array_function_dispatch, module='numpy')
__all__ = [
diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py
index 62fc9c5b3..733795671 100644
--- a/numpy/lib/npyio.py
+++ b/numpy/lib/npyio.py
@@ -12,6 +12,7 @@ import numpy as np
from . import format
from ._datasource import DataSource
from numpy.core.multiarray import packbits, unpackbits
+from numpy.core.overrides import array_function_dispatch
from numpy.core._internal import recursive
from ._iotools import (
LineSplitter, NameValidator, StringConverter, ConverterError,
@@ -447,6 +448,11 @@ def load(file, mmap_mode=None, allow_pickle=True, fix_imports=True,
fid.close()
+def _save_dispatcher(file, arr, allow_pickle=None, fix_imports=None):
+ return (arr,)
+
+
+@array_function_dispatch(_save_dispatcher)
def save(file, arr, allow_pickle=True, fix_imports=True):
"""
Save an array to a binary file in NumPy ``.npy`` format.
@@ -525,6 +531,14 @@ def save(file, arr, allow_pickle=True, fix_imports=True):
fid.close()
+def _savez_dispatcher(file, *args, **kwds):
+ for a in args:
+ yield a
+ for v in kwds.values():
+ yield v
+
+
+@array_function_dispatch(_savez_dispatcher)
def savez(file, *args, **kwds):
"""
Save several arrays into a single file in uncompressed ``.npz`` format.
@@ -604,6 +618,14 @@ def savez(file, *args, **kwds):
_savez(file, args, kwds, False)
+def _savez_compressed_dispatcher(file, *args, **kwds):
+ for a in args:
+ yield a
+ for v in kwds.values():
+ yield v
+
+
+@array_function_dispatch(_savez_compressed_dispatcher)
def savez_compressed(file, *args, **kwds):
"""
Save several arrays into a single file in compressed ``.npz`` format.
@@ -1154,6 +1176,13 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None,
return X
+def _savetxt_dispatcher(fname, X, fmt=None, delimiter=None, newline=None,
+ header=None, footer=None, comments=None,
+ encoding=None):
+ return (X,)
+
+
+@array_function_dispatch(_savetxt_dispatcher)
def savetxt(fname, X, fmt='%.18e', delimiter=' ', newline='\n', header='',
footer='', comments='# ', encoding=None):
"""
diff --git a/numpy/lib/polynomial.py b/numpy/lib/polynomial.py
index 9f3b84732..c2702f0a7 100644
--- a/numpy/lib/polynomial.py
+++ b/numpy/lib/polynomial.py
@@ -8,17 +8,24 @@ __all__ = ['poly', 'roots', 'polyint', 'polyder', 'polyadd',
'polysub', 'polymul', 'polydiv', 'polyval', 'poly1d',
'polyfit', 'RankWarning']
+import functools
import re
import warnings
import numpy.core.numeric as NX
from numpy.core import (isscalar, abs, finfo, atleast_1d, hstack, dot, array,
ones)
+from numpy.core import overrides
from numpy.lib.twodim_base import diag, vander
from numpy.lib.function_base import trim_zeros
from numpy.lib.type_check import iscomplex, real, imag, mintypecode
from numpy.linalg import eigvals, lstsq, inv
+
+array_function_dispatch = functools.partial(
+ overrides.array_function_dispatch, module='numpy')
+
+
class RankWarning(UserWarning):
"""
Issued by `polyfit` when the Vandermonde matrix is rank deficient.
@@ -29,6 +36,12 @@ class RankWarning(UserWarning):
"""
pass
+
+def _poly_dispatcher(seq_of_zeros):
+ return seq_of_zeros
+
+
+@array_function_dispatch(_poly_dispatcher)
def poly(seq_of_zeros):
"""
Find the coefficients of a polynomial with the given sequence of roots.
@@ -145,6 +158,12 @@ def poly(seq_of_zeros):
return a
+
+def _roots_dispatcher(p):
+ return p
+
+
+@array_function_dispatch(_roots_dispatcher)
def roots(p):
"""
Return the roots of a polynomial with coefficients given in p.
@@ -229,6 +248,12 @@ def roots(p):
roots = hstack((roots, NX.zeros(trailing_zeros, roots.dtype)))
return roots
+
+def _polyint_dispatcher(p, m=None, k=None):
+ return (p,)
+
+
+@array_function_dispatch(_polyint_dispatcher)
def polyint(p, m=1, k=None):
"""
Return an antiderivative (indefinite integral) of a polynomial.
@@ -322,6 +347,12 @@ def polyint(p, m=1, k=None):
return poly1d(val)
return val
+
+def _polyder_dispatcher(p, m=None):
+ return (p,)
+
+
+@array_function_dispatch(_polyder_dispatcher)
def polyder(p, m=1):
"""
Return the derivative of the specified order of a polynomial.
@@ -390,6 +421,12 @@ def polyder(p, m=1):
val = poly1d(val)
return val
+
+def _polyfit_dispatcher(x, y, deg, rcond=None, full=None, w=None, cov=None):
+ return (x, y, w)
+
+
+@array_function_dispatch(_polyfit_dispatcher)
def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False):
"""
Least squares polynomial fit.
@@ -610,6 +647,11 @@ def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False):
return c
+def _polyval_dispatcher(p, x):
+ return (p, x)
+
+
+@array_function_dispatch(_polyval_dispatcher)
def polyval(p, x):
"""
Evaluate a polynomial at specific values.
@@ -679,6 +721,12 @@ def polyval(p, x):
y = y * x + p[i]
return y
+
+def _binary_op_dispatcher(a1, a2):
+ return (a1, a2)
+
+
+@array_function_dispatch(_binary_op_dispatcher)
def polyadd(a1, a2):
"""
Find the sum of two polynomials.
@@ -739,6 +787,8 @@ def polyadd(a1, a2):
val = poly1d(val)
return val
+
+@array_function_dispatch(_binary_op_dispatcher)
def polysub(a1, a2):
"""
Difference (subtraction) of two polynomials.
@@ -786,6 +836,7 @@ def polysub(a1, a2):
return val
+@array_function_dispatch(_binary_op_dispatcher)
def polymul(a1, a2):
"""
Find the product of two polynomials.
@@ -842,6 +893,12 @@ def polymul(a1, a2):
val = poly1d(val)
return val
+
+def _polydiv_dispatcher(u, v):
+ return (u, v)
+
+
+@array_function_dispatch(_polydiv_dispatcher)
def polydiv(u, v):
"""
Returns the quotient and remainder of polynomial division.
diff --git a/numpy/lib/recfunctions.py b/numpy/lib/recfunctions.py
index b6453d5a2..53a586f56 100644
--- a/numpy/lib/recfunctions.py
+++ b/numpy/lib/recfunctions.py
@@ -14,6 +14,7 @@ import numpy.ma as ma
from numpy import ndarray, recarray
from numpy.ma import MaskedArray
from numpy.ma.mrecords import MaskedRecords
+from numpy.core.overrides import array_function_dispatch
from numpy.lib._iotools import _is_string_like
from numpy.compat import basestring
@@ -31,6 +32,11 @@ __all__ = [
]
+def _recursive_fill_fields_dispatcher(input, output):
+ return (input, output)
+
+
+@array_function_dispatch(_recursive_fill_fields_dispatcher)
def recursive_fill_fields(input, output):
"""
Fills fields from output with fields from input,
@@ -189,6 +195,11 @@ def flatten_descr(ndtype):
return tuple(descr)
+def _zip_dtype_dispatcher(seqarrays, flatten=None):
+ return seqarrays
+
+
+@array_function_dispatch(_zip_dtype_dispatcher)
def zip_dtype(seqarrays, flatten=False):
newdtype = []
if flatten:
@@ -205,6 +216,7 @@ def zip_dtype(seqarrays, flatten=False):
return np.dtype(newdtype)
+@array_function_dispatch(_zip_dtype_dispatcher)
def zip_descr(seqarrays, flatten=False):
"""
Combine the dtype description of a series of arrays.
@@ -297,6 +309,11 @@ def _izip_fields(iterable):
yield element
+def _izip_records_dispatcher(seqarrays, fill_value=None, flatten=None):
+ return seqarrays
+
+
+@array_function_dispatch(_izip_records_dispatcher)
def izip_records(seqarrays, fill_value=None, flatten=True):
"""
Returns an iterator of concatenated items from a sequence of arrays.
@@ -357,6 +374,12 @@ def _fix_defaults(output, defaults=None):
return output
+def _merge_arrays_dispatcher(seqarrays, fill_value=None, flatten=None,
+ usemask=None, asrecarray=None):
+ return seqarrays
+
+
+@array_function_dispatch(_merge_arrays_dispatcher)
def merge_arrays(seqarrays, fill_value=-1, flatten=False,
usemask=False, asrecarray=False):
"""
@@ -494,6 +517,11 @@ def merge_arrays(seqarrays, fill_value=-1, flatten=False,
return output
+def _drop_fields_dispatcher(base, drop_names, usemask=None, asrecarray=None):
+ return (base,)
+
+
+@array_function_dispatch(_drop_fields_dispatcher)
def drop_fields(base, drop_names, usemask=True, asrecarray=False):
"""
Return a new array with fields in `drop_names` dropped.
@@ -583,6 +611,11 @@ def _keep_fields(base, keep_names, usemask=True, asrecarray=False):
return _fix_output(output, usemask=usemask, asrecarray=asrecarray)
+def _rec_drop_fields_dispatcher(base, drop_names):
+ return (base,)
+
+
+@array_function_dispatch(_rec_drop_fields_dispatcher)
def rec_drop_fields(base, drop_names):
"""
Returns a new numpy.recarray with fields in `drop_names` dropped.
@@ -590,6 +623,11 @@ def rec_drop_fields(base, drop_names):
return drop_fields(base, drop_names, usemask=False, asrecarray=True)
+def _rename_fields_dispatcher(base, namemapper):
+ return (base,)
+
+
+@array_function_dispatch(_rename_fields_dispatcher)
def rename_fields(base, namemapper):
"""
Rename the fields from a flexible-datatype ndarray or recarray.
@@ -629,6 +667,14 @@ def rename_fields(base, namemapper):
return base.view(newdtype)
+def _append_fields_dispatcher(base, names, data, dtypes=None,
+ fill_value=None, usemask=None, asrecarray=None):
+ yield base
+ for d in data:
+ yield d
+
+
+@array_function_dispatch(_append_fields_dispatcher)
def append_fields(base, names, data, dtypes=None,
fill_value=-1, usemask=True, asrecarray=False):
"""
@@ -699,6 +745,13 @@ def append_fields(base, names, data, dtypes=None,
return _fix_output(output, usemask=usemask, asrecarray=asrecarray)
+def _rec_append_fields_dispatcher(base, names, data, dtypes=None):
+ yield base
+ for d in data:
+ yield d
+
+
+@array_function_dispatch(_rec_append_fields_dispatcher)
def rec_append_fields(base, names, data, dtypes=None):
"""
Add new fields to an existing array.
@@ -732,6 +785,12 @@ def rec_append_fields(base, names, data, dtypes=None):
return append_fields(base, names, data=data, dtypes=dtypes,
asrecarray=True, usemask=False)
+
+def _repack_fields_dispatcher(a, align=None, recurse=None):
+ return (a,)
+
+
+@array_function_dispatch(_repack_fields_dispatcher)
def repack_fields(a, align=False, recurse=False):
"""
Re-pack the fields of a structured array or dtype in memory.
@@ -811,6 +870,13 @@ def repack_fields(a, align=False, recurse=False):
dt = np.dtype(fieldinfo, align=align)
return np.dtype((a.type, dt))
+
+def _stack_arrays_dispatcher(arrays, defaults=None, usemask=None,
+ asrecarray=None, autoconvert=None):
+ return arrays
+
+
+@array_function_dispatch(_stack_arrays_dispatcher)
def stack_arrays(arrays, defaults=None, usemask=True, asrecarray=False,
autoconvert=False):
"""
@@ -897,6 +963,12 @@ def stack_arrays(arrays, defaults=None, usemask=True, asrecarray=False,
usemask=usemask, asrecarray=asrecarray)
+def _find_duplicates_dispatcher(
+ a, key=None, ignoremask=None, return_index=None):
+ return (a,)
+
+
+@array_function_dispatch(_find_duplicates_dispatcher)
def find_duplicates(a, key=None, ignoremask=True, return_index=False):
"""
Find the duplicates in a structured array along a given key
@@ -951,8 +1023,15 @@ def find_duplicates(a, key=None, ignoremask=True, return_index=False):
return duplicates
+def _join_by_dispatcher(
+ key, r1, r2, jointype=None, r1postfix=None, r2postfix=None,
+ defaults=None, usemask=None, asrecarray=None):
+ return (r1, r2)
+
+
+@array_function_dispatch(_join_by_dispatcher)
def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2',
- defaults=None, usemask=True, asrecarray=False):
+ defaults=None, usemask=True, asrecarray=False):
"""
Join arrays `r1` and `r2` on key `key`.
@@ -1130,6 +1209,13 @@ def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2',
return _fix_output(_fix_defaults(output, defaults), **kwargs)
+def _rec_join_dispatcher(
+ key, r1, r2, jointype=None, r1postfix=None, r2postfix=None,
+ defaults=None):
+ return (r1, r2)
+
+
+@array_function_dispatch(_rec_join_dispatcher)
def rec_join(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2',
defaults=None):
"""
diff --git a/numpy/lib/scimath.py b/numpy/lib/scimath.py
index f1838fee6..9ca006841 100644
--- a/numpy/lib/scimath.py
+++ b/numpy/lib/scimath.py
@@ -20,6 +20,7 @@ from __future__ import division, absolute_import, print_function
import numpy.core.numeric as nx
import numpy.core.numerictypes as nt
from numpy.core.numeric import asarray, any
+from numpy.core.overrides import array_function_dispatch
from numpy.lib.type_check import isreal
@@ -94,6 +95,7 @@ def _tocomplex(arr):
else:
return arr.astype(nt.cdouble)
+
def _fix_real_lt_zero(x):
"""Convert `x` to complex if it has real, negative components.
@@ -121,6 +123,7 @@ def _fix_real_lt_zero(x):
x = _tocomplex(x)
return x
+
def _fix_int_lt_zero(x):
"""Convert `x` to double if it has real, negative components.
@@ -147,6 +150,7 @@ def _fix_int_lt_zero(x):
x = x * 1.0
return x
+
def _fix_real_abs_gt_1(x):
"""Convert `x` to complex if it has real components x_i with abs(x_i)>1.
@@ -173,6 +177,12 @@ def _fix_real_abs_gt_1(x):
x = _tocomplex(x)
return x
+
+def _unary_dispatcher(x):
+ return (x,)
+
+
+@array_function_dispatch(_unary_dispatcher)
def sqrt(x):
"""
Compute the square root of x.
@@ -215,6 +225,8 @@ def sqrt(x):
x = _fix_real_lt_zero(x)
return nx.sqrt(x)
+
+@array_function_dispatch(_unary_dispatcher)
def log(x):
"""
Compute the natural logarithm of `x`.
@@ -261,6 +273,8 @@ def log(x):
x = _fix_real_lt_zero(x)
return nx.log(x)
+
+@array_function_dispatch(_unary_dispatcher)
def log10(x):
"""
Compute the logarithm base 10 of `x`.
@@ -309,6 +323,12 @@ def log10(x):
x = _fix_real_lt_zero(x)
return nx.log10(x)
+
+def _logn_dispatcher(n, x):
+ return (n, x,)
+
+
+@array_function_dispatch(_logn_dispatcher)
def logn(n, x):
"""
Take log base n of x.
@@ -318,8 +338,8 @@ def logn(n, x):
Parameters
----------
- n : int
- The base in which the log is taken.
+ n : array_like
+ The integer base(s) in which the log is taken.
x : array_like
The value(s) whose log base `n` is (are) required.
@@ -343,6 +363,8 @@ def logn(n, x):
n = _fix_real_lt_zero(n)
return nx.log(x)/nx.log(n)
+
+@array_function_dispatch(_unary_dispatcher)
def log2(x):
"""
Compute the logarithm base 2 of `x`.
@@ -389,6 +411,12 @@ def log2(x):
x = _fix_real_lt_zero(x)
return nx.log2(x)
+
+def _power_dispatcher(x, p):
+ return (x, p)
+
+
+@array_function_dispatch(_power_dispatcher)
def power(x, p):
"""
Return x to the power p, (x**p).
@@ -432,6 +460,8 @@ def power(x, p):
p = _fix_int_lt_zero(p)
return nx.power(x, p)
+
+@array_function_dispatch(_unary_dispatcher)
def arccos(x):
"""
Compute the inverse cosine of x.
@@ -475,6 +505,8 @@ def arccos(x):
x = _fix_real_abs_gt_1(x)
return nx.arccos(x)
+
+@array_function_dispatch(_unary_dispatcher)
def arcsin(x):
"""
Compute the inverse sine of x.
@@ -519,6 +551,8 @@ def arcsin(x):
x = _fix_real_abs_gt_1(x)
return nx.arcsin(x)
+
+@array_function_dispatch(_unary_dispatcher)
def arctanh(x):
"""
Compute the inverse hyperbolic tangent of `x`.
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py
index 66f534734..00424d55d 100644
--- a/numpy/lib/shape_base.py
+++ b/numpy/lib/shape_base.py
@@ -1,5 +1,6 @@
from __future__ import division, absolute_import, print_function
+import functools
import warnings
import numpy.core.numeric as _nx
@@ -8,6 +9,7 @@ from numpy.core.numeric import (
)
from numpy.core.fromnumeric import product, reshape, transpose
from numpy.core.multiarray import normalize_axis_index
+from numpy.core import overrides
from numpy.core import vstack, atleast_3d
from numpy.lib.index_tricks import ndindex
from numpy.matrixlib.defmatrix import matrix # this raises all the right alarm bells
@@ -21,6 +23,10 @@ __all__ = [
]
+array_function_dispatch = functools.partial(
+ overrides.array_function_dispatch, module='numpy')
+
+
def _make_along_axis_idx(arr_shape, indices, axis):
# compute dimensions to iterate over
if not _nx.issubdtype(indices.dtype, _nx.integer):
@@ -44,6 +50,11 @@ def _make_along_axis_idx(arr_shape, indices, axis):
return tuple(fancy_index)
+def _take_along_axis_dispatcher(arr, indices, axis):
+ return (arr, indices)
+
+
+@array_function_dispatch(_take_along_axis_dispatcher)
def take_along_axis(arr, indices, axis):
"""
Take values from the input array by matching 1d index and data slices.
@@ -160,6 +171,11 @@ def take_along_axis(arr, indices, axis):
return arr[_make_along_axis_idx(arr_shape, indices, axis)]
+def _put_along_axis_dispatcher(arr, indices, values, axis):
+ return (arr, indices, values)
+
+
+@array_function_dispatch(_put_along_axis_dispatcher)
def put_along_axis(arr, indices, values, axis):
"""
Put values into the destination array by matching 1d index and data slices.
@@ -245,6 +261,11 @@ def put_along_axis(arr, indices, values, axis):
arr[_make_along_axis_idx(arr_shape, indices, axis)] = values
+def _apply_along_axis_dispatcher(func1d, axis, arr, *args, **kwargs):
+ return (arr,)
+
+
+@array_function_dispatch(_apply_along_axis_dispatcher)
def apply_along_axis(func1d, axis, arr, *args, **kwargs):
"""
Apply a function to 1-D slices along the given axis.
@@ -392,6 +413,11 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
return res.__array_wrap__(out_arr)
+def _apply_over_axes_dispatcher(func, a, axes):
+ return (a,)
+
+
+@array_function_dispatch(_apply_over_axes_dispatcher)
def apply_over_axes(func, a, axes):
"""
Apply a function repeatedly over multiple axes.
@@ -474,9 +500,15 @@ def apply_over_axes(func, a, axes):
val = res
else:
raise ValueError("function is not returning "
- "an array of the correct shape")
+ "an array of the correct shape")
return val
+
+def _expand_dims_dispatcher(a, axis):
+ return (a,)
+
+
+@array_function_dispatch(_expand_dims_dispatcher)
def expand_dims(a, axis):
"""
Expand the shape of an array.
@@ -554,8 +586,15 @@ def expand_dims(a, axis):
# axis = normalize_axis_index(axis, a.ndim + 1)
return a.reshape(shape[:axis] + (1,) + shape[axis:])
+
row_stack = vstack
+
+def _column_stack_dispatcher(tup):
+ return tup
+
+
+@array_function_dispatch(_column_stack_dispatcher)
def column_stack(tup):
"""
Stack 1-D arrays as columns into a 2-D array.
@@ -597,6 +636,12 @@ def column_stack(tup):
arrays.append(arr)
return _nx.concatenate(arrays, 1)
+
+def _dstack_dispatcher(tup):
+ return tup
+
+
+@array_function_dispatch(_dstack_dispatcher)
def dstack(tup):
"""
Stack arrays in sequence depth wise (along third axis).
@@ -649,6 +694,7 @@ def dstack(tup):
"""
return _nx.concatenate([atleast_3d(_m) for _m in tup], 2)
+
def _replace_zero_by_x_arrays(sub_arys):
for i in range(len(sub_arys)):
if _nx.ndim(sub_arys[i]) == 0:
@@ -657,6 +703,12 @@ def _replace_zero_by_x_arrays(sub_arys):
sub_arys[i] = _nx.empty(0, dtype=sub_arys[i].dtype)
return sub_arys
+
+def _array_split_dispatcher(ary, indices_or_sections, axis=None):
+ return (ary, indices_or_sections)
+
+
+@array_function_dispatch(_array_split_dispatcher)
def array_split(ary, indices_or_sections, axis=0):
"""
Split an array into multiple sub-arrays.
@@ -712,7 +764,12 @@ def array_split(ary, indices_or_sections, axis=0):
return sub_arys
-def split(ary,indices_or_sections,axis=0):
+def _split_dispatcher(ary, indices_or_sections, axis=None):
+ return (ary, indices_or_sections)
+
+
+@array_function_dispatch(_split_dispatcher)
+def split(ary, indices_or_sections, axis=0):
"""
Split an array into multiple sub-arrays.
@@ -789,6 +846,12 @@ def split(ary,indices_or_sections,axis=0):
res = array_split(ary, indices_or_sections, axis)
return res
+
+def _hvdsplit_dispatcher(ary, indices_or_sections):
+ return (ary, indices_or_sections)
+
+
+@array_function_dispatch(_hvdsplit_dispatcher)
def hsplit(ary, indices_or_sections):
"""
Split an array into multiple sub-arrays horizontally (column-wise).
@@ -851,6 +914,8 @@ def hsplit(ary, indices_or_sections):
else:
return split(ary, indices_or_sections, 0)
+
+@array_function_dispatch(_hvdsplit_dispatcher)
def vsplit(ary, indices_or_sections):
"""
Split an array into multiple sub-arrays vertically (row-wise).
@@ -902,6 +967,8 @@ def vsplit(ary, indices_or_sections):
raise ValueError('vsplit only works on arrays of 2 or more dimensions')
return split(ary, indices_or_sections, 0)
+
+@array_function_dispatch(_hvdsplit_dispatcher)
def dsplit(ary, indices_or_sections):
"""
Split array into multiple sub-arrays along the 3rd axis (depth).
@@ -971,6 +1038,12 @@ def get_array_wrap(*args):
return wrappers[-1][-1]
return None
+
+def _kron_dispatcher(a, b):
+ return (a, b)
+
+
+@array_function_dispatch(_kron_dispatcher)
def kron(a, b):
"""
Kronecker product of two arrays.
@@ -1070,6 +1143,11 @@ def kron(a, b):
return result
+def _tile_dispatcher(A, reps):
+ return (A, reps)
+
+
+@array_function_dispatch(_tile_dispatcher)
def tile(A, reps):
"""
Construct an array by repeating A the number of times given by reps.
diff --git a/numpy/lib/stride_tricks.py b/numpy/lib/stride_tricks.py
index ca13738c1..0dc36e41c 100644
--- a/numpy/lib/stride_tricks.py
+++ b/numpy/lib/stride_tricks.py
@@ -8,6 +8,7 @@ NumPy reference guide.
from __future__ import division, absolute_import, print_function
import numpy as np
+from numpy.core.overrides import array_function_dispatch
__all__ = ['broadcast_to', 'broadcast_arrays']
@@ -135,6 +136,11 @@ def _broadcast_to(array, shape, subok, readonly):
return result
+def _broadcast_to_dispatcher(array, shape, subok=None):
+ return (array,)
+
+
+@array_function_dispatch(_broadcast_to_dispatcher, module='numpy')
def broadcast_to(array, shape, subok=False):
"""Broadcast an array to a new shape.
@@ -195,6 +201,11 @@ def _broadcast_shape(*args):
return b.shape
+def _broadcast_arrays_dispatcher(*args, **kwargs):
+ return args
+
+
+@array_function_dispatch(_broadcast_arrays_dispatcher, module='numpy')
def broadcast_arrays(*args, **kwargs):
"""
Broadcast any number of arrays against each other.
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index 40cca1dbb..0c789e012 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -3114,3 +3114,29 @@ class TestAdd_newdoc(object):
assert_equal(np.core.flatiter.index.__doc__[:len(tgt)], tgt)
assert_(len(np.core.ufunc.identity.__doc__) > 300)
assert_(len(np.lib.index_tricks.mgrid.__doc__) > 300)
+
+class TestSortComplex(object):
+
+ @pytest.mark.parametrize("type_in, type_out", [
+ ('l', 'D'),
+ ('h', 'F'),
+ ('H', 'F'),
+ ('b', 'F'),
+ ('B', 'F'),
+ ('g', 'G'),
+ ])
+ def test_sort_real(self, type_in, type_out):
+ # sort_complex() type casting for real input types
+ a = np.array([5, 3, 6, 2, 1], dtype=type_in)
+ actual = np.sort_complex(a)
+ expected = np.sort(a).astype(type_out)
+ assert_equal(actual, expected)
+ assert_equal(actual.dtype, expected.dtype)
+
+ def test_sort_complex(self):
+ # sort_complex() handling of complex input
+ a = np.array([2 + 3j, 1 - 2j, 1 - 3j, 2 + 1j], dtype='D')
+ expected = np.array([1 - 3j, 1 - 2j, 2 + 1j, 2 + 3j], dtype='D')
+ actual = np.sort_complex(a)
+ assert_equal(actual, expected)
+ assert_equal(actual.dtype, expected.dtype)
diff --git a/numpy/lib/tests/test_index_tricks.py b/numpy/lib/tests/test_index_tricks.py
index 76d9b403e..3246f68ff 100644
--- a/numpy/lib/tests/test_index_tricks.py
+++ b/numpy/lib/tests/test_index_tricks.py
@@ -226,6 +226,11 @@ class TestConcatenator(object):
g = r_[-10.1, np.array([1]), np.array([2, 3, 4]), 10.0]
assert_(g.dtype == 'f8')
+ def test_complex_step(self):
+ # Regression test for #12262
+ g = r_[0:36:100j]
+ assert_(g.shape == (100,))
+
def test_2d(self):
b = np.random.rand(5, 5)
c = np.random.rand(5, 5)
diff --git a/numpy/lib/twodim_base.py b/numpy/lib/twodim_base.py
index 98efba191..a05e68375 100644
--- a/numpy/lib/twodim_base.py
+++ b/numpy/lib/twodim_base.py
@@ -3,11 +3,14 @@
"""
from __future__ import division, absolute_import, print_function
+import functools
+
from numpy.core.numeric import (
absolute, asanyarray, arange, zeros, greater_equal, multiply, ones,
asarray, where, int8, int16, int32, int64, empty, promote_types, diagonal,
nonzero
)
+from numpy.core import overrides
from numpy.core import iinfo, transpose
@@ -17,6 +20,10 @@ __all__ = [
'tril_indices_from', 'triu_indices', 'triu_indices_from', ]
+array_function_dispatch = functools.partial(
+ overrides.array_function_dispatch, module='numpy')
+
+
i1 = iinfo(int8)
i2 = iinfo(int16)
i4 = iinfo(int32)
@@ -33,6 +40,11 @@ def _min_int(low, high):
return int64
+def _flip_dispatcher(m):
+ return (m,)
+
+
+@array_function_dispatch(_flip_dispatcher)
def fliplr(m):
"""
Flip array in the left/right direction.
@@ -83,6 +95,7 @@ def fliplr(m):
return m[:, ::-1]
+@array_function_dispatch(_flip_dispatcher)
def flipud(m):
"""
Flip array in the up/down direction.
@@ -194,6 +207,11 @@ def eye(N, M=None, k=0, dtype=float, order='C'):
return m
+def _diag_dispatcher(v, k=None):
+ return (v,)
+
+
+@array_function_dispatch(_diag_dispatcher)
def diag(v, k=0):
"""
Extract a diagonal or construct a diagonal array.
@@ -265,6 +283,7 @@ def diag(v, k=0):
raise ValueError("Input must be 1- or 2-d.")
+@array_function_dispatch(_diag_dispatcher)
def diagflat(v, k=0):
"""
Create a two-dimensional array with the flattened input as a diagonal.
@@ -373,6 +392,11 @@ def tri(N, M=None, k=0, dtype=float):
return m
+def _trilu_dispatcher(m, k=None):
+ return (m,)
+
+
+@array_function_dispatch(_trilu_dispatcher)
def tril(m, k=0):
"""
Lower triangle of an array.
@@ -411,6 +435,7 @@ def tril(m, k=0):
return where(mask, m, zeros(1, m.dtype))
+@array_function_dispatch(_trilu_dispatcher)
def triu(m, k=0):
"""
Upper triangle of an array.
@@ -439,7 +464,12 @@ def triu(m, k=0):
return where(mask, zeros(1, m.dtype), m)
+def _vander_dispatcher(x, N=None, increasing=None):
+ return (x,)
+
+
# Originally borrowed from John Hunter and matplotlib
+@array_function_dispatch(_vander_dispatcher)
def vander(x, N=None, increasing=False):
"""
Generate a Vandermonde matrix.
@@ -530,6 +560,12 @@ def vander(x, N=None, increasing=False):
return v
+def _histogram2d_dispatcher(x, y, bins=None, range=None, normed=None,
+ weights=None, density=None):
+ return (x, y, bins, weights)
+
+
+@array_function_dispatch(_histogram2d_dispatcher)
def histogram2d(x, y, bins=10, range=None, normed=None, weights=None,
density=None):
"""
@@ -812,6 +848,11 @@ def tril_indices(n, k=0, m=None):
return nonzero(tri(n, m, k=k, dtype=bool))
+def _trilu_indices_form_dispatcher(arr, k=None):
+ return (arr,)
+
+
+@array_function_dispatch(_trilu_indices_form_dispatcher)
def tril_indices_from(arr, k=0):
"""
Return the indices for the lower-triangle of arr.
@@ -922,6 +963,7 @@ def triu_indices(n, k=0, m=None):
return nonzero(~tri(n, m, k=k-1, dtype=bool))
+@array_function_dispatch(_trilu_indices_form_dispatcher)
def triu_indices_from(arr, k=0):
"""
Return the indices for the upper-triangle of arr.
diff --git a/numpy/lib/type_check.py b/numpy/lib/type_check.py
index 603da8567..9153e1692 100644
--- a/numpy/lib/type_check.py
+++ b/numpy/lib/type_check.py
@@ -2,6 +2,7 @@
"""
from __future__ import division, absolute_import, print_function
+import functools
import warnings
__all__ = ['iscomplexobj', 'isrealobj', 'imag', 'iscomplex',
@@ -11,10 +12,17 @@ __all__ = ['iscomplexobj', 'isrealobj', 'imag', 'iscomplex',
import numpy.core.numeric as _nx
from numpy.core.numeric import asarray, asanyarray, array, isnan, zeros
+from numpy.core import overrides
from .ufunclike import isneginf, isposinf
+
+array_function_dispatch = functools.partial(
+ overrides.array_function_dispatch, module='numpy')
+
+
_typecodes_by_elsize = 'GDFgdfQqLlIiHhBb?'
+
def mintypecode(typechars,typeset='GDFgdf',default='d'):
"""
Return the character for the minimum-size type to which given types can
@@ -104,6 +112,11 @@ def asfarray(a, dtype=_nx.float_):
return asarray(a, dtype=dtype)
+def _real_dispatcher(val):
+ return (val,)
+
+
+@array_function_dispatch(_real_dispatcher)
def real(val):
"""
Return the real part of the complex argument.
@@ -145,6 +158,11 @@ def real(val):
return asanyarray(val).real
+def _imag_dispatcher(val):
+ return (val,)
+
+
+@array_function_dispatch(_imag_dispatcher)
def imag(val):
"""
Return the imaginary part of the complex argument.
@@ -183,6 +201,11 @@ def imag(val):
return asanyarray(val).imag
+def _is_type_dispatcher(x):
+ return (x,)
+
+
+@array_function_dispatch(_is_type_dispatcher)
def iscomplex(x):
"""
Returns a bool array, where True if input element is complex.
@@ -218,6 +241,8 @@ def iscomplex(x):
res = zeros(ax.shape, bool)
return res[()] # convert to scalar if needed
+
+@array_function_dispatch(_is_type_dispatcher)
def isreal(x):
"""
Returns a bool array, where True if input element is real.
@@ -248,6 +273,8 @@ def isreal(x):
"""
return imag(x) == 0
+
+@array_function_dispatch(_is_type_dispatcher)
def iscomplexobj(x):
"""
Check for a complex type or an array of complex numbers.
@@ -288,6 +315,7 @@ def iscomplexobj(x):
return issubclass(type_, _nx.complexfloating)
+@array_function_dispatch(_is_type_dispatcher)
def isrealobj(x):
"""
Return True if x is a not complex type or an array of complex numbers.
@@ -329,6 +357,12 @@ def _getmaxmin(t):
f = getlimits.finfo(t)
return f.max, f.min
+
+def _nan_to_num_dispatcher(x, copy=None):
+ return (x,)
+
+
+@array_function_dispatch(_nan_to_num_dispatcher)
def nan_to_num(x, copy=True):
"""
Replace NaN with zero and infinity with large finite numbers.
@@ -411,7 +445,12 @@ def nan_to_num(x, copy=True):
#-----------------------------------------------------------------------------
-def real_if_close(a,tol=100):
+def _real_if_close_dispatcher(a, tol=None):
+ return (a,)
+
+
+@array_function_dispatch(_real_if_close_dispatcher)
+def real_if_close(a, tol=100):
"""
If complex input returns a real array if complex parts are close to zero.
@@ -466,6 +505,11 @@ def real_if_close(a,tol=100):
return a
+def _asscalar_dispatcher(a):
+ return (a,)
+
+
+@array_function_dispatch(_asscalar_dispatcher)
def asscalar(a):
"""
Convert an array of size 1 to its scalar equivalent.
@@ -586,6 +630,13 @@ array_precision = {_nx.half: 0,
_nx.csingle: 1,
_nx.cdouble: 2,
_nx.clongdouble: 3}
+
+
+def _common_type_dispatcher(*arrays):
+ return arrays
+
+
+@array_function_dispatch(_common_type_dispatcher)
def common_type(*arrays):
"""
Return a scalar type which is common to the input arrays.
diff --git a/numpy/lib/ufunclike.py b/numpy/lib/ufunclike.py
index 6259c5445..ac0af0b37 100644
--- a/numpy/lib/ufunclike.py
+++ b/numpy/lib/ufunclike.py
@@ -8,6 +8,7 @@ from __future__ import division, absolute_import, print_function
__all__ = ['fix', 'isneginf', 'isposinf']
import numpy.core.numeric as nx
+from numpy.core.overrides import array_function_dispatch
import warnings
import functools
@@ -37,7 +38,30 @@ def _deprecate_out_named_y(f):
return func
+def _fix_out_named_y(f):
+ """
+ Allow the out argument to be passed as the name `y` (deprecated)
+
+ This decorator should only be used if _deprecate_out_named_y is used on
+ a corresponding dispatcher fucntion.
+ """
+ @functools.wraps(f)
+ def func(x, out=None, **kwargs):
+ if 'y' in kwargs:
+ # we already did error checking in _deprecate_out_named_y
+ out = kwargs.pop('y')
+ return f(x, out=out, **kwargs)
+
+ return func
+
+
@_deprecate_out_named_y
+def _dispatcher(x, out=None):
+ return (x, out)
+
+
+@array_function_dispatch(_dispatcher, verify=False, module='numpy')
+@_fix_out_named_y
def fix(x, out=None):
"""
Round to nearest integer towards zero.
@@ -83,7 +107,8 @@ def fix(x, out=None):
return res
-@_deprecate_out_named_y
+@array_function_dispatch(_dispatcher, verify=False, module='numpy')
+@_fix_out_named_y
def isposinf(x, out=None):
"""
Test element-wise for positive infinity, return result as bool array.
@@ -151,7 +176,8 @@ def isposinf(x, out=None):
return nx.logical_and(is_inf, signbit, out)
-@_deprecate_out_named_y
+@array_function_dispatch(_dispatcher, verify=False, module='numpy')
+@_fix_out_named_y
def isneginf(x, out=None):
"""
Test element-wise for negative infinity, return result as bool array.