diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2018-10-11 14:01:05 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-10-11 14:01:05 -0500 |
commit | eb2bd11870731ea19a0eee72e616c7deb00f6c54 (patch) | |
tree | b9eb90efea2f4757ce94da52417bd34b07f0b82b /numpy/core/numeric.py | |
parent | e309fb06d8c42e1500e3249e79b5d04fec0d66c2 (diff) | |
parent | 607842ab59b0479c485eb6fa30778f47dccc224a (diff) | |
download | numpy-eb2bd11870731ea19a0eee72e616c7deb00f6c54.tar.gz |
Merge pull request #12115 from shoyer/array-function-numpy-core
ENH: __array_function__ support for most of numpy.core
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r-- | numpy/core/numeric.py | 91 |
1 files changed, 91 insertions, 0 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 56ac69424..6e4e585c3 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -28,6 +28,7 @@ if sys.version_info[0] < 3: from .multiarray import newbuffer, getbuffer from . import umath +from .overrides import array_function_dispatch from .umath import (multiply, invert, sin, UFUNC_BUFSIZE_DEFAULT, ERR_IGNORE, ERR_WARN, ERR_RAISE, ERR_CALL, ERR_PRINT, ERR_LOG, ERR_DEFAULT, PINF, NAN) @@ -97,6 +98,11 @@ class ComplexWarning(RuntimeWarning): pass +def _zeros_like_dispatcher(a, dtype=None, order=None, subok=None): + return (a,) + + +@array_function_dispatch(_zeros_like_dispatcher) def zeros_like(a, dtype=None, order='K', subok=True): """ Return an array of zeros with the same shape and type as a given array. @@ -211,6 +217,11 @@ def ones(shape, dtype=None, order='C'): return a +def _ones_like_dispatcher(a, dtype=None, order=None, subok=None): + return (a,) + + +@array_function_dispatch(_ones_like_dispatcher) def ones_like(a, dtype=None, order='K', subok=True): """ Return an array of ones with the same shape and type as a given array. @@ -317,6 +328,11 @@ def full(shape, fill_value, dtype=None, order='C'): return a +def _full_like_dispatcher(a, fill_value, dtype=None, order=None, subok=None): + return (a,) + + +@array_function_dispatch(_full_like_dispatcher) def full_like(a, fill_value, dtype=None, order='K', subok=True): """ Return a full array with the same shape and type as a given array. @@ -374,6 +390,11 @@ def full_like(a, fill_value, dtype=None, order='K', subok=True): return res +def _count_nonzero_dispatcher(a, axis=None): + return (a,) + + +@array_function_dispatch(_count_nonzero_dispatcher) def count_nonzero(a, axis=None): """ Counts the number of non-zero values in the array ``a``. @@ -793,6 +814,11 @@ def isfortran(a): return a.flags.fnc +def _argwhere_dispatcher(a): + return (a,) + + +@array_function_dispatch(_argwhere_dispatcher) def argwhere(a): """ Find the indices of array elements that are non-zero, grouped by element. @@ -834,6 +860,11 @@ def argwhere(a): return transpose(nonzero(a)) +def _flatnonzero_dispatcher(a): + return (a,) + + +@array_function_dispatch(_flatnonzero_dispatcher) def flatnonzero(a): """ Return indices that are non-zero in the flattened version of a. @@ -885,6 +916,11 @@ def _mode_from_name(mode): return mode +def _correlate_dispatcher(a, v, mode=None): + return (a, v) + + +@array_function_dispatch(_correlate_dispatcher) def correlate(a, v, mode='valid'): """ Cross-correlation of two 1-dimensional sequences. @@ -953,6 +989,11 @@ def correlate(a, v, mode='valid'): return multiarray.correlate2(a, v, mode) +def _convolve_dispatcher(a, v, mode=None): + return (a, v) + + +@array_function_dispatch(_convolve_dispatcher) def convolve(a, v, mode='full'): """ Returns the discrete, linear convolution of two one-dimensional sequences. @@ -1052,6 +1093,11 @@ def convolve(a, v, mode='full'): return multiarray.correlate(a, v[::-1], mode) +def _outer_dispatcher(a, b, out=None): + return (a, b, out) + + +@array_function_dispatch(_outer_dispatcher) def outer(a, b, out=None): """ Compute the outer product of two vectors. @@ -1136,6 +1182,11 @@ def outer(a, b, out=None): return multiply(a.ravel()[:, newaxis], b.ravel()[newaxis, :], out) +def _tensordot_dispatcher(a, b, axes=None): + return (a, b) + + +@array_function_dispatch(_tensordot_dispatcher) def tensordot(a, b, axes=2): """ Compute tensor dot product along specified axes for arrays >= 1-D. @@ -1322,6 +1373,11 @@ def tensordot(a, b, axes=2): return res.reshape(olda + oldb) +def _roll_dispatcher(a, shift, axis=None): + return (a,) + + +@array_function_dispatch(_roll_dispatcher) def roll(a, shift, axis=None): """ Roll array elements along a given axis. @@ -1411,6 +1467,11 @@ def roll(a, shift, axis=None): return result +def _rollaxis_dispatcher(a, axis, start=None): + return (a,) + + +@array_function_dispatch(_rollaxis_dispatcher) def rollaxis(a, axis, start=0): """ Roll the specified axis backwards, until it lies in a given position. @@ -1531,6 +1592,11 @@ def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False): return axis +def _moveaxis_dispatcher(a, source, destination): + return (a,) + + +@array_function_dispatch(_moveaxis_dispatcher) def moveaxis(a, source, destination): """ Move axes of an array to new positions. @@ -1607,6 +1673,11 @@ def _move_axis_to_0(a, axis): return moveaxis(a, axis, 0) +def _cross_dispatcher(a, b, axisa=None, axisb=None, axisc=None, axis=None): + return (a, b) + + +@array_function_dispatch(_cross_dispatcher) def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): """ Return the cross product of two (arrays of) vectors. @@ -2250,6 +2321,11 @@ def identity(n, dtype=None): return eye(n, dtype=dtype) +def _allclose_dispatcher(a, b, rtol=None, atol=None, equal_nan=None): + return (a, b) + + +@array_function_dispatch(_allclose_dispatcher) def allclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False): """ Returns True if two arrays are element-wise equal within a tolerance. @@ -2321,6 +2397,11 @@ def allclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False): return bool(res) +def _isclose_dispatcher(a, b, rtol=None, atol=None, equal_nan=None): + return (a, b) + + +@array_function_dispatch(_isclose_dispatcher) def isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False): """ Returns a boolean array where two arrays are element-wise equal within a @@ -2436,6 +2517,11 @@ def isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False): return cond[()] # Flatten 0d arrays to scalars +def _array_equal_dispatcher(a1, a2): + return (a1, a2) + + +@array_function_dispatch(_array_equal_dispatcher) def array_equal(a1, a2): """ True if two arrays have the same shape and elements, False otherwise. @@ -2478,6 +2564,11 @@ def array_equal(a1, a2): return bool(asarray(a1 == a2).all()) +def _array_equiv_dispatcher(a1, a2): + return (a1, a2) + + +@array_function_dispatch(_array_equiv_dispatcher) def array_equiv(a1, a2): """ Returns True if input arrays are shape consistent and all elements equal. |