Diffstat (limited to 'numpy')
-rw-r--r--  numpy/core/_dtype.py                      |    4
-rw-r--r--  numpy/core/_dtype_ctypes.py               |   68
-rw-r--r--  numpy/core/_internal.py                   |    6
-rw-r--r--  numpy/core/arrayprint.py                  |    6
-rw-r--r--  numpy/core/defchararray.py                |    7
-rw-r--r--  numpy/core/einsumfunc.py                  |   21
-rw-r--r--  numpy/core/fromnumeric.py                 |    6
-rw-r--r--  numpy/core/multiarray.py                  |    7
-rw-r--r--  numpy/core/numeric.py                     |    7
-rw-r--r--  numpy/core/overrides.py                   |   11
-rw-r--r--  numpy/core/setup.py                       |    2
-rw-r--r--  numpy/core/shape_base.py                  |  217
-rw-r--r--  numpy/core/src/common/npy_ctypes.h        |   49
-rw-r--r--  numpy/core/src/multiarray/ctors.c         |   12
-rw-r--r--  numpy/core/src/multiarray/descriptor.c    |  124
-rw-r--r--  numpy/core/src/multiarray/descriptor.h    |    2
-rw-r--r--  numpy/core/src/multiarray/scalarapi.c     |    2
-rw-r--r--  numpy/core/tests/test_dtype.py            |   49
-rw-r--r--  numpy/core/tests/test_overrides.py        |   48
-rw-r--r--  numpy/core/tests/test_shape_base.py       |  172
-rw-r--r--  numpy/fft/fftpack.py                      |    8
-rw-r--r--  numpy/fft/helper.py                       |    4
-rw-r--r--  numpy/lib/arraypad.py                     |    2
-rw-r--r--  numpy/lib/arraysetops.py                  |    8
-rw-r--r--  numpy/lib/financial.py                    |    7
-rw-r--r--  numpy/lib/function_base.py                |    8
-rw-r--r--  numpy/lib/index_tricks.py                 |   10
-rw-r--r--  numpy/lib/nanfunctions.py                 |    7
-rw-r--r--  numpy/lib/npyio.py                        |   29
-rw-r--r--  numpy/lib/polynomial.py                   |   57
-rw-r--r--  numpy/lib/recfunctions.py                 |   88
-rw-r--r--  numpy/lib/scimath.py                      |   38
-rw-r--r--  numpy/lib/shape_base.py                   |   82
-rw-r--r--  numpy/lib/stride_tricks.py                |   11
-rw-r--r--  numpy/lib/tests/test_function_base.py     |   26
-rw-r--r--  numpy/lib/tests/test_index_tricks.py      |    5
-rw-r--r--  numpy/lib/twodim_base.py                  |   42
-rw-r--r--  numpy/lib/type_check.py                   |   53
-rw-r--r--  numpy/lib/ufunclike.py                    |   30
-rw-r--r--  numpy/linalg/linalg.py                    |    8
-rw-r--r--  numpy/testing/_private/utils.py           |   36
-rw-r--r--  numpy/testing/tests/test_utils.py         |   56
42 files changed, 1246 insertions, 189 deletions
diff --git a/numpy/core/_dtype.py b/numpy/core/_dtype.py
index d115e0fa6..3a12c8fad 100644
--- a/numpy/core/_dtype.py
+++ b/numpy/core/_dtype.py
@@ -138,7 +138,9 @@ def _scalar_str(dtype, short):
else:
return "'%sU%d'" % (byteorder, dtype.itemsize / 4)
- elif dtype.type == np.void:
+ # unlike the other types, subclasses of void are preserved - but
+ # historically the repr does not actually reveal the subclass
+ elif issubclass(dtype.type, np.void):
if _isunsized(dtype):
return "'V'"
else:
diff --git a/numpy/core/_dtype_ctypes.py b/numpy/core/_dtype_ctypes.py
new file mode 100644
index 000000000..f10b4e99f
--- /dev/null
+++ b/numpy/core/_dtype_ctypes.py
@@ -0,0 +1,68 @@
+"""
+Conversion from ctypes to dtype.
+
+In an ideal world, we could achieve this through the PEP3118 buffer protocol,
+something like::
+
+ def dtype_from_ctypes_type(t):
+ # needed to ensure that the shape of `t` is within memoryview.format
+ class DummyStruct(ctypes.Structure):
+ _fields_ = [('a', t)]
+
+ # empty to avoid memory allocation
+ ctype_0 = (DummyStruct * 0)()
+ mv = memoryview(ctype_0)
+
+ # convert the struct, and slice back out the field
+ return _dtype_from_pep3118(mv.format)['a']
+
+Unfortunately, this fails because:
+
+* ctypes cannot handle length-0 arrays with PEP3118 (bpo-32782)
+* PEP3118 cannot represent unions, but both numpy and ctypes can
+* ctypes cannot handle big-endian structs with PEP3118 (bpo-32780)
+"""
+import _ctypes
+import ctypes
+
+import numpy as np
+
+
+def _from_ctypes_array(t):
+ return np.dtype((dtype_from_ctypes_type(t._type_), (t._length_,)))
+
+
+def _from_ctypes_structure(t):
+ # TODO: gh-10533, gh-10532
+ fields = []
+ for item in t._fields_:
+ if len(item) > 2:
+ raise TypeError(
+ "ctypes bitfields have no dtype equivalent")
+ fname, ftyp = item
+ fields.append((fname, dtype_from_ctypes_type(ftyp)))
+
+ # by default, ctypes structs are aligned
+ return np.dtype(fields, align=True)
+
+
+def dtype_from_ctypes_type(t):
+ """
+ Construct a dtype object from a ctypes type
+ """
+ if issubclass(t, _ctypes.Array):
+ return _from_ctypes_array(t)
+ elif issubclass(t, _ctypes._Pointer):
+ raise TypeError("ctypes pointers have no dtype equivalent")
+ elif issubclass(t, _ctypes.Structure):
+ return _from_ctypes_structure(t)
+ elif issubclass(t, _ctypes.Union):
+ # TODO
+ raise NotImplementedError(
+ "conversion from ctypes.Union types like {} to dtype"
+ .format(t.__name__))
+ elif isinstance(t._type_, str):
+ return np.dtype(t._type_)
+ else:
+ raise NotImplementedError(
+ "Unknown ctypes type {}".format(t.__name__))
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py
index c4d967dc2..30069f0ca 100644
--- a/numpy/core/_internal.py
+++ b/numpy/core/_internal.py
@@ -796,13 +796,13 @@ def _ufunc_doc_signature_formatter(ufunc):
)
-def _is_from_ctypes(obj):
- # determine if an object comes from ctypes, in order to work around
+def npy_ctypes_check(cls):
+ # determine if a class comes from ctypes, in order to work around
# a bug in the buffer protocol for those objects, bpo-10746
try:
# ctypes class are new-style, so have an __mro__. This probably fails
# for ctypes classes with multiple inheritance.
- ctype_base = type(obj).__mro__[-2]
+ ctype_base = cls.__mro__[-2]
# right now, they're part of the _ctypes module
return 'ctypes' in ctype_base.__module__
except Exception:
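The check leans on the fact that every ctypes class has a shared ctypes base sitting directly above object in its MRO. A rough standalone illustration of the same idea (this is not the exact NumPy helper):

    import ctypes

    def looks_like_ctypes(cls):
        # the second-to-last MRO entry is the shared ctypes base class;
        # the last entry is always `object`
        ctype_base = cls.__mro__[-2]
        return 'ctypes' in ctype_base.__module__

    class Plain(object):
        pass

    print(looks_like_ctypes(ctypes.c_int))      # True
    print(looks_like_ctypes(ctypes.c_int * 4))  # True (array types too)
    print(looks_like_ctypes(Plain))             # False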
diff --git a/numpy/core/arrayprint.py b/numpy/core/arrayprint.py
index 1b9fbbfa9..3201b2f78 100644
--- a/numpy/core/arrayprint.py
+++ b/numpy/core/arrayprint.py
@@ -506,7 +506,7 @@ def _array2string_dispatcher(
return (a,)
-@array_function_dispatch(_array2string_dispatcher)
+@array_function_dispatch(_array2string_dispatcher, module='numpy')
def array2string(a, max_line_width=None, precision=None,
suppress_small=None, separator=' ', prefix="",
style=np._NoValue, formatter=None, threshold=None,
@@ -1386,7 +1386,7 @@ def _array_repr_dispatcher(
return (arr,)
-@array_function_dispatch(_array_repr_dispatcher)
+@array_function_dispatch(_array_repr_dispatcher, module='numpy')
def array_repr(arr, max_line_width=None, precision=None, suppress_small=None):
"""
Return the string representation of an array.
@@ -1480,7 +1480,7 @@ def _array_str_dispatcher(
return (a,)
-@array_function_dispatch(_array_str_dispatcher)
+@array_function_dispatch(_array_str_dispatcher, module='numpy')
def array_str(a, max_line_width=None, precision=None, suppress_small=None):
"""
Return a string representation of the data in an array.
diff --git a/numpy/core/defchararray.py b/numpy/core/defchararray.py
index 0a8c7bbec..e86086012 100644
--- a/numpy/core/defchararray.py
+++ b/numpy/core/defchararray.py
@@ -17,12 +17,13 @@ The preferred alias for `defchararray` is `numpy.char`.
"""
from __future__ import division, absolute_import, print_function
+import functools
import sys
from .numerictypes import string_, unicode_, integer, object_, bool_, character
from .numeric import ndarray, compare_chararrays
from .numeric import array as narray
from numpy.core.multiarray import _vec_string
-from numpy.core.overrides import array_function_dispatch
+from numpy.core import overrides
from numpy.compat import asbytes, long
import numpy
@@ -48,6 +49,10 @@ else:
_bytes = str
_len = len
+array_function_dispatch = functools.partial(
+ overrides.array_function_dispatch, module='numpy.char')
+
+
def _use_unicode(*args):
"""
Helper function for determining the output type of some string
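The functools.partial call above simply pre-binds module='numpy.char' so the many decorator call sites in this file do not have to repeat it; the same pattern recurs in fromnumeric, multiarray, numeric, fft and lib below. A small sketch of the equivalence, assuming the updated array_function_dispatch(dispatcher, module=None, verify=True) signature (the dispatcher and function names here are hypothetical):

    import functools
    from numpy.core import overrides

    # pre-bind the module name once per file
    array_function_dispatch = functools.partial(
        overrides.array_function_dispatch, module='numpy.char')

    def _example_dispatcher(x1, x2):
        return (x1, x2)

    # equivalent to decorating with
    # overrides.array_function_dispatch(_example_dispatcher, module='numpy.char')
    @array_function_dispatch(_example_dispatcher)
    def example_func(x1, x2):
        return (x1, x2)

    # the decorator rewrites __module__, so help() and error messages
    # report the public location
    assert example_func.__module__ == 'numpy.char'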
diff --git a/numpy/core/einsumfunc.py b/numpy/core/einsumfunc.py
index 1281b3c98..3ffb152e1 100644
--- a/numpy/core/einsumfunc.py
+++ b/numpy/core/einsumfunc.py
@@ -9,6 +9,7 @@ import itertools
from numpy.compat import basestring
from numpy.core.multiarray import c_einsum
from numpy.core.numeric import asanyarray, tensordot
+from numpy.core.overrides import array_function_dispatch
__all__ = ['einsum', 'einsum_path']
@@ -689,6 +690,17 @@ def _parse_einsum_input(operands):
return (input_subscripts, output_subscript, operands)
+def _einsum_path_dispatcher(*operands, **kwargs):
+ # NOTE: technically, we should only dispatch on array-like arguments, not
+ # subscripts (given as strings). But separating operands into
+ # arrays/subscripts is a little tricky/slow (given einsum's two supported
+ # signatures), so as a practical shortcut we dispatch on everything.
+ # Strings will be ignored for dispatching since they don't define
+ # __array_function__.
+ return operands
+
+
+@array_function_dispatch(_einsum_path_dispatcher)
def einsum_path(*operands, **kwargs):
"""
einsum_path(subscripts, *operands, optimize='greedy')
@@ -980,7 +992,16 @@ def einsum_path(*operands, **kwargs):
return (path, path_print)
+def _einsum_dispatcher(*operands, **kwargs):
+ # Arguably we dispatch on more arguments than we really should; see note in
+ # _einsum_path_dispatcher for why.
+ for op in operands:
+ yield op
+ yield kwargs.get('out')
+
+
# Rewrite einsum to handle different cases
+@array_function_dispatch(_einsum_dispatcher)
def einsum(*operands, **kwargs):
"""
einsum(subscripts, *operands, out=None, dtype=None, order='K',
diff --git a/numpy/core/fromnumeric.py b/numpy/core/fromnumeric.py
index 2fdbf3e23..7dfb52fea 100644
--- a/numpy/core/fromnumeric.py
+++ b/numpy/core/fromnumeric.py
@@ -3,16 +3,17 @@
"""
from __future__ import division, absolute_import, print_function
+import functools
import types
import warnings
import numpy as np
from .. import VisibleDeprecationWarning
from . import multiarray as mu
+from . import overrides
from . import umath as um
from . import numerictypes as nt
from .numeric import asarray, array, asanyarray, concatenate
-from .overrides import array_function_dispatch
from . import _methods
_dt_ = nt.sctype2char
@@ -32,6 +33,9 @@ _gentype = types.GeneratorType
# save away Python sum
_sum_ = sum
+array_function_dispatch = functools.partial(
+ overrides.array_function_dispatch, module='numpy')
+
# functions that are now methods
def _wrapit(obj, method, *args, **kwds):
diff --git a/numpy/core/multiarray.py b/numpy/core/multiarray.py
index 4dbd3b0fd..25debd2f8 100644
--- a/numpy/core/multiarray.py
+++ b/numpy/core/multiarray.py
@@ -6,8 +6,10 @@ by importing from the extension module.
"""
+import functools
+
+from . import overrides
from . import _multiarray_umath
-from .overrides import array_function_dispatch
import numpy as np
from numpy.core._multiarray_umath import *
from numpy.core._multiarray_umath import (
@@ -37,6 +39,9 @@ __all__ = [
'tracemalloc_domain', 'typeinfo', 'unpackbits', 'unravel_index', 'vdot',
'where', 'zeros']
+array_function_dispatch = functools.partial(
+ overrides.array_function_dispatch, module='numpy')
+
def _empty_like_dispatcher(prototype, dtype=None, order=None, subok=None):
return (prototype,)
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 6e4e585c3..5d82bbd8d 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -6,6 +6,7 @@ try:
import collections.abc as collections_abc
except ImportError:
import collections as collections_abc
+import functools
import itertools
import operator
import sys
@@ -27,8 +28,8 @@ from .multiarray import (
if sys.version_info[0] < 3:
from .multiarray import newbuffer, getbuffer
+from . import overrides
from . import umath
-from .overrides import array_function_dispatch
from .umath import (multiply, invert, sin, UFUNC_BUFSIZE_DEFAULT,
ERR_IGNORE, ERR_WARN, ERR_RAISE, ERR_CALL, ERR_PRINT,
ERR_LOG, ERR_DEFAULT, PINF, NAN)
@@ -55,6 +56,10 @@ else:
import __builtin__ as builtins
+array_function_dispatch = functools.partial(
+ overrides.array_function_dispatch, module='numpy')
+
+
def loads(*args, **kwargs):
# NumPy 1.15.0, 2017-12-10
warnings.warn(
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py
index 5be60cd29..4640efd31 100644
--- a/numpy/core/overrides.py
+++ b/numpy/core/overrides.py
@@ -105,9 +105,10 @@ def array_function_implementation_or_override(
if result is not NotImplemented:
return result
- raise TypeError('no implementation found for {} on types that implement '
+ func_name = '{}.{}'.format(public_api.__module__, public_api.__name__)
+ raise TypeError("no implementation found for '{}' on types that implement "
'__array_function__: {}'
- .format(public_api, list(map(type, overloaded_args))))
+ .format(func_name, list(map(type, overloaded_args))))
ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults')
@@ -135,7 +136,7 @@ def verify_matching_signatures(implementation, dispatcher):
'default argument values')
-def array_function_dispatch(dispatcher, verify=True):
+def array_function_dispatch(dispatcher, module=None, verify=True):
"""Decorator for adding dispatch with the __array_function__ protocol."""
def decorator(implementation):
# TODO: only do this check when the appropriate flag is enabled or for
@@ -149,6 +150,10 @@ def array_function_dispatch(dispatcher, verify=True):
relevant_args = dispatcher(*args, **kwargs)
return array_function_implementation_or_override(
implementation, public_api, relevant_args, args, kwargs)
+
+ if module is not None:
+ public_api.__module__ = module
+
return public_api
return decorator
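Taken together, the module= argument and the reworded error mean a failed dispatch now names the public function by its qualified name. A minimal sketch of how the decorator is used and what the failure looks like at this point in the series (demo and NoImpl are illustrative names, not part of the patch):

    import numpy as np
    from numpy.core.overrides import array_function_dispatch

    def _demo_dispatcher(a, out=None):
        return (a, out)

    @array_function_dispatch(_demo_dispatcher, module='numpy')
    def demo(a, out=None):
        return np.asarray(a) * 2

    class NoImpl(object):
        # always declines, so dispatch falls through to the TypeError
        def __array_function__(self, func, types, args, kwargs):
            return NotImplemented

    demo(np.arange(3))  # uses the default implementation
    demo(NoImpl())      # TypeError: no implementation found for 'numpy.demo' ...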
diff --git a/numpy/core/setup.py b/numpy/core/setup.py
index fc15fe59f..a4429cee2 100644
--- a/numpy/core/setup.py
+++ b/numpy/core/setup.py
@@ -731,7 +731,9 @@ def configuration(parent_package='',top_path=None):
join('src', 'common', 'lowlevel_strided_loops.h'),
join('src', 'common', 'mem_overlap.h'),
join('src', 'common', 'npy_config.h'),
+ join('src', 'common', 'npy_ctypes.h'),
join('src', 'common', 'npy_extint128.h'),
+ join('src', 'common', 'npy_import.h'),
join('src', 'common', 'npy_longdouble.h'),
join('src', 'common', 'templ_common.h.src'),
join('src', 'common', 'ucsnarrow.h'),
diff --git a/numpy/core/shape_base.py b/numpy/core/shape_base.py
index fde23076b..c9f8ebccb 100644
--- a/numpy/core/shape_base.py
+++ b/numpy/core/shape_base.py
@@ -3,6 +3,8 @@ from __future__ import division, absolute_import, print_function
__all__ = ['atleast_1d', 'atleast_2d', 'atleast_3d', 'block', 'hstack',
'stack', 'vstack']
+import functools
+import operator
from . import numeric as _nx
from .numeric import array, asanyarray, newaxis
@@ -432,6 +434,10 @@ def _block_check_depths_match(arrays, parent_index=[]):
refer to it, and the last index along the empty axis will be `None`.
max_arr_ndim : int
The maximum of the ndims of the arrays nested in `arrays`.
+ final_size: int
+ The number of elements in the final array. This is used to motivate
+ the choice of algorithm, based on benchmarking wisdom.
+
"""
if type(arrays) is tuple:
# not strictly necessary, but saves us from:
@@ -450,8 +456,9 @@ def _block_check_depths_match(arrays, parent_index=[]):
idxs_ndims = (_block_check_depths_match(arr, parent_index + [i])
for i, arr in enumerate(arrays))
- first_index, max_arr_ndim = next(idxs_ndims)
- for index, ndim in idxs_ndims:
+ first_index, max_arr_ndim, final_size = next(idxs_ndims)
+ for index, ndim, size in idxs_ndims:
+ final_size += size
if ndim > max_arr_ndim:
max_arr_ndim = ndim
if len(index) != len(first_index):
@@ -466,13 +473,15 @@ def _block_check_depths_match(arrays, parent_index=[]):
# propagate our flag that indicates an empty list at the bottom
if index[-1] is None:
first_index = index
- return first_index, max_arr_ndim
+
+ return first_index, max_arr_ndim, final_size
elif type(arrays) is list and len(arrays) == 0:
# We've 'bottomed out' on an empty list
- return parent_index + [None], 0
+ return parent_index + [None], 0, 0
else:
# We've 'bottomed out' - arrays is either a scalar or an array
- return parent_index, _nx.ndim(arrays)
+ size = _nx.size(arrays)
+ return parent_index, _nx.ndim(arrays), size
def _atleast_nd(a, ndim):
@@ -481,9 +490,132 @@ def _atleast_nd(a, ndim):
return array(a, ndmin=ndim, copy=False, subok=True)
+def _accumulate(values):
+ # Helper function because Python 2.7 doesn't have
+ # itertools.accumulate
+ value = 0
+ accumulated = []
+ for v in values:
+ value += v
+ accumulated.append(value)
+ return accumulated
+
+
+def _concatenate_shapes(shapes, axis):
+ """Given array shapes, return the resulting shape and slices prefixes.
+
+ These help in nested concatation.
+ Returns
+ -------
+ shape: tuple of int
+ This tuple satisfies:
+ ```
+ shape, _ = _concatenate_shapes([arr.shape for arr in arrs], axis)
+ shape == concatenate(arrs, axis).shape
+ ```
+
+ slice_prefixes: tuple of (slice(start, end), )
+ For a list of arrays being concatenated, this returns the slice
+ in the larger array at axis that needs to be sliced into.
+
+ For example, the following holds:
+ ```
+ ret = concatenate([a, b, c], axis)
+ _, (sl_a, sl_b, sl_c) = _concatenate_shapes([a.shape, b.shape, c.shape], axis)
+
+ ret[(slice(None),) * axis + sl_a] == a
+ ret[(slice(None),) * axis + sl_b] == b
+ ret[(slice(None),) * axis + sl_c] == c
+ ```
+
+ These are called slice prefixes since they are used in the recursive
+ blocking algorithm to compute the left-most slices during the
+ recursion. Therefore, they must be prepended to the rest of the slice
+ that was computed deeper in the recursion.
+
+ They are returned as tuples to ensure that they can quickly be added
+ to an existing slice tuple without creating a new tuple every time.
+
+ """
+ # Cache a result that will be reused.
+ shape_at_axis = [shape[axis] for shape in shapes]
+
+ # Take a shape, any shape
+ first_shape = shapes[0]
+ first_shape_pre = first_shape[:axis]
+ first_shape_post = first_shape[axis+1:]
+
+ if any(shape[:axis] != first_shape_pre or
+ shape[axis+1:] != first_shape_post for shape in shapes):
+ raise ValueError(
+ 'Mismatched array shapes in block along axis {}.'.format(axis))
+
+ shape = (first_shape_pre + (sum(shape_at_axis),) + first_shape[axis+1:])
+
+ offsets_at_axis = _accumulate(shape_at_axis)
+ slice_prefixes = [(slice(start, end),)
+ for start, end in zip([0] + offsets_at_axis,
+ offsets_at_axis)]
+ return shape, slice_prefixes
+
+
+def _block_info_recursion(arrays, max_depth, result_ndim, depth=0):
+ """
+ Returns the shape of the final array, along with a list
+ of slices and a list of arrays that can be used for assignment inside the
+ new array
+
+ Parameters
+ ----------
+ arrays : nested list of arrays
+ The arrays to check
+ max_depth : list of int
+ The number of nested lists
+ result_ndim: int
+ The number of dimensions in the final array.
+
+ Returns
+ -------
+ shape : tuple of int
+ The shape that the final array will take on.
+ slices: list of tuple of slices
+ The slices into the full array required for assignment. These are
+ required to be prepended with ``(Ellipsis, )`` to obtain the correct
+ final index.
+ arrays: list of ndarray
+ The data to assign to each slice of the full array
+
+ """
+ if depth < max_depth:
+ shapes, slices, arrays = zip(
+ *[_block_info_recursion(arr, max_depth, result_ndim, depth+1)
+ for arr in arrays])
+
+ axis = result_ndim - max_depth + depth
+ shape, slice_prefixes = _concatenate_shapes(shapes, axis)
+
+ # Prepend the slice prefix and flatten the slices
+ slices = [slice_prefix + the_slice
+ for slice_prefix, inner_slices in zip(slice_prefixes, slices)
+ for the_slice in inner_slices]
+
+ # Flatten the array list
+ arrays = functools.reduce(operator.add, arrays)
+
+ return shape, slices, arrays
+ else:
+ # We've 'bottomed out' - arrays is either a scalar or an array
+ # type(arrays) is not list
+ # Return the slice and the array inside a list to be consistent with
+ # the recursive case.
+ arr = _atleast_nd(arrays, result_ndim)
+ return arr.shape, [()], [arr]
+
+
def _block(arrays, max_depth, result_ndim, depth=0):
"""
- Internal implementation of block. `arrays` is the argument passed to
+ Internal implementation of block based on repeated concatenation.
+ `arrays` is the argument passed to
block. `max_depth` is the depth of nested lists within `arrays` and
`result_ndim` is the greatest of the dimensions of the arrays in
`arrays` and the depth of the lists in `arrays` (see block docstring
@@ -499,7 +631,19 @@ def _block(arrays, max_depth, result_ndim, depth=0):
return _atleast_nd(arrays, result_ndim)
-# TODO: support array_function_dispatch
+def _block_dispatcher(arrays):
+ # Use type(...) is list to match the behavior of np.block(), which special
+ # cases list specifically rather than allowing for generic iterables or
+ # tuple. Also, we know that list.__array_function__ will never exist.
+ if type(arrays) is list:
+ for subarrays in arrays:
+ for subarray in _block_dispatcher(subarrays):
+ yield subarray
+ else:
+ yield arrays
+
+
+@array_function_dispatch(_block_dispatcher)
def block(arrays):
"""
Assemble an nd-array from nested lists of blocks.
@@ -648,7 +792,38 @@ def block(arrays):
"""
- bottom_index, arr_ndim = _block_check_depths_match(arrays)
+ arrays, list_ndim, result_ndim, final_size = _block_setup(arrays)
+
+ # It was found through benchmarking that making an array of final size
+ # around 256x256 was faster by straight concatenation on an
+ # i7-7700HQ processor with dual-channel RAM at 2400 MHz.
+ # The dtype used did not seem to matter much.
+ #
+ # A 2D array using repeated concatenation requires 2 copies of the array.
+ #
+ # The fastest algorithm will depend on the ratio of CPU power to memory
+ # speed.
+ # One can monitor the results of the benchmark
+ # https://pv.github.io/numpy-bench/#bench_shape_base.Block2D.time_block2d
+ # to tune this parameter until a C version of the `_block_info_recursion`
+ # algorithm is implemented which would likely be faster than the python
+ # version.
+ if list_ndim * final_size > (2 * 512 * 512):
+ return _block_slicing(arrays, list_ndim, result_ndim)
+ else:
+ return _block_concatenate(arrays, list_ndim, result_ndim)
+
+
+# These helper functions are mostly used for testing.
+# They allow us to write tests that directly call `_block_slicing`
+# or `_block_concatenate` without having to block large arrays to force
+# the benchmark-derived heuristic onto the desired path.
+def _block_setup(arrays):
+ """
+ Returns
+ (`arrays`, list_ndim, result_ndim, final_size)
+ """
+ bottom_index, arr_ndim, final_size = _block_check_depths_match(arrays)
list_ndim = len(bottom_index)
if bottom_index and bottom_index[-1] is None:
raise ValueError(
@@ -656,7 +831,31 @@ def block(arrays):
_block_format_index(bottom_index)
)
)
- result = _block(arrays, list_ndim, max(arr_ndim, list_ndim))
+ result_ndim = max(arr_ndim, list_ndim)
+ return arrays, list_ndim, result_ndim, final_size
+
+
+def _block_slicing(arrays, list_ndim, result_ndim):
+ shape, slices, arrays = _block_info_recursion(
+ arrays, list_ndim, result_ndim)
+ dtype = _nx.result_type(*[arr.dtype for arr in arrays])
+
+ # Test preferring F only in the case that all input arrays are F
+ F_order = all(arr.flags['F_CONTIGUOUS'] for arr in arrays)
+ C_order = all(arr.flags['C_CONTIGUOUS'] for arr in arrays)
+ order = 'F' if F_order and not C_order else 'C'
+ result = _nx.empty(shape=shape, dtype=dtype, order=order)
+ # Note: In a c implementation, the function
+ # PyArray_CreateMultiSortedStridePerm could be used for more advanced
+ # guessing of the desired order.
+
+ for the_slice, arr in zip(slices, arrays):
+ result[(Ellipsis,) + the_slice] = arr
+ return result
+
+
+def _block_concatenate(arrays, list_ndim, result_ndim):
+ result = _block(arrays, list_ndim, result_ndim)
if list_ndim == 0:
# Catch an edge case where _block returns a view because
# `arrays` is a single numpy array and not a list of numpy arrays.
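The slicing path amounts to: compute the final shape plus one destination slice per leaf array, allocate the result once, and assign each leaf into its slice, instead of building intermediate copies through repeated concatenation. A condensed sketch of that idea for a 2x2 block layout, written with plain NumPy rather than the private helpers above:

    import numpy as np

    a = np.ones((2, 3))
    b = 2 * np.ones((2, 2))
    c = 3 * np.ones((1, 3))
    d = 4 * np.ones((1, 2))

    # shapes concatenate to (3, 5); accumulated offsets along each axis
    # ([0, 2, 3] for rows, [0, 3, 5] for columns) give the slices
    result = np.empty((3, 5), dtype=np.result_type(a, b, c, d))
    result[0:2, 0:3] = a
    result[0:2, 3:5] = b
    result[2:3, 0:3] = c
    result[2:3, 3:5] = d

    assert np.array_equal(result, np.block([[a, b], [c, d]]))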
diff --git a/numpy/core/src/common/npy_ctypes.h b/numpy/core/src/common/npy_ctypes.h
new file mode 100644
index 000000000..f26db9e05
--- /dev/null
+++ b/numpy/core/src/common/npy_ctypes.h
@@ -0,0 +1,49 @@
+#ifndef NPY_CTYPES_H
+#define NPY_CTYPES_H
+
+#include <Python.h>
+
+#include "npy_import.h"
+
+/*
+ * Check if a python type is a ctypes class.
+ *
+ * Works like the Py<type>_Check functions, returning true if the argument
+ * looks like a ctypes object.
+ *
+ * This entire function is just a wrapper around the Python function of the
+ * same name.
+ */
+NPY_INLINE static int
+npy_ctypes_check(PyTypeObject *obj)
+{
+ static PyObject *py_func = NULL;
+ PyObject *ret_obj;
+ int ret;
+
+ npy_cache_import("numpy.core._internal", "npy_ctypes_check", &py_func);
+ if (py_func == NULL) {
+ goto fail;
+ }
+
+ ret_obj = PyObject_CallFunctionObjArgs(py_func, (PyObject *)obj, NULL);
+ if (ret_obj == NULL) {
+ goto fail;
+ }
+
+ ret = PyObject_IsTrue(ret_obj);
+ if (ret == -1) {
+ goto fail;
+ }
+
+ return ret;
+
+fail:
+ /* If the above fails, then we should just assume that the type is not from
+ * ctypes
+ */
+ PyErr_Clear();
+ return 0;
+}
+
+#endif
diff --git a/numpy/core/src/multiarray/ctors.c b/numpy/core/src/multiarray/ctors.c
index aaaaeee82..bf888659d 100644
--- a/numpy/core/src/multiarray/ctors.c
+++ b/numpy/core/src/multiarray/ctors.c
@@ -11,7 +11,7 @@
#include "npy_config.h"
-#include "npy_import.h"
+#include "npy_ctypes.h"
#include "npy_pycompat.h"
#include "multiarraymodule.h"
@@ -1381,15 +1381,7 @@ _array_from_buffer_3118(PyObject *memoryview)
* Note that even if the above are fixed in master, we have to drop the
* early patch versions of python to actually make use of the fixes.
*/
-
- int is_ctypes = _is_from_ctypes(view->obj);
- if (is_ctypes < 0) {
- /* This error is not useful */
- PyErr_WriteUnraisable(view->obj);
- is_ctypes = 0;
- }
-
- if (!is_ctypes) {
+ if (!npy_ctypes_check(Py_TYPE(view->obj))) {
/* This object has no excuse for a broken PEP3118 buffer */
PyErr_Format(
PyExc_RuntimeError,
diff --git a/numpy/core/src/multiarray/descriptor.c b/numpy/core/src/multiarray/descriptor.c
index 439980877..7acac8059 100644
--- a/numpy/core/src/multiarray/descriptor.c
+++ b/numpy/core/src/multiarray/descriptor.c
@@ -10,7 +10,7 @@
#include "numpy/arrayscalars.h"
#include "npy_config.h"
-
+#include "npy_ctypes.h"
#include "npy_pycompat.h"
#include "_datetime.h"
@@ -54,79 +54,46 @@ Borrowed_PyMapping_GetItemString(PyObject *o, char *key)
return ret;
}
-/*
- * Creates a dtype object from ctypes inputs.
- *
- * Returns a new reference to a dtype object, or NULL
- * if this is not possible. When it returns NULL, it does
- * not set a Python exception.
- */
static PyArray_Descr *
-_arraydescr_fromctypes(PyObject *obj)
+_arraydescr_from_ctypes_type(PyTypeObject *type)
{
- PyObject *dtypedescr;
- PyArray_Descr *newdescr;
- int ret;
+ PyObject *_numpy_dtype_ctypes;
+ PyObject *res;
- /* Understand basic ctypes */
- dtypedescr = PyObject_GetAttrString(obj, "_type_");
- PyErr_Clear();
- if (dtypedescr) {
- ret = PyArray_DescrConverter(dtypedescr, &newdescr);
- Py_DECREF(dtypedescr);
- if (ret == NPY_SUCCEED) {
- PyObject *length;
- /* Check for ctypes arrays */
- length = PyObject_GetAttrString(obj, "_length_");
- PyErr_Clear();
- if (length) {
- /* derived type */
- PyObject *newtup;
- PyArray_Descr *derived;
- newtup = Py_BuildValue("N(N)", newdescr, length);
- ret = PyArray_DescrConverter(newtup, &derived);
- Py_DECREF(newtup);
- if (ret == NPY_SUCCEED) {
- return derived;
- }
- PyErr_Clear();
- return NULL;
- }
- return newdescr;
- }
- PyErr_Clear();
+ /* Call the python function of the same name. */
+ _numpy_dtype_ctypes = PyImport_ImportModule("numpy.core._dtype_ctypes");
+ if (_numpy_dtype_ctypes == NULL) {
return NULL;
}
- /* Understand ctypes structures --
- bit-fields are not supported
- automatically aligns */
- dtypedescr = PyObject_GetAttrString(obj, "_fields_");
- PyErr_Clear();
- if (dtypedescr) {
- ret = PyArray_DescrAlignConverter(dtypedescr, &newdescr);
- Py_DECREF(dtypedescr);
- if (ret == NPY_SUCCEED) {
- return newdescr;
- }
- PyErr_Clear();
+ res = PyObject_CallMethod(_numpy_dtype_ctypes, "dtype_from_ctypes_type", "O", (PyObject *)type);
+ Py_DECREF(_numpy_dtype_ctypes);
+ if (res == NULL) {
+ return NULL;
}
- return NULL;
+ /*
+ * sanity check that dtype_from_ctypes_type returned the right type,
+ * since getting it wrong would give segfaults.
+ */
+ if (!PyObject_TypeCheck(res, &PyArrayDescr_Type)) {
+ Py_DECREF(res);
+ PyErr_BadInternalCall();
+ return NULL;
+ }
+
+ return (PyArray_Descr *)res;
}
/*
- * This function creates a dtype object when:
- * - The object has a "dtype" attribute, and it can be converted
- * to a dtype object.
- * - The object is a ctypes type object, including array
- * and structure types.
+ * This function creates a dtype object when the object has a "dtype" attribute,
+ * and it can be converted to a dtype object.
*
* Returns a new reference to a dtype object, or NULL
* if this is not possible. When it returns NULL, it does
* not set a Python exception.
*/
NPY_NO_EXPORT PyArray_Descr *
-_arraydescr_fromobj(PyObject *obj)
+_arraydescr_from_dtype_attr(PyObject *obj)
{
PyObject *dtypedescr;
PyArray_Descr *newdescr = NULL;
@@ -135,15 +102,18 @@ _arraydescr_fromobj(PyObject *obj)
/* For arbitrary objects that have a "dtype" attribute */
dtypedescr = PyObject_GetAttrString(obj, "dtype");
PyErr_Clear();
- if (dtypedescr != NULL) {
- ret = PyArray_DescrConverter(dtypedescr, &newdescr);
- Py_DECREF(dtypedescr);
- if (ret == NPY_SUCCEED) {
- return newdescr;
- }
+ if (dtypedescr == NULL) {
+ return NULL;
+ }
+
+ ret = PyArray_DescrConverter(dtypedescr, &newdescr);
+ Py_DECREF(dtypedescr);
+ if (ret != NPY_SUCCEED) {
PyErr_Clear();
+ return NULL;
}
- return _arraydescr_fromctypes(obj);
+
+ return newdescr;
}
/*
@@ -1423,10 +1393,20 @@ PyArray_DescrConverter(PyObject *obj, PyArray_Descr **at)
check_num = NPY_VOID;
}
else {
- *at = _arraydescr_fromobj(obj);
+ *at = _arraydescr_from_dtype_attr(obj);
if (*at) {
return NPY_SUCCEED;
}
+
+ /*
+ * Note: this comes after _arraydescr_from_dtype_attr because the ctypes
+ * type might override the dtype if numpy does not otherwise
+ * support it.
+ */
+ if (npy_ctypes_check((PyTypeObject *)obj)) {
+ *at = _arraydescr_from_ctypes_type((PyTypeObject *)obj);
+ return *at ? NPY_SUCCEED : NPY_FAIL;
+ }
}
goto finish;
}
@@ -1596,13 +1576,23 @@ PyArray_DescrConverter(PyObject *obj, PyArray_Descr **at)
goto fail;
}
else {
- *at = _arraydescr_fromobj(obj);
+ *at = _arraydescr_from_dtype_attr(obj);
if (*at) {
return NPY_SUCCEED;
}
if (PyErr_Occurred()) {
return NPY_FAIL;
}
+
+ /*
+ * Note: this comes after _arraydescr_from_dtype_attr because the ctypes
+ * type might override the dtype if numpy does not otherwise
+ * support it.
+ */
+ if (npy_ctypes_check(Py_TYPE(obj))) {
+ *at = _arraydescr_from_ctypes_type(Py_TYPE(obj));
+ return *at ? NPY_SUCCEED : NPY_FAIL;
+ }
goto fail;
}
if (PyErr_Occurred()) {
diff --git a/numpy/core/src/multiarray/descriptor.h b/numpy/core/src/multiarray/descriptor.h
index 5a3e4b15f..a5f3b8cdf 100644
--- a/numpy/core/src/multiarray/descriptor.h
+++ b/numpy/core/src/multiarray/descriptor.h
@@ -8,7 +8,7 @@ NPY_NO_EXPORT PyObject *
array_set_typeDict(PyObject *NPY_UNUSED(ignored), PyObject *args);
NPY_NO_EXPORT PyArray_Descr *
-_arraydescr_fromobj(PyObject *obj);
+_arraydescr_from_dtype_attr(PyObject *obj);
NPY_NO_EXPORT int
diff --git a/numpy/core/src/multiarray/scalarapi.c b/numpy/core/src/multiarray/scalarapi.c
index 5ef6c0bbf..bc435d1ca 100644
--- a/numpy/core/src/multiarray/scalarapi.c
+++ b/numpy/core/src/multiarray/scalarapi.c
@@ -471,7 +471,7 @@ PyArray_DescrFromTypeObject(PyObject *type)
/* Do special thing for VOID sub-types */
if (PyType_IsSubtype((PyTypeObject *)type, &PyVoidArrType_Type)) {
new = PyArray_DescrNewFromType(NPY_VOID);
- conv = _arraydescr_fromobj(type);
+ conv = _arraydescr_from_dtype_attr(type);
if (conv) {
new->fields = conv->fields;
Py_INCREF(new->fields);
diff --git a/numpy/core/tests/test_dtype.py b/numpy/core/tests/test_dtype.py
index 1bce86a5a..ecb51f72d 100644
--- a/numpy/core/tests/test_dtype.py
+++ b/numpy/core/tests/test_dtype.py
@@ -620,6 +620,25 @@ class TestString(object):
# Pull request #4722
np.array(["", ""]).astype(object)
+ def test_void_subclass_unsized(self):
+ dt = np.dtype(np.record)
+ assert_equal(repr(dt), "dtype('V')")
+ assert_equal(str(dt), '|V0')
+ assert_equal(dt.name, 'record')
+
+ def test_void_subclass_sized(self):
+ dt = np.dtype((np.record, 2))
+ assert_equal(repr(dt), "dtype('V2')")
+ assert_equal(str(dt), '|V2')
+ assert_equal(dt.name, 'record16')
+
+ def test_void_subclass_fields(self):
+ dt = np.dtype((np.record, [('a', '<u2')]))
+ assert_equal(repr(dt), "dtype((numpy.record, [('a', '<u2')]))")
+ assert_equal(str(dt), "(numpy.record, [('a', '<u2')])")
+ assert_equal(dt.name, 'record16')
+
+
class TestDtypeAttributeDeletion(object):
def test_dtype_non_writable_attributes_deletion(self):
@@ -775,6 +794,36 @@ class TestFromCTypes(object):
], align=True)
self.check(PaddedStruct, expected)
+ def test_bit_fields(self):
+ class BitfieldStruct(ctypes.Structure):
+ _fields_ = [
+ ('a', ctypes.c_uint8, 7),
+ ('b', ctypes.c_uint8, 1)
+ ]
+ assert_raises(TypeError, np.dtype, BitfieldStruct)
+ assert_raises(TypeError, np.dtype, BitfieldStruct())
+
+ def test_pointer(self):
+ p_uint8 = ctypes.POINTER(ctypes.c_uint8)
+ assert_raises(TypeError, np.dtype, p_uint8)
+
+ @pytest.mark.xfail(
+ reason="Unions are not implemented",
+ raises=NotImplementedError)
+ def test_union(self):
+ class Union(ctypes.Union):
+ _fields_ = [
+ ('a', ctypes.c_uint8),
+ ('b', ctypes.c_uint16),
+ ]
+ expected = np.dtype(dict(
+ names=['a', 'b'],
+ formats=[np.uint8, np.uint16],
+ offsets=[0, 0],
+ itemsize=2
+ ))
+ self.check(Union, expected)
+
@pytest.mark.xfail(reason="_pack_ is ignored - see gh-11651")
def test_packed_structure(self):
class PackedStructure(ctypes.Structure):
diff --git a/numpy/core/tests/test_overrides.py b/numpy/core/tests/test_overrides.py
index 3f87a6afe..7b3472f96 100644
--- a/numpy/core/tests/test_overrides.py
+++ b/numpy/core/tests/test_overrides.py
@@ -16,8 +16,8 @@ def _get_overloaded_args(relevant_args):
return args
-def _return_self(self, *args, **kwargs):
- return self
+def _return_not_implemented(self, *args, **kwargs):
+ return NotImplemented
class TestGetOverloadedTypesAndArgs(object):
@@ -45,7 +45,7 @@ class TestGetOverloadedTypesAndArgs(object):
def test_ndarray_subclasses(self):
class OverrideSub(np.ndarray):
- __array_function__ = _return_self
+ __array_function__ = _return_not_implemented
class NoOverrideSub(np.ndarray):
pass
@@ -70,7 +70,7 @@ class TestGetOverloadedTypesAndArgs(object):
def test_ndarray_and_duck_array(self):
class Other(object):
- __array_function__ = _return_self
+ __array_function__ = _return_not_implemented
array = np.array(1)
other = Other()
@@ -86,10 +86,10 @@ class TestGetOverloadedTypesAndArgs(object):
def test_ndarray_subclass_and_duck_array(self):
class OverrideSub(np.ndarray):
- __array_function__ = _return_self
+ __array_function__ = _return_not_implemented
class Other(object):
- __array_function__ = _return_self
+ __array_function__ = _return_not_implemented
array = np.array(1)
subarray = np.array(1).view(OverrideSub)
@@ -103,16 +103,16 @@ class TestGetOverloadedTypesAndArgs(object):
def test_many_duck_arrays(self):
class A(object):
- __array_function__ = _return_self
+ __array_function__ = _return_not_implemented
class B(A):
- __array_function__ = _return_self
+ __array_function__ = _return_not_implemented
class C(A):
- __array_function__ = _return_self
+ __array_function__ = _return_not_implemented
class D(object):
- __array_function__ = _return_self
+ __array_function__ = _return_not_implemented
a = A()
b = B()
@@ -135,7 +135,7 @@ class TestNDArrayArrayFunction(object):
def test_method(self):
class SubOverride(np.ndarray):
- __array_function__ = _return_self
+ __array_function__ = _return_not_implemented
class NoOverrideSub(np.ndarray):
pass
@@ -189,7 +189,8 @@ class TestArrayFunctionDispatch(object):
assert_(obj is original)
assert_(func is dispatched_one_arg)
assert_equal(set(types), {MyArray})
- assert_equal(args, (original,))
+ # assert_equal uses the overloaded np.iscomplexobj() internally
+ assert_(args == (original,))
assert_equal(kwargs, {})
def test_not_implemented(self):
@@ -295,12 +296,31 @@ class TestArrayFunctionImplementation(object):
def test_not_implemented(self):
MyArray, implements = _new_duck_type_and_implements()
- @array_function_dispatch(lambda array: (array,))
+ @array_function_dispatch(lambda array: (array,), module='my')
def func(array):
return array
array = np.array(1)
assert_(func(array) is array)
- with assert_raises_regex(TypeError, 'no implementation found'):
+ with assert_raises_regex(
+ TypeError, "no implementation found for 'my.func'"):
func(MyArray())
+
+
+class TestNumPyFunctions(object):
+
+ def test_module(self):
+ assert_equal(np.sum.__module__, 'numpy')
+ assert_equal(np.char.equal.__module__, 'numpy.char')
+ assert_equal(np.fft.fft.__module__, 'numpy.fft')
+ assert_equal(np.linalg.solve.__module__, 'numpy.linalg')
+
+ def test_override_sum(self):
+ MyArray, implements = _new_duck_type_and_implements()
+
+ @implements(np.sum)
+ def _(array):
+ return 'yes'
+
+ assert_equal(np.sum(MyArray()), 'yes')
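test_override_sum exercises the user-facing half of the protocol: a duck type whose __array_function__ looks up overrides registered via an implements decorator. The _new_duck_type_and_implements helper is not shown in this diff, so the following is only an assumed sketch of that shape:

    import numpy as np

    HANDLED = {}

    def implements(numpy_function):
        # register func as the override for numpy_function
        def decorator(func):
            HANDLED[numpy_function] = func
            return func
        return decorator

    class MyArray(object):
        def __array_function__(self, func, types, args, kwargs):
            if func not in HANDLED:
                return NotImplemented
            return HANDLED[func](*args, **kwargs)

    @implements(np.sum)
    def _(array):
        return 'yes'

    # with the protocol active, np.sum dispatches to the registered override
    print(np.sum(MyArray()))  # 'yes'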
diff --git a/numpy/core/tests/test_shape_base.py b/numpy/core/tests/test_shape_base.py
index df819b73f..9bedd8670 100644
--- a/numpy/core/tests/test_shape_base.py
+++ b/numpy/core/tests/test_shape_base.py
@@ -6,6 +6,8 @@ from numpy.core import (
array, arange, atleast_1d, atleast_2d, atleast_3d, block, vstack, hstack,
newaxis, concatenate, stack
)
+from numpy.core.shape_base import (_block_dispatcher, _block_setup,
+ _block_concatenate, _block_slicing)
from numpy.testing import (
assert_, assert_raises, assert_array_equal, assert_equal,
assert_raises_regex, assert_almost_equal
@@ -372,14 +374,63 @@ def test_stack():
stack, [np.arange(2), np.arange(3)])
+# See the following for more information on how to parametrize a whole class:
+# https://docs.pytest.org/en/latest/example/parametrize.html#parametrizing-test-methods-through-per-class-configuration
+def pytest_generate_tests(metafunc):
+ # called once per each test function
+ if hasattr(metafunc.cls, 'params'):
+ arglist = metafunc.cls.params
+ argnames = sorted(arglist[0])
+ metafunc.parametrize(argnames,
+ [[funcargs[name] for name in argnames]
+ for funcargs in arglist])
+
+
+# Blocking small arrays and large arrays goes through different paths.
+# The algorithm used is selected based on the number of element
+# copies required.
+# We define a test fixture that forces most tests to go through
+# both code paths.
+# Ultimately, this should be removed if a single algorithm is found
+# to be faster for both small and large arrays.
+def _block_force_concatenate(arrays):
+ arrays, list_ndim, result_ndim, _ = _block_setup(arrays)
+ return _block_concatenate(arrays, list_ndim, result_ndim)
+
+
+def _block_force_slicing(arrays):
+ arrays, list_ndim, result_ndim, _ = _block_setup(arrays)
+ return _block_slicing(arrays, list_ndim, result_ndim)
+
+
class TestBlock(object):
- def test_returns_copy(self):
+ params = [dict(block=block),
+ dict(block=_block_force_concatenate),
+ dict(block=_block_force_slicing)]
+
+ def test_returns_copy(self, block):
a = np.eye(3)
- b = np.block(a)
+ b = block(a)
b[0, 0] = 2
assert b[0, 0] != a[0, 0]
- def test_block_simple_row_wise(self):
+ def test_block_total_size_estimate(self, block):
+ _, _, _, total_size = _block_setup([1])
+ assert total_size == 1
+
+ _, _, _, total_size = _block_setup([[1]])
+ assert total_size == 1
+
+ _, _, _, total_size = _block_setup([[1, 1]])
+ assert total_size == 2
+
+ _, _, _, total_size = _block_setup([[1], [1]])
+ assert total_size == 2
+
+ _, _, _, total_size = _block_setup([[1, 2], [3, 4]])
+ assert total_size == 4
+
+ def test_block_simple_row_wise(self, block):
a_2d = np.ones((2, 2))
b_2d = 2 * a_2d
desired = np.array([[1, 1, 2, 2],
@@ -387,7 +438,7 @@ class TestBlock(object):
result = block([a_2d, b_2d])
assert_equal(desired, result)
- def test_block_simple_column_wise(self):
+ def test_block_simple_column_wise(self, block):
a_2d = np.ones((2, 2))
b_2d = 2 * a_2d
expected = np.array([[1, 1],
@@ -397,7 +448,7 @@ class TestBlock(object):
result = block([[a_2d], [b_2d]])
assert_equal(expected, result)
- def test_block_with_1d_arrays_row_wise(self):
+ def test_block_with_1d_arrays_row_wise(self, block):
# # # 1-D vectors are treated as row arrays
a = np.array([1, 2, 3])
b = np.array([2, 3, 4])
@@ -405,7 +456,7 @@ class TestBlock(object):
result = block([a, b])
assert_equal(expected, result)
- def test_block_with_1d_arrays_multiple_rows(self):
+ def test_block_with_1d_arrays_multiple_rows(self, block):
a = np.array([1, 2, 3])
b = np.array([2, 3, 4])
expected = np.array([[1, 2, 3, 2, 3, 4],
@@ -413,7 +464,7 @@ class TestBlock(object):
result = block([[a, b], [a, b]])
assert_equal(expected, result)
- def test_block_with_1d_arrays_column_wise(self):
+ def test_block_with_1d_arrays_column_wise(self, block):
# # # 1-D vectors are treated as row arrays
a_1d = np.array([1, 2, 3])
b_1d = np.array([2, 3, 4])
@@ -422,7 +473,7 @@ class TestBlock(object):
result = block([[a_1d], [b_1d]])
assert_equal(expected, result)
- def test_block_mixed_1d_and_2d(self):
+ def test_block_mixed_1d_and_2d(self, block):
a_2d = np.ones((2, 2))
b_1d = np.array([2, 2])
result = block([[a_2d], [b_1d]])
@@ -431,7 +482,7 @@ class TestBlock(object):
[2, 2]])
assert_equal(expected, result)
- def test_block_complicated(self):
+ def test_block_complicated(self, block):
# a bit more complicated
one_2d = np.array([[1, 1, 1]])
two_2d = np.array([[2, 2, 2]])
@@ -455,7 +506,7 @@ class TestBlock(object):
[zero_2d]])
assert_equal(result, expected)
- def test_nested(self):
+ def test_nested(self, block):
one = np.array([1, 1, 1])
two = np.array([[2, 2, 2], [2, 2, 2], [2, 2, 2]])
three = np.array([3, 3, 3])
@@ -464,9 +515,9 @@ class TestBlock(object):
six = np.array([6, 6, 6, 6, 6])
zero = np.zeros((2, 6))
- result = np.block([
+ result = block([
[
- np.block([
+ block([
[one],
[three],
[four]
@@ -485,7 +536,7 @@ class TestBlock(object):
assert_equal(result, expected)
- def test_3d(self):
+ def test_3d(self, block):
a000 = np.ones((2, 2, 2), int) * 1
a100 = np.ones((3, 2, 2), int) * 2
@@ -498,7 +549,7 @@ class TestBlock(object):
a111 = np.ones((3, 3, 3), int) * 8
- result = np.block([
+ result = block([
[
[a000, a001],
[a010, a011],
@@ -540,55 +591,102 @@ class TestBlock(object):
assert_array_equal(result, expected)
- def test_block_with_mismatched_shape(self):
+ def test_block_with_mismatched_shape(self, block):
a = np.array([0, 0])
b = np.eye(2)
- assert_raises(ValueError, np.block, [a, b])
- assert_raises(ValueError, np.block, [b, a])
+ assert_raises(ValueError, block, [a, b])
+ assert_raises(ValueError, block, [b, a])
- def test_no_lists(self):
- assert_equal(np.block(1), np.array(1))
- assert_equal(np.block(np.eye(3)), np.eye(3))
+ to_block = [[np.ones((2,3)), np.ones((2,2))],
+ [np.ones((2,2)), np.ones((2,2))]]
+ assert_raises(ValueError, block, to_block)
+ def test_no_lists(self, block):
+ assert_equal(block(1), np.array(1))
+ assert_equal(block(np.eye(3)), np.eye(3))
- def test_invalid_nesting(self):
+ def test_invalid_nesting(self, block):
msg = 'depths are mismatched'
- assert_raises_regex(ValueError, msg, np.block, [1, [2]])
- assert_raises_regex(ValueError, msg, np.block, [1, []])
- assert_raises_regex(ValueError, msg, np.block, [[1], 2])
- assert_raises_regex(ValueError, msg, np.block, [[], 2])
- assert_raises_regex(ValueError, msg, np.block, [
+ assert_raises_regex(ValueError, msg, block, [1, [2]])
+ assert_raises_regex(ValueError, msg, block, [1, []])
+ assert_raises_regex(ValueError, msg, block, [[1], 2])
+ assert_raises_regex(ValueError, msg, block, [[], 2])
+ assert_raises_regex(ValueError, msg, block, [
[[1], [2]],
[[3, 4]],
[5] # missing brackets
])
- def test_empty_lists(self):
- assert_raises_regex(ValueError, 'empty', np.block, [])
- assert_raises_regex(ValueError, 'empty', np.block, [[]])
- assert_raises_regex(ValueError, 'empty', np.block, [[1], []])
+ def test_empty_lists(self, block):
+ assert_raises_regex(ValueError, 'empty', block, [])
+ assert_raises_regex(ValueError, 'empty', block, [[]])
+ assert_raises_regex(ValueError, 'empty', block, [[1], []])
- def test_tuple(self):
- assert_raises_regex(TypeError, 'tuple', np.block, ([1, 2], [3, 4]))
- assert_raises_regex(TypeError, 'tuple', np.block, [(1, 2), (3, 4)])
+ def test_tuple(self, block):
+ assert_raises_regex(TypeError, 'tuple', block, ([1, 2], [3, 4]))
+ assert_raises_regex(TypeError, 'tuple', block, [(1, 2), (3, 4)])
- def test_different_ndims(self):
+ def test_different_ndims(self, block):
a = 1.
b = 2 * np.ones((1, 2))
c = 3 * np.ones((1, 1, 3))
- result = np.block([a, b, c])
+ result = block([a, b, c])
expected = np.array([[[1., 2., 2., 3., 3., 3.]]])
assert_equal(result, expected)
- def test_different_ndims_depths(self):
+ def test_different_ndims_depths(self, block):
a = 1.
b = 2 * np.ones((1, 2))
c = 3 * np.ones((1, 2, 3))
- result = np.block([[a, b], [c]])
+ result = block([[a, b], [c]])
expected = np.array([[[1., 2., 2.],
[3., 3., 3.],
[3., 3., 3.]]])
assert_equal(result, expected)
+
+ def test_block_memory_order(self, block):
+ # 3D
+ arr_c = np.zeros((3,)*3, order='C')
+ arr_f = np.zeros((3,)*3, order='F')
+
+ b_c = [[[arr_c, arr_c],
+ [arr_c, arr_c]],
+ [[arr_c, arr_c],
+ [arr_c, arr_c]]]
+
+ b_f = [[[arr_f, arr_f],
+ [arr_f, arr_f]],
+ [[arr_f, arr_f],
+ [arr_f, arr_f]]]
+
+ assert block(b_c).flags['C_CONTIGUOUS']
+ assert block(b_f).flags['F_CONTIGUOUS']
+
+ arr_c = np.zeros((3, 3), order='C')
+ arr_f = np.zeros((3, 3), order='F')
+ # 2D
+ b_c = [[arr_c, arr_c],
+ [arr_c, arr_c]]
+
+ b_f = [[arr_f, arr_f],
+ [arr_f, arr_f]]
+
+ assert block(b_c).flags['C_CONTIGUOUS']
+ assert block(b_f).flags['F_CONTIGUOUS']
+
+
+def test_block_dispatcher():
+ class ArrayLike(object):
+ pass
+ a = ArrayLike()
+ b = ArrayLike()
+ c = ArrayLike()
+ assert_equal(list(_block_dispatcher(a)), [a])
+ assert_equal(list(_block_dispatcher([a])), [a])
+ assert_equal(list(_block_dispatcher([a, b])), [a, b])
+ assert_equal(list(_block_dispatcher([[a], [b, [c]]])), [a, b, c])
+ # don't recurse into non-lists
+ assert_equal(list(_block_dispatcher((a, b))), [(a, b)])
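The pytest_generate_tests hook combined with the class-level params list runs every TestBlock method once per block implementation (np.block, forced concatenation, forced slicing). In isolation the pattern looks roughly like this (the TestExample class is illustrative):

    # per-class parametrization via the pytest_generate_tests hook
    def pytest_generate_tests(metafunc):
        if hasattr(metafunc.cls, 'params'):
            arglist = metafunc.cls.params
            argnames = sorted(arglist[0])
            metafunc.parametrize(
                argnames,
                [[funcargs[name] for name in argnames] for funcargs in arglist])

    class TestExample(object):
        # each test method receives a `value` argument, once per dict in params
        params = [dict(value=1), dict(value=2)]

        def test_is_positive(self, value):
            assert value > 0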
diff --git a/numpy/fft/fftpack.py b/numpy/fft/fftpack.py
index d88990373..de675936f 100644
--- a/numpy/fft/fftpack.py
+++ b/numpy/fft/fftpack.py
@@ -35,10 +35,12 @@ from __future__ import division, absolute_import, print_function
__all__ = ['fft', 'ifft', 'rfft', 'irfft', 'hfft', 'ihfft', 'rfftn',
'irfftn', 'rfft2', 'irfft2', 'fft2', 'ifft2', 'fftn', 'ifftn']
+import functools
+
from numpy.core import (array, asarray, zeros, swapaxes, shape, conjugate,
take, sqrt)
from numpy.core.multiarray import normalize_axis_index
-from numpy.core.overrides import array_function_dispatch
+from numpy.core import overrides
from . import fftpack_lite as fftpack
from .helper import _FFTCache
@@ -46,6 +48,10 @@ _fft_cache = _FFTCache(max_size_in_mb=100, max_item_count=32)
_real_fft_cache = _FFTCache(max_size_in_mb=100, max_item_count=32)
+array_function_dispatch = functools.partial(
+ overrides.array_function_dispatch, module='numpy.fft')
+
+
def _raw_fft(a, n=None, axis=-1, init_function=fftpack.cffti,
work_function=fftpack.cfftf, fft_cache=_fft_cache):
a = asarray(a)
diff --git a/numpy/fft/helper.py b/numpy/fft/helper.py
index 4b698bb4d..e65883651 100644
--- a/numpy/fft/helper.py
+++ b/numpy/fft/helper.py
@@ -24,7 +24,7 @@ def _fftshift_dispatcher(x, axes=None):
return (x,)
-@array_function_dispatch(_fftshift_dispatcher)
+@array_function_dispatch(_fftshift_dispatcher, module='numpy.fft')
def fftshift(x, axes=None):
"""
Shift the zero-frequency component to the center of the spectrum.
@@ -81,7 +81,7 @@ def fftshift(x, axes=None):
return roll(x, shift, axes)
-@array_function_dispatch(_fftshift_dispatcher)
+@array_function_dispatch(_fftshift_dispatcher, module='numpy.fft')
def ifftshift(x, axes=None):
"""
The inverse of `fftshift`. Although identical for even-length `x`, the
diff --git a/numpy/lib/arraypad.py b/numpy/lib/arraypad.py
index f76ad456f..d27a3918f 100644
--- a/numpy/lib/arraypad.py
+++ b/numpy/lib/arraypad.py
@@ -995,7 +995,7 @@ def _pad_dispatcher(array, pad_width, mode, **kwargs):
return (array,)
-@array_function_dispatch(_pad_dispatcher)
+@array_function_dispatch(_pad_dispatcher, module='numpy')
def pad(array, pad_width, mode, **kwargs):
"""
Pads an array.
diff --git a/numpy/lib/arraysetops.py b/numpy/lib/arraysetops.py
index ec62cd7a6..850e20123 100644
--- a/numpy/lib/arraysetops.py
+++ b/numpy/lib/arraysetops.py
@@ -27,8 +27,14 @@ To do: Optionally return indices analogously to unique for all functions.
"""
from __future__ import division, absolute_import, print_function
+import functools
+
import numpy as np
-from numpy.core.overrides import array_function_dispatch
+from numpy.core import overrides
+
+
+array_function_dispatch = functools.partial(
+ overrides.array_function_dispatch, module='numpy')
__all__ = [
diff --git a/numpy/lib/financial.py b/numpy/lib/financial.py
index d1a0cd9c0..e1e297492 100644
--- a/numpy/lib/financial.py
+++ b/numpy/lib/financial.py
@@ -13,9 +13,14 @@ otherwise stated.
from __future__ import division, absolute_import, print_function
from decimal import Decimal
+import functools
import numpy as np
-from numpy.core.overrides import array_function_dispatch
+from numpy.core import overrides
+
+
+array_function_dispatch = functools.partial(
+ overrides.array_function_dispatch, module='numpy')
__all__ = ['fv', 'pmt', 'nper', 'ipmt', 'ppmt', 'pv', 'rate',
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index c52ecdbd8..fae6541bc 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -6,6 +6,7 @@ try:
import collections.abc as collections_abc
except ImportError:
import collections as collections_abc
+import functools
import re
import sys
import warnings
@@ -26,7 +27,7 @@ from numpy.core.fromnumeric import (
ravel, nonzero, partition, mean, any, sum
)
from numpy.core.numerictypes import typecodes
-from numpy.core.overrides import array_function_dispatch
+from numpy.core import overrides
from numpy.core.function_base import add_newdoc
from numpy.lib.twodim_base import diag
from .utils import deprecate
@@ -44,6 +45,11 @@ if sys.version_info[0] < 3:
else:
import builtins
+
+array_function_dispatch = functools.partial(
+ overrides.array_function_dispatch, module='numpy')
+
+
# needed in this module for compatibility
from numpy.lib.histograms import histogram, histogramdd
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py
index 26243d231..ff2e00d3e 100644
--- a/numpy/lib/index_tricks.py
+++ b/numpy/lib/index_tricks.py
@@ -1,5 +1,6 @@
from __future__ import division, absolute_import, print_function
+import functools
import sys
import math
@@ -9,14 +10,17 @@ from numpy.core.numeric import (
)
from numpy.core.numerictypes import find_common_type, issubdtype
-from . import function_base
import numpy.matrixlib as matrixlib
from .function_base import diff
from numpy.core.multiarray import ravel_multi_index, unravel_index
-from numpy.core.overrides import array_function_dispatch
+from numpy.core import overrides, linspace
from numpy.lib.stride_tricks import as_strided
+array_function_dispatch = functools.partial(
+ overrides.array_function_dispatch, module='numpy')
+
+
__all__ = [
'ravel_multi_index', 'unravel_index', 'mgrid', 'ogrid', 'r_', 'c_',
's_', 'index_exp', 'ix_', 'ndenumerate', 'ndindex', 'fill_diagonal',
@@ -341,7 +345,7 @@ class AxisConcatenator(object):
step = 1
if isinstance(step, complex):
size = int(abs(step))
- newobj = function_base.linspace(start, stop, num=size)
+ newobj = linspace(start, stop, num=size)
else:
newobj = _nx.arange(start, stop, step)
if ndmin > 1:
diff --git a/numpy/lib/nanfunctions.py b/numpy/lib/nanfunctions.py
index 279c4c5c4..d73d84467 100644
--- a/numpy/lib/nanfunctions.py
+++ b/numpy/lib/nanfunctions.py
@@ -22,10 +22,15 @@ Functions
"""
from __future__ import division, absolute_import, print_function
+import functools
import warnings
import numpy as np
from numpy.lib import function_base
-from numpy.core.overrides import array_function_dispatch
+from numpy.core import overrides
+
+
+array_function_dispatch = functools.partial(
+ overrides.array_function_dispatch, module='numpy')
__all__ = [
diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py
index 62fc9c5b3..733795671 100644
--- a/numpy/lib/npyio.py
+++ b/numpy/lib/npyio.py
@@ -12,6 +12,7 @@ import numpy as np
from . import format
from ._datasource import DataSource
from numpy.core.multiarray import packbits, unpackbits
+from numpy.core.overrides import array_function_dispatch
from numpy.core._internal import recursive
from ._iotools import (
LineSplitter, NameValidator, StringConverter, ConverterError,
@@ -447,6 +448,11 @@ def load(file, mmap_mode=None, allow_pickle=True, fix_imports=True,
fid.close()
+def _save_dispatcher(file, arr, allow_pickle=None, fix_imports=None):
+ return (arr,)
+
+
+@array_function_dispatch(_save_dispatcher)
def save(file, arr, allow_pickle=True, fix_imports=True):
"""
Save an array to a binary file in NumPy ``.npy`` format.
@@ -525,6 +531,14 @@ def save(file, arr, allow_pickle=True, fix_imports=True):
fid.close()
+def _savez_dispatcher(file, *args, **kwds):
+ for a in args:
+ yield a
+ for v in kwds.values():
+ yield v
+
+
+@array_function_dispatch(_savez_dispatcher)
def savez(file, *args, **kwds):
"""
Save several arrays into a single file in uncompressed ``.npz`` format.
@@ -604,6 +618,14 @@ def savez(file, *args, **kwds):
_savez(file, args, kwds, False)
+def _savez_compressed_dispatcher(file, *args, **kwds):
+ for a in args:
+ yield a
+ for v in kwds.values():
+ yield v
+
+
+@array_function_dispatch(_savez_compressed_dispatcher)
def savez_compressed(file, *args, **kwds):
"""
Save several arrays into a single file in compressed ``.npz`` format.
@@ -1154,6 +1176,13 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None,
return X
+def _savetxt_dispatcher(fname, X, fmt=None, delimiter=None, newline=None,
+ header=None, footer=None, comments=None,
+ encoding=None):
+ return (X,)
+
+
+@array_function_dispatch(_savetxt_dispatcher)
def savetxt(fname, X, fmt='%.18e', delimiter=' ', newline='\n', header='',
footer='', comments='# ', encoding=None):
"""
diff --git a/numpy/lib/polynomial.py b/numpy/lib/polynomial.py
index 9f3b84732..c2702f0a7 100644
--- a/numpy/lib/polynomial.py
+++ b/numpy/lib/polynomial.py
@@ -8,17 +8,24 @@ __all__ = ['poly', 'roots', 'polyint', 'polyder', 'polyadd',
'polysub', 'polymul', 'polydiv', 'polyval', 'poly1d',
'polyfit', 'RankWarning']
+import functools
import re
import warnings
import numpy.core.numeric as NX
from numpy.core import (isscalar, abs, finfo, atleast_1d, hstack, dot, array,
ones)
+from numpy.core import overrides
from numpy.lib.twodim_base import diag, vander
from numpy.lib.function_base import trim_zeros
from numpy.lib.type_check import iscomplex, real, imag, mintypecode
from numpy.linalg import eigvals, lstsq, inv
+
+array_function_dispatch = functools.partial(
+ overrides.array_function_dispatch, module='numpy')
+
+
class RankWarning(UserWarning):
"""
Issued by `polyfit` when the Vandermonde matrix is rank deficient.
@@ -29,6 +36,12 @@ class RankWarning(UserWarning):
"""
pass
+
+def _poly_dispatcher(seq_of_zeros):
+ return seq_of_zeros
+
+
+@array_function_dispatch(_poly_dispatcher)
def poly(seq_of_zeros):
"""
Find the coefficients of a polynomial with the given sequence of roots.
@@ -145,6 +158,12 @@ def poly(seq_of_zeros):
return a
+
+def _roots_dispatcher(p):
+ return p
+
+
+@array_function_dispatch(_roots_dispatcher)
def roots(p):
"""
Return the roots of a polynomial with coefficients given in p.
@@ -229,6 +248,12 @@ def roots(p):
roots = hstack((roots, NX.zeros(trailing_zeros, roots.dtype)))
return roots
+
+def _polyint_dispatcher(p, m=None, k=None):
+ return (p,)
+
+
+@array_function_dispatch(_polyint_dispatcher)
def polyint(p, m=1, k=None):
"""
Return an antiderivative (indefinite integral) of a polynomial.
@@ -322,6 +347,12 @@ def polyint(p, m=1, k=None):
return poly1d(val)
return val
+
+def _polyder_dispatcher(p, m=None):
+ return (p,)
+
+
+@array_function_dispatch(_polyder_dispatcher)
def polyder(p, m=1):
"""
Return the derivative of the specified order of a polynomial.
@@ -390,6 +421,12 @@ def polyder(p, m=1):
val = poly1d(val)
return val
+
+def _polyfit_dispatcher(x, y, deg, rcond=None, full=None, w=None, cov=None):
+ return (x, y, w)
+
+
+@array_function_dispatch(_polyfit_dispatcher)
def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False):
"""
Least squares polynomial fit.
@@ -610,6 +647,11 @@ def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False):
return c
+def _polyval_dispatcher(p, x):
+ return (p, x)
+
+
+@array_function_dispatch(_polyval_dispatcher)
def polyval(p, x):
"""
Evaluate a polynomial at specific values.
@@ -679,6 +721,12 @@ def polyval(p, x):
y = y * x + p[i]
return y
+
+def _binary_op_dispatcher(a1, a2):
+ return (a1, a2)
+
+
+@array_function_dispatch(_binary_op_dispatcher)
def polyadd(a1, a2):
"""
Find the sum of two polynomials.
@@ -739,6 +787,8 @@ def polyadd(a1, a2):
val = poly1d(val)
return val
+
+@array_function_dispatch(_binary_op_dispatcher)
def polysub(a1, a2):
"""
Difference (subtraction) of two polynomials.
@@ -786,6 +836,7 @@ def polysub(a1, a2):
return val
+@array_function_dispatch(_binary_op_dispatcher)
def polymul(a1, a2):
"""
Find the product of two polynomials.
@@ -842,6 +893,12 @@ def polymul(a1, a2):
val = poly1d(val)
return val
+
+def _polydiv_dispatcher(u, v):
+ return (u, v)
+
+
+@array_function_dispatch(_polydiv_dispatcher)
def polydiv(u, v):
"""
Returns the quotient and remainder of polynomial division.
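Every dispatcher added in this patch follows the same shape as the polynomial.py
ones above: it mirrors the public function's signature with every default
collapsed to None, and returns only the arguments that could carry an
__array_function__ override; the functools.partial(...) binding simply fixes
module='numpy' so the wrapped functions report the right __module__.  Roughly,
and ignoring the caching, type de-duplication and default ndarray handling of
the real numpy.core.overrides implementation, the decorator behaves like the
following simplified sketch (hypothetical names, for orientation only)::

    import functools

    def array_function_dispatch_sketch(dispatcher, module=None, verify=True):
        """Very rough stand-in for numpy.core.overrides.array_function_dispatch.

        The `verify` flag (signature checking) is accepted but ignored here.
        """
        def decorator(implementation):
            @functools.wraps(implementation)
            def public_api(*args, **kwargs):
                # the dispatcher extracts only the arguments that might
                # provide an __array_function__ override
                for arg in dispatcher(*args, **kwargs):
                    handler = getattr(type(arg), '__array_function__', None)
                    if handler is not None:
                        result = handler(arg, public_api, (type(arg),),
                                         args, kwargs)
                        if result is not NotImplemented:
                            return result
                # nothing claimed the call: run the plain implementation
                return implementation(*args, **kwargs)
            if module is not None:
                public_api.__module__ = module
            return public_api
        return decorator

In NumPy 1.16 the protocol is opt-in (via the NUMPY_EXPERIMENTAL_ARRAY_FUNCTION
environment variable), so by default the decorated functions fall straight
through to their original implementations and behaviour is unchanged.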
diff --git a/numpy/lib/recfunctions.py b/numpy/lib/recfunctions.py
index b6453d5a2..53a586f56 100644
--- a/numpy/lib/recfunctions.py
+++ b/numpy/lib/recfunctions.py
@@ -14,6 +14,7 @@ import numpy.ma as ma
from numpy import ndarray, recarray
from numpy.ma import MaskedArray
from numpy.ma.mrecords import MaskedRecords
+from numpy.core.overrides import array_function_dispatch
from numpy.lib._iotools import _is_string_like
from numpy.compat import basestring
@@ -31,6 +32,11 @@ __all__ = [
]
+def _recursive_fill_fields_dispatcher(input, output):
+ return (input, output)
+
+
+@array_function_dispatch(_recursive_fill_fields_dispatcher)
def recursive_fill_fields(input, output):
"""
Fills fields from output with fields from input,
@@ -189,6 +195,11 @@ def flatten_descr(ndtype):
return tuple(descr)
+def _zip_dtype_dispatcher(seqarrays, flatten=None):
+ return seqarrays
+
+
+@array_function_dispatch(_zip_dtype_dispatcher)
def zip_dtype(seqarrays, flatten=False):
newdtype = []
if flatten:
@@ -205,6 +216,7 @@ def zip_dtype(seqarrays, flatten=False):
return np.dtype(newdtype)
+@array_function_dispatch(_zip_dtype_dispatcher)
def zip_descr(seqarrays, flatten=False):
"""
Combine the dtype description of a series of arrays.
@@ -297,6 +309,11 @@ def _izip_fields(iterable):
yield element
+def _izip_records_dispatcher(seqarrays, fill_value=None, flatten=None):
+ return seqarrays
+
+
+@array_function_dispatch(_izip_records_dispatcher)
def izip_records(seqarrays, fill_value=None, flatten=True):
"""
Returns an iterator of concatenated items from a sequence of arrays.
@@ -357,6 +374,12 @@ def _fix_defaults(output, defaults=None):
return output
+def _merge_arrays_dispatcher(seqarrays, fill_value=None, flatten=None,
+ usemask=None, asrecarray=None):
+ return seqarrays
+
+
+@array_function_dispatch(_merge_arrays_dispatcher)
def merge_arrays(seqarrays, fill_value=-1, flatten=False,
usemask=False, asrecarray=False):
"""
@@ -494,6 +517,11 @@ def merge_arrays(seqarrays, fill_value=-1, flatten=False,
return output
+def _drop_fields_dispatcher(base, drop_names, usemask=None, asrecarray=None):
+ return (base,)
+
+
+@array_function_dispatch(_drop_fields_dispatcher)
def drop_fields(base, drop_names, usemask=True, asrecarray=False):
"""
Return a new array with fields in `drop_names` dropped.
@@ -583,6 +611,11 @@ def _keep_fields(base, keep_names, usemask=True, asrecarray=False):
return _fix_output(output, usemask=usemask, asrecarray=asrecarray)
+def _rec_drop_fields_dispatcher(base, drop_names):
+ return (base,)
+
+
+@array_function_dispatch(_rec_drop_fields_dispatcher)
def rec_drop_fields(base, drop_names):
"""
Returns a new numpy.recarray with fields in `drop_names` dropped.
@@ -590,6 +623,11 @@ def rec_drop_fields(base, drop_names):
return drop_fields(base, drop_names, usemask=False, asrecarray=True)
+def _rename_fields_dispatcher(base, namemapper):
+ return (base,)
+
+
+@array_function_dispatch(_rename_fields_dispatcher)
def rename_fields(base, namemapper):
"""
Rename the fields from a flexible-datatype ndarray or recarray.
@@ -629,6 +667,14 @@ def rename_fields(base, namemapper):
return base.view(newdtype)
+def _append_fields_dispatcher(base, names, data, dtypes=None,
+ fill_value=None, usemask=None, asrecarray=None):
+ yield base
+ for d in data:
+ yield d
+
+
+@array_function_dispatch(_append_fields_dispatcher)
def append_fields(base, names, data, dtypes=None,
fill_value=-1, usemask=True, asrecarray=False):
"""
@@ -699,6 +745,13 @@ def append_fields(base, names, data, dtypes=None,
return _fix_output(output, usemask=usemask, asrecarray=asrecarray)
+def _rec_append_fields_dispatcher(base, names, data, dtypes=None):
+ yield base
+ for d in data:
+ yield d
+
+
+@array_function_dispatch(_rec_append_fields_dispatcher)
def rec_append_fields(base, names, data, dtypes=None):
"""
Add new fields to an existing array.
@@ -732,6 +785,12 @@ def rec_append_fields(base, names, data, dtypes=None):
return append_fields(base, names, data=data, dtypes=dtypes,
asrecarray=True, usemask=False)
+
+def _repack_fields_dispatcher(a, align=None, recurse=None):
+ return (a,)
+
+
+@array_function_dispatch(_repack_fields_dispatcher)
def repack_fields(a, align=False, recurse=False):
"""
Re-pack the fields of a structured array or dtype in memory.
@@ -811,6 +870,13 @@ def repack_fields(a, align=False, recurse=False):
dt = np.dtype(fieldinfo, align=align)
return np.dtype((a.type, dt))
+
+def _stack_arrays_dispatcher(arrays, defaults=None, usemask=None,
+ asrecarray=None, autoconvert=None):
+ return arrays
+
+
+@array_function_dispatch(_stack_arrays_dispatcher)
def stack_arrays(arrays, defaults=None, usemask=True, asrecarray=False,
autoconvert=False):
"""
@@ -897,6 +963,12 @@ def stack_arrays(arrays, defaults=None, usemask=True, asrecarray=False,
usemask=usemask, asrecarray=asrecarray)
+def _find_duplicates_dispatcher(
+ a, key=None, ignoremask=None, return_index=None):
+ return (a,)
+
+
+@array_function_dispatch(_find_duplicates_dispatcher)
def find_duplicates(a, key=None, ignoremask=True, return_index=False):
"""
Find the duplicates in a structured array along a given key
@@ -951,8 +1023,15 @@ def find_duplicates(a, key=None, ignoremask=True, return_index=False):
return duplicates
+def _join_by_dispatcher(
+ key, r1, r2, jointype=None, r1postfix=None, r2postfix=None,
+ defaults=None, usemask=None, asrecarray=None):
+ return (r1, r2)
+
+
+@array_function_dispatch(_join_by_dispatcher)
def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2',
- defaults=None, usemask=True, asrecarray=False):
+ defaults=None, usemask=True, asrecarray=False):
"""
Join arrays `r1` and `r2` on key `key`.
@@ -1130,6 +1209,13 @@ def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2',
return _fix_output(_fix_defaults(output, defaults), **kwargs)
+def _rec_join_dispatcher(
+ key, r1, r2, jointype=None, r1postfix=None, r2postfix=None,
+ defaults=None):
+ return (r1, r2)
+
+
+@array_function_dispatch(_rec_join_dispatcher)
def rec_join(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2',
defaults=None):
"""
diff --git a/numpy/lib/scimath.py b/numpy/lib/scimath.py
index f1838fee6..9ca006841 100644
--- a/numpy/lib/scimath.py
+++ b/numpy/lib/scimath.py
@@ -20,6 +20,7 @@ from __future__ import division, absolute_import, print_function
import numpy.core.numeric as nx
import numpy.core.numerictypes as nt
from numpy.core.numeric import asarray, any
+from numpy.core.overrides import array_function_dispatch
from numpy.lib.type_check import isreal
@@ -94,6 +95,7 @@ def _tocomplex(arr):
else:
return arr.astype(nt.cdouble)
+
def _fix_real_lt_zero(x):
"""Convert `x` to complex if it has real, negative components.
@@ -121,6 +123,7 @@ def _fix_real_lt_zero(x):
x = _tocomplex(x)
return x
+
def _fix_int_lt_zero(x):
"""Convert `x` to double if it has real, negative components.
@@ -147,6 +150,7 @@ def _fix_int_lt_zero(x):
x = x * 1.0
return x
+
def _fix_real_abs_gt_1(x):
"""Convert `x` to complex if it has real components x_i with abs(x_i)>1.
@@ -173,6 +177,12 @@ def _fix_real_abs_gt_1(x):
x = _tocomplex(x)
return x
+
+def _unary_dispatcher(x):
+ return (x,)
+
+
+@array_function_dispatch(_unary_dispatcher)
def sqrt(x):
"""
Compute the square root of x.
@@ -215,6 +225,8 @@ def sqrt(x):
x = _fix_real_lt_zero(x)
return nx.sqrt(x)
+
+@array_function_dispatch(_unary_dispatcher)
def log(x):
"""
Compute the natural logarithm of `x`.
@@ -261,6 +273,8 @@ def log(x):
x = _fix_real_lt_zero(x)
return nx.log(x)
+
+@array_function_dispatch(_unary_dispatcher)
def log10(x):
"""
Compute the logarithm base 10 of `x`.
@@ -309,6 +323,12 @@ def log10(x):
x = _fix_real_lt_zero(x)
return nx.log10(x)
+
+def _logn_dispatcher(n, x):
+ return (n, x,)
+
+
+@array_function_dispatch(_logn_dispatcher)
def logn(n, x):
"""
Take log base n of x.
@@ -318,8 +338,8 @@ def logn(n, x):
Parameters
----------
- n : int
- The base in which the log is taken.
+ n : array_like
+        The base(s) in which the log is taken.
x : array_like
The value(s) whose log base `n` is (are) required.
@@ -343,6 +363,8 @@ def logn(n, x):
n = _fix_real_lt_zero(n)
return nx.log(x)/nx.log(n)
+
+@array_function_dispatch(_unary_dispatcher)
def log2(x):
"""
Compute the logarithm base 2 of `x`.
@@ -389,6 +411,12 @@ def log2(x):
x = _fix_real_lt_zero(x)
return nx.log2(x)
+
+def _power_dispatcher(x, p):
+ return (x, p)
+
+
+@array_function_dispatch(_power_dispatcher)
def power(x, p):
"""
Return x to the power p, (x**p).
@@ -432,6 +460,8 @@ def power(x, p):
p = _fix_int_lt_zero(p)
return nx.power(x, p)
+
+@array_function_dispatch(_unary_dispatcher)
def arccos(x):
"""
Compute the inverse cosine of x.
@@ -475,6 +505,8 @@ def arccos(x):
x = _fix_real_abs_gt_1(x)
return nx.arccos(x)
+
+@array_function_dispatch(_unary_dispatcher)
def arcsin(x):
"""
Compute the inverse sine of x.
@@ -519,6 +551,8 @@ def arcsin(x):
x = _fix_real_abs_gt_1(x)
return nx.arcsin(x)
+
+@array_function_dispatch(_unary_dispatcher)
def arctanh(x):
"""
Compute the inverse hyperbolic tangent of `x`.
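The logn docstring tweak above matches what the code already does: both n and x
pass through _fix_real_lt_zero and the result is nx.log(x)/nx.log(n), so the
base broadcasts like any other array argument.  For example (assuming an
installed NumPy)::

    >>> from numpy.lib import scimath
    >>> scimath.logn(2, [4, 8, 16])        # scalar base
    array([2., 3., 4.])
    >>> scimath.logn([2, 10], [8, 100])    # array base, broadcast element-wise
    array([3., 2.])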
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py
index 66f534734..00424d55d 100644
--- a/numpy/lib/shape_base.py
+++ b/numpy/lib/shape_base.py
@@ -1,5 +1,6 @@
from __future__ import division, absolute_import, print_function
+import functools
import warnings
import numpy.core.numeric as _nx
@@ -8,6 +9,7 @@ from numpy.core.numeric import (
)
from numpy.core.fromnumeric import product, reshape, transpose
from numpy.core.multiarray import normalize_axis_index
+from numpy.core import overrides
from numpy.core import vstack, atleast_3d
from numpy.lib.index_tricks import ndindex
from numpy.matrixlib.defmatrix import matrix # this raises all the right alarm bells
@@ -21,6 +23,10 @@ __all__ = [
]
+array_function_dispatch = functools.partial(
+ overrides.array_function_dispatch, module='numpy')
+
+
def _make_along_axis_idx(arr_shape, indices, axis):
# compute dimensions to iterate over
if not _nx.issubdtype(indices.dtype, _nx.integer):
@@ -44,6 +50,11 @@ def _make_along_axis_idx(arr_shape, indices, axis):
return tuple(fancy_index)
+def _take_along_axis_dispatcher(arr, indices, axis):
+ return (arr, indices)
+
+
+@array_function_dispatch(_take_along_axis_dispatcher)
def take_along_axis(arr, indices, axis):
"""
Take values from the input array by matching 1d index and data slices.
@@ -160,6 +171,11 @@ def take_along_axis(arr, indices, axis):
return arr[_make_along_axis_idx(arr_shape, indices, axis)]
+def _put_along_axis_dispatcher(arr, indices, values, axis):
+ return (arr, indices, values)
+
+
+@array_function_dispatch(_put_along_axis_dispatcher)
def put_along_axis(arr, indices, values, axis):
"""
Put values into the destination array by matching 1d index and data slices.
@@ -245,6 +261,11 @@ def put_along_axis(arr, indices, values, axis):
arr[_make_along_axis_idx(arr_shape, indices, axis)] = values
+def _apply_along_axis_dispatcher(func1d, axis, arr, *args, **kwargs):
+ return (arr,)
+
+
+@array_function_dispatch(_apply_along_axis_dispatcher)
def apply_along_axis(func1d, axis, arr, *args, **kwargs):
"""
Apply a function to 1-D slices along the given axis.
@@ -392,6 +413,11 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
return res.__array_wrap__(out_arr)
+def _apply_over_axes_dispatcher(func, a, axes):
+ return (a,)
+
+
+@array_function_dispatch(_apply_over_axes_dispatcher)
def apply_over_axes(func, a, axes):
"""
Apply a function repeatedly over multiple axes.
@@ -474,9 +500,15 @@ def apply_over_axes(func, a, axes):
val = res
else:
raise ValueError("function is not returning "
- "an array of the correct shape")
+ "an array of the correct shape")
return val
+
+def _expand_dims_dispatcher(a, axis):
+ return (a,)
+
+
+@array_function_dispatch(_expand_dims_dispatcher)
def expand_dims(a, axis):
"""
Expand the shape of an array.
@@ -554,8 +586,15 @@ def expand_dims(a, axis):
# axis = normalize_axis_index(axis, a.ndim + 1)
return a.reshape(shape[:axis] + (1,) + shape[axis:])
+
row_stack = vstack
+
+def _column_stack_dispatcher(tup):
+ return tup
+
+
+@array_function_dispatch(_column_stack_dispatcher)
def column_stack(tup):
"""
Stack 1-D arrays as columns into a 2-D array.
@@ -597,6 +636,12 @@ def column_stack(tup):
arrays.append(arr)
return _nx.concatenate(arrays, 1)
+
+def _dstack_dispatcher(tup):
+ return tup
+
+
+@array_function_dispatch(_dstack_dispatcher)
def dstack(tup):
"""
Stack arrays in sequence depth wise (along third axis).
@@ -649,6 +694,7 @@ def dstack(tup):
"""
return _nx.concatenate([atleast_3d(_m) for _m in tup], 2)
+
def _replace_zero_by_x_arrays(sub_arys):
for i in range(len(sub_arys)):
if _nx.ndim(sub_arys[i]) == 0:
@@ -657,6 +703,12 @@ def _replace_zero_by_x_arrays(sub_arys):
sub_arys[i] = _nx.empty(0, dtype=sub_arys[i].dtype)
return sub_arys
+
+def _array_split_dispatcher(ary, indices_or_sections, axis=None):
+ return (ary, indices_or_sections)
+
+
+@array_function_dispatch(_array_split_dispatcher)
def array_split(ary, indices_or_sections, axis=0):
"""
Split an array into multiple sub-arrays.
@@ -712,7 +764,12 @@ def array_split(ary, indices_or_sections, axis=0):
return sub_arys
-def split(ary,indices_or_sections,axis=0):
+def _split_dispatcher(ary, indices_or_sections, axis=None):
+ return (ary, indices_or_sections)
+
+
+@array_function_dispatch(_split_dispatcher)
+def split(ary, indices_or_sections, axis=0):
"""
Split an array into multiple sub-arrays.
@@ -789,6 +846,12 @@ def split(ary,indices_or_sections,axis=0):
res = array_split(ary, indices_or_sections, axis)
return res
+
+def _hvdsplit_dispatcher(ary, indices_or_sections):
+ return (ary, indices_or_sections)
+
+
+@array_function_dispatch(_hvdsplit_dispatcher)
def hsplit(ary, indices_or_sections):
"""
Split an array into multiple sub-arrays horizontally (column-wise).
@@ -851,6 +914,8 @@ def hsplit(ary, indices_or_sections):
else:
return split(ary, indices_or_sections, 0)
+
+@array_function_dispatch(_hvdsplit_dispatcher)
def vsplit(ary, indices_or_sections):
"""
Split an array into multiple sub-arrays vertically (row-wise).
@@ -902,6 +967,8 @@ def vsplit(ary, indices_or_sections):
raise ValueError('vsplit only works on arrays of 2 or more dimensions')
return split(ary, indices_or_sections, 0)
+
+@array_function_dispatch(_hvdsplit_dispatcher)
def dsplit(ary, indices_or_sections):
"""
Split array into multiple sub-arrays along the 3rd axis (depth).
@@ -971,6 +1038,12 @@ def get_array_wrap(*args):
return wrappers[-1][-1]
return None
+
+def _kron_dispatcher(a, b):
+ return (a, b)
+
+
+@array_function_dispatch(_kron_dispatcher)
def kron(a, b):
"""
Kronecker product of two arrays.
@@ -1070,6 +1143,11 @@ def kron(a, b):
return result
+def _tile_dispatcher(A, reps):
+ return (A, reps)
+
+
+@array_function_dispatch(_tile_dispatcher)
def tile(A, reps):
"""
Construct an array by repeating A the number of times given by reps.
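Note that dispatchers such as _take_along_axis_dispatcher above return only the
array-like arguments (arr and indices, not axis), since only those can carry an
override.  A minimal sketch of an overriding subclass, assuming a NumPy build
where __array_function__ dispatch is active (opt-in through the
NUMPY_EXPERIMENTAL_ARRAY_FUNCTION environment variable on 1.16, on by default
in later releases)::

    import os
    # must be set before numpy is imported; harmless where the protocol
    # is always enabled
    os.environ.setdefault('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', '1')

    import numpy as np

    class LoggingArray(np.ndarray):
        """Toy subclass that reports which public functions dispatch to it."""
        def __array_function__(self, func, types, args, kwargs):
            print('dispatched:', func.__name__)
            # strip the subclass and defer to plain ndarray behaviour
            args = tuple(np.asarray(a) if isinstance(a, LoggingArray) else a
                         for a in args)
            return func(*args, **kwargs)

    a = np.arange(6).reshape(2, 3).view(LoggingArray)
    idx = np.array([[2, 1, 0], [0, 2, 1]])
    out = np.take_along_axis(a, idx, axis=1)  # prints "dispatched: take_along_axis"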
diff --git a/numpy/lib/stride_tricks.py b/numpy/lib/stride_tricks.py
index ca13738c1..0dc36e41c 100644
--- a/numpy/lib/stride_tricks.py
+++ b/numpy/lib/stride_tricks.py
@@ -8,6 +8,7 @@ NumPy reference guide.
from __future__ import division, absolute_import, print_function
import numpy as np
+from numpy.core.overrides import array_function_dispatch
__all__ = ['broadcast_to', 'broadcast_arrays']
@@ -135,6 +136,11 @@ def _broadcast_to(array, shape, subok, readonly):
return result
+def _broadcast_to_dispatcher(array, shape, subok=None):
+ return (array,)
+
+
+@array_function_dispatch(_broadcast_to_dispatcher, module='numpy')
def broadcast_to(array, shape, subok=False):
"""Broadcast an array to a new shape.
@@ -195,6 +201,11 @@ def _broadcast_shape(*args):
return b.shape
+def _broadcast_arrays_dispatcher(*args, **kwargs):
+ return args
+
+
+@array_function_dispatch(_broadcast_arrays_dispatcher, module='numpy')
def broadcast_arrays(*args, **kwargs):
"""
Broadcast any number of arrays against each other.
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index 40cca1dbb..0c789e012 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -3114,3 +3114,29 @@ class TestAdd_newdoc(object):
assert_equal(np.core.flatiter.index.__doc__[:len(tgt)], tgt)
assert_(len(np.core.ufunc.identity.__doc__) > 300)
assert_(len(np.lib.index_tricks.mgrid.__doc__) > 300)
+
+class TestSortComplex(object):
+
+ @pytest.mark.parametrize("type_in, type_out", [
+ ('l', 'D'),
+ ('h', 'F'),
+ ('H', 'F'),
+ ('b', 'F'),
+ ('B', 'F'),
+ ('g', 'G'),
+ ])
+ def test_sort_real(self, type_in, type_out):
+ # sort_complex() type casting for real input types
+ a = np.array([5, 3, 6, 2, 1], dtype=type_in)
+ actual = np.sort_complex(a)
+ expected = np.sort(a).astype(type_out)
+ assert_equal(actual, expected)
+ assert_equal(actual.dtype, expected.dtype)
+
+ def test_sort_complex(self):
+ # sort_complex() handling of complex input
+ a = np.array([2 + 3j, 1 - 2j, 1 - 3j, 2 + 1j], dtype='D')
+ expected = np.array([1 - 3j, 1 - 2j, 2 + 1j, 2 + 3j], dtype='D')
+ actual = np.sort_complex(a)
+ assert_equal(actual, expected)
+ assert_equal(actual.dtype, expected.dtype)
diff --git a/numpy/lib/tests/test_index_tricks.py b/numpy/lib/tests/test_index_tricks.py
index 76d9b403e..3246f68ff 100644
--- a/numpy/lib/tests/test_index_tricks.py
+++ b/numpy/lib/tests/test_index_tricks.py
@@ -226,6 +226,11 @@ class TestConcatenator(object):
g = r_[-10.1, np.array([1]), np.array([2, 3, 4]), 10.0]
assert_(g.dtype == 'f8')
+ def test_complex_step(self):
+ # Regression test for #12262
+ g = r_[0:36:100j]
+ assert_(g.shape == (100,))
+
def test_2d(self):
b = np.random.rand(5, 5)
c = np.random.rand(5, 5)
diff --git a/numpy/lib/twodim_base.py b/numpy/lib/twodim_base.py
index 98efba191..a05e68375 100644
--- a/numpy/lib/twodim_base.py
+++ b/numpy/lib/twodim_base.py
@@ -3,11 +3,14 @@
"""
from __future__ import division, absolute_import, print_function
+import functools
+
from numpy.core.numeric import (
absolute, asanyarray, arange, zeros, greater_equal, multiply, ones,
asarray, where, int8, int16, int32, int64, empty, promote_types, diagonal,
nonzero
)
+from numpy.core import overrides
from numpy.core import iinfo, transpose
@@ -17,6 +20,10 @@ __all__ = [
'tril_indices_from', 'triu_indices', 'triu_indices_from', ]
+array_function_dispatch = functools.partial(
+ overrides.array_function_dispatch, module='numpy')
+
+
i1 = iinfo(int8)
i2 = iinfo(int16)
i4 = iinfo(int32)
@@ -33,6 +40,11 @@ def _min_int(low, high):
return int64
+def _flip_dispatcher(m):
+ return (m,)
+
+
+@array_function_dispatch(_flip_dispatcher)
def fliplr(m):
"""
Flip array in the left/right direction.
@@ -83,6 +95,7 @@ def fliplr(m):
return m[:, ::-1]
+@array_function_dispatch(_flip_dispatcher)
def flipud(m):
"""
Flip array in the up/down direction.
@@ -194,6 +207,11 @@ def eye(N, M=None, k=0, dtype=float, order='C'):
return m
+def _diag_dispatcher(v, k=None):
+ return (v,)
+
+
+@array_function_dispatch(_diag_dispatcher)
def diag(v, k=0):
"""
Extract a diagonal or construct a diagonal array.
@@ -265,6 +283,7 @@ def diag(v, k=0):
raise ValueError("Input must be 1- or 2-d.")
+@array_function_dispatch(_diag_dispatcher)
def diagflat(v, k=0):
"""
Create a two-dimensional array with the flattened input as a diagonal.
@@ -373,6 +392,11 @@ def tri(N, M=None, k=0, dtype=float):
return m
+def _trilu_dispatcher(m, k=None):
+ return (m,)
+
+
+@array_function_dispatch(_trilu_dispatcher)
def tril(m, k=0):
"""
Lower triangle of an array.
@@ -411,6 +435,7 @@ def tril(m, k=0):
return where(mask, m, zeros(1, m.dtype))
+@array_function_dispatch(_trilu_dispatcher)
def triu(m, k=0):
"""
Upper triangle of an array.
@@ -439,7 +464,12 @@ def triu(m, k=0):
return where(mask, zeros(1, m.dtype), m)
+def _vander_dispatcher(x, N=None, increasing=None):
+ return (x,)
+
+
# Originally borrowed from John Hunter and matplotlib
+@array_function_dispatch(_vander_dispatcher)
def vander(x, N=None, increasing=False):
"""
Generate a Vandermonde matrix.
@@ -530,6 +560,12 @@ def vander(x, N=None, increasing=False):
return v
+def _histogram2d_dispatcher(x, y, bins=None, range=None, normed=None,
+ weights=None, density=None):
+ return (x, y, bins, weights)
+
+
+@array_function_dispatch(_histogram2d_dispatcher)
def histogram2d(x, y, bins=10, range=None, normed=None, weights=None,
density=None):
"""
@@ -812,6 +848,11 @@ def tril_indices(n, k=0, m=None):
return nonzero(tri(n, m, k=k, dtype=bool))
+def _trilu_indices_form_dispatcher(arr, k=None):
+ return (arr,)
+
+
+@array_function_dispatch(_trilu_indices_form_dispatcher)
def tril_indices_from(arr, k=0):
"""
Return the indices for the lower-triangle of arr.
@@ -922,6 +963,7 @@ def triu_indices(n, k=0, m=None):
return nonzero(~tri(n, m, k=k-1, dtype=bool))
+@array_function_dispatch(_trilu_indices_form_dispatcher)
def triu_indices_from(arr, k=0):
"""
Return the indices for the upper-triangle of arr.
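Unlike most dispatchers in this patch, _histogram2d_dispatcher also returns
bins and weights: both may themselves be arrays (e.g. explicit bin edges), so
they have to take part in dispatch as well.  A quick illustration of the
array-valued bins argument (assuming an installed NumPy)::

    >>> import numpy as np
    >>> x = y = np.linspace(0.0, 1.0, 100)
    >>> xedges = np.linspace(0.0, 1.0, 5)   # 4 bins
    >>> yedges = np.linspace(0.0, 1.0, 4)   # 3 bins
    >>> H, xe, ye = np.histogram2d(x, y, bins=[xedges, yedges])
    >>> H.shape
    (4, 3)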
diff --git a/numpy/lib/type_check.py b/numpy/lib/type_check.py
index 603da8567..9153e1692 100644
--- a/numpy/lib/type_check.py
+++ b/numpy/lib/type_check.py
@@ -2,6 +2,7 @@
"""
from __future__ import division, absolute_import, print_function
+import functools
import warnings
__all__ = ['iscomplexobj', 'isrealobj', 'imag', 'iscomplex',
@@ -11,10 +12,17 @@ __all__ = ['iscomplexobj', 'isrealobj', 'imag', 'iscomplex',
import numpy.core.numeric as _nx
from numpy.core.numeric import asarray, asanyarray, array, isnan, zeros
+from numpy.core import overrides
from .ufunclike import isneginf, isposinf
+
+array_function_dispatch = functools.partial(
+ overrides.array_function_dispatch, module='numpy')
+
+
_typecodes_by_elsize = 'GDFgdfQqLlIiHhBb?'
+
def mintypecode(typechars,typeset='GDFgdf',default='d'):
"""
Return the character for the minimum-size type to which given types can
@@ -104,6 +112,11 @@ def asfarray(a, dtype=_nx.float_):
return asarray(a, dtype=dtype)
+def _real_dispatcher(val):
+ return (val,)
+
+
+@array_function_dispatch(_real_dispatcher)
def real(val):
"""
Return the real part of the complex argument.
@@ -145,6 +158,11 @@ def real(val):
return asanyarray(val).real
+def _imag_dispatcher(val):
+ return (val,)
+
+
+@array_function_dispatch(_imag_dispatcher)
def imag(val):
"""
Return the imaginary part of the complex argument.
@@ -183,6 +201,11 @@ def imag(val):
return asanyarray(val).imag
+def _is_type_dispatcher(x):
+ return (x,)
+
+
+@array_function_dispatch(_is_type_dispatcher)
def iscomplex(x):
"""
Returns a bool array, where True if input element is complex.
@@ -218,6 +241,8 @@ def iscomplex(x):
res = zeros(ax.shape, bool)
return res[()] # convert to scalar if needed
+
+@array_function_dispatch(_is_type_dispatcher)
def isreal(x):
"""
Returns a bool array, where True if input element is real.
@@ -248,6 +273,8 @@ def isreal(x):
"""
return imag(x) == 0
+
+@array_function_dispatch(_is_type_dispatcher)
def iscomplexobj(x):
"""
Check for a complex type or an array of complex numbers.
@@ -288,6 +315,7 @@ def iscomplexobj(x):
return issubclass(type_, _nx.complexfloating)
+@array_function_dispatch(_is_type_dispatcher)
def isrealobj(x):
"""
Return True if x is a not complex type or an array of complex numbers.
@@ -329,6 +357,12 @@ def _getmaxmin(t):
f = getlimits.finfo(t)
return f.max, f.min
+
+def _nan_to_num_dispatcher(x, copy=None):
+ return (x,)
+
+
+@array_function_dispatch(_nan_to_num_dispatcher)
def nan_to_num(x, copy=True):
"""
Replace NaN with zero and infinity with large finite numbers.
@@ -411,7 +445,12 @@ def nan_to_num(x, copy=True):
#-----------------------------------------------------------------------------
-def real_if_close(a,tol=100):
+def _real_if_close_dispatcher(a, tol=None):
+ return (a,)
+
+
+@array_function_dispatch(_real_if_close_dispatcher)
+def real_if_close(a, tol=100):
"""
If complex input returns a real array if complex parts are close to zero.
@@ -466,6 +505,11 @@ def real_if_close(a,tol=100):
return a
+def _asscalar_dispatcher(a):
+ return (a,)
+
+
+@array_function_dispatch(_asscalar_dispatcher)
def asscalar(a):
"""
Convert an array of size 1 to its scalar equivalent.
@@ -586,6 +630,13 @@ array_precision = {_nx.half: 0,
_nx.csingle: 1,
_nx.cdouble: 2,
_nx.clongdouble: 3}
+
+
+def _common_type_dispatcher(*arrays):
+ return arrays
+
+
+@array_function_dispatch(_common_type_dispatcher)
def common_type(*arrays):
"""
Return a scalar type which is common to the input arrays.
diff --git a/numpy/lib/ufunclike.py b/numpy/lib/ufunclike.py
index 6259c5445..ac0af0b37 100644
--- a/numpy/lib/ufunclike.py
+++ b/numpy/lib/ufunclike.py
@@ -8,6 +8,7 @@ from __future__ import division, absolute_import, print_function
__all__ = ['fix', 'isneginf', 'isposinf']
import numpy.core.numeric as nx
+from numpy.core.overrides import array_function_dispatch
import warnings
import functools
@@ -37,7 +38,30 @@ def _deprecate_out_named_y(f):
return func
+def _fix_out_named_y(f):
+ """
+ Allow the out argument to be passed as the name `y` (deprecated)
+
+ This decorator should only be used if _deprecate_out_named_y is used on
+    a corresponding dispatcher function.
+ """
+ @functools.wraps(f)
+ def func(x, out=None, **kwargs):
+ if 'y' in kwargs:
+ # we already did error checking in _deprecate_out_named_y
+ out = kwargs.pop('y')
+ return f(x, out=out, **kwargs)
+
+ return func
+
+
@_deprecate_out_named_y
+def _dispatcher(x, out=None):
+ return (x, out)
+
+
+@array_function_dispatch(_dispatcher, verify=False, module='numpy')
+@_fix_out_named_y
def fix(x, out=None):
"""
Round to nearest integer towards zero.
@@ -83,7 +107,8 @@ def fix(x, out=None):
return res
-@_deprecate_out_named_y
+@array_function_dispatch(_dispatcher, verify=False, module='numpy')
+@_fix_out_named_y
def isposinf(x, out=None):
"""
Test element-wise for positive infinity, return result as bool array.
@@ -151,7 +176,8 @@ def isposinf(x, out=None):
return nx.logical_and(is_inf, signbit, out)
-@_deprecate_out_named_y
+@array_function_dispatch(_dispatcher, verify=False, module='numpy')
+@_fix_out_named_y
def isneginf(x, out=None):
"""
Test element-wise for negative infinity, return result as bool array.
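With the split above, _deprecate_out_named_y now wraps the shared dispatcher,
so the DeprecationWarning for the old y= spelling is raised where the original
keywords are still visible, while _fix_out_named_y on the implementations only
forwards the value; verify=False is presumably needed because the wrapped
functions no longer expose signatures that the dispatch machinery can check.
The user-visible behaviour should be unchanged, e.g. on a release where the
deprecated spelling is still accepted::

    >>> import numpy as np
    >>> x = np.array([2.7, -1.3])
    >>> buf = np.empty_like(x)
    >>> res = np.fix(x, y=buf)   # deprecated alias for out=, may warn
    >>> res
    array([ 2., -1.])
    >>> res is buf
    True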
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 59923f3c5..771481e8e 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -16,6 +16,7 @@ __all__ = ['matrix_power', 'solve', 'tensorsolve', 'tensorinv', 'inv',
'svd', 'eig', 'eigh', 'lstsq', 'norm', 'qr', 'cond', 'matrix_rank',
'LinAlgError', 'multi_dot']
+import functools
import operator
import warnings
@@ -28,10 +29,15 @@ from numpy.core import (
swapaxes, divide, count_nonzero, isnan
)
from numpy.core.multiarray import normalize_axis_index
-from numpy.core.overrides import array_function_dispatch
+from numpy.core import overrides
from numpy.lib.twodim_base import triu, eye
from numpy.linalg import lapack_lite, _umath_linalg
+
+array_function_dispatch = functools.partial(
+ overrides.array_function_dispatch, module='numpy.linalg')
+
+
# For Python2/3 compatibility
_N = b'N'
_V = b'V'
diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py
index a3832fcde..20a7dfd0b 100644
--- a/numpy/testing/_private/utils.py
+++ b/numpy/testing/_private/utils.py
@@ -19,7 +19,7 @@ from warnings import WarningMessage
import pprint
from numpy.core import(
- float32, empty, arange, array_repr, ndarray, isnat, array)
+ bool_, float32, empty, arange, array_repr, ndarray, isnat, array)
from numpy.lib.utils import deprecate
if sys.version_info[0] >= 3:
@@ -352,7 +352,7 @@ def assert_equal(actual, desired, err_msg='', verbose=True):
# XXX: catch ValueError for subclasses of ndarray where iscomplex fail
try:
usecomplex = iscomplexobj(actual) or iscomplexobj(desired)
- except ValueError:
+ except (ValueError, TypeError):
usecomplex = False
if usecomplex:
@@ -692,6 +692,9 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
x = array(x, copy=False, subok=True)
y = array(y, copy=False, subok=True)
+    # keep references to the original arrays for output formatting
+ ox, oy = x, y
+
def isnumber(x):
return x.dtype.char in '?bhilqpBHILQPefdgFDG'
@@ -705,15 +708,20 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
at the same locations.
"""
- # Both the != True comparison here and the cast to bool_ at the end are
- # done to deal with `masked`, which cannot be compared usefully, and
- # for which np.all yields masked. The use of the function np.all is
- # for back compatibility with ndarray subclasses that changed the
- # return values of the all method. We are not committed to supporting
- # such subclasses, but some used to work.
x_id = func(x)
y_id = func(y)
- if npall(x_id == y_id) != True:
+ # We include work-arounds here to handle three types of slightly
+ # pathological ndarray subclasses:
+ # (1) all() on `masked` array scalars can return masked arrays, so we
+ # use != True
+ # (2) __eq__ on some ndarray subclasses returns Python booleans
+ # instead of element-wise comparisons, so we cast to bool_() and
+ # use isinstance(..., bool) checks
+    #     (3) subclasses with bare-bones __array_function__ implementations may
+ # not implement np.all(), so favor using the .all() method
+ # We are not committed to supporting such subclasses, but it's nice to
+ # support them if possible.
+ if bool_(x_id == y_id).all() != True:
msg = build_err_msg([x, y],
err_msg + '\nx and y %s location mismatch:'
% (hasval), verbose=verbose, header=header,
@@ -721,9 +729,9 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
raise AssertionError(msg)
# If there is a scalar, then here we know the array has the same
# flag as it everywhere, so we should return the scalar flag.
- if x_id.ndim == 0:
+ if isinstance(x_id, bool) or x_id.ndim == 0:
return bool_(x_id)
- elif y_id.ndim == 0:
+        elif isinstance(y_id, bool) or y_id.ndim == 0:
return bool_(y_id)
else:
return y_id
@@ -780,10 +788,10 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
# do not trigger a failure (np.ma.masked != True evaluates as
# np.ma.masked, which is falsy).
if cond != True:
- match = 100-100.0*reduced.count(1)/len(reduced)
- msg = build_err_msg([x, y],
+ mismatch = 100.0 * reduced.count(0) / ox.size
+ msg = build_err_msg([ox, oy],
err_msg
- + '\n(mismatch %s%%)' % (match,),
+ + '\n(mismatch %s%%)' % (mismatch,),
verbose=verbose, header=header,
names=('x', 'y'), precision=precision)
raise AssertionError(msg)
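The rewritten comment block explains why assert_array_compare now coerces the
comparison result with bool_(...) and adds isinstance(..., bool) checks; the
first two quirks are easy to reproduce (illustration only, not part of the
patch)::

    >>> import numpy as np
    >>> np.ma.masked == np.ma.masked       # comparisons collapse to `masked`...
    masked
    >>> bool(np.ma.masked != True)         # ...which is falsy, hence the != True
    False
    >>> class EqReturnsBool(np.ndarray):   # __eq__ yields a bare Python bool
    ...     def __eq__(self, other):
    ...         return bool(np.equal(self, other).all())
    ...
    >>> a = np.array([1., 2.]).view(EqReturnsBool)
    >>> np.bool_(a == a).all()             # coerced before .all()/.ndim are used
    True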
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py
index e0d3414f7..e54fbc390 100644
--- a/numpy/testing/tests/test_utils.py
+++ b/numpy/testing/tests/test_utils.py
@@ -158,6 +158,44 @@ class TestArrayEqual(_GenericTest):
self._test_equal(a, b)
self._test_equal(b, a)
+ def test_subclass_that_overrides_eq(self):
+ # While we cannot guarantee testing functions will always work for
+ # subclasses, the tests should ideally rely only on subclasses having
+ # comparison operators, not on them being able to store booleans
+ # (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
+ class MyArray(np.ndarray):
+ def __eq__(self, other):
+ return bool(np.equal(self, other).all())
+
+ def __ne__(self, other):
+ return not self == other
+
+ a = np.array([1., 2.]).view(MyArray)
+ b = np.array([2., 3.]).view(MyArray)
+        assert_(isinstance(a == a, bool))
+ assert_(a == a)
+ assert_(a != b)
+ self._test_equal(a, a)
+ self._test_not_equal(a, b)
+ self._test_not_equal(b, a)
+
+ def test_subclass_that_does_not_implement_npall(self):
+ # While we cannot guarantee testing functions will always work for
+ # subclasses, the tests should ideally rely only on subclasses having
+ # comparison operators, not on them being able to store booleans
+ # (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
+ class MyArray(np.ndarray):
+ def __array_function__(self, *args, **kwargs):
+ return NotImplemented
+
+ a = np.array([1., 2.]).view(MyArray)
+ b = np.array([2., 3.]).view(MyArray)
+ with assert_raises(TypeError):
+ np.all(a)
+ self._test_equal(a, a)
+ self._test_not_equal(a, b)
+ self._test_not_equal(b, a)
+
class TestBuildErrorMessage(object):
@@ -469,7 +507,8 @@ class TestAlmostEqual(_GenericTest):
self._test_not_equal(x, z)
def test_error_message(self):
- """Check the message is formatted correctly for the decimal value"""
+ """Check the message is formatted correctly for the decimal value.
+        Also check the message when input includes inf or nan (gh-12200)."""
x = np.array([1.00000000001, 2.00000000002, 3.00003])
y = np.array([1.00000000002, 2.00000000003, 3.00004])
@@ -493,6 +532,19 @@ class TestAlmostEqual(_GenericTest):
# remove anything that's not the array string
assert_equal(str(e).split('%)\n ')[1], b)
+ # Check the error message when input includes inf or nan
+ x = np.array([np.inf, 0])
+ y = np.array([np.inf, 1])
+ try:
+ self._assert_func(x, y)
+ except AssertionError as e:
+ msgs = str(e).split('\n')
+ # assert error percentage is 50%
+ assert_equal(msgs[3], '(mismatch 50.0%)')
+ # assert output array contains inf
+ assert_equal(msgs[4], ' x: array([inf, 0.])')
+ assert_equal(msgs[5], ' y: array([inf, 1.])')
+
def test_subclass_that_cannot_be_bool(self):
# While we cannot guarantee testing functions will always work for
# subclasses, the tests should ideally rely only on subclasses having
@@ -1077,7 +1129,7 @@ class TestStringEqual(object):
assert_raises(AssertionError,
lambda: assert_string_equal("foo", "hello"))
-
+
def test_regex(self):
assert_string_equal("a+*b", "a+*b")