diff options
-rw-r--r-- | numpy/core/src/multiarray/cblasfuncs.c | 29 |
1 file changed, 28 insertions, 1 deletion
diff --git a/numpy/core/src/multiarray/cblasfuncs.c b/numpy/core/src/multiarray/cblasfuncs.c index 1789b2caf..516c6e8ae 100644 --- a/numpy/core/src/multiarray/cblasfuncs.c +++ b/numpy/core/src/multiarray/cblasfuncs.c @@ -715,7 +715,34 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2, Trans2 = CblasTrans; ldb = (PyArray_DIM(ap2, 0) > 1 ? PyArray_DIM(ap2, 0) : 1); } - gemm(typenum, Order, Trans1, Trans2, L, N, M, ap1, lda, ap2, ldb, ret); + + /* + * Use syrk if we have a case of a matrix times its transpose. + * Otherwise, use gemm for all other cases. + */ + if ( + (PyArray_BYTES(ap1) == PyArray_BYTES(ap2)) && + (PyArray_DIM(ap1, 0) == PyArray_DIM(ap2, 1)) && + (PyArray_DIM(ap1, 1) == PyArray_DIM(ap2, 0)) && + (PyArray_STRIDE(ap1, 0) == PyArray_STRIDE(ap2, 1)) && + (PyArray_STRIDE(ap1, 1) == PyArray_STRIDE(ap2, 0)) && + ((Trans1 == CblasTrans) ^ (Trans2 == CblasTrans)) && + ((Trans1 == CblasNoTrans) ^ (Trans2 == CblasNoTrans)) + ) + { + if (Trans1 == CblasNoTrans) + { + syrk(typenum, Order, Trans1, N, M, ap1, lda, ret); + } + else + { + syrk(typenum, Order, Trans1, N, M, ap2, ldb, ret); + } + } + else + { + gemm(typenum, Order, Trans1, Trans2, L, N, M, ap1, lda, ap2, ldb, ret); + } NPY_END_ALLOW_THREADS; } |