diff options
author | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2018-05-02 11:31:24 -0400 |
---|---|---|
committer | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2018-06-07 14:37:21 -0400 |
commit | 2d67af51f4d87282237636cb3a288bf50f548fc8 (patch) | |
tree | 8641b7191c9abdd86c1421916bbd14370716c871 /numpy | |
parent | 2abef6cb51d6b7e8c704b4c65328f392152ad23d (diff) | |
download | numpy-2d67af51f4d87282237636cb3a288bf50f548fc8.tar.gz |
MAINT: let ufunc override reject passing in both axis and axes.
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/umath/override.c | 12 | ||||
-rw-r--r-- | numpy/core/tests/test_umath.py | 1 |
2 files changed, 11 insertions, 2 deletions
diff --git a/numpy/core/src/umath/override.c b/numpy/core/src/umath/override.c index c298fe315..c0bc47b7b 100644 --- a/numpy/core/src/umath/override.c +++ b/numpy/core/src/umath/override.c @@ -51,6 +51,7 @@ normalize___call___args(PyUFuncObject *ufunc, PyObject *args, npy_intp nin = ufunc->nin; npy_intp nout = ufunc->nout; npy_intp nargs = PyTuple_GET_SIZE(args); + npy_intp nkwds = PyDict_Size(*normal_kwds); PyObject *obj; if (nargs < nin) { @@ -74,7 +75,7 @@ normalize___call___args(PyUFuncObject *ufunc, PyObject *args, /* If we have more args than nin, they must be the output variables.*/ if (nargs > nin) { - if(PyDict_GetItemString(*normal_kwds, "out")) { + if(nkwds > 0 && PyDict_GetItemString(*normal_kwds, "out")) { PyErr_Format(PyExc_TypeError, "argument given by name ('out') and position " "(%"NPY_INTP_FMT")", nin); @@ -112,8 +113,15 @@ normalize___call___args(PyUFuncObject *ufunc, PyObject *args, Py_DECREF(obj); } } + /* gufuncs accept either 'axes' or 'axis', but not both */ + if (nkwds >= 2 && (PyDict_GetItemString(*normal_kwds, "axis") && + PyDict_GetItemString(*normal_kwds, "axes"))) { + PyErr_SetString(PyExc_TypeError, + "cannot specify both 'axis' and 'axes'"); + return -1; + } /* finally, ufuncs accept 'sig' or 'signature' normalize to 'signature' */ - return normalize_signature_keyword(*normal_kwds); + return nkwds == 0 ? 0 : normalize_signature_keyword(*normal_kwds); } static int diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py index 93ec73094..3c0d1759a 100644 --- a/numpy/core/tests/test_umath.py +++ b/numpy/core/tests/test_umath.py @@ -1810,6 +1810,7 @@ class TestSpecialMethods(object): assert_raises(TypeError, np.multiply, a) assert_raises(TypeError, np.multiply, a, a, a, a) assert_raises(TypeError, np.multiply, a, a, sig='a', signature='a') + assert_raises(TypeError, ncu_tests.inner1d, a, a, axis=0, axes=[0, 0]) # reduce, positional args res = np.multiply.reduce(a, 'axis0', 'dtype0', 'out0', 'keep0') |