summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorStephan Hoyer <shoyer@google.com>2018-12-19 10:08:01 -0800
committerStephan Hoyer <shoyer@google.com>2018-12-19 10:46:40 -0800
commit45413455791c77465c5f33a5082053274eb18900 (patch)
tree7362d739f4558b4aece09bbde333df5c219d8f1a /numpy
parentf4ddc2b3251d4606cc227761933e991131425416 (diff)
downloadnumpy-45413455791c77465c5f33a5082053274eb18900.tar.gz
ENH: refactor __array_function__ pure Python implementation
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/_methods.py2
-rw-r--r--numpy/core/multiarray.py8
-rw-r--r--numpy/core/overrides.py108
-rw-r--r--numpy/core/shape_base.py7
-rw-r--r--numpy/core/tests/test_overrides.py67
-rw-r--r--numpy/core/tests/test_shape_base.py4
6 files changed, 113 insertions, 83 deletions
diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py
index baeab6383..c6c163e16 100644
--- a/numpy/core/_methods.py
+++ b/numpy/core/_methods.py
@@ -161,7 +161,7 @@ def _array_function(self, func, types, args, kwargs):
# TODO: rewrite this in C
# Cannot handle items that have __array_function__ other than our own.
for t in types:
- if not issubclass(t, mu.ndarray) and hasattr(t, '__array_function__'):
+ if not issubclass(t, mu.ndarray):
return NotImplemented
# The regular implementation can handle this, so we call it directly.
diff --git a/numpy/core/multiarray.py b/numpy/core/multiarray.py
index 1b9c65782..4c2715892 100644
--- a/numpy/core/multiarray.py
+++ b/numpy/core/multiarray.py
@@ -211,9 +211,11 @@ def concatenate(arrays, axis=None, out=None):
fill_value=999999)
"""
- for array in arrays:
- yield array
- yield out
+ if out is not None:
+ # optimize for the typical case where only arrays is provided
+ arrays = list(arrays)
+ arrays.append(out)
+ return arrays
@array_function_from_c_func_and_dispatcher(_multiarray_umath.inner)
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py
index 0979858a1..8f535b213 100644
--- a/numpy/core/overrides.py
+++ b/numpy/core/overrides.py
@@ -1,4 +1,4 @@
-"""Preliminary implementation of NEP-18
+"""Preliminary implementation of NEP-18.
TODO: rewrite this in C for performance.
"""
@@ -10,64 +10,80 @@ from numpy.core._multiarray_umath import add_docstring, ndarray
from numpy.compat._inspect import getargspec
-_NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__
-_NDARRAY_ONLY = [ndarray]
-
ENABLE_ARRAY_FUNCTION = bool(
int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 0)))
-def get_overloaded_types_and_args(relevant_args):
+def get_implementing_types_and_args(relevant_args):
"""Returns a list of arguments on which to call __array_function__.
-
Parameters
----------
relevant_args : iterable of array-like
Iterable of array-like arguments to check for __array_function__
methods.
-
Returns
-------
- overloaded_types : collection of types
+ implementing_types : collection of types
Types of arguments from relevant_args with __array_function__ methods.
- overloaded_args : list
+ implementing_args : list
Arguments from relevant_args on which to call __array_function__
methods, in the order in which they should be called.
"""
# Runtime is O(num_arguments * num_unique_types)
- overloaded_types = []
- overloaded_args = []
+ implementing_types = []
+ implementing_args = []
for arg in relevant_args:
arg_type = type(arg)
# We only collect arguments if they have a unique type, which ensures
# reasonable performance even with a long list of possibly overloaded
# arguments.
- if (arg_type not in overloaded_types and
+ if (arg_type not in implementing_types and
hasattr(arg_type, '__array_function__')):
# Create lists explicitly for the first type (usually the only one
- # done) to avoid setting up the iterator for overloaded_args.
- if overloaded_types:
- overloaded_types.append(arg_type)
+ # done) to avoid setting up the iterator for implementing_args.
+ if implementing_types:
+ implementing_types.append(arg_type)
# By default, insert argument at the end, but if it is
# subclass of another argument, insert it before that argument.
# This ensures "subclasses before superclasses".
- index = len(overloaded_args)
- for i, old_arg in enumerate(overloaded_args):
+ index = len(implementing_args)
+ for i, old_arg in enumerate(implementing_args):
if issubclass(arg_type, type(old_arg)):
index = i
break
- overloaded_args.insert(index, arg)
+ implementing_args.insert(index, arg)
else:
- overloaded_types = [arg_type]
- overloaded_args = [arg]
+ implementing_types = [arg_type]
+ implementing_args = [arg]
+
+ return implementing_types, implementing_args
+
+
+_NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__
+
- return overloaded_types, overloaded_args
+def any_overrides(relevant_args):
+ """Are there any __array_function__ methods that need to be called?"""
+ for arg in relevant_args:
+ arg_type = type(arg)
+ if (arg_type is not ndarray and
+ getattr(arg_type, '__array_function__',
+ _NDARRAY_ARRAY_FUNCTION)
+ is not _NDARRAY_ARRAY_FUNCTION):
+ return True
+ return False
-def array_function_implementation_or_override(
+_TUPLE_OR_LIST = {tuple, list}
+
+
+def implement_array_function(
implementation, public_api, relevant_args, args, kwargs):
- """Implement a function with checks for __array_function__ overrides.
+ """
+ Implement a function with checks for __array_function__ overrides.
+
+ All arguments are required, and can only be passed by position.
Arguments
---------
@@ -82,41 +98,55 @@ def array_function_implementation_or_override(
Iterable of arguments to check for __array_function__ methods.
args : tuple
Arbitrary positional arguments originally passed into ``public_api``.
- kwargs : tuple
+ kwargs : dict
Arbitrary keyword arguments originally passed into ``public_api``.
Returns
-------
- Result from calling `implementation()` or an `__array_function__`
+ Result from calling ``implementation()`` or an ``__array_function__``
method, as appropriate.
Raises
------
TypeError : if no implementation is found.
"""
- # Check for __array_function__ methods.
- types, overloaded_args = get_overloaded_types_and_args(relevant_args)
- # Short-cut for common cases: no overload or only ndarray overload
- # (directly or with subclasses that do not override __array_function__).
- if (not overloaded_args or types == _NDARRAY_ONLY or
- all(type(arg).__array_function__ is _NDARRAY_ARRAY_FUNCTION
- for arg in overloaded_args)):
+ if type(relevant_args) not in _TUPLE_OR_LIST:
+ relevant_args = tuple(relevant_args)
+
+ if not any_overrides(relevant_args):
return implementation(*args, **kwargs)
# Call overrides
- for overloaded_arg in overloaded_args:
+ types, implementing_args = get_implementing_types_and_args(relevant_args)
+ for arg in implementing_args:
# Use `public_api` instead of `implemenation` so __array_function__
# implementations can do equality/identity comparisons.
- result = overloaded_arg.__array_function__(
- public_api, types, args, kwargs)
-
+ result = arg.__array_function__(public_api, types, args, kwargs)
if result is not NotImplemented:
return result
func_name = '{}.{}'.format(public_api.__module__, public_api.__name__)
raise TypeError("no implementation found for '{}' on types that implement "
- '__array_function__: {}'
- .format(func_name, list(map(type, overloaded_args))))
+ '__array_function__: {}'.format(func_name, list(types)))
+
+
+def _get_implementing_args(relevant_args):
+ """
+ Collect arguments on which to call __array_function__.
+
+ Parameters
+ ----------
+ relevant_args : iterable of array-like
+ Iterable of possibly array-like arguments to check for
+ __array_function__ methods.
+
+ Returns
+ -------
+ Sequence of arguments with __array_function__ methods, in the order in
+ which they should be called.
+ """
+ _, args = get_implementing_types_and_args(relevant_args)
+ return args
ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults')
@@ -215,7 +245,7 @@ def array_function_dispatch(dispatcher, module=None, verify=True,
@functools.wraps(implementation)
def public_api(*args, **kwargs):
relevant_args = dispatcher(*args, **kwargs)
- return array_function_implementation_or_override(
+ return implement_array_function(
implementation, public_api, relevant_args, args, kwargs)
if module is not None:
diff --git a/numpy/core/shape_base.py b/numpy/core/shape_base.py
index 0378d3c1f..f8332c362 100644
--- a/numpy/core/shape_base.py
+++ b/numpy/core/shape_base.py
@@ -342,10 +342,11 @@ def hstack(tup):
def _stack_dispatcher(arrays, axis=None, out=None):
arrays = _arrays_for_stack_dispatcher(arrays, stacklevel=6)
- for a in arrays:
- yield a
if out is not None:
- yield out
+ # optimize for the typical case where only arrays is provided
+ arrays = list(arrays)
+ arrays.append(out)
+ return arrays
@array_function_dispatch(_stack_dispatcher)
diff --git a/numpy/core/tests/test_overrides.py b/numpy/core/tests/test_overrides.py
index 62b2a3e53..8c7ef576e 100644
--- a/numpy/core/tests/test_overrides.py
+++ b/numpy/core/tests/test_overrides.py
@@ -7,7 +7,7 @@ import numpy as np
from numpy.testing import (
assert_, assert_equal, assert_raises, assert_raises_regex)
from numpy.core.overrides import (
- get_overloaded_types_and_args, array_function_dispatch,
+ _get_implementing_args, array_function_dispatch,
verify_matching_signatures, ENABLE_ARRAY_FUNCTION)
from numpy.core.numeric import pickle
import pytest
@@ -18,11 +18,6 @@ requires_array_function = pytest.mark.skipif(
reason="__array_function__ dispatch not enabled.")
-def _get_overloaded_args(relevant_args):
- types, args = get_overloaded_types_and_args(relevant_args)
- return args
-
-
def _return_not_implemented(self, *args, **kwargs):
return NotImplemented
@@ -41,26 +36,21 @@ def dispatched_two_arg(array1, array2):
@requires_array_function
-class TestGetOverloadedTypesAndArgs(object):
+class TestGetImplementingArgs(object):
def test_ndarray(self):
array = np.array(1)
- types, args = get_overloaded_types_and_args([array])
- assert_equal(set(types), {np.ndarray})
+ args = _get_implementing_args([array])
assert_equal(list(args), [array])
- types, args = get_overloaded_types_and_args([array, array])
- assert_equal(len(types), 1)
- assert_equal(set(types), {np.ndarray})
+ args = _get_implementing_args([array, array])
assert_equal(list(args), [array])
- types, args = get_overloaded_types_and_args([array, 1])
- assert_equal(set(types), {np.ndarray})
+ args = _get_implementing_args([array, 1])
assert_equal(list(args), [array])
- types, args = get_overloaded_types_and_args([1, array])
- assert_equal(set(types), {np.ndarray})
+ args = _get_implementing_args([1, array])
assert_equal(list(args), [array])
def test_ndarray_subclasses(self):
@@ -75,17 +65,14 @@ class TestGetOverloadedTypesAndArgs(object):
override_sub = np.array(1).view(OverrideSub)
no_override_sub = np.array(1).view(NoOverrideSub)
- types, args = get_overloaded_types_and_args([array, override_sub])
- assert_equal(set(types), {np.ndarray, OverrideSub})
+ args = _get_implementing_args([array, override_sub])
assert_equal(list(args), [override_sub, array])
- types, args = get_overloaded_types_and_args([array, no_override_sub])
- assert_equal(set(types), {np.ndarray, NoOverrideSub})
+ args = _get_implementing_args([array, no_override_sub])
assert_equal(list(args), [no_override_sub, array])
- types, args = get_overloaded_types_and_args(
+ args = _get_implementing_args(
[override_sub, no_override_sub])
- assert_equal(set(types), {OverrideSub, NoOverrideSub})
assert_equal(list(args), [override_sub, no_override_sub])
def test_ndarray_and_duck_array(self):
@@ -96,12 +83,10 @@ class TestGetOverloadedTypesAndArgs(object):
array = np.array(1)
other = Other()
- types, args = get_overloaded_types_and_args([other, array])
- assert_equal(set(types), {np.ndarray, Other})
+ args = _get_implementing_args([other, array])
assert_equal(list(args), [other, array])
- types, args = get_overloaded_types_and_args([array, other])
- assert_equal(set(types), {np.ndarray, Other})
+ args = _get_implementing_args([array, other])
assert_equal(list(args), [array, other])
def test_ndarray_subclass_and_duck_array(self):
@@ -116,9 +101,9 @@ class TestGetOverloadedTypesAndArgs(object):
subarray = np.array(1).view(OverrideSub)
other = Other()
- assert_equal(_get_overloaded_args([array, subarray, other]),
+ assert_equal(_get_implementing_args([array, subarray, other]),
[subarray, array, other])
- assert_equal(_get_overloaded_args([array, other, subarray]),
+ assert_equal(_get_implementing_args([array, other, subarray]),
[subarray, array, other])
def test_many_duck_arrays(self):
@@ -140,15 +125,15 @@ class TestGetOverloadedTypesAndArgs(object):
c = C()
d = D()
- assert_equal(_get_overloaded_args([1]), [])
- assert_equal(_get_overloaded_args([a]), [a])
- assert_equal(_get_overloaded_args([a, 1]), [a])
- assert_equal(_get_overloaded_args([a, a, a]), [a])
- assert_equal(_get_overloaded_args([a, d, a]), [a, d])
- assert_equal(_get_overloaded_args([a, b]), [b, a])
- assert_equal(_get_overloaded_args([b, a]), [b, a])
- assert_equal(_get_overloaded_args([a, b, c]), [b, c, a])
- assert_equal(_get_overloaded_args([a, c, b]), [c, b, a])
+ assert_equal(_get_implementing_args([1]), [])
+ assert_equal(_get_implementing_args([a]), [a])
+ assert_equal(_get_implementing_args([a, 1]), [a])
+ assert_equal(_get_implementing_args([a, a, a]), [a])
+ assert_equal(_get_implementing_args([a, d, a]), [a, d])
+ assert_equal(_get_implementing_args([a, b]), [b, a])
+ assert_equal(_get_implementing_args([b, a]), [b, a])
+ assert_equal(_get_implementing_args([a, b, c]), [b, c, a])
+ assert_equal(_get_implementing_args([a, c, b]), [c, b, a])
@requires_array_function
@@ -201,6 +186,14 @@ class TestNDArrayArrayFunction(object):
result = np.concatenate((array, override_sub))
assert_equal(result, expected.view(OverrideSub))
+ def test_no_wrapper(self):
+ array = np.array(1)
+ func = dispatched_one_arg.__wrapped__
+ with assert_raises_regex(AttributeError, '__wrapped__'):
+ array.__array_function__(func=func,
+ types=(np.ndarray,),
+ args=(array,), kwargs={})
+
@requires_array_function
class TestArrayFunctionDispatch(object):
diff --git a/numpy/core/tests/test_shape_base.py b/numpy/core/tests/test_shape_base.py
index ef5c118ec..b996321c2 100644
--- a/numpy/core/tests/test_shape_base.py
+++ b/numpy/core/tests/test_shape_base.py
@@ -373,6 +373,10 @@ def test_stack():
# empty arrays
assert_(stack([[], [], []]).shape == (3, 0))
assert_(stack([[], [], []], axis=1).shape == (0, 3))
+ # out
+ out = np.zeros_like(r1)
+ np.stack((a, b), out=out)
+ assert_array_equal(out, r1)
# edge cases
assert_raises_regex(ValueError, 'need at least one array', stack, [])
assert_raises_regex(ValueError, 'must have the same shape',