summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEric Moore <ewm@redtetrahedron.org>2013-02-14 22:45:39 -0500
committerEric Moore <ewm@redtetrahedron.org>2013-02-14 22:45:39 -0500
commit1d04300248882b207f493c7207f83a8b467be2be (patch)
tree4f246d0c6ff392f92a501dcf770dfe93aa10fc62
parent4bf5a3feb00fe1d63e7d8fcf852cbf34e22fd60b (diff)
downloadnumpy-1d04300248882b207f493c7207f83a8b467be2be.tar.gz
BUG: gh-2687 make multiarray dot method accept out array and keyword
args.
-rw-r--r--numpy/core/src/multiarray/methods.c13
-rw-r--r--numpy/core/tests/test_multiarray.py8
2 files changed, 16 insertions, 5 deletions
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c
index 7f0e3861b..ab3cae113 100644
--- a/numpy/core/src/multiarray/methods.c
+++ b/numpy/core/src/multiarray/methods.c
@@ -1847,10 +1847,11 @@ array_cumprod(PyArrayObject *self, PyObject *args, PyObject *kwds)
static PyObject *
array_dot(PyArrayObject *self, PyObject *args, PyObject *kwds)
{
- PyObject *b;
+ PyObject *b, *out = NULL;
static PyObject *numpycore = NULL;
+ char * kwords[] = {"b", "out", NULL };
- if (!PyArg_ParseTuple(args, "O", &b)) {
+ if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O", kwords, &b, &out)) {
return NULL;
}
@@ -1862,8 +1863,10 @@ array_dot(PyArrayObject *self, PyObject *args, PyObject *kwds)
return NULL;
}
}
-
- return PyObject_CallMethod(numpycore, "dot", "OO", self, b);
+ if (out == NULL) {
+ return PyObject_CallMethod(numpycore, "dot", "OO", self, b);
+ }
+ return PyObject_CallMethod(numpycore, "dot", "OOO", self, b, out);
}
@@ -2222,7 +2225,7 @@ NPY_NO_EXPORT PyMethodDef array_methods[] = {
METH_VARARGS | METH_KEYWORDS, NULL},
{"dot",
(PyCFunction)array_dot,
- METH_VARARGS, NULL},
+ METH_VARARGS | METH_KEYWORDS, NULL},
{"fill",
(PyCFunction)array_fill,
METH_VARARGS, NULL},
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 319147970..9449046b9 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -902,6 +902,14 @@ class TestMethods(TestCase):
assert_equal(np.dot(a, b), a.dot(b))
assert_equal(np.dot(np.dot(a, b), c), a.dot(b).dot(c))
+ # test changes from gh-2687 (trac 2096)
+ c = np.zeros_like(a)
+ a.dot(b,c)
+ assert_equal(c, np.dot(a,b))
+ c = np.zeros_like(a)
+ a.dot(b=b,out=c)
+ assert_equal(c, np.dot(a,b))
+
def test_diagonal(self):
a = np.arange(12).reshape((3, 4))
assert_equal(a.diagonal(), [0, 5, 10])