summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorStephan Hoyer <shoyer@google.com>2018-12-19 10:48:10 -0800
committerStephan Hoyer <shoyer@google.com>2018-12-19 10:48:10 -0800
commit7104a3e9fe989f43f20765d740cf22dabae563d4 (patch)
treecde09169afa607600abd09d68989fe62a4f37ccf /numpy
parent45413455791c77465c5f33a5082053274eb18900 (diff)
downloadnumpy-7104a3e9fe989f43f20765d740cf22dabae563d4.tar.gz
ENH: port __array_function__ overrides to C
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/_internal.py7
-rw-r--r--numpy/core/_methods.py12
-rw-r--r--numpy/core/code_generators/genapi.py1
-rw-r--r--numpy/core/overrides.py104
-rw-r--r--numpy/core/setup.py2
-rw-r--r--numpy/core/src/common/get_attr_string.h1
-rw-r--r--numpy/core/src/multiarray/arrayfunction_override.c376
-rw-r--r--numpy/core/src/multiarray/arrayfunction_override.h16
-rw-r--r--numpy/core/src/multiarray/methods.c25
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.c11
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.h1
-rw-r--r--numpy/core/tests/test_overrides.py11
12 files changed, 455 insertions, 112 deletions
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py
index 27a3deeda..1d3bb5584 100644
--- a/numpy/core/_internal.py
+++ b/numpy/core/_internal.py
@@ -830,6 +830,13 @@ def array_ufunc_errmsg_formatter(dummy, ufunc, method, *inputs, **kwargs):
.format(ufunc, method, args_string, types_string))
+def array_function_errmsg_formatter(public_api, types):
+ """ Format the error message for when __array_ufunc__ gives up. """
+ func_name = '{}.{}'.format(public_api.__module__, public_api.__name__)
+ return ("no implementation found for '{}' on types that implement "
+ '__array_function__: {}'.format(func_name, list(types)))
+
+
def _ufunc_doc_signature_formatter(ufunc):
"""
Builds a signature string which resembles PEP 457
diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py
index c6c163e16..33f6d01a8 100644
--- a/numpy/core/_methods.py
+++ b/numpy/core/_methods.py
@@ -154,15 +154,3 @@ def _ptp(a, axis=None, out=None, keepdims=False):
umr_minimum(a, axis, None, None, keepdims),
out
)
-
-_NDARRAY_ARRAY_FUNCTION = mu.ndarray.__array_function__
-
-def _array_function(self, func, types, args, kwargs):
- # TODO: rewrite this in C
- # Cannot handle items that have __array_function__ other than our own.
- for t in types:
- if not issubclass(t, mu.ndarray):
- return NotImplemented
-
- # The regular implementation can handle this, so we call it directly.
- return func.__wrapped__(*args, **kwargs)
diff --git a/numpy/core/code_generators/genapi.py b/numpy/core/code_generators/genapi.py
index 1d2cd25c8..4aca2373c 100644
--- a/numpy/core/code_generators/genapi.py
+++ b/numpy/core/code_generators/genapi.py
@@ -19,6 +19,7 @@ __docformat__ = 'restructuredtext'
# The files under src/ that are scanned for API functions
API_FILES = [join('multiarray', 'alloc.c'),
+ join('multiarray', 'arrayfunction_override.c'),
join('multiarray', 'array_assign_array.c'),
join('multiarray', 'array_assign_scalar.c'),
join('multiarray', 'arrayobject.c'),
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py
index 8f535b213..c55174ecd 100644
--- a/numpy/core/overrides.py
+++ b/numpy/core/overrides.py
@@ -1,12 +1,10 @@
-"""Preliminary implementation of NEP-18.
-
-TODO: rewrite this in C for performance.
-"""
+"""Implementation of __array_function__ overrides from NEP-18."""
import collections
import functools
import os
-from numpy.core._multiarray_umath import add_docstring, ndarray
+from numpy.core._multiarray_umath import (
+ add_docstring, implement_array_function, _get_implementing_args)
from numpy.compat._inspect import getargspec
@@ -14,72 +12,8 @@ ENABLE_ARRAY_FUNCTION = bool(
int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 0)))
-def get_implementing_types_and_args(relevant_args):
- """Returns a list of arguments on which to call __array_function__.
- Parameters
- ----------
- relevant_args : iterable of array-like
- Iterable of array-like arguments to check for __array_function__
- methods.
- Returns
- -------
- implementing_types : collection of types
- Types of arguments from relevant_args with __array_function__ methods.
- implementing_args : list
- Arguments from relevant_args on which to call __array_function__
- methods, in the order in which they should be called.
- """
- # Runtime is O(num_arguments * num_unique_types)
- implementing_types = []
- implementing_args = []
- for arg in relevant_args:
- arg_type = type(arg)
- # We only collect arguments if they have a unique type, which ensures
- # reasonable performance even with a long list of possibly overloaded
- # arguments.
- if (arg_type not in implementing_types and
- hasattr(arg_type, '__array_function__')):
-
- # Create lists explicitly for the first type (usually the only one
- # done) to avoid setting up the iterator for implementing_args.
- if implementing_types:
- implementing_types.append(arg_type)
- # By default, insert argument at the end, but if it is
- # subclass of another argument, insert it before that argument.
- # This ensures "subclasses before superclasses".
- index = len(implementing_args)
- for i, old_arg in enumerate(implementing_args):
- if issubclass(arg_type, type(old_arg)):
- index = i
- break
- implementing_args.insert(index, arg)
- else:
- implementing_types = [arg_type]
- implementing_args = [arg]
-
- return implementing_types, implementing_args
-
-
-_NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__
-
-
-def any_overrides(relevant_args):
- """Are there any __array_function__ methods that need to be called?"""
- for arg in relevant_args:
- arg_type = type(arg)
- if (arg_type is not ndarray and
- getattr(arg_type, '__array_function__',
- _NDARRAY_ARRAY_FUNCTION)
- is not _NDARRAY_ARRAY_FUNCTION):
- return True
- return False
-
-
-_TUPLE_OR_LIST = {tuple, list}
-
-
-def implement_array_function(
- implementation, public_api, relevant_args, args, kwargs):
+add_docstring(
+ implement_array_function,
"""
Implement a function with checks for __array_function__ overrides.
@@ -109,28 +43,12 @@ def implement_array_function(
Raises
------
TypeError : if no implementation is found.
- """
- if type(relevant_args) not in _TUPLE_OR_LIST:
- relevant_args = tuple(relevant_args)
-
- if not any_overrides(relevant_args):
- return implementation(*args, **kwargs)
+ """)
- # Call overrides
- types, implementing_args = get_implementing_types_and_args(relevant_args)
- for arg in implementing_args:
- # Use `public_api` instead of `implemenation` so __array_function__
- # implementations can do equality/identity comparisons.
- result = arg.__array_function__(public_api, types, args, kwargs)
- if result is not NotImplemented:
- return result
- func_name = '{}.{}'.format(public_api.__module__, public_api.__name__)
- raise TypeError("no implementation found for '{}' on types that implement "
- '__array_function__: {}'.format(func_name, list(types)))
-
-
-def _get_implementing_args(relevant_args):
+# exposed for testing purposes; used internally by implement_array_function
+add_docstring(
+ _get_implementing_args,
"""
Collect arguments on which to call __array_function__.
@@ -144,9 +62,7 @@ def _get_implementing_args(relevant_args):
-------
Sequence of arguments with __array_function__ methods, in the order in
which they should be called.
- """
- _, args = get_implementing_types_and_args(relevant_args)
- return args
+ """)
ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults')
diff --git a/numpy/core/setup.py b/numpy/core/setup.py
index 467b590ac..9ccca629e 100644
--- a/numpy/core/setup.py
+++ b/numpy/core/setup.py
@@ -775,6 +775,7 @@ def configuration(parent_package='',top_path=None):
multiarray_deps = [
join('src', 'multiarray', 'arrayobject.h'),
join('src', 'multiarray', 'arraytypes.h'),
+ join('src', 'multiarray', 'arrayfunction_override.h'),
join('src', 'multiarray', 'buffer.h'),
join('src', 'multiarray', 'calculation.h'),
join('src', 'multiarray', 'common.h'),
@@ -827,6 +828,7 @@ def configuration(parent_package='',top_path=None):
join('src', 'multiarray', 'arraytypes.c.src'),
join('src', 'multiarray', 'array_assign_scalar.c'),
join('src', 'multiarray', 'array_assign_array.c'),
+ join('src', 'multiarray', 'arrayfunction_override.c'),
join('src', 'multiarray', 'buffer.c'),
join('src', 'multiarray', 'calculation.c'),
join('src', 'multiarray', 'compiled_base.c'),
diff --git a/numpy/core/src/common/get_attr_string.h b/numpy/core/src/common/get_attr_string.h
index bec87c5ed..d458d9550 100644
--- a/numpy/core/src/common/get_attr_string.h
+++ b/numpy/core/src/common/get_attr_string.h
@@ -103,7 +103,6 @@ PyArray_LookupSpecial(PyObject *obj, char *name)
if (_is_basic_python_type(tp)) {
return NULL;
}
-
return maybe_get_attr((PyObject *)tp, name);
}
diff --git a/numpy/core/src/multiarray/arrayfunction_override.c b/numpy/core/src/multiarray/arrayfunction_override.c
new file mode 100644
index 000000000..e62b32ab2
--- /dev/null
+++ b/numpy/core/src/multiarray/arrayfunction_override.c
@@ -0,0 +1,376 @@
+#define NPY_NO_DEPRECATED_API NPY_API_VERSION
+#define _MULTIARRAYMODULE
+
+#include "npy_pycompat.h"
+#include "get_attr_string.h"
+#include "npy_import.h"
+#include "multiarraymodule.h"
+
+
+/* Return the ndarray.__array_function__ method. */
+static PyObject *
+get_ndarray_array_function(void)
+{
+ PyObject* method = PyObject_GetAttrString((PyObject *)&PyArray_Type,
+ "__array_function__");
+ assert(method != NULL);
+ return method;
+}
+
+
+/*
+ * Get an object's __array_function__ method in the fastest way possible.
+ * Never raises an exception. Returns NULL if the method doesn't exist.
+ */
+static PyObject *
+get_array_function(PyObject *obj)
+{
+ static PyObject *ndarray_array_function = NULL;
+
+ if (ndarray_array_function == NULL) {
+ ndarray_array_function = get_ndarray_array_function();
+ }
+
+ /* Fast return for ndarray */
+ if (PyArray_CheckExact(obj)) {
+ Py_INCREF(ndarray_array_function);
+ return ndarray_array_function;
+ }
+
+ return PyArray_LookupSpecial(obj, "__array_function__");
+}
+
+
+/*
+ * Like list.insert(), but for C arrays of PyObject*. Skips error checking.
+ */
+static void
+pyobject_array_insert(PyObject **array, int length, int index, PyObject *item)
+{
+ int j;
+
+ for (j = length; j > index; j--) {
+ array[j] = array[j - 1];
+ }
+ array[index] = item;
+}
+
+
+/*
+ * Collects arguments with __array_function__ and their corresponding methods
+ * in the order in which they should be tried (i.e., skipping redundant types).
+ * `relevant_args` is expected to have been produced by PySequence_Fast.
+ * Returns the number of arguments, or -1 on failure.
+ */
+static int
+get_implementing_args_and_methods(PyObject *relevant_args,
+ PyObject **implementing_args,
+ PyObject **methods)
+{
+ int num_implementing_args = 0;
+ Py_ssize_t i;
+ int j;
+
+ PyObject **items = PySequence_Fast_ITEMS(relevant_args);
+ Py_ssize_t length = PySequence_Fast_GET_SIZE(relevant_args);
+
+ for (i = 0; i < length; i++) {
+ int new_class = 1;
+ PyObject *argument = items[i];
+
+ /* Have we seen this type before? */
+ for (j = 0; j < num_implementing_args; j++) {
+ if (Py_TYPE(argument) == Py_TYPE(implementing_args[j])) {
+ new_class = 0;
+ break;
+ }
+ }
+ if (new_class) {
+ PyObject *method = get_array_function(argument);
+
+ if (method != NULL) {
+ int arg_index;
+
+ if (num_implementing_args >= NPY_MAXARGS) {
+ PyErr_Format(
+ PyExc_TypeError,
+ "maximum number (%d) of distinct argument types " \
+ "implementing __array_function__ exceeded",
+ NPY_MAXARGS);
+ Py_DECREF(method);
+ goto fail;
+ }
+
+ /* "subclasses before superclasses, otherwise left to right" */
+ arg_index = num_implementing_args;
+ for (j = 0; j < num_implementing_args; j++) {
+ PyObject *other_type;
+ other_type = (PyObject *)Py_TYPE(implementing_args[j]);
+ if (PyObject_IsInstance(argument, other_type)) {
+ arg_index = j;
+ break;
+ }
+ }
+ Py_INCREF(argument);
+ pyobject_array_insert(implementing_args, num_implementing_args,
+ arg_index, argument);
+ pyobject_array_insert(methods, num_implementing_args,
+ arg_index, method);
+ ++num_implementing_args;
+ }
+ }
+ }
+ return num_implementing_args;
+
+fail:
+ for (j = 0; j < num_implementing_args; j++) {
+ Py_DECREF(implementing_args[j]);
+ Py_DECREF(methods[j]);
+ }
+ return -1;
+}
+
+
+/*
+ * Is this object ndarray.__array_function__?
+ */
+static int
+is_default_array_function(PyObject *obj)
+{
+ static PyObject *ndarray_array_function = NULL;
+
+ if (ndarray_array_function == NULL) {
+ ndarray_array_function = get_ndarray_array_function();
+ }
+ return obj == ndarray_array_function;
+}
+
+
+/*
+ * Core implementation of ndarray.__array_function__. This is exposed
+ * separately so we can avoid the overhead of a Python method call from
+ * within `implement_array_function`.
+ */
+NPY_NO_EXPORT PyObject *
+array_function_method_impl(PyObject *func, PyObject *types, PyObject *args,
+ PyObject *kwargs)
+{
+ Py_ssize_t j;
+ PyObject *implementation, *result;
+
+ PyObject **items = PySequence_Fast_ITEMS(types);
+ Py_ssize_t length = PySequence_Fast_GET_SIZE(types);
+
+ for (j = 0; j < length; j++) {
+ int is_subclass = PyObject_IsSubclass(
+ items[j], (PyObject *)&PyArray_Type);
+ if (is_subclass == -1) {
+ return NULL;
+ }
+ if (!is_subclass) {
+ Py_INCREF(Py_NotImplemented);
+ return Py_NotImplemented;
+ }
+ }
+
+ implementation = PyObject_GetAttr(func, npy_ma_str_wrapped);
+ if (implementation == NULL) {
+ return NULL;
+ }
+ result = PyObject_Call(implementation, args, kwargs);
+ Py_DECREF(implementation);
+ return result;
+}
+
+
+/*
+ * Calls __array_function__ on the provided argument, with a fast-path for
+ * ndarray.
+ */
+static PyObject *
+call_array_function(PyObject* argument, PyObject* method,
+ PyObject* public_api, PyObject* types,
+ PyObject* args, PyObject* kwargs)
+{
+ if (is_default_array_function(method)) {
+ return array_function_method_impl(public_api, types, args, kwargs);
+ }
+ else {
+ return PyObject_CallFunctionObjArgs(
+ method, argument, public_api, types, args, kwargs, NULL);
+ }
+}
+
+
+/*
+ * Implements the __array_function__ protocol for a function, as described in
+ * in NEP-18. See numpy.core.overrides for a full docstring.
+ */
+NPY_NO_EXPORT PyObject *
+array_implement_array_function(
+ PyObject *NPY_UNUSED(dummy), PyObject *positional_args)
+{
+ PyObject *implementation, *public_api, *relevant_args, *args, *kwargs;
+
+ PyObject *types = NULL;
+ PyObject *implementing_args[NPY_MAXARGS];
+ PyObject *array_function_methods[NPY_MAXARGS];
+
+ int j, any_overrides;
+ int num_implementing_args = 0;
+ PyObject *result = NULL;
+
+ static PyObject *errmsg_formatter = NULL;
+
+ if (!PyArg_UnpackTuple(
+ positional_args, "implement_array_function", 5, 5,
+ &implementation, &public_api, &relevant_args, &args, &kwargs)) {
+ return NULL;
+ }
+
+ relevant_args = PySequence_Fast(
+ relevant_args,
+ "dispatcher for __array_function__ did not return an iterable");
+ if (relevant_args == NULL) {
+ return NULL;
+ }
+
+ /* Collect __array_function__ implementations */
+ num_implementing_args = get_implementing_args_and_methods(
+ relevant_args, implementing_args, array_function_methods);
+ if (num_implementing_args == -1) {
+ goto cleanup;
+ }
+
+ /*
+ * Handle the typical case of no overrides. This is merely an optimization
+ * if some arguments are ndarray objects, but is also necessary if no
+ * arguments implement __array_function__ at all (e.g., if they are all
+ * built-in types).
+ */
+ any_overrides = 0;
+ for (j = 0; j < num_implementing_args; j++) {
+ if (!is_default_array_function(array_function_methods[j])) {
+ any_overrides = 1;
+ break;
+ }
+ }
+ if (!any_overrides) {
+ result = PyObject_Call(implementation, args, kwargs);
+ goto cleanup;
+ }
+
+ /*
+ * Create a Python object for types.
+ * We use a tuple, because it's the fastest Python collection to create
+ * and has the bonus of being immutable.
+ */
+ types = PyTuple_New(num_implementing_args);
+ if (types == NULL) {
+ goto cleanup;
+ }
+ for (j = 0; j < num_implementing_args; j++) {
+ PyObject *arg_type = (PyObject *)Py_TYPE(implementing_args[j]);
+ Py_INCREF(arg_type);
+ PyTuple_SET_ITEM(types, j, arg_type);
+ }
+
+ /* Call __array_function__ methods */
+ for (j = 0; j < num_implementing_args; j++) {
+ PyObject *argument = implementing_args[j];
+ PyObject *method = array_function_methods[j];
+
+ /*
+ * We use `public_api` instead of `implementation` here so
+ * __array_function__ implementations can do equality/identity
+ * comparisons.
+ */
+ result = call_array_function(
+ argument, method, public_api, types, args, kwargs);
+
+ if (result == Py_NotImplemented) {
+ /* Try the next one */
+ Py_DECREF(result);
+ result = NULL;
+ }
+ else {
+ /* Either a good result, or an exception was raised. */
+ goto cleanup;
+ }
+ }
+
+ /* No acceptable override found, raise TypeError. */
+ npy_cache_import("numpy.core._internal",
+ "array_function_errmsg_formatter",
+ &errmsg_formatter);
+ if (errmsg_formatter != NULL) {
+ PyObject *errmsg = PyObject_CallFunctionObjArgs(
+ errmsg_formatter, public_api, types, NULL);
+ if (errmsg != NULL) {
+ PyErr_SetObject(PyExc_TypeError, errmsg);
+ Py_DECREF(errmsg);
+ }
+ }
+
+cleanup:
+ for (j = 0; j < num_implementing_args; j++) {
+ Py_DECREF(implementing_args[j]);
+ Py_DECREF(array_function_methods[j]);
+ }
+ Py_XDECREF(types);
+ Py_DECREF(relevant_args);
+ return result;
+}
+
+
+/*
+ * Python wrapper for get_implementing_args_and_methods, for testing purposes.
+ */
+NPY_NO_EXPORT PyObject *
+array__get_implementing_args(
+ PyObject *NPY_UNUSED(dummy), PyObject *positional_args)
+{
+ PyObject *relevant_args;
+ int j;
+ int num_implementing_args = 0;
+ PyObject *implementing_args[NPY_MAXARGS];
+ PyObject *array_function_methods[NPY_MAXARGS];
+ PyObject *result = NULL;
+
+ if (!PyArg_ParseTuple(positional_args, "O:array__get_implementing_args",
+ &relevant_args)) {
+ return NULL;
+ }
+
+ relevant_args = PySequence_Fast(
+ relevant_args,
+ "dispatcher for __array_function__ did not return an iterable");
+ if (relevant_args == NULL) {
+ return NULL;
+ }
+
+ num_implementing_args = get_implementing_args_and_methods(
+ relevant_args, implementing_args, array_function_methods);
+ if (num_implementing_args == -1) {
+ goto cleanup;
+ }
+
+ /* create a Python object for implementing_args */
+ result = PyList_New(num_implementing_args);
+ if (result == NULL) {
+ goto cleanup;
+ }
+ for (j = 0; j < num_implementing_args; j++) {
+ PyObject *argument = implementing_args[j];
+ Py_INCREF(argument);
+ PyList_SET_ITEM(result, j, argument);
+ }
+
+cleanup:
+ for (j = 0; j < num_implementing_args; j++) {
+ Py_DECREF(implementing_args[j]);
+ Py_DECREF(array_function_methods[j]);
+ }
+ Py_DECREF(relevant_args);
+ return result;
+}
diff --git a/numpy/core/src/multiarray/arrayfunction_override.h b/numpy/core/src/multiarray/arrayfunction_override.h
new file mode 100644
index 000000000..0d224e2b6
--- /dev/null
+++ b/numpy/core/src/multiarray/arrayfunction_override.h
@@ -0,0 +1,16 @@
+#ifndef _NPY_PRIVATE__ARRAYFUNCTION_OVERRIDE_H
+#define _NPY_PRIVATE__ARRAYFUNCTION_OVERRIDE_H
+
+NPY_NO_EXPORT PyObject *
+array_implement_array_function(
+ PyObject *NPY_UNUSED(dummy), PyObject *positional_args);
+
+NPY_NO_EXPORT PyObject *
+array__get_implementing_args(
+ PyObject *NPY_UNUSED(dummy), PyObject *positional_args);
+
+NPY_NO_EXPORT PyObject *
+array_function_method_impl(PyObject *func, PyObject *types, PyObject *args,
+ PyObject *kwargs);
+
+#endif
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c
index 7c814e6e6..085bc00c0 100644
--- a/numpy/core/src/multiarray/methods.c
+++ b/numpy/core/src/multiarray/methods.c
@@ -8,6 +8,7 @@
#include "numpy/arrayobject.h"
#include "numpy/arrayscalars.h"
+#include "arrayfunction_override.h"
#include "npy_config.h"
#include "npy_pycompat.h"
#include "npy_import.h"
@@ -1088,13 +1089,29 @@ cleanup:
return result;
}
-
static PyObject *
-array_function(PyArrayObject *self, PyObject *args, PyObject *kwds)
+array_function(PyArrayObject *self, PyObject *c_args, PyObject *c_kwds)
{
- NPY_FORWARD_NDARRAY_METHOD("_array_function");
-}
+ PyObject *func, *types, *args, *kwargs, *result;
+ static char *kwlist[] = {"func", "types", "args", "kwargs", NULL};
+
+ if (!PyArg_ParseTupleAndKeywords(
+ c_args, c_kwds, "OOOO:__array_function__", kwlist,
+ &func, &types, &args, &kwargs)) {
+ return NULL;
+ }
+ types = PySequence_Fast(
+ types,
+ "types argument to ndarray.__array_function__ must be iterable");
+ if (types == NULL) {
+ return NULL;
+ }
+
+ result = array_function_method_impl(func, types, args, kwargs);
+ Py_DECREF(types);
+ return result;
+}
static PyObject *
array_copy(PyArrayObject *self, PyObject *args, PyObject *kwds)
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index 8135769d9..62345d2b0 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -34,6 +34,7 @@
NPY_NO_EXPORT int NPY_NUMUSERTYPES = 0;
/* Internal APIs */
+#include "arrayfunction_override.h"
#include "arraytypes.h"
#include "arrayobject.h"
#include "hashdescr.h"
@@ -4062,6 +4063,9 @@ normalize_axis_index(PyObject *NPY_UNUSED(self), PyObject *args, PyObject *kwds)
}
static struct PyMethodDef array_module_methods[] = {
+ {"_get_implementing_args",
+ (PyCFunction)array__get_implementing_args,
+ METH_VARARGS, NULL},
{"_get_ndarray_c_version",
(PyCFunction)array__get_ndarray_c_version,
METH_VARARGS|METH_KEYWORDS, NULL},
@@ -4224,6 +4228,9 @@ static struct PyMethodDef array_module_methods[] = {
METH_VARARGS | METH_KEYWORDS, NULL},
{"_monotonicity", (PyCFunction)arr__monotonicity,
METH_VARARGS | METH_KEYWORDS, NULL},
+ {"implement_array_function",
+ (PyCFunction)array_implement_array_function,
+ METH_VARARGS, NULL},
{"interp", (PyCFunction)arr_interp,
METH_VARARGS | METH_KEYWORDS, NULL},
{"interp_complex", (PyCFunction)arr_interp_complex,
@@ -4476,6 +4483,7 @@ NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_array_wrap = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_array_finalize = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_buffer = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_ufunc = NULL;
+NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_wrapped = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_order = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_copy = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_dtype = NULL;
@@ -4492,6 +4500,7 @@ intern_strings(void)
npy_ma_str_array_finalize = PyUString_InternFromString("__array_finalize__");
npy_ma_str_buffer = PyUString_InternFromString("__buffer__");
npy_ma_str_ufunc = PyUString_InternFromString("__array_ufunc__");
+ npy_ma_str_wrapped = PyUString_InternFromString("__wrapped__");
npy_ma_str_order = PyUString_InternFromString("order");
npy_ma_str_copy = PyUString_InternFromString("copy");
npy_ma_str_dtype = PyUString_InternFromString("dtype");
@@ -4501,7 +4510,7 @@ intern_strings(void)
return npy_ma_str_array && npy_ma_str_array_prepare &&
npy_ma_str_array_wrap && npy_ma_str_array_finalize &&
- npy_ma_str_buffer && npy_ma_str_ufunc &&
+ npy_ma_str_buffer && npy_ma_str_ufunc && npy_ma_str_wrapped &&
npy_ma_str_order && npy_ma_str_copy && npy_ma_str_dtype &&
npy_ma_str_ndmin && npy_ma_str_axis1 && npy_ma_str_axis2;
}
diff --git a/numpy/core/src/multiarray/multiarraymodule.h b/numpy/core/src/multiarray/multiarraymodule.h
index 3de68c549..60a3965c9 100644
--- a/numpy/core/src/multiarray/multiarraymodule.h
+++ b/numpy/core/src/multiarray/multiarraymodule.h
@@ -7,6 +7,7 @@ NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_array_wrap;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_array_finalize;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_buffer;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_ufunc;
+NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_wrapped;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_order;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_copy;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_dtype;
diff --git a/numpy/core/tests/test_overrides.py b/numpy/core/tests/test_overrides.py
index 8c7ef576e..8f1c16539 100644
--- a/numpy/core/tests/test_overrides.py
+++ b/numpy/core/tests/test_overrides.py
@@ -135,6 +135,17 @@ class TestGetImplementingArgs(object):
assert_equal(_get_implementing_args([a, b, c]), [b, c, a])
assert_equal(_get_implementing_args([a, c, b]), [c, b, a])
+ def test_too_many_duck_arrays(self):
+ namespace = dict(__array_function__=_return_not_implemented)
+ types = [type('A' + str(i), (object,), namespace) for i in range(33)]
+ relevant_args = [t() for t in types]
+
+ actual = _get_implementing_args(relevant_args[:32])
+ assert_equal(actual, relevant_args[:32])
+
+ with assert_raises_regex(TypeError, 'distinct argument types'):
+ _get_implementing_args(relevant_args)
+
@requires_array_function
class TestNDArrayArrayFunction(object):