diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2015-05-30 11:04:37 -0600 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2015-06-04 18:00:50 -0600 |
commit | 721675e19dd3f0c798f9582f69ea153484917e34 (patch) | |
tree | ad51f8875b4904922c0a875136b8bdfb3927d562 | |
parent | 8c18581110925ebe9e2b3530d4d5fc4afefbc8b0 (diff) | |
download | numpy-721675e19dd3f0c798f9582f69ea153484917e34.tar.gz |
ENH: Add a matmul function to multiarray
This is the functional counterpart of the '@' operator that will be
available in Python 3.5 with the addition of an out keyword. It
operates like the dot function except that
- scalar multiplication is not allowed.
- multiplication of arrays with more than 2 dimensions broadcasts.
The last means that when arrays have more than 2 dimensions they are
treated as stacks of matrices and those stacks are broadcast against
each other unlike the current behavior of dot that does an outer
product. Like dot, matmul is aware of `__numpy_ufunc__` and can be
overridden.
The current version of the function uses einsum when cblas cannot be
used, hence object arrays are not yet supported.
-rw-r--r-- | numpy/core/numeric.py | 13 | ||||
-rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.c | 169 |
2 files changed, 175 insertions, 7 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index ea2d4d0a2..bf22f6954 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -43,7 +43,8 @@ __all__ = ['newaxis', 'ndarray', 'flatiter', 'nditer', 'nested_iters', 'ufunc', 'Inf', 'inf', 'infty', 'Infinity', 'nan', 'NaN', 'False_', 'True_', 'bitwise_not', 'CLIP', 'RAISE', 'WRAP', 'MAXDIMS', 'BUFSIZE', 'ALLOW_THREADS', - 'ComplexWarning', 'may_share_memory', 'full', 'full_like'] + 'ComplexWarning', 'may_share_memory', 'full', 'full_like', + 'matmul'] if sys.version_info[0] < 3: __all__.extend(['getbuffer', 'newbuffer']) @@ -390,6 +391,11 @@ lexsort = multiarray.lexsort compare_chararrays = multiarray.compare_chararrays putmask = multiarray.putmask einsum = multiarray.einsum +dot = multiarray.dot +inner = multiarray.inner +vdot = multiarray.vdot +matmul = multiarray.matmul + def asarray(a, dtype=None, order=None): """ @@ -1081,11 +1087,6 @@ def outer(a, b, out=None): b = asarray(b) return multiply(a.ravel()[:, newaxis], b.ravel()[newaxis,:], out) -# try to import blas optimized dot if available -envbak = os.environ.copy() -dot = multiarray.dot -inner = multiarray.inner -vdot = multiarray.vdot def alterdot(): """ diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index 369b5a8e1..b9259214f 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -27,8 +27,8 @@ #include "numpy/npy_math.h" #include "npy_config.h" - #include "npy_pycompat.h" +#include "npy_import.h" NPY_NO_EXPORT int NPY_NUMUSERTYPES = 0; @@ -2369,6 +2369,170 @@ fail: } + +/* + * matmul + * + * Implements the protocol used by the '@' operator defined in PEP 364. + * Not in the NUMPY API at this time, maybe later. + * + * + * in1: Left hand side operand + * in2: Right hand side operand + * out: Either NULL, or an array into which the output should be placed. + * + * Returns NULL on error. + * Returns NotImplemented on priority override. + */ +static PyObject * +array_matmul(PyObject *NPY_UNUSED(m), PyObject *args, PyObject* kwds) +{ + static PyObject *matmul = NULL; + int errval; + PyObject *override = NULL; + PyObject *in1, *in2, *out = NULL; + char* kwlist[] = {"a", "b", "out", NULL }; + PyArrayObject *ap1, *ap2, *ret = NULL; + NPY_ORDER order = NPY_KEEPORDER; + NPY_CASTING casting = NPY_SAFE_CASTING; + PyArray_Descr *dtype; + int nd1, nd2, typenum; + char *subscripts; + PyArrayObject *ops[2]; + + npy_cache_pyfunc("numpy.core.multiarray", "matmul", &matmul); + if (matmul == NULL) { + return NULL; + } + + errval = PyUFunc_CheckOverride(matmul, "__call__", + args, kwds, &override, 2); + if (errval) { + return NULL; + } + else if (override) { + return override; + } + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO|O", kwlist, + &in1, &in2, &out)) { + return NULL; + } + + if (out == Py_None) { + out = NULL; + } + if (out != NULL && !PyArray_Check(out)) { + PyErr_SetString(PyExc_TypeError, + "'out' must be an array"); + return NULL; + } + + dtype = PyArray_DescrFromObject(in1, NULL); + dtype = PyArray_DescrFromObject(in2, dtype); + if (dtype == NULL) { + PyErr_SetString(PyExc_ValueError, + "Cannot find a common data type."); + return NULL; + } + typenum = dtype->type_num; + + if (typenum == NPY_OBJECT) { + /* matmul is not currently implemented for object arrays */ + PyErr_SetString(PyExc_TypeError, + "Object arrays are not currently supported"); + Py_DECREF(dtype); + return NULL; + } + + ap1 = (PyArrayObject *)PyArray_FromAny(in1, dtype, 0, 0, + NPY_ARRAY_ALIGNED, NULL); + if (ap1 == NULL) { + return NULL; + } + + Py_INCREF(dtype); + ap2 = (PyArrayObject *)PyArray_FromAny(in2, dtype, 0, 0, + NPY_ARRAY_ALIGNED, NULL); + if (ap2 == NULL) { + Py_DECREF(ap1); + return NULL; + } + + if (PyArray_NDIM(ap1) == 0 || PyArray_NDIM(ap2) == 0) { + /* Scalars are rejected */ + PyErr_SetString(PyExc_ValueError, + "Scalar operands are not allowed, use '*' instead"); + return NULL; + } + + nd1 = PyArray_NDIM(ap1); + nd2 = PyArray_NDIM(ap2); + +#if defined(HAVE_CBLAS) + if (nd1 <= 2 && nd2 <= 2 && + (NPY_DOUBLE == typenum || NPY_CDOUBLE == typenum || + NPY_FLOAT == typenum || NPY_CFLOAT == typenum)) { + return cblas_matrixproduct(typenum, ap1, ap2, out); + } +#endif + + /* + * Use einsum for the stacked cases. This is a quick implementation + * to avoid setting up the proper iterators. Einsum broadcasts, so + * we need to check dimensions before the call. + */ + if (nd1 == 1 && nd2 == 1) { + /* vector vector */ + if (PyArray_DIM(ap1, 0) != PyArray_DIM(ap2, 0)) { + dot_alignment_error(ap1, 0, ap2, 0); + goto fail; + } + subscripts = "i, i"; + } + else if (nd1 == 1) { + /* vector matrix */ + if (PyArray_DIM(ap1, 0) != PyArray_DIM(ap2, nd2 - 2)) { + dot_alignment_error(ap1, 0, ap2, nd2 - 2); + goto fail; + } + subscripts = "i, ...ij"; + } + else if (nd2 == 1) { + /* matrix vector */ + if (PyArray_DIM(ap1, nd1 - 1) != PyArray_DIM(ap2, 0)) { + dot_alignment_error(ap1, nd1 - 1, ap2, 0); + goto fail; + } + subscripts = "...i, i"; + } + else { + /* matrix * matrix */ + if (PyArray_DIM(ap1, nd1 - 1) != PyArray_DIM(ap2, nd2 - 2)) { + dot_alignment_error(ap1, nd1 - 1, ap2, nd2 - 2); + goto fail; + } + subscripts = "...ij, ...jk"; + } + ops[0] = ap1; + ops[1] = ap2; + ret = PyArray_EinsteinSum(subscripts, 2, ops, NULL, order, casting, out); + Py_DECREF(ap1); + Py_DECREF(ap2); + + /* If no output was supplied, possibly convert to a scalar */ + if (ret != NULL && out == NULL) { + ret = PyArray_Return((PyArrayObject *)ret); + } + return (PyObject *)ret; + +fail: + Py_XDECREF(ap1); + Py_XDECREF(ap2); + return NULL; +} + + static int einsum_sub_op_from_str(PyObject *args, PyObject **str_obj, char **subscripts, PyArrayObject **op) @@ -3936,6 +4100,9 @@ static struct PyMethodDef array_module_methods[] = { {"vdot", (PyCFunction)array_vdot, METH_VARARGS | METH_KEYWORDS, NULL}, + {"matmul", + (PyCFunction)array_matmul, + METH_VARARGS | METH_KEYWORDS, NULL}, {"einsum", (PyCFunction)array_einsum, METH_VARARGS|METH_KEYWORDS, NULL}, |