summary refs log tree commit diff
diff options
context:
space:
mode:
author John Kirkham <kirkhamj@janelia.hhmi.org> 2016-01-03 19:42:51 -0500
committer John Kirkham <kirkhamj@janelia.hhmi.org> 2016-01-06 15:03:13 -0500
commit 924e08f8eabe0aafb77d6c3ce435e2a6cf2df2e6 (patch)
tree a27b06684ddaf9046889a6cf7da65acb3186a585
parent fee3ccbe8c2248bb8946d057fab94fa7012df41b (diff)
download numpy-924e08f8eabe0aafb77d6c3ce435e2a6cf2df2e6.tar.gz
ENH: Use the helper function `syrk` to compute `dot` more quickly and accurately in certain special cases.
-rw-r--r-- numpy/core/src/multiarray/cblasfuncs.c 29
1 file changed, 28 insertions, 1 deletion
diff --git a/numpy/core/src/multiarray/cblasfuncs.c b/numpy/core/src/multiarray/cblasfuncs.c
index 1789b2caf..516c6e8ae 100644
--- a/numpy/core/src/multiarray/cblasfuncs.c
+++ b/numpy/core/src/multiarray/cblasfuncs.c
@@ -715,7 +715,34 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2,
Trans2 = CblasTrans;
ldb = (PyArray_DIM(ap2, 0) > 1 ? PyArray_DIM(ap2, 0) : 1);
}
- gemm(typenum, Order, Trans1, Trans2, L, N, M, ap1, lda, ap2, ldb, ret);
+
+ /*
+ * Use syrk if we have a case of a matrix times its transpose.
+ * Otherwise, use gemm for all other cases.
+ */
+ if (
+ (PyArray_BYTES(ap1) == PyArray_BYTES(ap2)) &&
+ (PyArray_DIM(ap1, 0) == PyArray_DIM(ap2, 1)) &&
+ (PyArray_DIM(ap1, 1) == PyArray_DIM(ap2, 0)) &&
+ (PyArray_STRIDE(ap1, 0) == PyArray_STRIDE(ap2, 1)) &&
+ (PyArray_STRIDE(ap1, 1) == PyArray_STRIDE(ap2, 0)) &&
+ ((Trans1 == CblasTrans) ^ (Trans2 == CblasTrans)) &&
+ ((Trans1 == CblasNoTrans) ^ (Trans2 == CblasNoTrans))
+ )
+ {
+ if (Trans1 == CblasNoTrans)
+ {
+ syrk(typenum, Order, Trans1, N, M, ap1, lda, ret);
+ }
+ else
+ {
+ syrk(typenum, Order, Trans1, N, M, ap2, ldb, ret);
+ }
+ }
+ else
+ {
+ gemm(typenum, Order, Trans1, Trans2, L, N, M, ap1, lda, ap2, ldb, ret);
+ }
NPY_END_ALLOW_THREADS;
}