diff options
author | Eric Moore <ewm@redtetrahedron.org> | 2013-02-14 22:45:39 -0500 |
---|---|---|
committer | Eric Moore <ewm@redtetrahedron.org> | 2013-02-14 22:45:39 -0500 |
commit | 1d04300248882b207f493c7207f83a8b467be2be (patch) | |
tree | 4f246d0c6ff392f92a501dcf770dfe93aa10fc62 | |
parent | 4bf5a3feb00fe1d63e7d8fcf852cbf34e22fd60b (diff) | |
download | numpy-1d04300248882b207f493c7207f83a8b467be2be.tar.gz |
BUG: gh-2687 make multiarray dot method accept out array and keyword
args.
-rw-r--r-- | numpy/core/src/multiarray/methods.c | 13 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 8 |
2 files changed, 16 insertions, 5 deletions
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c index 7f0e3861b..ab3cae113 100644 --- a/numpy/core/src/multiarray/methods.c +++ b/numpy/core/src/multiarray/methods.c @@ -1847,10 +1847,11 @@ array_cumprod(PyArrayObject *self, PyObject *args, PyObject *kwds) static PyObject * array_dot(PyArrayObject *self, PyObject *args, PyObject *kwds) { - PyObject *b; + PyObject *b, *out = NULL; static PyObject *numpycore = NULL; + char * kwords[] = {"b", "out", NULL }; - if (!PyArg_ParseTuple(args, "O", &b)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O", kwords, &b, &out)) { return NULL; } @@ -1862,8 +1863,10 @@ array_dot(PyArrayObject *self, PyObject *args, PyObject *kwds) return NULL; } } - - return PyObject_CallMethod(numpycore, "dot", "OO", self, b); + if (out == NULL) { + return PyObject_CallMethod(numpycore, "dot", "OO", self, b); + } + return PyObject_CallMethod(numpycore, "dot", "OOO", self, b, out); } @@ -2222,7 +2225,7 @@ NPY_NO_EXPORT PyMethodDef array_methods[] = { METH_VARARGS | METH_KEYWORDS, NULL}, {"dot", (PyCFunction)array_dot, - METH_VARARGS, NULL}, + METH_VARARGS | METH_KEYWORDS, NULL}, {"fill", (PyCFunction)array_fill, METH_VARARGS, NULL}, diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 319147970..9449046b9 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -902,6 +902,14 @@ class TestMethods(TestCase): assert_equal(np.dot(a, b), a.dot(b)) assert_equal(np.dot(np.dot(a, b), c), a.dot(b).dot(c)) + # test changes from gh-2687 (trac 2096) + c = np.zeros_like(a) + a.dot(b,c) + assert_equal(c, np.dot(a,b)) + c = np.zeros_like(a) + a.dot(b=b,out=c) + assert_equal(c, np.dot(a,b)) + def test_diagonal(self): a = np.arange(12).reshape((3, 4)) assert_equal(a.diagonal(), [0, 5, 10]) |