diff options
author | John Kirkham <kirkhamj@janelia.hhmi.org> | 2016-01-03 19:42:51 -0500 |
---|---|---|
committer | John Kirkham <kirkhamj@janelia.hhmi.org> | 2016-01-06 15:03:13 -0500 |
commit | 924e08f8eabe0aafb77d6c3ce435e2a6cf2df2e6 (patch) | |
tree | a27b06684ddaf9046889a6cf7da65acb3186a585 | |
parent | fee3ccbe8c2248bb8946d057fab94fa7012df41b (diff) | |
download | numpy-924e08f8eabe0aafb77d6c3ce435e2a6cf2df2e6.tar.gz |
ENH: Use the helper function `syrk` to compute `dot` more quickly and accurately in certain special cases.
-rw-r--r-- | numpy/core/src/multiarray/cblasfuncs.c | 29 |
1 file changed, 28 insertions, 1 deletion
```diff
diff --git a/numpy/core/src/multiarray/cblasfuncs.c b/numpy/core/src/multiarray/cblasfuncs.c
index 1789b2caf..516c6e8ae 100644
--- a/numpy/core/src/multiarray/cblasfuncs.c
+++ b/numpy/core/src/multiarray/cblasfuncs.c
@@ -715,7 +715,34 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2,
                 Trans2 = CblasTrans;
                 ldb = (PyArray_DIM(ap2, 0) > 1 ? PyArray_DIM(ap2, 0) : 1);
             }
-            gemm(typenum, Order, Trans1, Trans2, L, N, M, ap1, lda, ap2, ldb, ret);
+
+            /*
+             * Use syrk if we have a case of a matrix times its transpose.
+             * Otherwise, use gemm for all other cases.
+             */
+            if (
+                (PyArray_BYTES(ap1) == PyArray_BYTES(ap2)) &&
+                (PyArray_DIM(ap1, 0) == PyArray_DIM(ap2, 1)) &&
+                (PyArray_DIM(ap1, 1) == PyArray_DIM(ap2, 0)) &&
+                (PyArray_STRIDE(ap1, 0) == PyArray_STRIDE(ap2, 1)) &&
+                (PyArray_STRIDE(ap1, 1) == PyArray_STRIDE(ap2, 0)) &&
+                ((Trans1 == CblasTrans) ^ (Trans2 == CblasTrans)) &&
+                ((Trans1 == CblasNoTrans) ^ (Trans2 == CblasNoTrans))
+            )
+            {
+                if (Trans1 == CblasNoTrans)
+                {
+                    syrk(typenum, Order, Trans1, N, M, ap1, lda, ret);
+                }
+                else
+                {
+                    syrk(typenum, Order, Trans1, N, M, ap2, ldb, ret);
+                }
+            }
+            else
+            {
+                gemm(typenum, Order, Trans1, Trans2, L, N, M, ap1, lda, ap2, ldb, ret);
+            }
             NPY_END_ALLOW_THREADS;
         }
```