1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
|
"""Preliminary implementation of NEP-18
TODO: rewrite this in C for performance.
"""
import collections
import functools
from numpy.core.multiarray import ndarray
from numpy.compat._inspect import getargspec
# ndarray's own __array_function__, captured once at import time so that
# get_overloaded_types_and_args can recognise it with an identity check and
# leave out arguments whose type has not overridden it.
_NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__
def get_overloaded_types_and_args(relevant_args):
    """Collect the arguments whose types implement ``__array_function__``.

    Parameters
    ----------
    relevant_args : iterable of array-like
        Arguments whose types are inspected for an ``__array_function__``
        attribute.

    Returns
    -------
    overloaded_types : list of types
        Distinct types from ``relevant_args`` that define
        ``__array_function__``, in order of first appearance.
    overloaded_args : list
        The first argument seen for each of those types, in the order in
        which their ``__array_function__`` methods should be called.
        Arguments whose type still uses ndarray's own
        ``__array_function__`` are left out.
    """
    # Cost is O(num_arguments * num_unique_types).
    seen_types = []
    ordered_args = []

    for candidate in relevant_args:
        candidate_type = type(candidate)

        # Only the first argument of any given type is kept; this bounds
        # the work even for long lists of possibly overloaded arguments.
        if candidate_type in seen_types:
            continue
        if not hasattr(candidate_type, '__array_function__'):
            continue

        seen_types.append(candidate_type)

        # "Subclasses before superclasses": place the argument ahead of the
        # first already-collected argument whose type it derives from, and
        # fall back to appending at the end.
        position = next(
            (pos for pos, earlier in enumerate(ordered_args)
             if issubclass(candidate_type, type(earlier))),
            len(ordered_args))
        ordered_args.insert(position, candidate)

    # ndarray.__array_function__ is special-cased: arguments that have not
    # overridden it are dropped from the call list (their types are still
    # reported in the first return value).
    return seen_types, [
        candidate for candidate in ordered_args
        if type(candidate).__array_function__ is not _NDARRAY_ARRAY_FUNCTION
    ]
def array_function_implementation_or_override(
        implementation, public_api, relevant_args, args, kwargs):
    """Run ``implementation`` unless an argument overrides the function.

    Parameters
    ----------
    implementation : function
        Function that carries out the operation without overrides when
        called as ``implementation(*args, **kwargs)``.
    public_api : function
        Function exposed by NumPy's public API, originally called as
        ``public_api(*args, **kwargs)``. This is the object handed to
        ``__array_function__`` methods.
    relevant_args : iterable
        Arguments to check for ``__array_function__`` methods.
    args : tuple
        Positional arguments originally passed into ``public_api``.
    kwargs : dict
        Keyword arguments originally passed into ``public_api``.

    Returns
    -------
    The result of ``implementation(*args, **kwargs)`` when there is nothing
    to dispatch to; otherwise the first value other than ``NotImplemented``
    returned by an ``__array_function__`` method.

    Raises
    ------
    TypeError
        If every ``__array_function__`` method returns ``NotImplemented``.
    """
    types, overloaded_args = get_overloaded_types_and_args(relevant_args)

    if overloaded_args:
        for candidate in overloaded_args:
            # Hand over `public_api` rather than `implementation` so that
            # __array_function__ methods can compare the function they
            # receive against NumPy's public functions by equality/identity.
            outcome = candidate.__array_function__(
                public_api, types, args, kwargs)
            if outcome is not NotImplemented:
                return outcome

        # Every override declined.
        declined_types = [type(candidate) for candidate in overloaded_args]
        raise TypeError('no implementation found for {} on types that implement '
                        '__array_function__: {}'
                        .format(public_api, declined_types))

    # Nothing to dispatch to: use the plain implementation.
    return implementation(*args, **kwargs)
# Named view over the 4-tuple returned by getargspec.
ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults')


def verify_matching_signatures(implementation, dispatcher):
    """Check that ``dispatcher`` mirrors the signature of ``implementation``.

    The two functions must agree on argument names, on the names of their
    ``*args`` / ``**kwargs`` catch-alls and on how many arguments carry
    defaults. The default values themselves are not compared, but every
    default declared by the dispatcher has to be ``None``.

    Parameters
    ----------
    implementation : function
        Function whose signature is the reference.
    dispatcher : function
        Dispatcher function to validate against ``implementation``.

    Raises
    ------
    RuntimeError
        If the signatures differ, or if the dispatcher uses a default value
        other than ``None``.
    """
    impl = ArgSpec(*getargspec(implementation))
    disp = ArgSpec(*getargspec(dispatcher))

    # The checks are staged so that the length comparison is only reached
    # once both sides are known to agree on whether defaults exist at all.
    mismatch = (impl.args != disp.args
                or impl.varargs != disp.varargs
                or impl.keywords != disp.keywords)
    if not mismatch:
        mismatch = bool(impl.defaults) != bool(disp.defaults)
    if not mismatch and impl.defaults is not None:
        mismatch = len(impl.defaults) != len(disp.defaults)
    if mismatch:
        raise RuntimeError('implementation and dispatcher for %s have '
                           'different function signatures' % implementation)

    if impl.defaults is None:
        return
    if disp.defaults != (None,) * len(disp.defaults):
        raise RuntimeError('dispatcher functions can only use None for '
                           'default argument values')
def array_function_dispatch(dispatcher, verify=True):
    """Build a decorator that adds ``__array_function__`` dispatch.

    Parameters
    ----------
    dispatcher : callable
        Called with the same arguments as the decorated function; its
        return value is the iterable of arguments that are checked for
        ``__array_function__`` overrides.
    verify : bool, optional
        When true (the default), check at decoration time that the
        dispatcher's signature matches that of the implementation.

    Returns
    -------
    decorator : callable
        Takes the implementation and returns the wrapped public function,
        which carries the implementation's metadata via ``functools.wraps``.
    """
    def decorator(implementation):
        # TODO: only do this check when the appropriate flag is enabled or
        # for a dev install. The check is wanted for testing, but it should
        # not slow down every numpy import.
        if verify:
            verify_matching_signatures(implementation, dispatcher)

        def public_api(*args, **kwargs):
            # The wrapper passes itself along so overrides receive the
            # public function rather than the raw implementation.
            return array_function_implementation_or_override(
                implementation, public_api,
                dispatcher(*args, **kwargs), args, kwargs)

        # wraps() updates `public_api` in place and returns that same object.
        return functools.wraps(implementation)(public_api)

    return decorator
|