summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPauli Virtanen <pav@iki.fi>2014-01-08 10:34:48 -0800
committerPauli Virtanen <pav@iki.fi>2014-01-08 10:34:48 -0800
commit5f36f57ecd0322c6a6110bd910c82ebdc65262b0 (patch)
tree51c9ba0e2e91f24a934f73f02e37763a68fcbf69
parentb3420fd08d45f591a2149a87ad0889c6c4d45dda (diff)
parentd849245b44417e3e632a19a5e04a627ca6434887 (diff)
downloadnumpy-5f36f57ecd0322c6a6110bd910c82ebdc65262b0.tar.gz
Merge pull request #4171 from cowlicks/ufunc-override-out
BUG: Allow __numpy_ufunc__ to handle multiple output ufuncs.
-rw-r--r--doc/neps/ufunc-overrides.rst17
-rw-r--r--numpy/core/src/private/ufunc_override.h13
-rw-r--r--numpy/core/tests/test_umath.py9
3 files changed, 33 insertions, 6 deletions
diff --git a/doc/neps/ufunc-overrides.rst b/doc/neps/ufunc-overrides.rst
index f57e77054..5a0a0334f 100644
--- a/doc/neps/ufunc-overrides.rst
+++ b/doc/neps/ufunc-overrides.rst
@@ -131,11 +131,22 @@ Here:
- *i* is the index of *self* in *inputs*.
- *inputs* is a tuple of the input arguments to the ``ufunc``
- *kwargs* are the keyword arguments passed to the function. The ``out``
- argument is always contained in *kwargs*, if given.
+ arguments are always contained in *kwargs*, how positional variables
+ are passed is discussed below.
The ufunc's arguments are first normalized into a tuple of input data
-(``inputs``), and dict of keyword arguments. If the output argument is
-passed as a positional argument it is moved to the keyword argmunets.
+(``inputs``), and dict of keyword arguments. If there are output
+arguments they are handeled as follows:
+
+- One positional output variable x is passed in the kwargs dict as ``out :
+ x``.
+- Multiple positional output variables ``x0, x1, ...`` are passed as a tuple
+ in the kwargs dict as ``out : (x0, x1, ...)``.
+- Keyword output variables like ``out = x`` and ``out = (x0, x1, ...)`` are
+ passed unchanged to the kwargs dict like ``out : x`` and ``out : (x0, x1,
+ ...)`` respectively.
+- Combinations of positional and keyword output variables are not
+ supported.
The function dispatch proceeds as follows:
diff --git a/numpy/core/src/private/ufunc_override.h b/numpy/core/src/private/ufunc_override.h
index d445ac2b8..380aef714 100644
--- a/numpy/core/src/private/ufunc_override.h
+++ b/numpy/core/src/private/ufunc_override.h
@@ -94,10 +94,17 @@ PyUFunc_CheckOverride(PyObject *ufunc, char *method,
goto fail;
}
- /* If we have more args than nin, the last one must be `out`.*/
+ /* If we have more args than nin, they must be the output variables.*/
if (nargs > nin) {
- obj = PyTuple_GET_ITEM(args, nargs - 1);
- PyDict_SetItemString(normal_kwds, "out", obj);
+ if ((nargs - nin) == 1) {
+ obj = PyTuple_GET_ITEM(args, nargs - 1);
+ PyDict_SetItemString(normal_kwds, "out", obj);
+ }
+ else {
+ obj = PyTuple_GetSlice(args, nin, nargs);
+ PyDict_SetItemString(normal_kwds, "out", obj);
+ Py_DECREF(obj);
+ }
}
method_name = PyUString_FromString(method);
diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py
index d61b516ac..6db95e9a6 100644
--- a/numpy/core/tests/test_umath.py
+++ b/numpy/core/tests/test_umath.py
@@ -1057,6 +1057,15 @@ class TestSpecialMethods(TestCase):
assert_equal(res4['out'], 'out_arg')
assert_equal(res5['out'], 'out_arg')
+ # ufuncs with multiple output modf and frexp.
+ res6 = np.modf(a, 'out0', 'out1')
+ res7 = np.frexp(a, 'out0', 'out1')
+ assert_equal(res6['out'][0], 'out0')
+ assert_equal(res6['out'][1], 'out1')
+ assert_equal(res7['out'][0], 'out0')
+ assert_equal(res7['out'][1], 'out1')
+
+
def test_ufunc_override_exception(self):
class A(object):
def __numpy_ufunc__(self, *a, **kwargs):