summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/multiarray/methods.c16
-rw-r--r--numpy/core/tests/test_multiarray.py10
2 files changed, 21 insertions, 5 deletions
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c
index 902384b0d..c81f4b1a8 100644
--- a/numpy/core/src/multiarray/methods.c
+++ b/numpy/core/src/multiarray/methods.c
@@ -1848,10 +1848,11 @@ array_cumprod(PyArrayObject *self, PyObject *args, PyObject *kwds)
static PyObject *
array_dot(PyArrayObject *self, PyObject *args, PyObject *kwds)
{
- PyObject *b;
+ PyObject *fname, *ret, *b, *out = NULL;
static PyObject *numpycore = NULL;
+ char * kwords[] = {"b", "out", NULL };
- if (!PyArg_ParseTuple(args, "O", &b)) {
+ if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O", kwords, &b, &out)) {
return NULL;
}
@@ -1863,8 +1864,13 @@ array_dot(PyArrayObject *self, PyObject *args, PyObject *kwds)
return NULL;
}
}
-
- return PyObject_CallMethod(numpycore, "dot", "OO", self, b);
+ fname = PyUString_FromString("dot");
+ if (out == NULL) {
+ ret = PyObject_CallMethodObjArgs(numpycore, fname, self, b, NULL);
+ }
+ ret = PyObject_CallMethodObjArgs(numpycore, fname, self, b, out, NULL);
+ Py_DECREF(fname);
+ return ret;
}
@@ -2223,7 +2229,7 @@ NPY_NO_EXPORT PyMethodDef array_methods[] = {
METH_VARARGS | METH_KEYWORDS, NULL},
{"dot",
(PyCFunction)array_dot,
- METH_VARARGS, NULL},
+ METH_VARARGS | METH_KEYWORDS, NULL},
{"fill",
(PyCFunction)array_fill,
METH_VARARGS, NULL},
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 319147970..b3c0626dc 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -902,6 +902,16 @@ class TestMethods(TestCase):
assert_equal(np.dot(a, b), a.dot(b))
assert_equal(np.dot(np.dot(a, b), c), a.dot(b).dot(c))
+ # test passing in an output array
+ c = np.zeros_like(a)
+ a.dot(b,c)
+ assert_equal(c, np.dot(a,b))
+
+ # test keyword args
+ c = np.zeros_like(a)
+ a.dot(b=b,out=c)
+ assert_equal(c, np.dot(a,b))
+
def test_diagonal(self):
a = np.arange(12).reshape((3, 4))
assert_equal(a.diagonal(), [0, 5, 10])