summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorPauli Virtanen <pav@iki.fi>2010-04-30 07:06:02 +0000
committerPauli Virtanen <pav@iki.fi>2010-04-30 07:06:02 +0000
commit6dc315f8aa7e67f637bcd2beb0c4105ddfcb4a83 (patch)
treedc381a77caf5913b30df444a1faae8f8d15d4b38 /numpy
parentbceb9c20700514db5667831ce2878f1660fb071f (diff)
downloadnumpy-6dc315f8aa7e67f637bcd2beb0c4105ddfcb4a83.tar.gz
ENH: core: add .dot() method to ndarrays; a.dot(b) == np.dot(a, b)
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/methods.c26
-rw-r--r--numpy/core/tests/test_multiarray.py7
-rw-r--r--numpy/matrixlib/tests/test_defmatrix.py5
3 files changed, 36 insertions, 2 deletions
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c
index dbc1c9086..d034bc929 100644
--- a/numpy/core/src/multiarray/methods.c
+++ b/numpy/core/src/multiarray/methods.c
@@ -1793,6 +1793,29 @@ array_cumprod(PyArrayObject *self, PyObject *args, PyObject *kwds)
static PyObject *
+array_dot(PyArrayObject *self, PyObject *args, PyObject *kwds)
+{
+ PyObject *b;
+ static PyObject *numpycore = NULL;
+
+ if (!PyArg_ParseTuple(args, "O", &b)) {
+ return NULL;
+ }
+
+ /* Since blas-dot is exposed only on the Python side, we need to grab it
+ * from there */
+ if (numpycore == NULL) {
+ numpycore = PyImport_ImportModule("numpy.core");
+ if (numpycore == NULL) {
+ return NULL;
+ }
+ }
+
+ return PyObject_CallMethod(numpycore, "dot", "OO", self, b);
+}
+
+
+static PyObject *
array_any(PyArrayObject *self, PyObject *args, PyObject *kwds)
{
int axis = MAX_DIMS;
@@ -2192,6 +2215,9 @@ NPY_NO_EXPORT PyMethodDef array_methods[] = {
{"diagonal",
(PyCFunction)array_diagonal,
METH_VARARGS | METH_KEYWORDS, NULL},
+ {"dot",
+ (PyCFunction)array_dot,
+ METH_VARARGS, NULL},
{"fill",
(PyCFunction)array_fill,
METH_VARARGS, NULL},
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index b8ea0a240..460261694 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -541,6 +541,13 @@ class TestMethods(TestCase):
assert_equal(x1.flatten('F'), y1f)
assert_equal(x1.flatten('F'), x1.T.flatten())
+ def test_dot(self):
+ a = np.array([[1, 0], [0, 1]])
+ b = np.array([[0, 1], [1, 0]])
+ c = np.array([[9, 1], [1, -9]])
+
+ assert_equal(np.dot(a, b), a.dot(b))
+ assert_equal(np.dot(np.dot(a, b), c), a.dot(b).dot(c))
class TestSubscripting(TestCase):
def test_test_zero_rank(self):
diff --git a/numpy/matrixlib/tests/test_defmatrix.py b/numpy/matrixlib/tests/test_defmatrix.py
index cb6bd801d..65d79df0b 100644
--- a/numpy/matrixlib/tests/test_defmatrix.py
+++ b/numpy/matrixlib/tests/test_defmatrix.py
@@ -254,7 +254,8 @@ class TestMatrixReturn(TestCase):
'compress' : ([1],),
'repeat' : (1,),
'reshape' : (1,),
- 'swapaxes' : (0,0)
+ 'swapaxes' : (0,0),
+ 'dot': np.array([1.0]),
}
excluded_methods = [
'argmin', 'choose', 'dump', 'dumps', 'fill', 'getfield',
@@ -267,7 +268,7 @@ class TestMatrixReturn(TestCase):
for attrib in dir(a):
if attrib.startswith('_') or attrib in excluded_methods:
continue
- f = eval('a.%s' % attrib)
+ f = getattr(a, attrib)
if callable(f):
# reset contents of a
a.astype('f8')