summaryrefslogtreecommitdiff
path: root/doc/source/user/basics.dispatch.rst
diff options
context:
space:
mode:
Diffstat (limited to 'doc/source/user/basics.dispatch.rst')
-rw-r--r--doc/source/user/basics.dispatch.rst27
1 files changed, 17 insertions, 10 deletions
diff --git a/doc/source/user/basics.dispatch.rst b/doc/source/user/basics.dispatch.rst
index 089a7df17..35c73dde4 100644
--- a/doc/source/user/basics.dispatch.rst
+++ b/doc/source/user/basics.dispatch.rst
@@ -57,7 +57,7 @@ array([[2., 0., 0., 0., 0.],
Notice that the return type is a standard ``numpy.ndarray``.
>>> type(np.multiply(arr, 2))
-numpy.ndarray
+<class 'numpy.ndarray'>
How can we pass our custom array type through this function? Numpy allows a
class to indicate that it would like to handle computations in a custom-defined
@@ -119,7 +119,9 @@ DiagonalArray(N=5, value=0.8414709848078965)
At this point ``arr + 3`` does not work.
>>> arr + 3
-TypeError: unsupported operand type(s) for *: 'DiagonalArray' and 'int'
+Traceback (most recent call last):
+...
+TypeError: unsupported operand type(s) for +: 'DiagonalArray' and 'int'
To support it, we need to define the Python interfaces ``__add__``, ``__lt__``,
and so on to dispatch to the corresponding ufunc. We can achieve this
@@ -193,14 +195,14 @@ functions to our custom variants.
... return self.__class__(N, ufunc(*scalars, **kwargs))
... else:
... return NotImplemented
-... def __array_function__(self, func, types, args, kwargs):
-... if func not in HANDLED_FUNCTIONS:
-... return NotImplemented
-... # Note: this allows subclasses that don't override
-... # __array_function__ to handle DiagonalArray objects.
-... if not all(issubclass(t, self.__class__) for t in types):
-... return NotImplemented
-... return HANDLED_FUNCTIONS[func](*args, **kwargs)
+... def __array_function__(self, func, types, args, kwargs):
+... if func not in HANDLED_FUNCTIONS:
+... return NotImplemented
+... # Note: this allows subclasses that don't override
+... # __array_function__ to handle DiagonalArray objects.
+... if not all(issubclass(t, self.__class__) for t in types):
+... return NotImplemented
+... return HANDLED_FUNCTIONS[func](*args, **kwargs)
...
A convenient pattern is to define a decorator ``implements`` that can be used
@@ -241,14 +243,19 @@ this operation is not supported. For example, concatenating two
supported.
>>> np.concatenate([arr, arr])
+Traceback (most recent call last):
+...
TypeError: no implementation found for 'numpy.concatenate' on types that implement __array_function__: [<class '__main__.DiagonalArray'>]
Additionally, our implementations of ``sum`` and ``mean`` do not accept the
optional arguments that numpy's implementation does.
>>> np.sum(arr, axis=0)
+Traceback (most recent call last):
+...
TypeError: sum() got an unexpected keyword argument 'axis'
+
The user always has the option of converting to a normal ``numpy.ndarray`` with
:func:`numpy.asarray` and using standard numpy from there.