summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorMarten van Kerkwijk <mhvk@astro.utoronto.ca>2018-05-02 11:31:24 -0400
committerMarten van Kerkwijk <mhvk@astro.utoronto.ca>2018-06-07 14:37:21 -0400
commit2d67af51f4d87282237636cb3a288bf50f548fc8 (patch)
tree8641b7191c9abdd86c1421916bbd14370716c871 /numpy
parent2abef6cb51d6b7e8c704b4c65328f392152ad23d (diff)
downloadnumpy-2d67af51f4d87282237636cb3a288bf50f548fc8.tar.gz
MAINT: let ufunc override reject passing in both axis and axes.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/umath/override.c12
-rw-r--r--numpy/core/tests/test_umath.py1
2 files changed, 11 insertions, 2 deletions
diff --git a/numpy/core/src/umath/override.c b/numpy/core/src/umath/override.c
index c298fe315..c0bc47b7b 100644
--- a/numpy/core/src/umath/override.c
+++ b/numpy/core/src/umath/override.c
@@ -51,6 +51,7 @@ normalize___call___args(PyUFuncObject *ufunc, PyObject *args,
npy_intp nin = ufunc->nin;
npy_intp nout = ufunc->nout;
npy_intp nargs = PyTuple_GET_SIZE(args);
+ npy_intp nkwds = PyDict_Size(*normal_kwds);
PyObject *obj;
if (nargs < nin) {
@@ -74,7 +75,7 @@ normalize___call___args(PyUFuncObject *ufunc, PyObject *args,
/* If we have more args than nin, they must be the output variables.*/
if (nargs > nin) {
- if(PyDict_GetItemString(*normal_kwds, "out")) {
+ if(nkwds > 0 && PyDict_GetItemString(*normal_kwds, "out")) {
PyErr_Format(PyExc_TypeError,
"argument given by name ('out') and position "
"(%"NPY_INTP_FMT")", nin);
@@ -112,8 +113,15 @@ normalize___call___args(PyUFuncObject *ufunc, PyObject *args,
Py_DECREF(obj);
}
}
+ /* gufuncs accept either 'axes' or 'axis', but not both */
+ if (nkwds >= 2 && (PyDict_GetItemString(*normal_kwds, "axis") &&
+ PyDict_GetItemString(*normal_kwds, "axes"))) {
+ PyErr_SetString(PyExc_TypeError,
+ "cannot specify both 'axis' and 'axes'");
+ return -1;
+ }
/* finally, ufuncs accept 'sig' or 'signature' normalize to 'signature' */
- return normalize_signature_keyword(*normal_kwds);
+ return nkwds == 0 ? 0 : normalize_signature_keyword(*normal_kwds);
}
static int
diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py
index 93ec73094..3c0d1759a 100644
--- a/numpy/core/tests/test_umath.py
+++ b/numpy/core/tests/test_umath.py
@@ -1810,6 +1810,7 @@ class TestSpecialMethods(object):
assert_raises(TypeError, np.multiply, a)
assert_raises(TypeError, np.multiply, a, a, a, a)
assert_raises(TypeError, np.multiply, a, a, sig='a', signature='a')
+ assert_raises(TypeError, ncu_tests.inner1d, a, a, axis=0, axes=[0, 0])
# reduce, positional args
res = np.multiply.reduce(a, 'axis0', 'dtype0', 'out0', 'keep0')