summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2015-05-30 11:04:37 -0600
committerCharles Harris <charlesr.harris@gmail.com>2015-06-04 18:00:50 -0600
commit721675e19dd3f0c798f9582f69ea153484917e34 (patch)
treead51f8875b4904922c0a875136b8bdfb3927d562
parent8c18581110925ebe9e2b3530d4d5fc4afefbc8b0 (diff)
downloadnumpy-721675e19dd3f0c798f9582f69ea153484917e34.tar.gz
ENH: Add a matmul function to multiarray
This is the functional counterpart of the '@' operator that will be available in Python 3.5 with the addition of an out keyword. It operates like the dot function except that - scalar multiplication is not allowed. - multiplication of arrays with more than 2 dimensions broadcasts. The last means that when arrays have more than 2 dimensions they are treated as stacks of matrices and those stacks are broadcast against each other unlike the current behavior of dot that does an outer product. Like dot, matmul is aware of `__numpy_ufunc__` and can be overridden. The current version of the function uses einsum when cblas cannot be used, hence object arrays are not yet supported.
-rw-r--r--numpy/core/numeric.py13
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.c169
2 files changed, 175 insertions, 7 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index ea2d4d0a2..bf22f6954 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -43,7 +43,8 @@ __all__ = ['newaxis', 'ndarray', 'flatiter', 'nditer', 'nested_iters', 'ufunc',
'Inf', 'inf', 'infty', 'Infinity',
'nan', 'NaN', 'False_', 'True_', 'bitwise_not',
'CLIP', 'RAISE', 'WRAP', 'MAXDIMS', 'BUFSIZE', 'ALLOW_THREADS',
- 'ComplexWarning', 'may_share_memory', 'full', 'full_like']
+ 'ComplexWarning', 'may_share_memory', 'full', 'full_like',
+ 'matmul']
if sys.version_info[0] < 3:
__all__.extend(['getbuffer', 'newbuffer'])
@@ -390,6 +391,11 @@ lexsort = multiarray.lexsort
compare_chararrays = multiarray.compare_chararrays
putmask = multiarray.putmask
einsum = multiarray.einsum
+dot = multiarray.dot
+inner = multiarray.inner
+vdot = multiarray.vdot
+matmul = multiarray.matmul
+
def asarray(a, dtype=None, order=None):
"""
@@ -1081,11 +1087,6 @@ def outer(a, b, out=None):
b = asarray(b)
return multiply(a.ravel()[:, newaxis], b.ravel()[newaxis,:], out)
-# try to import blas optimized dot if available
-envbak = os.environ.copy()
-dot = multiarray.dot
-inner = multiarray.inner
-vdot = multiarray.vdot
def alterdot():
"""
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index 369b5a8e1..b9259214f 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -27,8 +27,8 @@
#include "numpy/npy_math.h"
#include "npy_config.h"
-
#include "npy_pycompat.h"
+#include "npy_import.h"
NPY_NO_EXPORT int NPY_NUMUSERTYPES = 0;
@@ -2369,6 +2369,170 @@ fail:
}
+
+/*
+ * matmul
+ *
+ * Implements the protocol used by the '@' operator defined in PEP 364.
+ * Not in the NUMPY API at this time, maybe later.
+ *
+ *
+ * in1: Left hand side operand
+ * in2: Right hand side operand
+ * out: Either NULL, or an array into which the output should be placed.
+ *
+ * Returns NULL on error.
+ * Returns NotImplemented on priority override.
+ */
+static PyObject *
+array_matmul(PyObject *NPY_UNUSED(m), PyObject *args, PyObject* kwds)
+{
+ static PyObject *matmul = NULL;
+ int errval;
+ PyObject *override = NULL;
+ PyObject *in1, *in2, *out = NULL;
+ char* kwlist[] = {"a", "b", "out", NULL };
+ PyArrayObject *ap1, *ap2, *ret = NULL;
+ NPY_ORDER order = NPY_KEEPORDER;
+ NPY_CASTING casting = NPY_SAFE_CASTING;
+ PyArray_Descr *dtype;
+ int nd1, nd2, typenum;
+ char *subscripts;
+ PyArrayObject *ops[2];
+
+ npy_cache_pyfunc("numpy.core.multiarray", "matmul", &matmul);
+ if (matmul == NULL) {
+ return NULL;
+ }
+
+ errval = PyUFunc_CheckOverride(matmul, "__call__",
+ args, kwds, &override, 2);
+ if (errval) {
+ return NULL;
+ }
+ else if (override) {
+ return override;
+ }
+
+ if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO|O", kwlist,
+ &in1, &in2, &out)) {
+ return NULL;
+ }
+
+ if (out == Py_None) {
+ out = NULL;
+ }
+ if (out != NULL && !PyArray_Check(out)) {
+ PyErr_SetString(PyExc_TypeError,
+ "'out' must be an array");
+ return NULL;
+ }
+
+ dtype = PyArray_DescrFromObject(in1, NULL);
+ dtype = PyArray_DescrFromObject(in2, dtype);
+ if (dtype == NULL) {
+ PyErr_SetString(PyExc_ValueError,
+ "Cannot find a common data type.");
+ return NULL;
+ }
+ typenum = dtype->type_num;
+
+ if (typenum == NPY_OBJECT) {
+ /* matmul is not currently implemented for object arrays */
+ PyErr_SetString(PyExc_TypeError,
+ "Object arrays are not currently supported");
+ Py_DECREF(dtype);
+ return NULL;
+ }
+
+ ap1 = (PyArrayObject *)PyArray_FromAny(in1, dtype, 0, 0,
+ NPY_ARRAY_ALIGNED, NULL);
+ if (ap1 == NULL) {
+ return NULL;
+ }
+
+ Py_INCREF(dtype);
+ ap2 = (PyArrayObject *)PyArray_FromAny(in2, dtype, 0, 0,
+ NPY_ARRAY_ALIGNED, NULL);
+ if (ap2 == NULL) {
+ Py_DECREF(ap1);
+ return NULL;
+ }
+
+ if (PyArray_NDIM(ap1) == 0 || PyArray_NDIM(ap2) == 0) {
+ /* Scalars are rejected */
+ PyErr_SetString(PyExc_ValueError,
+ "Scalar operands are not allowed, use '*' instead");
+ return NULL;
+ }
+
+ nd1 = PyArray_NDIM(ap1);
+ nd2 = PyArray_NDIM(ap2);
+
+#if defined(HAVE_CBLAS)
+ if (nd1 <= 2 && nd2 <= 2 &&
+ (NPY_DOUBLE == typenum || NPY_CDOUBLE == typenum ||
+ NPY_FLOAT == typenum || NPY_CFLOAT == typenum)) {
+ return cblas_matrixproduct(typenum, ap1, ap2, out);
+ }
+#endif
+
+ /*
+ * Use einsum for the stacked cases. This is a quick implementation
+ * to avoid setting up the proper iterators. Einsum broadcasts, so
+ * we need to check dimensions before the call.
+ */
+ if (nd1 == 1 && nd2 == 1) {
+ /* vector vector */
+ if (PyArray_DIM(ap1, 0) != PyArray_DIM(ap2, 0)) {
+ dot_alignment_error(ap1, 0, ap2, 0);
+ goto fail;
+ }
+ subscripts = "i, i";
+ }
+ else if (nd1 == 1) {
+ /* vector matrix */
+ if (PyArray_DIM(ap1, 0) != PyArray_DIM(ap2, nd2 - 2)) {
+ dot_alignment_error(ap1, 0, ap2, nd2 - 2);
+ goto fail;
+ }
+ subscripts = "i, ...ij";
+ }
+ else if (nd2 == 1) {
+ /* matrix vector */
+ if (PyArray_DIM(ap1, nd1 - 1) != PyArray_DIM(ap2, 0)) {
+ dot_alignment_error(ap1, nd1 - 1, ap2, 0);
+ goto fail;
+ }
+ subscripts = "...i, i";
+ }
+ else {
+ /* matrix * matrix */
+ if (PyArray_DIM(ap1, nd1 - 1) != PyArray_DIM(ap2, nd2 - 2)) {
+ dot_alignment_error(ap1, nd1 - 1, ap2, nd2 - 2);
+ goto fail;
+ }
+ subscripts = "...ij, ...jk";
+ }
+ ops[0] = ap1;
+ ops[1] = ap2;
+ ret = PyArray_EinsteinSum(subscripts, 2, ops, NULL, order, casting, out);
+ Py_DECREF(ap1);
+ Py_DECREF(ap2);
+
+ /* If no output was supplied, possibly convert to a scalar */
+ if (ret != NULL && out == NULL) {
+ ret = PyArray_Return((PyArrayObject *)ret);
+ }
+ return (PyObject *)ret;
+
+fail:
+ Py_XDECREF(ap1);
+ Py_XDECREF(ap2);
+ return NULL;
+}
+
+
static int
einsum_sub_op_from_str(PyObject *args, PyObject **str_obj, char **subscripts,
PyArrayObject **op)
@@ -3936,6 +4100,9 @@ static struct PyMethodDef array_module_methods[] = {
{"vdot",
(PyCFunction)array_vdot,
METH_VARARGS | METH_KEYWORDS, NULL},
+ {"matmul",
+ (PyCFunction)array_matmul,
+ METH_VARARGS | METH_KEYWORDS, NULL},
{"einsum",
(PyCFunction)array_einsum,
METH_VARARGS|METH_KEYWORDS, NULL},