author     mattip <matti.picus@gmail.com>  2018-10-19 15:38:53 +0300
committer  mattip <matti.picus@gmail.com>  2018-11-25 12:23:35 -0600
commit     8c9450a7fd69d5b74b47ffec60b5c235361daeff (patch)
tree       c8d86a390ff4c167dcf1ff84d17cf189e6ed17c6 /numpy/core/src
parent     d7e411bcbf4fb8279b4a8485517fd38ce6eb43a9 (diff)
ENH: make matmul into a ufunc
MAINT: fixes from review
Diffstat (limited to 'numpy/core/src')
-rw-r--r--  numpy/core/src/common/cblasfuncs.c               2
-rw-r--r--  numpy/core/src/common/cblasfuncs.h               2
-rw-r--r--  numpy/core/src/multiarray/multiarraymodule.c   155
-rw-r--r--  numpy/core/src/multiarray/number.c              10
-rw-r--r--  numpy/core/src/multiarray/number.h               1
-rw-r--r--  numpy/core/src/multiarray/scalartypes.c.src     16
-rw-r--r--  numpy/core/src/umath/matmul.c.src              403
-rw-r--r--  numpy/core/src/umath/matmul.h.src               12
-rw-r--r--  numpy/core/src/umath/scalarmath.c.src           17
-rw-r--r--  numpy/core/src/umath/ufunc_object.c             12
-rw-r--r--  numpy/core/src/umath/ufunc_type_resolution.c    87
-rw-r--r--  numpy/core/src/umath/ufunc_type_resolution.h     6
12 files changed, 538 insertions, 185 deletions
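
The heart of the change: the hand-rolled matmul removed from multiarraymodule.c
below is replaced by a generalized ufunc in umath whose core signature is
(m?,n),(n,p?)->(m?,p?). Each '?' marks a flexible core dimension that is
dropped when the corresponding operand is 1-d, which is how one ufunc covers
vector-vector, vector-matrix, matrix-vector and matrix-matrix at once. A
minimal plain-C sketch of how the signature resolves for the four rank
combinations -- purely illustrative, not NumPy's gufunc machinery:

    #include <stdio.h>

    /* Illustrative only: show which core dimensions of
     * (m?,n),(n,p?)->(m?,p?) survive for each operand rank. */
    static const char *
    core_dims(int nd1, int nd2)
    {
        if (nd1 == 1 && nd2 == 1) {
            return "(n),(n)->()";        /* vector @ vector: m and p dropped */
        }
        else if (nd1 == 1) {
            return "(n),(n,p)->(p)";     /* vector @ matrix: m dropped */
        }
        else if (nd2 == 1) {
            return "(m,n),(n)->(m)";     /* matrix @ vector: p dropped */
        }
        return "(m,n),(n,p)->(m,p)";     /* matrix @ matrix */
    }

    int main(void)
    {
        int ranks[4][2] = {{1, 1}, {1, 2}, {2, 1}, {2, 2}};
        for (int i = 0; i < 4; i++) {
            printf("nd1=%d nd2=%d -> %s\n", ranks[i][0], ranks[i][1],
                   core_dims(ranks[i][0], ranks[i][1]));
        }
        return 0;
    }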
diff --git a/numpy/core/src/common/cblasfuncs.c b/numpy/core/src/common/cblasfuncs.c
index 6460c5db1..514297940 100644
--- a/numpy/core/src/common/cblasfuncs.c
+++ b/numpy/core/src/common/cblasfuncs.c
@@ -182,7 +182,7 @@ _select_matrix_shape(PyArrayObject *array)
* This also makes sure that the data segment is aligned with
* an itemsize address as well by returning one if not true.
*/
-static int
+NPY_NO_EXPORT int
_bad_strides(PyArrayObject *ap)
{
int itemsize = PyArray_ITEMSIZE(ap);
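
_bad_strides is exported here (it was static) so the new matmul type resolver
can reuse it to decide when an operand must be copied before calling BLAS. It
rejects layouts BLAS cannot consume; a standalone sketch of the same test,
with plain C types standing in for PyArrayObject*:

    #include <stdio.h>
    #include <stddef.h>
    #include <stdint.h>

    /* Sketch of the _bad_strides check: the data pointer and every stride
     * must be a non-negative multiple of the itemsize. */
    static int
    bad_strides(const void *data, ptrdiff_t itemsize, int ndim,
                const ptrdiff_t *strides)
    {
        if (((uintptr_t)data % itemsize) != 0) {
            return 1;                 /* data segment not itemsize-aligned */
        }
        for (int i = 0; i < ndim; i++) {
            if (strides[i] < 0 || (strides[i] % itemsize) != 0) {
                return 1;             /* negative or misaligned stride */
            }
        }
        return 0;
    }

    int main(void)
    {
        double buf[6];
        ptrdiff_t ok[2] = {3 * sizeof(double), sizeof(double)};
        ptrdiff_t bad[2] = {3 * sizeof(double), 1};  /* 1-byte stride */
        printf("%d %d\n", bad_strides(buf, sizeof(double), 2, ok),
               bad_strides(buf, sizeof(double), 2, bad));  /* prints: 0 1 */
        return 0;
    }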
diff --git a/numpy/core/src/common/cblasfuncs.h b/numpy/core/src/common/cblasfuncs.h
index 66ce4ca5b..78eff25a0 100644
--- a/numpy/core/src/common/cblasfuncs.h
+++ b/numpy/core/src/common/cblasfuncs.h
@@ -3,5 +3,7 @@
NPY_NO_EXPORT PyObject *
cblas_matrixproduct(int, PyArrayObject *, PyArrayObject *, PyArrayObject *);
+NPY_NO_EXPORT int
+_bad_strides(PyArrayObject *ap);
#endif
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index 909a24359..5ccb7f6d6 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -72,7 +72,6 @@ NPY_NO_EXPORT int NPY_NUMUSERTYPES = 0;
*****************************************************************************
*/
#include "funcs.inc"
-#include "loops.h"
#include "umathmodule.h"
NPY_NO_EXPORT int initscalarmath(PyObject *);
@@ -2318,157 +2317,6 @@ fail:
return NULL;
}
-
-
-/*
- * matmul
- *
- * Implements the protocol used by the '@' operator defined in PEP 465.
- * Not in the NUMPY API at this time, maybe later.
- *
- *
- * in1: Left hand side operand
- * in2: Right hand side operand
- * out: Either NULL, or an array into which the output should be placed.
- *
- * Returns NULL on error.
- */
-static PyObject *
-array_matmul(PyObject *NPY_UNUSED(m), PyObject *args, PyObject* kwds)
-{
- PyObject *in1, *in2, *out = NULL;
- char* kwlist[] = {"a", "b", "out", NULL };
- PyArrayObject *ap1, *ap2, *ret = NULL;
- NPY_ORDER order = NPY_KEEPORDER;
- NPY_CASTING casting = NPY_SAFE_CASTING;
- PyArray_Descr *dtype;
- int nd1, nd2, typenum;
- char *subscripts;
- PyArrayObject *ops[2];
-
- if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO|O:matmul", kwlist,
- &in1, &in2, &out)) {
- return NULL;
- }
-
- if (out != NULL) {
- if (out == Py_None) {
- out = NULL;
- }
- else if (!PyArray_Check(out)) {
- PyErr_SetString(PyExc_TypeError, "'out' must be an array");
- return NULL;
- }
- }
-
- dtype = PyArray_DescrFromObject(in1, NULL);
- dtype = PyArray_DescrFromObject(in2, dtype);
- if (dtype == NULL) {
- if (!PyErr_Occurred()) {
- PyErr_SetString(PyExc_ValueError,
- "Cannot find a common data type.");
- }
- return NULL;
- }
- typenum = dtype->type_num;
-
- if (typenum == NPY_OBJECT) {
- /* matmul is not currently implemented for object arrays */
- PyErr_SetString(PyExc_TypeError,
- "Object arrays are not currently supported");
- Py_DECREF(dtype);
- return NULL;
- }
-
- ap1 = (PyArrayObject *)PyArray_FromAny(in1, dtype, 0, 0,
- NPY_ARRAY_ALIGNED, NULL);
- if (ap1 == NULL) {
- return NULL;
- }
-
- Py_INCREF(dtype);
- ap2 = (PyArrayObject *)PyArray_FromAny(in2, dtype, 0, 0,
- NPY_ARRAY_ALIGNED, NULL);
- if (ap2 == NULL) {
- Py_DECREF(ap1);
- return NULL;
- }
-
- if (PyArray_NDIM(ap1) == 0 || PyArray_NDIM(ap2) == 0) {
- /* Scalars are rejected */
- PyErr_SetString(PyExc_ValueError,
- "Scalar operands are not allowed, use '*' instead");
- return NULL;
- }
-
- nd1 = PyArray_NDIM(ap1);
- nd2 = PyArray_NDIM(ap2);
-
-#if defined(HAVE_CBLAS)
- if (nd1 <= 2 && nd2 <= 2 &&
- (NPY_DOUBLE == typenum || NPY_CDOUBLE == typenum ||
- NPY_FLOAT == typenum || NPY_CFLOAT == typenum)) {
- return cblas_matrixproduct(typenum, ap1, ap2, (PyArrayObject *)out);
- }
-#endif
-
- /*
- * Use einsum for the stacked cases. This is a quick implementation
- * to avoid setting up the proper iterators. Einsum broadcasts, so
- * we need to check dimensions before the call.
- */
- if (nd1 == 1 && nd2 == 1) {
- /* vector vector */
- if (PyArray_DIM(ap1, 0) != PyArray_DIM(ap2, 0)) {
- dot_alignment_error(ap1, 0, ap2, 0);
- goto fail;
- }
- subscripts = "i, i";
- }
- else if (nd1 == 1) {
- /* vector matrix */
- if (PyArray_DIM(ap1, 0) != PyArray_DIM(ap2, nd2 - 2)) {
- dot_alignment_error(ap1, 0, ap2, nd2 - 2);
- goto fail;
- }
- subscripts = "i, ...ij";
- }
- else if (nd2 == 1) {
- /* matrix vector */
- if (PyArray_DIM(ap1, nd1 - 1) != PyArray_DIM(ap2, 0)) {
- dot_alignment_error(ap1, nd1 - 1, ap2, 0);
- goto fail;
- }
- subscripts = "...i, i";
- }
- else {
- /* matrix * matrix */
- if (PyArray_DIM(ap1, nd1 - 1) != PyArray_DIM(ap2, nd2 - 2)) {
- dot_alignment_error(ap1, nd1 - 1, ap2, nd2 - 2);
- goto fail;
- }
- subscripts = "...ij, ...jk";
- }
- ops[0] = ap1;
- ops[1] = ap2;
- ret = PyArray_EinsteinSum(subscripts, 2, ops, NULL, order, casting,
- (PyArrayObject *)out);
- Py_DECREF(ap1);
- Py_DECREF(ap2);
-
- /* If no output was supplied, possibly convert to a scalar */
- if (ret != NULL && out == NULL) {
- return PyArray_Return((PyArrayObject *)ret);
- }
- return (PyObject *)ret;
-
-fail:
- Py_XDECREF(ap1);
- Py_XDECREF(ap2);
- return NULL;
-}
-
-
static int
einsum_sub_op_from_str(PyObject *args, PyObject **str_obj, char **subscripts,
PyArrayObject **op)
@@ -4285,9 +4133,6 @@ static struct PyMethodDef array_module_methods[] = {
{"vdot",
(PyCFunction)array_vdot,
METH_VARARGS | METH_KEYWORDS, NULL},
- {"matmul",
- (PyCFunction)array_matmul,
- METH_VARARGS | METH_KEYWORDS, NULL},
{"c_einsum",
(PyCFunction)array_einsum,
METH_VARARGS|METH_KEYWORDS, NULL},
diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c
index 5ee536d4f..d153a8a64 100644
--- a/numpy/core/src/multiarray/number.c
+++ b/numpy/core/src/multiarray/number.c
@@ -112,6 +112,7 @@ _PyArray_SetNumericOps(PyObject *dict)
SET(minimum);
SET(rint);
SET(conjugate);
+ SET(matmul);
return 0;
}
@@ -177,6 +178,7 @@ _PyArray_GetNumericOps(void)
GET(minimum);
GET(rint);
GET(conjugate);
+ GET(matmul);
return dict;
fail:
@@ -382,14 +384,8 @@ array_divmod(PyArrayObject *m1, PyObject *m2)
static PyObject *
array_matrix_multiply(PyArrayObject *m1, PyObject *m2)
{
- static PyObject *matmul = NULL;
-
- npy_cache_import("numpy.core.multiarray", "matmul", &matmul);
- if (matmul == NULL) {
- return NULL;
- }
BINOP_GIVE_UP_IF_NEEDED(m1, m2, nb_matrix_multiply, array_matrix_multiply);
- return PyArray_GenericBinaryFunction(m1, m2, matmul);
+ return PyArray_GenericBinaryFunction(m1, m2, n_ops.matmul);
}
static PyObject *
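
array_matrix_multiply previously pulled numpy.core.multiarray.matmul in
through npy_cache_import at first use; with matmul registered as a ufunc it
is now cached up front in the n_ops struct like every other numeric operator,
via the SET(matmul)/GET(matmul) lines above. A reduced sketch of that caching
pattern, cut down to one slot, with the module lookup assumed rather than
quoted from the patch:

    #include <Python.h>

    typedef struct { PyObject *matmul; } numeric_ops;  /* one-slot stand-in */
    static numeric_ops ops;

    /* Mirrors the SET(...) macro: fetch a named callable from a module
     * dict once and keep an owned reference in the struct. */
    static int
    set_ops_from_dict(PyObject *dict)
    {
        PyObject *obj = PyDict_GetItemString(dict, "matmul");  /* borrowed */
        if (obj == NULL) {
            return -1;
        }
        Py_INCREF(obj);
        Py_XDECREF(ops.matmul);
        ops.matmul = obj;
        return 0;
    }

    int main(void)
    {
        Py_Initialize();
        PyObject *umath = PyImport_ImportModule("numpy.core.umath");
        if (umath != NULL &&
                set_ops_from_dict(PyModule_GetDict(umath)) == 0) {
            printf("cached a %s\n", Py_TYPE(ops.matmul)->tp_name);
        }
        Py_XDECREF(umath);
        Py_Finalize();
        return 0;
    }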
diff --git a/numpy/core/src/multiarray/number.h b/numpy/core/src/multiarray/number.h
index fbdfe6f94..33a7cf872 100644
--- a/numpy/core/src/multiarray/number.h
+++ b/numpy/core/src/multiarray/number.h
@@ -39,6 +39,7 @@ typedef struct {
PyObject *minimum;
PyObject *rint;
PyObject *conjugate;
+ PyObject *matmul;
} NumericOps;
extern NPY_NO_EXPORT NumericOps n_ops;
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src
index 0f201b966..2f71c8ae9 100644
--- a/numpy/core/src/multiarray/scalartypes.c.src
+++ b/numpy/core/src/multiarray/scalartypes.c.src
@@ -1104,8 +1104,7 @@ static PyNumberMethods gentype_as_number = {
(binaryfunc)gentype_add, /*nb_add*/
(binaryfunc)gentype_subtract, /*nb_subtract*/
(binaryfunc)gentype_multiply, /*nb_multiply*/
-#if defined(NPY_PY3K)
-#else
+#if !defined(NPY_PY3K)
(binaryfunc)gentype_divide, /*nb_divide*/
#endif
(binaryfunc)gentype_remainder, /*nb_remainder*/
@@ -1121,8 +1120,7 @@ static PyNumberMethods gentype_as_number = {
(binaryfunc)gentype_and, /*nb_and*/
(binaryfunc)gentype_xor, /*nb_xor*/
(binaryfunc)gentype_or, /*nb_or*/
-#if defined(NPY_PY3K)
-#else
+#if !defined(NPY_PY3K)
0, /*nb_coerce*/
#endif
(unaryfunc)gentype_int, /*nb_int*/
@@ -1132,16 +1130,14 @@ static PyNumberMethods gentype_as_number = {
(unaryfunc)gentype_long, /*nb_long*/
#endif
(unaryfunc)gentype_float, /*nb_float*/
-#if defined(NPY_PY3K)
-#else
+#if !defined(NPY_PY3K)
(unaryfunc)gentype_oct, /*nb_oct*/
(unaryfunc)gentype_hex, /*nb_hex*/
#endif
0, /*inplace_add*/
0, /*inplace_subtract*/
0, /*inplace_multiply*/
-#if defined(NPY_PY3K)
-#else
+#if !defined(NPY_PY3K)
0, /*inplace_divide*/
#endif
0, /*inplace_remainder*/
@@ -1156,6 +1152,10 @@ static PyNumberMethods gentype_as_number = {
0, /*nb_inplace_floor_divide*/
0, /*nb_inplace_true_divide*/
(unaryfunc)NULL, /*nb_index*/
+#if PY_VERSION_HEX >= 0x03050000
+ 0, /*nb_matrix_multiply*/
+ 0, /*nb_inplace_matrix_multiply*/
+#endif
};
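
Both new slots are left 0 on purpose: scalar types do not implement '@'
themselves, so the interpreter falls back to the other operand (the array
side, where nb_matrix_multiply routes to n_ops.matmul). An illustrative
initializer showing just the tail of the table -- designated initializers
for brevity here, whereas the real tables are positional:

    #include <Python.h>

    /* On Python >= 3.5 PyNumberMethods grew two trailing slots; leaving
     * them 0 means "not implemented by this type". */
    static PyNumberMethods demo_as_number = {
        .nb_add = 0,
        /* ... all other slots elided ... */
    #if PY_VERSION_HEX >= 0x03050000
        .nb_matrix_multiply = 0,
        .nb_inplace_matrix_multiply = 0,
    #endif
    };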
diff --git a/numpy/core/src/umath/matmul.c.src b/numpy/core/src/umath/matmul.c.src
new file mode 100644
index 000000000..5f5bd3506
--- /dev/null
+++ b/numpy/core/src/umath/matmul.c.src
@@ -0,0 +1,403 @@
+/* -*- c -*- */
+
+#define _UMATHMODULE
+#define _MULTIARRAYMODULE
+#define NPY_NO_DEPRECATED_API NPY_API_VERSION
+
+#include "Python.h"
+
+#include "npy_config.h"
+#include "numpy/npy_common.h"
+#include "numpy/arrayobject.h"
+#include "numpy/ufuncobject.h"
+#include "numpy/npy_math.h"
+#include "numpy/halffloat.h"
+#include "lowlevel_strided_loops.h"
+
+#include "npy_pycompat.h"
+
+#include "npy_cblas.h"
+#include "arraytypes.h" /* For TYPE_dot functions */
+#include <assert.h>
+
+/*
+ *****************************************************************************
+ ** BASICS **
+ *****************************************************************************
+ */
+
+#if defined(HAVE_CBLAS)
+static const npy_cdouble oneD = {1.0, 0.0}, zeroD = {0.0, 0.0};
+static const npy_cfloat oneF = {1.0, 0.0}, zeroF = {0.0, 0.0};
+
+/**begin repeat
+ *
+ * #name = FLOAT, DOUBLE, CFLOAT, CDOUBLE#
+ * #ctype = npy_float, npy_double, npy_cfloat, npy_cdouble#
+ * #type = npy_float, npy_double, npy_cfloat, npy_cdouble#
+ * #prefix = s, d, c, z#
+ * #step1 = 1.F, 1., &oneF, &oneD#
+ * #step0 = 0.F, 0., &zeroF, &zeroD#
+ */
+NPY_NO_EXPORT void
+@name@_gemv(void *ip1, npy_intp is1_m, void *ip2, npy_intp is2_n, void *op,
+ npy_intp m, npy_intp n)
+{
+ /*
+ * Vector matrix multiplication -- Level 2 BLAS
+ * arguments
+ * ip1: contiguous data, m*n shape
+ * ip2: data in c order, n*1 shape
+ * op: contiguous data in c order, m shape
+ */
+ enum CBLAS_ORDER order;
+ int lda;
+
+ if (is1_m == sizeof(@type@)) {
+ order = CblasColMajor;
+ lda = n;
+ }
+ else {
+ /* If not ColMajor, caller should have ensured we are RowMajor */
+ /* will not assert in release mode */
+ assert(is1_m == sizeof(@type@) * m);
+ order = CblasRowMajor;
+ lda = m;
+ }
+ cblas_@prefix@gemv(order, CblasTrans, n, m, @step1@, ip1, lda, ip2,
+ is2_n / sizeof(@type@), @step0@, op, 1);
+}
+
+NPY_NO_EXPORT void
+@name@_matmul_matrixmatrix(void *ip1, npy_intp is1_m, npy_intp is1_n,
+ void *ip2, npy_intp is2_n, npy_intp is2_p,
+ void *op, npy_intp m, npy_intp n, npy_intp p)
+{
+ /*
+ * matrix matrix multiplication -- Level 3 BLAS
+ */
+ enum CBLAS_ORDER order = CblasRowMajor;
+ enum CBLAS_TRANSPOSE trans1, trans2;
+ int M, N, P, lda, ldb;
+ M = m;
+ N = n;
+ P = p;
+
+ if (is1_m == sizeof(@type@)) {
+ trans1 = CblasTrans;
+ lda = N > 1 ? is1_n / sizeof(@type@) : 1;
+ }
+ else {
+ /* If not ColMajor, caller should have ensured we are RowMajor */
+ /* will not assert in release mode */
+ assert(is1_n == sizeof(@type@));
+ trans1 = CblasNoTrans;
+ lda = N > 1 ? is1_m / sizeof(@type@) : 1;
+ }
+ if (is2_n == sizeof(@type@)) {
+ trans2 = CblasTrans;
+ ldb = N > 1 ? is2_p / sizeof(@type@) : 1;
+ }
+ else {
+ /* If not ColMajor, caller should have ensured we are RowMajor */
+ /* will not assert in release mode */
+ assert(is2_p == sizeof(@type@));
+ trans2 = CblasNoTrans;
+ ldb = P > 1 ? P : 1;
+ }
+ /*
+ * Use syrk if we have a case of a matrix times its transpose.
+ * Otherwise, use gemm for all other cases.
+ */
+ if (
+ (ip1 == ip2) &&
+ (m == p) &&
+ (is1_m == is2_p) &&
+ (is1_n == is2_n) &&
+ (trans1 != trans2)
+ ) {
+ npy_intp i,j;
+ if (trans1 == CblasNoTrans) {
+ cblas_@prefix@syrk(order, CblasUpper, trans1, P, N, @step1@,
+ ip1, lda, @step0@, op, P);
+ }
+ else {
+ cblas_@prefix@syrk(order, CblasUpper, trans1, P, N, @step1@,
+ ip1, ldb, @step0@, op, P);
+ }
+ /* Copy the triangle */
+ for (i = 0; i < P; i++) {
+ for (j = i + 1; j < P; j++) {
+ ((@type@*)op)[j * P + i] = ((@type@*)op)[i * P + j];
+ }
+ }
+
+ }
+ else {
+ cblas_@prefix@gemm(order, trans1, trans2, M, P, N, @step1@, ip1, lda,
+ ip2, ldb, @step0@, op, P);
+ }
+}
+
+/**end repeat**/
+#endif
+
+/*
+ * matmul loops
+ * signature is (m?,n),(n,p?)->(m?,p?)
+ */
+
+/**begin repeat
+ * #TYPE = FLOAT, DOUBLE, HALF#
+ * #typ = npy_float,npy_double,npy_half#
+ * #SPECL = 0,0,1#
+ */
+
+NPY_NO_EXPORT void
+@TYPE@_matmul_inner_noblas(char *ip1, char *ip2, char *op,
+ npy_intp dm, npy_intp dn, npy_intp dp,
+ npy_intp ib1_n, npy_intp ib2_n, npy_intp ib2_p,
+ npy_intp ob_p, npy_intp is1_m, npy_intp is1_n,
+ npy_intp is2_n, npy_intp is2_p, npy_intp os_m,
+ npy_intp os_p)
+{
+ npy_intp m, n, p;
+ for (m = 0; m < dm; m++) {
+ for (p = 0; p < dp; p++) {
+ /*
+ * Use a double as an intermediate sum, which is natural for
+ * npy_double, slightly increases the accuracy of npy_float,
+ * and is perhaps overkill for npy_half.
+ */
+ double sum = 0;
+ for (n = 0; n < dn; n++) {
+#if @SPECL@ == 1
+ @typ@ val1 = (*(@typ@ *)ip1);
+ @typ@ val2 = (*(@typ@ *)ip2);
+ sum += npy_half_to_float(val1) * npy_half_to_float(val2);
+#else
+ sum += ((double)*(@typ@ *)ip1) * ((double)*(@typ@ *)ip2);
+#endif
+ ip2 += is2_n;
+ ip1 += is1_n;
+ }
+#if @SPECL@ == 1
+ *(@typ@ *)op = npy_float_to_half((float)sum);
+#else
+ /* in the case of double -> float, may produce INF */
+ *(@typ@ *)op = (@typ@)sum;
+#endif
+ ip1 -= ib1_n;
+ ip2 -= ib2_n;
+ op += os_p;
+ ip2 += is2_p;
+ }
+ op -= ob_p;
+ ip2 -= ib2_p;
+ ip1 += is1_m;
+ op += os_m;
+ }
+}
+
+/**end repeat**/
+
+/**begin repeat
+ * #TYPE = LONGDOUBLE,
+ * CFLOAT, CDOUBLE, CLONGDOUBLE,
+ * UBYTE, USHORT, UINT, ULONG, ULONGLONG,
+ * BYTE, SHORT, INT, LONG, LONGLONG,
+ * BOOL#
+ * #typ = npy_longdouble,
+ * npy_cfloat, npy_cdouble, npy_clongdouble,
+ * npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong,
+ * npy_byte, npy_short, npy_int, npy_long, npy_longlong,
+ * npy_bool#
+ * #IS_COMPLEX = 0, 1, 1, 1, 0*11#
+ */
+
+NPY_NO_EXPORT void
+@TYPE@_matmul_inner_noblas(char *ip1, char *ip2, char *op,
+ npy_intp dm, npy_intp dn, npy_intp dp,
+ npy_intp ib1_n, npy_intp ib2_n, npy_intp ib2_p,
+ npy_intp ob_p, npy_intp is1_m, npy_intp is1_n,
+ npy_intp is2_n, npy_intp is2_p, npy_intp os_m,
+ npy_intp os_p)
+{
+ npy_intp m, n, p;
+ for (m = 0; m < dm; m++) {
+ for (p = 0; p < dp; p++) {
+#if @IS_COMPLEX@ == 1
+ (*(@typ@ *)op).real = 0;
+ (*(@typ@ *)op).imag = 0;
+#else
+ *(@typ@ *)op = 0;
+#endif
+ for (n = 0; n < dn; n++) {
+ @typ@ val1 = (*(@typ@ *)ip1);
+ @typ@ val2 = (*(@typ@ *)ip2);
+#if @IS_COMPLEX@ == 1
+ (*(@typ@ *)op).real += (val1.real * val2.real) -
+ (val1.imag * val2.imag);
+ (*(@typ@ *)op).imag += (val1.real * val2.imag) +
+ (val1.imag * val2.real);
+#else
+ *(@typ@ *)op += val1 * val2;
+#endif
+ ip2 += is2_n;
+ ip1 += is1_n;
+ }
+ ip1 -= ib1_n;
+ ip2 -= ib2_n;
+ op += os_p;
+ ip2 += is2_p;
+ }
+ op -= ob_p;
+ ip2 -= ib2_p;
+ ip1 += is1_m;
+ op += os_m;
+ }
+}
+
+/**end repeat**/
+
+/**begin repeat
+ * #TYPE = FLOAT, DOUBLE, LONGDOUBLE, HALF,
+ * CFLOAT, CDOUBLE, CLONGDOUBLE,
+ * UBYTE, USHORT, UINT, ULONG, ULONGLONG,
+ * BYTE, SHORT, INT, LONG, LONGLONG,
+ * BOOL#
+ * #typ = npy_float,npy_double,npy_longdouble, npy_half,
+ * npy_cfloat, npy_cdouble, npy_clongdouble,
+ * npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong,
+ * npy_byte, npy_short, npy_int, npy_long, npy_longlong,
+ * npy_bool#
+ * #SPECL = 0, 0, 0, 2, 1, 1, 1, 0*11#
+ * #USEBLAS = 1, 1, 0, 0, 1, 1, 0*12#
+ * #chr = s, d, 0, 0, c, z, 0*12#
+ * #blas_typ = npy_float, npy_double, void*16#
+ */
+
+
+NPY_NO_EXPORT void
+@TYPE@_matmul(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func))
+{
+ npy_intp dOuter = *dimensions++;
+ npy_intp iOuter;
+ npy_intp s0 = *steps++;
+ npy_intp s1 = *steps++;
+ npy_intp s2 = *steps++;
+ npy_intp dm = dimensions[0];
+ npy_intp dn = dimensions[1];
+ npy_intp dp = dimensions[2];
+ npy_intp is1_m=steps[0], is1_n=steps[1], is2_n=steps[2], is2_p=steps[3],
+ os_m=steps[4], os_p=steps[5];
+ npy_intp ib1_n, ib2_n, ib2_p, ob_p;
+#if @USEBLAS@ && defined(HAVE_CBLAS)
+ npy_bool special_case = (dm == 1 || dn == 1 || dp == 1);
+ npy_bool scalar_out = (dm == 1 && dp == 1);
+ npy_bool scalar_vec = (dn == 1 && (dp == 1 || dm == 1));
+ npy_bool too_big_for_blas = (dm > NPY_MAX_INT || dn > NPY_MAX_INT ||
+ dp > NPY_MAX_INT);
+ npy_bool input_contiguous = ((is1_m == sizeof(@typ@) ||
+ is1_n == sizeof(@typ@)) &&
+ (is2_n == sizeof(@typ@) ||
+ is2_p == sizeof(@typ@)));
+ npy_bool vector_matrix = ((dm == 1) &&
+ (is2_n == sizeof(@typ@) || (is2_n == sizeof(@typ@) * dp)));
+ npy_bool matrix_vector = ((dp == 1) &&
+ (is1_n == sizeof(@typ@) || (is1_n == sizeof(@typ@) * dm)));
+#endif
+
+ ib1_n = is1_n*dn;
+ ib2_n = is2_n*dn;
+ ib2_p = is2_p*dp;
+ ob_p = os_p *dp;
+
+ if (dn == 0) {
+ /* No operand, need to zero the output */
+ for (iOuter = 0; iOuter < dOuter; iOuter++,
+ args[0] += s0, args[1] += s1, args[2] += s2) {
+ npy_intp m, p;
+ char *op=args[2];
+ for (m = 0; m < dm; m++) {
+ for (p = 0; p < dp; p++) {
+#if @SPECL@ == 1
+ (*(@typ@ *)op).real = 0;
+ (*(@typ@ *)op).imag = 0;
+#else
+ *(@typ@ *)op = 0;
+#endif
+ op += os_p;
+ }
+ op += os_m - ob_p;
+ }
+ }
+ return;
+ }
+ for (iOuter = 0; iOuter < dOuter; iOuter++,
+ args[0] += s0, args[1] += s1, args[2] += s2) {
+ char *ip1=args[0], *ip2=args[1], *op=args[2];
+#if @USEBLAS@ && defined(HAVE_CBLAS)
+ /*
+ * TODO: refactor this out to an inner_loop_selector, in
+ * PyUFunc_MatmulLoopSelector. But that call does not have access to
+ * n, m, p and strides.
+ */
+ if (too_big_for_blas) {
+ @TYPE@_matmul_inner_noblas(ip1, ip2, op, dm, dn, dp,
+ ib1_n, ib2_n, ib2_p, ob_p,
+ is1_m, is1_n, is2_n, is2_p, os_m, os_p);
+ }
+ else if (special_case) {
+ /* Special case variants that have a 1 in the core dimensions */
+ if (scalar_out) {
+ /* row @ column, 1,1 output */
+ @TYPE@_dot(ip1, is1_n, ip2, is2_n, op, dn, NULL);
+ } else if (scalar_vec){
+ /*
+ * 0d @ vector or vector @ 0d
+ * could use cblas_Xaxy, but that requires 0ing output
+ * and would not be faster (XXX prove it)
+ */
+ @TYPE@_matmul_inner_noblas(ip1, ip2, op, dm, dn, dp,
+ ib1_n, ib2_n, ib2_p, ob_p,
+ is1_m, is1_n, is2_n, is2_p, os_m, os_p);
+ } else if (vector_matrix) {
+ /* vector @ matrix, switch ip1, ip2, p and m */
+ @TYPE@_gemv((void*)ip2, is2_n, (void*)ip1, is1_n,
+ (void*)op, dp, dn);
+ } else if (matrix_vector) {
+ /* matrix @ vector */
+ @TYPE@_gemv((void*)ip1, is1_n, (void*)ip2, is2_n,
+ (void*)op, dm, dn);
+ } else {
+ /* column @ row, 2d output, no blas needed or non-contiguous input */
+ @TYPE@_matmul_inner_noblas(ip1, ip2, op, dm, dn, dp,
+ ib1_n, ib2_n, ib2_p, ob_p,
+ is1_m, is1_n, is2_n, is2_p, os_m, os_p);
+ }
+ } else {
+ /* matrix @ matrix */
+ if (input_contiguous) {
+ /* can only use blas if input is contiguous */
+ @TYPE@_matmul_matrixmatrix((void*)ip1, is1_m, is1_n,
+ (void*)ip2, is2_n, is2_p,
+ (void*)op, dm, dn, dp);
+ } else {
+ @TYPE@_matmul_inner_noblas(ip1, ip2, op, dm, dn, dp,
+ ib1_n, ib2_n, ib2_p, ob_p,
+ is1_m, is1_n, is2_n, is2_p, os_m, os_p);
+ }
+ }
+#else
+ @TYPE@_matmul_inner_noblas(ip1, ip2, op, dm, dn, dp,
+ ib1_n, ib2_n, ib2_p, ob_p,
+ is1_m, is1_n, is2_n, is2_p, os_m, os_p);
+
+#endif
+ }
+}
+
+/**end repeat**/
+
+
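
The _matmul_inner_noblas loops above never recompute addresses from indices:
they advance raw byte pointers by the input strides and rewind them with the
precomputed "back strides" (ib1_n = is1_n*dn and friends). A self-contained
double-precision sketch of the same walking scheme, with all strides in bytes
exactly as the real loop expects:

    #include <stdio.h>
    #include <stddef.h>

    /* Same pointer walk as @TYPE@_matmul_inner_noblas, fixed to double. */
    static void
    matmul_noblas(char *ip1, char *ip2, char *op,
                  ptrdiff_t dm, ptrdiff_t dn, ptrdiff_t dp,
                  ptrdiff_t is1_m, ptrdiff_t is1_n,
                  ptrdiff_t is2_n, ptrdiff_t is2_p,
                  ptrdiff_t os_m, ptrdiff_t os_p)
    {
        ptrdiff_t ib1_n = is1_n * dn, ib2_n = is2_n * dn;
        ptrdiff_t ib2_p = is2_p * dp, ob_p = os_p * dp;

        for (ptrdiff_t m = 0; m < dm; m++) {
            for (ptrdiff_t p = 0; p < dp; p++) {
                double sum = 0;
                for (ptrdiff_t n = 0; n < dn; n++) {
                    sum += *(double *)ip1 * *(double *)ip2;
                    ip1 += is1_n;
                    ip2 += is2_n;
                }
                *(double *)op = sum;
                ip1 -= ib1_n;    /* rewind to the start of the row */
                ip2 -= ib2_n;    /* rewind the column ... */
                ip2 += is2_p;    /* ... and step to the next one */
                op += os_p;
            }
            ip2 -= ib2_p;
            op -= ob_p;
            ip1 += is1_m;        /* next row of the left operand */
            op += os_m;
        }
    }

    int main(void)
    {
        double a[2][3] = {{1, 2, 3}, {4, 5, 6}};
        double b[3][2] = {{7, 8}, {9, 10}, {11, 12}};
        double c[2][2];

        matmul_noblas((char *)a, (char *)b, (char *)c, 2, 3, 2,
                      sizeof(a[0]), sizeof(double),   /* C-order strides */
                      sizeof(b[0]), sizeof(double),
                      sizeof(c[0]), sizeof(double));
        printf("%g %g\n%g %g\n", c[0][0], c[0][1], c[1][0], c[1][1]);
        /* prints: 58 64 / 139 154 */
        return 0;
    }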
diff --git a/numpy/core/src/umath/matmul.h.src b/numpy/core/src/umath/matmul.h.src
new file mode 100644
index 000000000..16be7675b
--- /dev/null
+++ b/numpy/core/src/umath/matmul.h.src
@@ -0,0 +1,12 @@
+/**begin repeat
+ * #TYPE = FLOAT, DOUBLE, LONGDOUBLE, HALF,
+ * CFLOAT, CDOUBLE, CLONGDOUBLE,
+ * UBYTE, USHORT, UINT, ULONG, ULONGLONG,
+ * BYTE, SHORT, INT, LONG, LONGLONG,
+ * BOOL#
+ **/
+NPY_NO_EXPORT void
+@TYPE@_matmul(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func));
+/**end repeat**/
+
+
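
One subtlety in the BLAS path above: @name@_matmul_matrixmatrix detects
A @ A.T (same data pointer, mirrored strides, opposite transpose flags) and
calls cblas_Xsyrk, which fills only the upper triangle of the symmetric
product; the loop that follows mirrors it into the lower triangle. That
mirroring step in isolation, assuming a row-major P x P buffer whose upper
triangle is already filled:

    #include <stdio.h>
    #include <stddef.h>

    /* Copy the upper triangle of a row-major P x P matrix into the lower
     * triangle -- the fix-up applied after cblas_Xsyrk. */
    static void
    mirror_upper(double *op, ptrdiff_t P)
    {
        for (ptrdiff_t i = 0; i < P; i++) {
            for (ptrdiff_t j = i + 1; j < P; j++) {
                op[j * P + i] = op[i * P + j];
            }
        }
    }

    int main(void)
    {
        /* Upper triangle of A @ A.T for A = [[1,2],[3,4]] */
        double c[4] = {5, 11, 0, 25};
        mirror_upper(c, 2);
        printf("%g %g\n%g %g\n", c[0], c[1], c[2], c[3]);  /* 5 11 / 11 25 */
        return 0;
    }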
diff --git a/numpy/core/src/umath/scalarmath.c.src b/numpy/core/src/umath/scalarmath.c.src
index e98d9f865..a7987acda 100644
--- a/numpy/core/src/umath/scalarmath.c.src
+++ b/numpy/core/src/umath/scalarmath.c.src
@@ -1564,7 +1564,6 @@ static PyObject*
}
/**end repeat**/
-
/**begin repeat
* #name = byte, ubyte, short, ushort, int, uint,
* long, ulong, longlong, ulonglong,
@@ -1575,8 +1574,7 @@ static PyNumberMethods @name@_as_number = {
(binaryfunc)@name@_add, /*nb_add*/
(binaryfunc)@name@_subtract, /*nb_subtract*/
(binaryfunc)@name@_multiply, /*nb_multiply*/
-#if defined(NPY_PY3K)
-#else
+#if !defined(NPY_PY3K)
(binaryfunc)@name@_divide, /*nb_divide*/
#endif
(binaryfunc)@name@_remainder, /*nb_remainder*/
@@ -1596,8 +1594,7 @@ static PyNumberMethods @name@_as_number = {
(binaryfunc)@name@_and, /*nb_and*/
(binaryfunc)@name@_xor, /*nb_xor*/
(binaryfunc)@name@_or, /*nb_or*/
-#if defined(NPY_PY3K)
-#else
+#if !defined(NPY_PY3K)
0, /*nb_coerce*/
#endif
(unaryfunc)@name@_int, /*nb_int*/
@@ -1607,16 +1604,14 @@ static PyNumberMethods @name@_as_number = {
(unaryfunc)@name@_long, /*nb_long*/
#endif
(unaryfunc)@name@_float, /*nb_float*/
-#if defined(NPY_PY3K)
-#else
+#if !defined(NPY_PY3K)
(unaryfunc)@name@_oct, /*nb_oct*/
(unaryfunc)@name@_hex, /*nb_hex*/
#endif
0, /*inplace_add*/
0, /*inplace_subtract*/
0, /*inplace_multiply*/
-#if defined(NPY_PY3K)
-#else
+#if !defined(NPY_PY3K)
0, /*inplace_divide*/
#endif
0, /*inplace_remainder*/
@@ -1631,6 +1626,10 @@ static PyNumberMethods @name@_as_number = {
0, /*nb_inplace_floor_divide*/
0, /*nb_inplace_true_divide*/
(unaryfunc)NULL, /*nb_index*/
+#if PY_VERSION_HEX >= 0x03050000
+ 0, /*nb_matrix_multiply*/
+ 0, /*nb_inplace_matrix_multiply*/
+#endif
};
/**end repeat**/
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index 1fe8745a0..66f512f7b 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -2845,13 +2845,15 @@ PyUFunc_GeneralizedFunction(PyUFuncObject *ufunc,
}
/* Fill in any allocated outputs */
- for (i = nin; i < nop; ++i) {
- if (op[i] == NULL) {
- op[i] = NpyIter_GetOperandArray(iter)[i];
- Py_INCREF(op[i]);
+ {
+ PyArrayObject **operands = NpyIter_GetOperandArray(iter);
+ for (i = 0; i < nop; ++i) {
+ if (op[i] == NULL) {
+ op[i] = operands[i];
+ Py_INCREF(op[i]);
+ }
}
}
-
/*
* Set up the inner strides array. Because we're not doing
* buffering, the strides are fixed throughout the looping.
diff --git a/numpy/core/src/umath/ufunc_type_resolution.c b/numpy/core/src/umath/ufunc_type_resolution.c
index 6b042d837..efd923972 100644
--- a/numpy/core/src/umath/ufunc_type_resolution.c
+++ b/numpy/core/src/umath/ufunc_type_resolution.c
@@ -22,6 +22,11 @@
#include "ufunc_object.h"
#include "common.h"
+#include "mem_overlap.h"
+#if defined(HAVE_CBLAS)
+#include "cblasfuncs.h"
+#endif
+
static const char *
npy_casting_to_string(NPY_CASTING casting)
{
@@ -1299,6 +1304,88 @@ PyUFunc_MixedDivisionTypeResolver(PyUFuncObject *ufunc,
type_tup, out_dtypes);
}
+/*
+ * XXX This is too restrictive, we should only check the inner axis involved
+ * gh-12365
+ */
+#define NOT_CONTIGUOUS(a) (PyArray_NDIM(a) >=2 && !( \
+ PyArray_IS_C_CONTIGUOUS(a) || PyArray_IS_F_CONTIGUOUS(a)))
+
+/*
+ * This function applies special type resolution rules for the case
+ * where all the functions have the pattern XX->X, using
+ * PyArray_ResultType instead of a linear search to get the best
+ * loop, like PyUFunc_SimpleBinaryOperationTypeResolver, and adds
+ * memory overlap and contiguity considerations to the operands,
+ * possibly creating Writeback temporary data
+ *
+ * Returns 0 on success, -1 on error.
+ */
+NPY_NO_EXPORT int
+PyUFunc_MatmulTypeResolver(PyUFuncObject *ufunc,
+ NPY_CASTING casting,
+ PyArrayObject **operands,
+ PyObject *type_tup,
+ PyArray_Descr **out_dtypes)
+{
+
+ if (PyUFunc_SimpleBinaryOperationTypeResolver(ufunc, casting, operands,
+ type_tup, out_dtypes) < 0) {
+ return -1;
+ }
+ if (PyArray_NDIM(operands[0]) >= 2 || PyArray_NDIM(operands[1]) >= 2) {
+#if defined(HAVE_CBLAS)
+ int typenum = out_dtypes[2]->type_num;
+ if ( (NPY_DOUBLE == typenum || NPY_CDOUBLE == typenum ||
+ NPY_FLOAT == typenum || NPY_CFLOAT == typenum)) {
+ /*
+ * We are going to use BLAS
+ * make sure 2d and more arrays are contiguous
+ */
+ if (_bad_strides(operands[0]) || NOT_CONTIGUOUS(operands[0])) {
+ PyObject *op = PyArray_NewCopy(operands[0], NPY_ANYORDER);
+
+ if (op == NULL) {
+ return -1;
+ }
+ Py_DECREF(operands[0]);
+ operands[0] = (PyArrayObject *)op;
+ }
+ if (_bad_strides(operands[1]) || NOT_CONTIGUOUS(operands[1])) {
+ PyObject *op = PyArray_NewCopy(operands[1], NPY_ANYORDER);
+
+ if (op == NULL) {
+ return -1;
+ }
+ Py_DECREF(operands[1]);
+ operands[1] = (PyArrayObject *)op;
+ }
+ }
+#endif
+ if (operands[2] != NULL) {
+ npy_intp last_stride = PyArray_STRIDE(operands[2],
+ PyArray_NDIM(operands[2]) - 1);
+ if (last_stride != PyArray_ITEMSIZE(operands[2])) {
+ PyErr_SetString(PyExc_ValueError,
+ "output array is not acceptable (must be C-contiguous)");
+ return -1;
+ }
+ /*
+ * Use all the flags in PyUFunc_GeneralizedFunction
+ * except NPY_ITER_OVERLAP_ASSUME_ELEMENTWISE and
+ * NPY_ITER_WRITEONLY. Without NPY_ITER_OVERLAP_ASSUME_ELEMENTWISE
+ * nditer will properly check overlap between output and any input.
+ */
+ ufunc->op_flags[2] = NPY_ITER_WRITEONLY |
+ NPY_ITER_UPDATEIFCOPY |
+ NPY_ITER_ALIGNED |
+ NPY_ITER_ALLOCATE |
+ NPY_ITER_NO_BROADCAST;
+ }
+ }
+ return 0;
+}
+
static int
find_userloop(PyUFuncObject *ufunc,
diff --git a/numpy/core/src/umath/ufunc_type_resolution.h b/numpy/core/src/umath/ufunc_type_resolution.h
index bb4823d24..5306a5983 100644
--- a/numpy/core/src/umath/ufunc_type_resolution.h
+++ b/numpy/core/src/umath/ufunc_type_resolution.h
@@ -145,5 +145,11 @@ PyUFunc_DefaultMaskedInnerLoopSelector(PyUFuncObject *ufunc,
NpyAuxData **out_innerloopdata,
int *out_needs_api);
+NPY_NO_EXPORT int
+PyUFunc_MatmulTypeResolver(PyUFuncObject *ufunc,
+ NPY_CASTING casting,
+ PyArrayObject **operands,
+ PyObject *type_tup,
+ PyArray_Descr **out_dtypes);
#endif
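
PyUFunc_MatmulTypeResolver is attached to the matmul gufunc during umath
module setup, which lives outside the numpy/core/src paths in this diffstat.
A sketch of that hookup, with the surrounding lookup assumed rather than
quoted from the commit:

    #include <Python.h>
    #include "numpy/ufuncobject.h"
    #include "ufunc_type_resolution.h"

    /* Hypothetical hookup: repoint a gufunc's type_resolver slot at the
     * custom matmul resolver; the default is PyUFunc_DefaultTypeResolver. */
    static void
    install_matmul_resolver(PyUFuncObject *mm)
    {
        mm->type_resolver = PyUFunc_MatmulTypeResolver;
    }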