summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/_methods.py15
-rw-r--r--numpy/core/overrides.py86
-rw-r--r--numpy/core/src/multiarray/methods.c10
-rw-r--r--numpy/core/tests/test_overrides.py273
4 files changed, 384 insertions, 0 deletions
diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py
index 33f6d01a8..8974f0ce1 100644
--- a/numpy/core/_methods.py
+++ b/numpy/core/_methods.py
@@ -154,3 +154,18 @@ def _ptp(a, axis=None, out=None, keepdims=False):
umr_minimum(a, axis, None, None, keepdims),
out
)
+
+_NDARRAY_ARRAY_FUNCTION = mu.ndarray.__array_function__
+
+def _array_function(self, func, types, args, kwargs):
+ # TODO: rewrite this in C
+ # Cannot handle items that have __array_function__ other than our own.
+ for t in types:
+ if t is not mu.ndarray:
+ method = getattr(t, '__array_function__', _NDARRAY_ARRAY_FUNCTION)
+ if method is not _NDARRAY_ARRAY_FUNCTION:
+ return NotImplemented
+
+ # Arguments contain no overrides, so we can safely call the
+ # overloaded function again.
+ return func(*args, **kwargs)
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py
new file mode 100644
index 000000000..c1d5e3864
--- /dev/null
+++ b/numpy/core/overrides.py
@@ -0,0 +1,86 @@
+"""Preliminary implementation of NEP-18
+
+TODO: rewrite this in C for performance.
+"""
+import functools
+from numpy.core.multiarray import ndarray
+
+
+_NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__
+
+
+def get_overloaded_types_and_args(relevant_args):
+ """Returns a list of arguments on which to call __array_function__.
+
+ __array_function__ implementations should be called in order on the return
+ values from this function.
+ """
+ # Runtime is O(num_arguments * num_unique_types)
+ overloaded_types = []
+ overloaded_args = []
+ for arg in relevant_args:
+ arg_type = type(arg)
+ if (arg_type not in overloaded_types and
+ hasattr(arg_type, '__array_function__')):
+
+ overloaded_types.append(arg_type)
+
+ # By default, insert this argument at the end, but if it is
+ # subclass of another argument, insert it before that argument.
+ # This ensures "subclasses before superclasses".
+ index = len(overloaded_args)
+ for i, old_arg in enumerate(overloaded_args):
+ if issubclass(arg_type, type(old_arg)):
+ index = i
+ break
+ overloaded_args.insert(index, arg)
+
+ # Special handling for ndarray.
+ overloaded_args = [
+ arg for arg in overloaded_args
+ if type(arg).__array_function__ is not _NDARRAY_ARRAY_FUNCTION
+ ]
+
+ return overloaded_types, overloaded_args
+
+
+def array_function_override(overloaded_args, func, types, args, kwargs):
+ """Call __array_function__ implementations."""
+ for overloaded_arg in overloaded_args:
+ # Note that we're only calling __array_function__ on the *first*
+ # occurence of each argument type. This is necessary for reasonable
+ # performance with a possibly long list of overloaded arguments, for
+ # which each __array_function__ implementation might reasonably need to
+ # check all argument types.
+ result = overloaded_arg.__array_function__(func, types, args, kwargs)
+
+ if result is not NotImplemented:
+ return result
+
+ raise TypeError('no implementation found for {} on types that implement '
+ '__array_function__: {}'
+ .format(func, list(map(type, overloaded_args))))
+
+
+def array_function_dispatch(dispatcher):
+ """Wrap a function for dispatch with the __array_function__ protocol."""
+ def decorator(func):
+ @functools.wraps(func)
+ def new_func(*args, **kwargs):
+ # Collect array-like arguments.
+ relevant_arguments = dispatcher(*args, **kwargs)
+ # Check for __array_function__ methods.
+ types, overloaded_args = get_overloaded_types_and_args(
+ relevant_arguments)
+ # Call overrides, if necessary.
+ if overloaded_args:
+ # new_func is the function exposed in NumPy's public API. We
+ # use it instead of func so __array_function__ implementations
+ # can do equality/identity comparisons.
+ return array_function_override(
+ overloaded_args, new_func, types, args, kwargs)
+ else:
+ return func(*args, **kwargs)
+
+ return new_func
+ return decorator
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c
index 3d2cce5e1..6317d6a16 100644
--- a/numpy/core/src/multiarray/methods.c
+++ b/numpy/core/src/multiarray/methods.c
@@ -1022,6 +1022,13 @@ cleanup:
static PyObject *
+array_function(PyArrayObject *self, PyObject *args, PyObject *kwds)
+{
+ NPY_FORWARD_NDARRAY_METHOD("_array_function");
+}
+
+
+static PyObject *
array_copy(PyArrayObject *self, PyObject *args, PyObject *kwds)
{
NPY_ORDER order = NPY_CORDER;
@@ -2472,6 +2479,9 @@ NPY_NO_EXPORT PyMethodDef array_methods[] = {
{"__array_ufunc__",
(PyCFunction)array_ufunc,
METH_VARARGS | METH_KEYWORDS, NULL},
+ {"__array_function__",
+ (PyCFunction)array_function,
+ METH_VARARGS | METH_KEYWORDS, NULL},
#ifndef NPY_PY3K
{"__unicode__",
diff --git a/numpy/core/tests/test_overrides.py b/numpy/core/tests/test_overrides.py
new file mode 100644
index 000000000..7f6157a5b
--- /dev/null
+++ b/numpy/core/tests/test_overrides.py
@@ -0,0 +1,273 @@
+from __future__ import division, absolute_import, print_function
+
+import pickle
+import sys
+
+import numpy as np
+from numpy.testing import (
+ assert_, assert_equal, assert_raises, assert_raises_regex)
+from numpy.core.overrides import (
+ get_overloaded_types_and_args, array_function_dispatch)
+
+
+def _get_overloaded_args(relevant_args):
+ types, args = get_overloaded_types_and_args(relevant_args)
+ return args
+
+
+def _return_self(self, *args, **kwargs):
+ return self
+
+
+class TestGetOverloadedTypesAndArgs(object):
+
+ def test_ndarray(self):
+ array = np.array(1)
+
+ types, args = get_overloaded_types_and_args([array])
+ assert_equal(set(types), {np.ndarray})
+ assert_equal(list(args), [])
+
+ types, args = get_overloaded_types_and_args([array, array])
+ assert_equal(len(types), 1)
+ assert_equal(set(types), {np.ndarray})
+ assert_equal(list(args), [])
+
+ types, args = get_overloaded_types_and_args([array, 1])
+ assert_equal(set(types), {np.ndarray})
+ assert_equal(list(args), [])
+
+ types, args = get_overloaded_types_and_args([1, array])
+ assert_equal(set(types), {np.ndarray})
+ assert_equal(list(args), [])
+
+ def test_ndarray_subclasses(self):
+
+ class OverrideSub(np.ndarray):
+ __array_function__ = _return_self
+
+ class NoOverrideSub(np.ndarray):
+ pass
+
+ array = np.array(1).view(np.ndarray)
+ override_sub = np.array(1).view(OverrideSub)
+ no_override_sub = np.array(1).view(NoOverrideSub)
+
+ types, args = get_overloaded_types_and_args([array, override_sub])
+ assert_equal(set(types), {np.ndarray, OverrideSub})
+ assert_equal(list(args), [override_sub])
+
+ types, args = get_overloaded_types_and_args([array, no_override_sub])
+ assert_equal(set(types), {np.ndarray, NoOverrideSub})
+ assert_equal(list(args), [])
+
+ types, args = get_overloaded_types_and_args(
+ [override_sub, no_override_sub])
+ assert_equal(set(types), {OverrideSub, NoOverrideSub})
+ assert_equal(list(args), [override_sub])
+
+ def test_ndarray_and_duck_array(self):
+
+ class Other(object):
+ __array_function__ = _return_self
+
+ array = np.array(1)
+ other = Other()
+
+ types, args = get_overloaded_types_and_args([other, array])
+ assert_equal(set(types), {np.ndarray, Other})
+ assert_equal(list(args), [other])
+
+ types, args = get_overloaded_types_and_args([array, other])
+ assert_equal(set(types), {np.ndarray, Other})
+ assert_equal(list(args), [other])
+
+ def test_ndarray_subclass_and_duck_array(self):
+
+ class OverrideSub(np.ndarray):
+ __array_function__ = _return_self
+
+ class Other(object):
+ __array_function__ = _return_self
+
+ array = np.array(1)
+ subarray = np.array(1).view(OverrideSub)
+ other = Other()
+
+ assert_equal(_get_overloaded_args([array, subarray, other]),
+ [subarray, other])
+ assert_equal(_get_overloaded_args([array, other, subarray]),
+ [subarray, other])
+
+ def test_many_duck_arrays(self):
+
+ class A(object):
+ __array_function__ = _return_self
+
+ class B(A):
+ __array_function__ = _return_self
+
+ class C(A):
+ __array_function__ = _return_self
+
+ class D(object):
+ __array_function__ = _return_self
+
+ a = A()
+ b = B()
+ c = C()
+ d = D()
+
+ assert_equal(_get_overloaded_args([1]), [])
+ assert_equal(_get_overloaded_args([a]), [a])
+ assert_equal(_get_overloaded_args([a, 1]), [a])
+ assert_equal(_get_overloaded_args([a, a, a]), [a])
+ assert_equal(_get_overloaded_args([a, d, a]), [a, d])
+ assert_equal(_get_overloaded_args([a, b]), [b, a])
+ assert_equal(_get_overloaded_args([b, a]), [b, a])
+ assert_equal(_get_overloaded_args([a, b, c]), [b, c, a])
+ assert_equal(_get_overloaded_args([a, c, b]), [c, b, a])
+
+
+class TestNDArrayArrayFunction(object):
+
+ def test_method(self):
+
+ class SubOverride(np.ndarray):
+ __array_function__ = _return_self
+
+ class NoOverrideSub(np.ndarray):
+ pass
+
+ array = np.array(1)
+
+ def func():
+ return 'original'
+
+ result = array.__array_function__(
+ func=func, types=(np.ndarray,), args=(), kwargs={})
+ assert_equal(result, 'original')
+
+ result = array.__array_function__(
+ func=func, types=(np.ndarray, SubOverride), args=(), kwargs={})
+ assert_(result is NotImplemented)
+
+ result = array.__array_function__(
+ func=func, types=(np.ndarray, NoOverrideSub), args=(), kwargs={})
+ assert_equal(result, 'original')
+
+
+# need to define this at the top level to test pickling
+@array_function_dispatch(lambda array: (array,))
+def dispatched_one_arg(array):
+ """Docstring."""
+ return 'original'
+
+
+class TestArrayFunctionDispatch(object):
+
+ def test_pickle(self):
+ roundtripped = pickle.loads(pickle.dumps(dispatched_one_arg))
+ assert_(roundtripped is dispatched_one_arg)
+
+ def test_name_and_docstring(self):
+ assert_equal(dispatched_one_arg.__name__, 'dispatched_one_arg')
+ if sys.flags.optimize < 2:
+ assert_equal(dispatched_one_arg.__doc__, 'Docstring.')
+
+ def test_interface(self):
+
+ class MyArray(object):
+ def __array_function__(self, func, types, args, kwargs):
+ return (self, func, types, args, kwargs)
+
+ original = MyArray()
+ (obj, func, types, args, kwargs) = dispatched_one_arg(original)
+ assert_(obj is original)
+ assert_(func is dispatched_one_arg)
+ assert_equal(set(types), {MyArray})
+ assert_equal(args, (original,))
+ assert_equal(kwargs, {})
+
+ def test_not_implemented(self):
+
+ class MyArray(object):
+ def __array_function__(self, func, types, args, kwargs):
+ return NotImplemented
+
+ array = MyArray()
+ with assert_raises_regex(TypeError, 'no implementation found'):
+ dispatched_one_arg(array)
+
+
+def _new_duck_type_and_implements():
+ """Create a duck array type and implements functions."""
+ HANDLED_FUNCTIONS = {}
+
+ class MyArray(object):
+ def __array_function__(self, func, types, args, kwargs):
+ if func not in HANDLED_FUNCTIONS:
+ return NotImplemented
+ if not all(issubclass(t, MyArray) for t in types):
+ return NotImplemented
+ return HANDLED_FUNCTIONS[func](*args, **kwargs)
+
+ def implements(numpy_function):
+ """Register an __array_function__ implementations."""
+ def decorator(func):
+ HANDLED_FUNCTIONS[numpy_function] = func
+ return func
+ return decorator
+
+ return (MyArray, implements)
+
+
+class TestArrayFunctionImplementation(object):
+
+ def test_one_arg(self):
+ MyArray, implements = _new_duck_type_and_implements()
+
+ @implements(dispatched_one_arg)
+ def _(array):
+ return 'myarray'
+
+ assert_equal(dispatched_one_arg(1), 'original')
+ assert_equal(dispatched_one_arg(MyArray()), 'myarray')
+
+ def test_optional_args(self):
+ MyArray, implements = _new_duck_type_and_implements()
+
+ @array_function_dispatch(lambda array, option=None: (array,))
+ def func_with_option(array, option='default'):
+ return option
+
+ @implements(func_with_option)
+ def my_array_func_with_option(array, new_option='myarray'):
+ return new_option
+
+ # we don't need to implement every option on __array_function__
+ # implementations
+ assert_equal(func_with_option(1), 'default')
+ assert_equal(func_with_option(1, option='extra'), 'extra')
+ assert_equal(func_with_option(MyArray()), 'myarray')
+ with assert_raises(TypeError):
+ func_with_option(MyArray(), option='extra')
+
+ # but new options on implementations can't be used
+ result = my_array_func_with_option(MyArray(), new_option='yes')
+ assert_equal(result, 'yes')
+ with assert_raises(TypeError):
+ func_with_option(MyArray(), new_option='no')
+
+ def test_not_implemented(self):
+ MyArray, implements = _new_duck_type_and_implements()
+
+ @array_function_dispatch(lambda array: (array,))
+ def func(array):
+ return array
+
+ array = np.array(1)
+ assert_(func(array) is array)
+
+ with assert_raises_regex(TypeError, 'no implementation found'):
+ func(MyArray())