summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- numpy/core/src/multiarray/cblasfuncs.c 29
1 files changed, 28 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/cblasfuncs.c b/numpy/core/src/multiarray/cblasfuncs.c
index 1789b2caf..516c6e8ae 100644
--- a/numpy/core/src/multiarray/cblasfuncs.c
+++ b/numpy/core/src/multiarray/cblasfuncs.c
@@ -715,7 +715,34 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2,
Trans2 = CblasTrans;
ldb = (PyArray_DIM(ap2, 0) > 1 ? PyArray_DIM(ap2, 0) : 1);
}
- gemm(typenum, Order, Trans1, Trans2, L, N, M, ap1, lda, ap2, ldb, ret);
+
+ /*
+ * Use syrk when a matrix is multiplied by its own transpose (same data,
+ * swapped dims and strides, one operand transposed); otherwise use gemm.
+ */
+ if (
+ (PyArray_BYTES(ap1) == PyArray_BYTES(ap2)) &&
+ (PyArray_DIM(ap1, 0) == PyArray_DIM(ap2, 1)) &&
+ (PyArray_DIM(ap1, 1) == PyArray_DIM(ap2, 0)) &&
+ (PyArray_STRIDE(ap1, 0) == PyArray_STRIDE(ap2, 1)) &&
+ (PyArray_STRIDE(ap1, 1) == PyArray_STRIDE(ap2, 0)) &&
+ ((Trans1 == CblasTrans) ^ (Trans2 == CblasTrans)) &&
+ ((Trans1 == CblasNoTrans) ^ (Trans2 == CblasNoTrans))
+ )
+ {
+ if (Trans1 == CblasNoTrans)
+ {
+ syrk(typenum, Order, Trans1, N, M, ap1, lda, ret);
+ }
+ else
+ {
+ syrk(typenum, Order, Trans1, N, M, ap2, ldb, ret);
+ }
+ }
+ else
+ {
+ gemm(typenum, Order, Trans1, Trans2, L, N, M, ap1, lda, ap2, ldb, ret);
+ }
NPY_END_ALLOW_THREADS;
}