diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/methods.c | 26 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 7 | ||||
-rw-r--r-- | numpy/matrixlib/tests/test_defmatrix.py | 5 |
3 files changed, 36 insertions, 2 deletions
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c index dbc1c9086..d034bc929 100644 --- a/numpy/core/src/multiarray/methods.c +++ b/numpy/core/src/multiarray/methods.c @@ -1793,6 +1793,29 @@ array_cumprod(PyArrayObject *self, PyObject *args, PyObject *kwds) static PyObject * +array_dot(PyArrayObject *self, PyObject *args, PyObject *kwds) +{ + PyObject *b; + static PyObject *numpycore = NULL; + + if (!PyArg_ParseTuple(args, "O", &b)) { + return NULL; + } + + /* Since blas-dot is exposed only on the Python side, we need to grab it + * from there */ + if (numpycore == NULL) { + numpycore = PyImport_ImportModule("numpy.core"); + if (numpycore == NULL) { + return NULL; + } + } + + return PyObject_CallMethod(numpycore, "dot", "OO", self, b); +} + + +static PyObject * array_any(PyArrayObject *self, PyObject *args, PyObject *kwds) { int axis = MAX_DIMS; @@ -2192,6 +2215,9 @@ NPY_NO_EXPORT PyMethodDef array_methods[] = { {"diagonal", (PyCFunction)array_diagonal, METH_VARARGS | METH_KEYWORDS, NULL}, + {"dot", + (PyCFunction)array_dot, + METH_VARARGS, NULL}, {"fill", (PyCFunction)array_fill, METH_VARARGS, NULL}, diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index b8ea0a240..460261694 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -541,6 +541,13 @@ class TestMethods(TestCase): assert_equal(x1.flatten('F'), y1f) assert_equal(x1.flatten('F'), x1.T.flatten()) + def test_dot(self): + a = np.array([[1, 0], [0, 1]]) + b = np.array([[0, 1], [1, 0]]) + c = np.array([[9, 1], [1, -9]]) + + assert_equal(np.dot(a, b), a.dot(b)) + assert_equal(np.dot(np.dot(a, b), c), a.dot(b).dot(c)) class TestSubscripting(TestCase): def test_test_zero_rank(self): diff --git a/numpy/matrixlib/tests/test_defmatrix.py b/numpy/matrixlib/tests/test_defmatrix.py index cb6bd801d..65d79df0b 100644 --- a/numpy/matrixlib/tests/test_defmatrix.py +++ b/numpy/matrixlib/tests/test_defmatrix.py @@ -254,7 +254,8 @@ class TestMatrixReturn(TestCase): 'compress' : ([1],), 'repeat' : (1,), 'reshape' : (1,), - 'swapaxes' : (0,0) + 'swapaxes' : (0,0), + 'dot': np.array([1.0]), } excluded_methods = [ 'argmin', 'choose', 'dump', 'dumps', 'fill', 'getfield', @@ -267,7 +268,7 @@ class TestMatrixReturn(TestCase): for attrib in dir(a): if attrib.startswith('_') or attrib in excluded_methods: continue - f = eval('a.%s' % attrib) + f = getattr(a, attrib) if callable(f): # reset contents of a a.astype('f8') |