diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2013-02-28 19:31:57 -0800 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2013-02-28 19:31:57 -0800 |
commit | 5ce8e28a97210fac9f244e79c186433cd373999d (patch) | |
tree | a44f605eb3b7e8e9b1f4b97f90fc928109d5337d | |
parent | 629a2d4daa376e5639ad5106289c77b8137f9f15 (diff) | |
parent | eb5de874f3e2e9e534a05ec791d3ae7c0573147e (diff) | |
download | numpy-5ce8e28a97210fac9f244e79c186433cd373999d.tar.gz |
Merge pull request #2988 from ewmoore/methdot2697
BUG: gh-2687 make multiarray dot method accept out array and keyword args
-rw-r--r-- | numpy/core/src/multiarray/methods.c | 16 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 10 |
2 files changed, 21 insertions, 5 deletions
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c index 902384b0d..c81f4b1a8 100644 --- a/numpy/core/src/multiarray/methods.c +++ b/numpy/core/src/multiarray/methods.c @@ -1848,10 +1848,11 @@ array_cumprod(PyArrayObject *self, PyObject *args, PyObject *kwds) static PyObject * array_dot(PyArrayObject *self, PyObject *args, PyObject *kwds) { - PyObject *b; + PyObject *fname, *ret, *b, *out = NULL; static PyObject *numpycore = NULL; + char * kwords[] = {"b", "out", NULL }; - if (!PyArg_ParseTuple(args, "O", &b)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O", kwords, &b, &out)) { return NULL; } @@ -1863,8 +1864,13 @@ array_dot(PyArrayObject *self, PyObject *args, PyObject *kwds) return NULL; } } - - return PyObject_CallMethod(numpycore, "dot", "OO", self, b); + fname = PyUString_FromString("dot"); + if (out == NULL) { + ret = PyObject_CallMethodObjArgs(numpycore, fname, self, b, NULL); + } + ret = PyObject_CallMethodObjArgs(numpycore, fname, self, b, out, NULL); + Py_DECREF(fname); + return ret; } @@ -2223,7 +2229,7 @@ NPY_NO_EXPORT PyMethodDef array_methods[] = { METH_VARARGS | METH_KEYWORDS, NULL}, {"dot", (PyCFunction)array_dot, - METH_VARARGS, NULL}, + METH_VARARGS | METH_KEYWORDS, NULL}, {"fill", (PyCFunction)array_fill, METH_VARARGS, NULL}, diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 319147970..b3c0626dc 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -902,6 +902,16 @@ class TestMethods(TestCase): assert_equal(np.dot(a, b), a.dot(b)) assert_equal(np.dot(np.dot(a, b), c), a.dot(b).dot(c)) + # test passing in an output array + c = np.zeros_like(a) + a.dot(b,c) + assert_equal(c, np.dot(a,b)) + + # test keyword args + c = np.zeros_like(a) + a.dot(b=b,out=c) + assert_equal(c, np.dot(a,b)) + def test_diagonal(self): a = np.arange(12).reshape((3, 4)) assert_equal(a.diagonal(), [0, 5, 10]) |