diff options
author | Pauli Virtanen <pav@iki.fi> | 2010-04-30 07:06:02 +0000 |
---|---|---|
committer | Pauli Virtanen <pav@iki.fi> | 2010-04-30 07:06:02 +0000 |
commit | 6dc315f8aa7e67f637bcd2beb0c4105ddfcb4a83 (patch) | |
tree | dc381a77caf5913b30df444a1faae8f8d15d4b38 | |
parent | bceb9c20700514db5667831ce2878f1660fb071f (diff) | |
download | numpy-6dc315f8aa7e67f637bcd2beb0c4105ddfcb4a83.tar.gz |
ENH: core: add .dot() method to ndarrays; a.dot(b) == np.dot(a, b)
-rw-r--r-- | doc/release/2.0.0-notes.rst | 13 | ||||
-rw-r--r-- | numpy/core/src/multiarray/methods.c | 26 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 7 | ||||
-rw-r--r-- | numpy/matrixlib/tests/test_defmatrix.py | 5 |
4 files changed, 49 insertions, 2 deletions
diff --git a/doc/release/2.0.0-notes.rst b/doc/release/2.0.0-notes.rst index 06e29de6b..ccb596711 100644 --- a/doc/release/2.0.0-notes.rst +++ b/doc/release/2.0.0-notes.rst @@ -35,3 +35,16 @@ turned off in the standard way: >>> import warnings >>> warnings.simplefilter("ignore", np.ComplexWarning) + +Dot method for ndarrays +~~~~~~~~~~~~~~~~~~~~~~~ + +Ndarrays now have the dot product also as a method, which allows writing +chains of matrix products as + + >>> a.dot(b).dot(c) + +instead of the longer alternative + + >>> np.dot(a, np.dot(b, c)) + diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c index dbc1c9086..d034bc929 100644 --- a/numpy/core/src/multiarray/methods.c +++ b/numpy/core/src/multiarray/methods.c @@ -1793,6 +1793,29 @@ array_cumprod(PyArrayObject *self, PyObject *args, PyObject *kwds) static PyObject * +array_dot(PyArrayObject *self, PyObject *args, PyObject *kwds) +{ + PyObject *b; + static PyObject *numpycore = NULL; + + if (!PyArg_ParseTuple(args, "O", &b)) { + return NULL; + } + + /* Since blas-dot is exposed only on the Python side, we need to grab it + * from there */ + if (numpycore == NULL) { + numpycore = PyImport_ImportModule("numpy.core"); + if (numpycore == NULL) { + return NULL; + } + } + + return PyObject_CallMethod(numpycore, "dot", "OO", self, b); +} + + +static PyObject * array_any(PyArrayObject *self, PyObject *args, PyObject *kwds) { int axis = MAX_DIMS; @@ -2192,6 +2215,9 @@ NPY_NO_EXPORT PyMethodDef array_methods[] = { {"diagonal", (PyCFunction)array_diagonal, METH_VARARGS | METH_KEYWORDS, NULL}, + {"dot", + (PyCFunction)array_dot, + METH_VARARGS, NULL}, {"fill", (PyCFunction)array_fill, METH_VARARGS, NULL}, diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index b8ea0a240..460261694 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -541,6 +541,13 @@ class TestMethods(TestCase): assert_equal(x1.flatten('F'), y1f) assert_equal(x1.flatten('F'), x1.T.flatten()) + def test_dot(self): + a = np.array([[1, 0], [0, 1]]) + b = np.array([[0, 1], [1, 0]]) + c = np.array([[9, 1], [1, -9]]) + + assert_equal(np.dot(a, b), a.dot(b)) + assert_equal(np.dot(np.dot(a, b), c), a.dot(b).dot(c)) class TestSubscripting(TestCase): def test_test_zero_rank(self): diff --git a/numpy/matrixlib/tests/test_defmatrix.py b/numpy/matrixlib/tests/test_defmatrix.py index cb6bd801d..65d79df0b 100644 --- a/numpy/matrixlib/tests/test_defmatrix.py +++ b/numpy/matrixlib/tests/test_defmatrix.py @@ -254,7 +254,8 @@ class TestMatrixReturn(TestCase): 'compress' : ([1],), 'repeat' : (1,), 'reshape' : (1,), - 'swapaxes' : (0,0) + 'swapaxes' : (0,0), + 'dot': np.array([1.0]), } excluded_methods = [ 'argmin', 'choose', 'dump', 'dumps', 'fill', 'getfield', @@ -267,7 +268,7 @@ class TestMatrixReturn(TestCase): for attrib in dir(a): if attrib.startswith('_') or attrib in excluded_methods: continue - f = eval('a.%s' % attrib) + f = getattr(a, attrib) if callable(f): # reset contents of a a.astype('f8') |