diff options
author | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2018-12-20 09:23:30 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-12-20 09:23:30 -0500 |
commit | f4c497c768e0646df740b647782df463825bfd27 (patch) | |
tree | 2effea18d0f94f0d010d2bdb33317a1010e637d4 /numpy | |
parent | 6b8441665fb877ab65ec067de7be776ff4f5ac54 (diff) | |
parent | 7104a3e9fe989f43f20765d740cf22dabae563d4 (diff) | |
download | numpy-f4c497c768e0646df740b647782df463825bfd27.tar.gz |
Merge pull request #12317 from shoyer/array-function-c
ENH: port np.core.overrides to C for speed
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/_internal.py | 7 | ||||
-rw-r--r-- | numpy/core/_methods.py | 12 | ||||
-rw-r--r-- | numpy/core/code_generators/genapi.py | 1 | ||||
-rw-r--r-- | numpy/core/multiarray.py | 8 | ||||
-rw-r--r-- | numpy/core/overrides.py | 114 | ||||
-rw-r--r-- | numpy/core/setup.py | 2 | ||||
-rw-r--r-- | numpy/core/shape_base.py | 7 | ||||
-rw-r--r-- | numpy/core/src/common/get_attr_string.h | 1 | ||||
-rw-r--r-- | numpy/core/src/multiarray/arrayfunction_override.c | 376 | ||||
-rw-r--r-- | numpy/core/src/multiarray/arrayfunction_override.h | 16 | ||||
-rw-r--r-- | numpy/core/src/multiarray/methods.c | 25 | ||||
-rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.c | 11 | ||||
-rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.h | 1 | ||||
-rw-r--r-- | numpy/core/tests/test_overrides.py | 78 | ||||
-rw-r--r-- | numpy/core/tests/test_shape_base.py | 4 |
15 files changed, 518 insertions, 145 deletions
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py index 27a3deeda..1d3bb5584 100644 --- a/numpy/core/_internal.py +++ b/numpy/core/_internal.py @@ -830,6 +830,13 @@ def array_ufunc_errmsg_formatter(dummy, ufunc, method, *inputs, **kwargs): .format(ufunc, method, args_string, types_string)) +def array_function_errmsg_formatter(public_api, types): + """ Format the error message for when __array_ufunc__ gives up. """ + func_name = '{}.{}'.format(public_api.__module__, public_api.__name__) + return ("no implementation found for '{}' on types that implement " + '__array_function__: {}'.format(func_name, list(types))) + + def _ufunc_doc_signature_formatter(ufunc): """ Builds a signature string which resembles PEP 457 diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py index baeab6383..33f6d01a8 100644 --- a/numpy/core/_methods.py +++ b/numpy/core/_methods.py @@ -154,15 +154,3 @@ def _ptp(a, axis=None, out=None, keepdims=False): umr_minimum(a, axis, None, None, keepdims), out ) - -_NDARRAY_ARRAY_FUNCTION = mu.ndarray.__array_function__ - -def _array_function(self, func, types, args, kwargs): - # TODO: rewrite this in C - # Cannot handle items that have __array_function__ other than our own. - for t in types: - if not issubclass(t, mu.ndarray) and hasattr(t, '__array_function__'): - return NotImplemented - - # The regular implementation can handle this, so we call it directly. - return func.__wrapped__(*args, **kwargs) diff --git a/numpy/core/code_generators/genapi.py b/numpy/core/code_generators/genapi.py index 1d2cd25c8..4aca2373c 100644 --- a/numpy/core/code_generators/genapi.py +++ b/numpy/core/code_generators/genapi.py @@ -19,6 +19,7 @@ __docformat__ = 'restructuredtext' # The files under src/ that are scanned for API functions API_FILES = [join('multiarray', 'alloc.c'), + join('multiarray', 'arrayfunction_override.c'), join('multiarray', 'array_assign_array.c'), join('multiarray', 'array_assign_scalar.c'), join('multiarray', 'arrayobject.c'), diff --git a/numpy/core/multiarray.py b/numpy/core/multiarray.py index 1b9c65782..4c2715892 100644 --- a/numpy/core/multiarray.py +++ b/numpy/core/multiarray.py @@ -211,9 +211,11 @@ def concatenate(arrays, axis=None, out=None): fill_value=999999) """ - for array in arrays: - yield array - yield out + if out is not None: + # optimize for the typical case where only arrays is provided + arrays = list(arrays) + arrays.append(out) + return arrays @array_function_from_c_func_and_dispatcher(_multiarray_umath.inner) diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py index 0979858a1..c55174ecd 100644 --- a/numpy/core/overrides.py +++ b/numpy/core/overrides.py @@ -1,73 +1,23 @@ -"""Preliminary implementation of NEP-18 - -TODO: rewrite this in C for performance. -""" +"""Implementation of __array_function__ overrides from NEP-18.""" import collections import functools import os -from numpy.core._multiarray_umath import add_docstring, ndarray +from numpy.core._multiarray_umath import ( + add_docstring, implement_array_function, _get_implementing_args) from numpy.compat._inspect import getargspec -_NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__ -_NDARRAY_ONLY = [ndarray] - ENABLE_ARRAY_FUNCTION = bool( int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 0))) -def get_overloaded_types_and_args(relevant_args): - """Returns a list of arguments on which to call __array_function__. - - Parameters - ---------- - relevant_args : iterable of array-like - Iterable of array-like arguments to check for __array_function__ - methods. - - Returns - ------- - overloaded_types : collection of types - Types of arguments from relevant_args with __array_function__ methods. - overloaded_args : list - Arguments from relevant_args on which to call __array_function__ - methods, in the order in which they should be called. +add_docstring( + implement_array_function, """ - # Runtime is O(num_arguments * num_unique_types) - overloaded_types = [] - overloaded_args = [] - for arg in relevant_args: - arg_type = type(arg) - # We only collect arguments if they have a unique type, which ensures - # reasonable performance even with a long list of possibly overloaded - # arguments. - if (arg_type not in overloaded_types and - hasattr(arg_type, '__array_function__')): - - # Create lists explicitly for the first type (usually the only one - # done) to avoid setting up the iterator for overloaded_args. - if overloaded_types: - overloaded_types.append(arg_type) - # By default, insert argument at the end, but if it is - # subclass of another argument, insert it before that argument. - # This ensures "subclasses before superclasses". - index = len(overloaded_args) - for i, old_arg in enumerate(overloaded_args): - if issubclass(arg_type, type(old_arg)): - index = i - break - overloaded_args.insert(index, arg) - else: - overloaded_types = [arg_type] - overloaded_args = [arg] - - return overloaded_types, overloaded_args - - -def array_function_implementation_or_override( - implementation, public_api, relevant_args, args, kwargs): - """Implement a function with checks for __array_function__ overrides. + Implement a function with checks for __array_function__ overrides. + + All arguments are required, and can only be passed by position. Arguments --------- @@ -82,41 +32,37 @@ def array_function_implementation_or_override( Iterable of arguments to check for __array_function__ methods. args : tuple Arbitrary positional arguments originally passed into ``public_api``. - kwargs : tuple + kwargs : dict Arbitrary keyword arguments originally passed into ``public_api``. Returns ------- - Result from calling `implementation()` or an `__array_function__` + Result from calling ``implementation()`` or an ``__array_function__`` method, as appropriate. Raises ------ TypeError : if no implementation is found. + """) + + +# exposed for testing purposes; used internally by implement_array_function +add_docstring( + _get_implementing_args, """ - # Check for __array_function__ methods. - types, overloaded_args = get_overloaded_types_and_args(relevant_args) - # Short-cut for common cases: no overload or only ndarray overload - # (directly or with subclasses that do not override __array_function__). - if (not overloaded_args or types == _NDARRAY_ONLY or - all(type(arg).__array_function__ is _NDARRAY_ARRAY_FUNCTION - for arg in overloaded_args)): - return implementation(*args, **kwargs) - - # Call overrides - for overloaded_arg in overloaded_args: - # Use `public_api` instead of `implemenation` so __array_function__ - # implementations can do equality/identity comparisons. - result = overloaded_arg.__array_function__( - public_api, types, args, kwargs) - - if result is not NotImplemented: - return result - - func_name = '{}.{}'.format(public_api.__module__, public_api.__name__) - raise TypeError("no implementation found for '{}' on types that implement " - '__array_function__: {}' - .format(func_name, list(map(type, overloaded_args)))) + Collect arguments on which to call __array_function__. + + Parameters + ---------- + relevant_args : iterable of array-like + Iterable of possibly array-like arguments to check for + __array_function__ methods. + + Returns + ------- + Sequence of arguments with __array_function__ methods, in the order in + which they should be called. + """) ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults') @@ -215,7 +161,7 @@ def array_function_dispatch(dispatcher, module=None, verify=True, @functools.wraps(implementation) def public_api(*args, **kwargs): relevant_args = dispatcher(*args, **kwargs) - return array_function_implementation_or_override( + return implement_array_function( implementation, public_api, relevant_args, args, kwargs) if module is not None: diff --git a/numpy/core/setup.py b/numpy/core/setup.py index 467b590ac..9ccca629e 100644 --- a/numpy/core/setup.py +++ b/numpy/core/setup.py @@ -775,6 +775,7 @@ def configuration(parent_package='',top_path=None): multiarray_deps = [ join('src', 'multiarray', 'arrayobject.h'), join('src', 'multiarray', 'arraytypes.h'), + join('src', 'multiarray', 'arrayfunction_override.h'), join('src', 'multiarray', 'buffer.h'), join('src', 'multiarray', 'calculation.h'), join('src', 'multiarray', 'common.h'), @@ -827,6 +828,7 @@ def configuration(parent_package='',top_path=None): join('src', 'multiarray', 'arraytypes.c.src'), join('src', 'multiarray', 'array_assign_scalar.c'), join('src', 'multiarray', 'array_assign_array.c'), + join('src', 'multiarray', 'arrayfunction_override.c'), join('src', 'multiarray', 'buffer.c'), join('src', 'multiarray', 'calculation.c'), join('src', 'multiarray', 'compiled_base.c'), diff --git a/numpy/core/shape_base.py b/numpy/core/shape_base.py index 0378d3c1f..f8332c362 100644 --- a/numpy/core/shape_base.py +++ b/numpy/core/shape_base.py @@ -342,10 +342,11 @@ def hstack(tup): def _stack_dispatcher(arrays, axis=None, out=None): arrays = _arrays_for_stack_dispatcher(arrays, stacklevel=6) - for a in arrays: - yield a if out is not None: - yield out + # optimize for the typical case where only arrays is provided + arrays = list(arrays) + arrays.append(out) + return arrays @array_function_dispatch(_stack_dispatcher) diff --git a/numpy/core/src/common/get_attr_string.h b/numpy/core/src/common/get_attr_string.h index bec87c5ed..d458d9550 100644 --- a/numpy/core/src/common/get_attr_string.h +++ b/numpy/core/src/common/get_attr_string.h @@ -103,7 +103,6 @@ PyArray_LookupSpecial(PyObject *obj, char *name) if (_is_basic_python_type(tp)) { return NULL; } - return maybe_get_attr((PyObject *)tp, name); } diff --git a/numpy/core/src/multiarray/arrayfunction_override.c b/numpy/core/src/multiarray/arrayfunction_override.c new file mode 100644 index 000000000..e62b32ab2 --- /dev/null +++ b/numpy/core/src/multiarray/arrayfunction_override.c @@ -0,0 +1,376 @@ +#define NPY_NO_DEPRECATED_API NPY_API_VERSION +#define _MULTIARRAYMODULE + +#include "npy_pycompat.h" +#include "get_attr_string.h" +#include "npy_import.h" +#include "multiarraymodule.h" + + +/* Return the ndarray.__array_function__ method. */ +static PyObject * +get_ndarray_array_function(void) +{ + PyObject* method = PyObject_GetAttrString((PyObject *)&PyArray_Type, + "__array_function__"); + assert(method != NULL); + return method; +} + + +/* + * Get an object's __array_function__ method in the fastest way possible. + * Never raises an exception. Returns NULL if the method doesn't exist. + */ +static PyObject * +get_array_function(PyObject *obj) +{ + static PyObject *ndarray_array_function = NULL; + + if (ndarray_array_function == NULL) { + ndarray_array_function = get_ndarray_array_function(); + } + + /* Fast return for ndarray */ + if (PyArray_CheckExact(obj)) { + Py_INCREF(ndarray_array_function); + return ndarray_array_function; + } + + return PyArray_LookupSpecial(obj, "__array_function__"); +} + + +/* + * Like list.insert(), but for C arrays of PyObject*. Skips error checking. + */ +static void +pyobject_array_insert(PyObject **array, int length, int index, PyObject *item) +{ + int j; + + for (j = length; j > index; j--) { + array[j] = array[j - 1]; + } + array[index] = item; +} + + +/* + * Collects arguments with __array_function__ and their corresponding methods + * in the order in which they should be tried (i.e., skipping redundant types). + * `relevant_args` is expected to have been produced by PySequence_Fast. + * Returns the number of arguments, or -1 on failure. + */ +static int +get_implementing_args_and_methods(PyObject *relevant_args, + PyObject **implementing_args, + PyObject **methods) +{ + int num_implementing_args = 0; + Py_ssize_t i; + int j; + + PyObject **items = PySequence_Fast_ITEMS(relevant_args); + Py_ssize_t length = PySequence_Fast_GET_SIZE(relevant_args); + + for (i = 0; i < length; i++) { + int new_class = 1; + PyObject *argument = items[i]; + + /* Have we seen this type before? */ + for (j = 0; j < num_implementing_args; j++) { + if (Py_TYPE(argument) == Py_TYPE(implementing_args[j])) { + new_class = 0; + break; + } + } + if (new_class) { + PyObject *method = get_array_function(argument); + + if (method != NULL) { + int arg_index; + + if (num_implementing_args >= NPY_MAXARGS) { + PyErr_Format( + PyExc_TypeError, + "maximum number (%d) of distinct argument types " \ + "implementing __array_function__ exceeded", + NPY_MAXARGS); + Py_DECREF(method); + goto fail; + } + + /* "subclasses before superclasses, otherwise left to right" */ + arg_index = num_implementing_args; + for (j = 0; j < num_implementing_args; j++) { + PyObject *other_type; + other_type = (PyObject *)Py_TYPE(implementing_args[j]); + if (PyObject_IsInstance(argument, other_type)) { + arg_index = j; + break; + } + } + Py_INCREF(argument); + pyobject_array_insert(implementing_args, num_implementing_args, + arg_index, argument); + pyobject_array_insert(methods, num_implementing_args, + arg_index, method); + ++num_implementing_args; + } + } + } + return num_implementing_args; + +fail: + for (j = 0; j < num_implementing_args; j++) { + Py_DECREF(implementing_args[j]); + Py_DECREF(methods[j]); + } + return -1; +} + + +/* + * Is this object ndarray.__array_function__? + */ +static int +is_default_array_function(PyObject *obj) +{ + static PyObject *ndarray_array_function = NULL; + + if (ndarray_array_function == NULL) { + ndarray_array_function = get_ndarray_array_function(); + } + return obj == ndarray_array_function; +} + + +/* + * Core implementation of ndarray.__array_function__. This is exposed + * separately so we can avoid the overhead of a Python method call from + * within `implement_array_function`. + */ +NPY_NO_EXPORT PyObject * +array_function_method_impl(PyObject *func, PyObject *types, PyObject *args, + PyObject *kwargs) +{ + Py_ssize_t j; + PyObject *implementation, *result; + + PyObject **items = PySequence_Fast_ITEMS(types); + Py_ssize_t length = PySequence_Fast_GET_SIZE(types); + + for (j = 0; j < length; j++) { + int is_subclass = PyObject_IsSubclass( + items[j], (PyObject *)&PyArray_Type); + if (is_subclass == -1) { + return NULL; + } + if (!is_subclass) { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } + } + + implementation = PyObject_GetAttr(func, npy_ma_str_wrapped); + if (implementation == NULL) { + return NULL; + } + result = PyObject_Call(implementation, args, kwargs); + Py_DECREF(implementation); + return result; +} + + +/* + * Calls __array_function__ on the provided argument, with a fast-path for + * ndarray. + */ +static PyObject * +call_array_function(PyObject* argument, PyObject* method, + PyObject* public_api, PyObject* types, + PyObject* args, PyObject* kwargs) +{ + if (is_default_array_function(method)) { + return array_function_method_impl(public_api, types, args, kwargs); + } + else { + return PyObject_CallFunctionObjArgs( + method, argument, public_api, types, args, kwargs, NULL); + } +} + + +/* + * Implements the __array_function__ protocol for a function, as described in + * in NEP-18. See numpy.core.overrides for a full docstring. + */ +NPY_NO_EXPORT PyObject * +array_implement_array_function( + PyObject *NPY_UNUSED(dummy), PyObject *positional_args) +{ + PyObject *implementation, *public_api, *relevant_args, *args, *kwargs; + + PyObject *types = NULL; + PyObject *implementing_args[NPY_MAXARGS]; + PyObject *array_function_methods[NPY_MAXARGS]; + + int j, any_overrides; + int num_implementing_args = 0; + PyObject *result = NULL; + + static PyObject *errmsg_formatter = NULL; + + if (!PyArg_UnpackTuple( + positional_args, "implement_array_function", 5, 5, + &implementation, &public_api, &relevant_args, &args, &kwargs)) { + return NULL; + } + + relevant_args = PySequence_Fast( + relevant_args, + "dispatcher for __array_function__ did not return an iterable"); + if (relevant_args == NULL) { + return NULL; + } + + /* Collect __array_function__ implementations */ + num_implementing_args = get_implementing_args_and_methods( + relevant_args, implementing_args, array_function_methods); + if (num_implementing_args == -1) { + goto cleanup; + } + + /* + * Handle the typical case of no overrides. This is merely an optimization + * if some arguments are ndarray objects, but is also necessary if no + * arguments implement __array_function__ at all (e.g., if they are all + * built-in types). + */ + any_overrides = 0; + for (j = 0; j < num_implementing_args; j++) { + if (!is_default_array_function(array_function_methods[j])) { + any_overrides = 1; + break; + } + } + if (!any_overrides) { + result = PyObject_Call(implementation, args, kwargs); + goto cleanup; + } + + /* + * Create a Python object for types. + * We use a tuple, because it's the fastest Python collection to create + * and has the bonus of being immutable. + */ + types = PyTuple_New(num_implementing_args); + if (types == NULL) { + goto cleanup; + } + for (j = 0; j < num_implementing_args; j++) { + PyObject *arg_type = (PyObject *)Py_TYPE(implementing_args[j]); + Py_INCREF(arg_type); + PyTuple_SET_ITEM(types, j, arg_type); + } + + /* Call __array_function__ methods */ + for (j = 0; j < num_implementing_args; j++) { + PyObject *argument = implementing_args[j]; + PyObject *method = array_function_methods[j]; + + /* + * We use `public_api` instead of `implementation` here so + * __array_function__ implementations can do equality/identity + * comparisons. + */ + result = call_array_function( + argument, method, public_api, types, args, kwargs); + + if (result == Py_NotImplemented) { + /* Try the next one */ + Py_DECREF(result); + result = NULL; + } + else { + /* Either a good result, or an exception was raised. */ + goto cleanup; + } + } + + /* No acceptable override found, raise TypeError. */ + npy_cache_import("numpy.core._internal", + "array_function_errmsg_formatter", + &errmsg_formatter); + if (errmsg_formatter != NULL) { + PyObject *errmsg = PyObject_CallFunctionObjArgs( + errmsg_formatter, public_api, types, NULL); + if (errmsg != NULL) { + PyErr_SetObject(PyExc_TypeError, errmsg); + Py_DECREF(errmsg); + } + } + +cleanup: + for (j = 0; j < num_implementing_args; j++) { + Py_DECREF(implementing_args[j]); + Py_DECREF(array_function_methods[j]); + } + Py_XDECREF(types); + Py_DECREF(relevant_args); + return result; +} + + +/* + * Python wrapper for get_implementing_args_and_methods, for testing purposes. + */ +NPY_NO_EXPORT PyObject * +array__get_implementing_args( + PyObject *NPY_UNUSED(dummy), PyObject *positional_args) +{ + PyObject *relevant_args; + int j; + int num_implementing_args = 0; + PyObject *implementing_args[NPY_MAXARGS]; + PyObject *array_function_methods[NPY_MAXARGS]; + PyObject *result = NULL; + + if (!PyArg_ParseTuple(positional_args, "O:array__get_implementing_args", + &relevant_args)) { + return NULL; + } + + relevant_args = PySequence_Fast( + relevant_args, + "dispatcher for __array_function__ did not return an iterable"); + if (relevant_args == NULL) { + return NULL; + } + + num_implementing_args = get_implementing_args_and_methods( + relevant_args, implementing_args, array_function_methods); + if (num_implementing_args == -1) { + goto cleanup; + } + + /* create a Python object for implementing_args */ + result = PyList_New(num_implementing_args); + if (result == NULL) { + goto cleanup; + } + for (j = 0; j < num_implementing_args; j++) { + PyObject *argument = implementing_args[j]; + Py_INCREF(argument); + PyList_SET_ITEM(result, j, argument); + } + +cleanup: + for (j = 0; j < num_implementing_args; j++) { + Py_DECREF(implementing_args[j]); + Py_DECREF(array_function_methods[j]); + } + Py_DECREF(relevant_args); + return result; +} diff --git a/numpy/core/src/multiarray/arrayfunction_override.h b/numpy/core/src/multiarray/arrayfunction_override.h new file mode 100644 index 000000000..0d224e2b6 --- /dev/null +++ b/numpy/core/src/multiarray/arrayfunction_override.h @@ -0,0 +1,16 @@ +#ifndef _NPY_PRIVATE__ARRAYFUNCTION_OVERRIDE_H +#define _NPY_PRIVATE__ARRAYFUNCTION_OVERRIDE_H + +NPY_NO_EXPORT PyObject * +array_implement_array_function( + PyObject *NPY_UNUSED(dummy), PyObject *positional_args); + +NPY_NO_EXPORT PyObject * +array__get_implementing_args( + PyObject *NPY_UNUSED(dummy), PyObject *positional_args); + +NPY_NO_EXPORT PyObject * +array_function_method_impl(PyObject *func, PyObject *types, PyObject *args, + PyObject *kwargs); + +#endif diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c index 7c814e6e6..085bc00c0 100644 --- a/numpy/core/src/multiarray/methods.c +++ b/numpy/core/src/multiarray/methods.c @@ -8,6 +8,7 @@ #include "numpy/arrayobject.h" #include "numpy/arrayscalars.h" +#include "arrayfunction_override.h" #include "npy_config.h" #include "npy_pycompat.h" #include "npy_import.h" @@ -1088,13 +1089,29 @@ cleanup: return result; } - static PyObject * -array_function(PyArrayObject *self, PyObject *args, PyObject *kwds) +array_function(PyArrayObject *self, PyObject *c_args, PyObject *c_kwds) { - NPY_FORWARD_NDARRAY_METHOD("_array_function"); -} + PyObject *func, *types, *args, *kwargs, *result; + static char *kwlist[] = {"func", "types", "args", "kwargs", NULL}; + + if (!PyArg_ParseTupleAndKeywords( + c_args, c_kwds, "OOOO:__array_function__", kwlist, + &func, &types, &args, &kwargs)) { + return NULL; + } + types = PySequence_Fast( + types, + "types argument to ndarray.__array_function__ must be iterable"); + if (types == NULL) { + return NULL; + } + + result = array_function_method_impl(func, types, args, kwargs); + Py_DECREF(types); + return result; +} static PyObject * array_copy(PyArrayObject *self, PyObject *args, PyObject *kwds) diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index 8135769d9..62345d2b0 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -34,6 +34,7 @@ NPY_NO_EXPORT int NPY_NUMUSERTYPES = 0; /* Internal APIs */ +#include "arrayfunction_override.h" #include "arraytypes.h" #include "arrayobject.h" #include "hashdescr.h" @@ -4062,6 +4063,9 @@ normalize_axis_index(PyObject *NPY_UNUSED(self), PyObject *args, PyObject *kwds) } static struct PyMethodDef array_module_methods[] = { + {"_get_implementing_args", + (PyCFunction)array__get_implementing_args, + METH_VARARGS, NULL}, {"_get_ndarray_c_version", (PyCFunction)array__get_ndarray_c_version, METH_VARARGS|METH_KEYWORDS, NULL}, @@ -4224,6 +4228,9 @@ static struct PyMethodDef array_module_methods[] = { METH_VARARGS | METH_KEYWORDS, NULL}, {"_monotonicity", (PyCFunction)arr__monotonicity, METH_VARARGS | METH_KEYWORDS, NULL}, + {"implement_array_function", + (PyCFunction)array_implement_array_function, + METH_VARARGS, NULL}, {"interp", (PyCFunction)arr_interp, METH_VARARGS | METH_KEYWORDS, NULL}, {"interp_complex", (PyCFunction)arr_interp_complex, @@ -4476,6 +4483,7 @@ NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_array_wrap = NULL; NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_array_finalize = NULL; NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_buffer = NULL; NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_ufunc = NULL; +NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_wrapped = NULL; NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_order = NULL; NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_copy = NULL; NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_dtype = NULL; @@ -4492,6 +4500,7 @@ intern_strings(void) npy_ma_str_array_finalize = PyUString_InternFromString("__array_finalize__"); npy_ma_str_buffer = PyUString_InternFromString("__buffer__"); npy_ma_str_ufunc = PyUString_InternFromString("__array_ufunc__"); + npy_ma_str_wrapped = PyUString_InternFromString("__wrapped__"); npy_ma_str_order = PyUString_InternFromString("order"); npy_ma_str_copy = PyUString_InternFromString("copy"); npy_ma_str_dtype = PyUString_InternFromString("dtype"); @@ -4501,7 +4510,7 @@ intern_strings(void) return npy_ma_str_array && npy_ma_str_array_prepare && npy_ma_str_array_wrap && npy_ma_str_array_finalize && - npy_ma_str_buffer && npy_ma_str_ufunc && + npy_ma_str_buffer && npy_ma_str_ufunc && npy_ma_str_wrapped && npy_ma_str_order && npy_ma_str_copy && npy_ma_str_dtype && npy_ma_str_ndmin && npy_ma_str_axis1 && npy_ma_str_axis2; } diff --git a/numpy/core/src/multiarray/multiarraymodule.h b/numpy/core/src/multiarray/multiarraymodule.h index 3de68c549..60a3965c9 100644 --- a/numpy/core/src/multiarray/multiarraymodule.h +++ b/numpy/core/src/multiarray/multiarraymodule.h @@ -7,6 +7,7 @@ NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_array_wrap; NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_array_finalize; NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_buffer; NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_ufunc; +NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_wrapped; NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_order; NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_copy; NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_dtype; diff --git a/numpy/core/tests/test_overrides.py b/numpy/core/tests/test_overrides.py index 62b2a3e53..8f1c16539 100644 --- a/numpy/core/tests/test_overrides.py +++ b/numpy/core/tests/test_overrides.py @@ -7,7 +7,7 @@ import numpy as np from numpy.testing import ( assert_, assert_equal, assert_raises, assert_raises_regex) from numpy.core.overrides import ( - get_overloaded_types_and_args, array_function_dispatch, + _get_implementing_args, array_function_dispatch, verify_matching_signatures, ENABLE_ARRAY_FUNCTION) from numpy.core.numeric import pickle import pytest @@ -18,11 +18,6 @@ requires_array_function = pytest.mark.skipif( reason="__array_function__ dispatch not enabled.") -def _get_overloaded_args(relevant_args): - types, args = get_overloaded_types_and_args(relevant_args) - return args - - def _return_not_implemented(self, *args, **kwargs): return NotImplemented @@ -41,26 +36,21 @@ def dispatched_two_arg(array1, array2): @requires_array_function -class TestGetOverloadedTypesAndArgs(object): +class TestGetImplementingArgs(object): def test_ndarray(self): array = np.array(1) - types, args = get_overloaded_types_and_args([array]) - assert_equal(set(types), {np.ndarray}) + args = _get_implementing_args([array]) assert_equal(list(args), [array]) - types, args = get_overloaded_types_and_args([array, array]) - assert_equal(len(types), 1) - assert_equal(set(types), {np.ndarray}) + args = _get_implementing_args([array, array]) assert_equal(list(args), [array]) - types, args = get_overloaded_types_and_args([array, 1]) - assert_equal(set(types), {np.ndarray}) + args = _get_implementing_args([array, 1]) assert_equal(list(args), [array]) - types, args = get_overloaded_types_and_args([1, array]) - assert_equal(set(types), {np.ndarray}) + args = _get_implementing_args([1, array]) assert_equal(list(args), [array]) def test_ndarray_subclasses(self): @@ -75,17 +65,14 @@ class TestGetOverloadedTypesAndArgs(object): override_sub = np.array(1).view(OverrideSub) no_override_sub = np.array(1).view(NoOverrideSub) - types, args = get_overloaded_types_and_args([array, override_sub]) - assert_equal(set(types), {np.ndarray, OverrideSub}) + args = _get_implementing_args([array, override_sub]) assert_equal(list(args), [override_sub, array]) - types, args = get_overloaded_types_and_args([array, no_override_sub]) - assert_equal(set(types), {np.ndarray, NoOverrideSub}) + args = _get_implementing_args([array, no_override_sub]) assert_equal(list(args), [no_override_sub, array]) - types, args = get_overloaded_types_and_args( + args = _get_implementing_args( [override_sub, no_override_sub]) - assert_equal(set(types), {OverrideSub, NoOverrideSub}) assert_equal(list(args), [override_sub, no_override_sub]) def test_ndarray_and_duck_array(self): @@ -96,12 +83,10 @@ class TestGetOverloadedTypesAndArgs(object): array = np.array(1) other = Other() - types, args = get_overloaded_types_and_args([other, array]) - assert_equal(set(types), {np.ndarray, Other}) + args = _get_implementing_args([other, array]) assert_equal(list(args), [other, array]) - types, args = get_overloaded_types_and_args([array, other]) - assert_equal(set(types), {np.ndarray, Other}) + args = _get_implementing_args([array, other]) assert_equal(list(args), [array, other]) def test_ndarray_subclass_and_duck_array(self): @@ -116,9 +101,9 @@ class TestGetOverloadedTypesAndArgs(object): subarray = np.array(1).view(OverrideSub) other = Other() - assert_equal(_get_overloaded_args([array, subarray, other]), + assert_equal(_get_implementing_args([array, subarray, other]), [subarray, array, other]) - assert_equal(_get_overloaded_args([array, other, subarray]), + assert_equal(_get_implementing_args([array, other, subarray]), [subarray, array, other]) def test_many_duck_arrays(self): @@ -140,15 +125,26 @@ class TestGetOverloadedTypesAndArgs(object): c = C() d = D() - assert_equal(_get_overloaded_args([1]), []) - assert_equal(_get_overloaded_args([a]), [a]) - assert_equal(_get_overloaded_args([a, 1]), [a]) - assert_equal(_get_overloaded_args([a, a, a]), [a]) - assert_equal(_get_overloaded_args([a, d, a]), [a, d]) - assert_equal(_get_overloaded_args([a, b]), [b, a]) - assert_equal(_get_overloaded_args([b, a]), [b, a]) - assert_equal(_get_overloaded_args([a, b, c]), [b, c, a]) - assert_equal(_get_overloaded_args([a, c, b]), [c, b, a]) + assert_equal(_get_implementing_args([1]), []) + assert_equal(_get_implementing_args([a]), [a]) + assert_equal(_get_implementing_args([a, 1]), [a]) + assert_equal(_get_implementing_args([a, a, a]), [a]) + assert_equal(_get_implementing_args([a, d, a]), [a, d]) + assert_equal(_get_implementing_args([a, b]), [b, a]) + assert_equal(_get_implementing_args([b, a]), [b, a]) + assert_equal(_get_implementing_args([a, b, c]), [b, c, a]) + assert_equal(_get_implementing_args([a, c, b]), [c, b, a]) + + def test_too_many_duck_arrays(self): + namespace = dict(__array_function__=_return_not_implemented) + types = [type('A' + str(i), (object,), namespace) for i in range(33)] + relevant_args = [t() for t in types] + + actual = _get_implementing_args(relevant_args[:32]) + assert_equal(actual, relevant_args[:32]) + + with assert_raises_regex(TypeError, 'distinct argument types'): + _get_implementing_args(relevant_args) @requires_array_function @@ -201,6 +197,14 @@ class TestNDArrayArrayFunction(object): result = np.concatenate((array, override_sub)) assert_equal(result, expected.view(OverrideSub)) + def test_no_wrapper(self): + array = np.array(1) + func = dispatched_one_arg.__wrapped__ + with assert_raises_regex(AttributeError, '__wrapped__'): + array.__array_function__(func=func, + types=(np.ndarray,), + args=(array,), kwargs={}) + @requires_array_function class TestArrayFunctionDispatch(object): diff --git a/numpy/core/tests/test_shape_base.py b/numpy/core/tests/test_shape_base.py index ef5c118ec..b996321c2 100644 --- a/numpy/core/tests/test_shape_base.py +++ b/numpy/core/tests/test_shape_base.py @@ -373,6 +373,10 @@ def test_stack(): # empty arrays assert_(stack([[], [], []]).shape == (3, 0)) assert_(stack([[], [], []], axis=1).shape == (0, 3)) + # out + out = np.zeros_like(r1) + np.stack((a, b), out=out) + assert_array_equal(out, r1) # edge cases assert_raises_regex(ValueError, 'need at least one array', stack, []) assert_raises_regex(ValueError, 'must have the same shape', |