summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorNathaniel J. Smith <njs@pobox.com>2016-01-06 23:28:32 +0000
committerNathaniel J. Smith <njs@pobox.com>2016-01-06 23:28:32 +0000
commit25c8d1c2d3590c63645d5267807b15be9abd00e4 (patch)
treee2c0ae8bba755913f67d868cc26f452d85c709dd /numpy
parenta7377d881469bb45ca4710df2a46c719aa9a80e3 (diff)
parent8d8a74d8b2f86d2548a04565744ab122537a4f62 (diff)
downloadnumpy-25c8d1c2d3590c63645d5267807b15be9abd00e4.tar.gz
Merge pull request #6932 from jakirkham/opt_dot_trans
ENH: Use `syrk` to compute certain dot products more quickly and accurately
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/cblasfuncs.c97
1 files changed, 96 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/cblasfuncs.c b/numpy/core/src/multiarray/cblasfuncs.c
index 67f325ba1..516c6e8ae 100644
--- a/numpy/core/src/multiarray/cblasfuncs.c
+++ b/numpy/core/src/multiarray/cblasfuncs.c
@@ -111,6 +111,74 @@ gemv(int typenum, enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans,
}
+/*
+ * Helper: dispatch to appropriate cblas_?syrk for typenum.
+ */
+static void
+syrk(int typenum, enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans,
+ int n, int k,
+ PyArrayObject *A, int lda, PyArrayObject *R)
+{
+ const void *Adata = PyArray_DATA(A);
+ void *Rdata = PyArray_DATA(R);
+ int ldc = PyArray_DIM(R, 1) > 1 ? PyArray_DIM(R, 1) : 1;
+
+ npy_intp i;
+ npy_intp j;
+
+ switch (typenum) {
+ case NPY_DOUBLE:
+ cblas_dsyrk(order, CblasUpper, trans, n, k, 1.,
+ Adata, lda, 0., Rdata, ldc);
+
+ for (i = 0; i < n; i++)
+ {
+ for (j = i + 1; j < n; j++)
+ {
+ *((npy_double*)PyArray_GETPTR2(R, j, i)) = *((npy_double*)PyArray_GETPTR2(R, i, j));
+ }
+ }
+ break;
+ case NPY_FLOAT:
+ cblas_ssyrk(order, CblasUpper, trans, n, k, 1.f,
+ Adata, lda, 0.f, Rdata, ldc);
+
+ for (i = 0; i < n; i++)
+ {
+ for (j = i + 1; j < n; j++)
+ {
+ *((npy_float*)PyArray_GETPTR2(R, j, i)) = *((npy_float*)PyArray_GETPTR2(R, i, j));
+ }
+ }
+ break;
+ case NPY_CDOUBLE:
+ cblas_zsyrk(order, CblasUpper, trans, n, k, oneD,
+ Adata, lda, zeroD, Rdata, ldc);
+
+ for (i = 0; i < n; i++)
+ {
+ for (j = i + 1; j < n; j++)
+ {
+ *((npy_cdouble*)PyArray_GETPTR2(R, j, i)) = *((npy_cdouble*)PyArray_GETPTR2(R, i, j));
+ }
+ }
+ break;
+ case NPY_CFLOAT:
+ cblas_csyrk(order, CblasUpper, trans, n, k, oneF,
+ Adata, lda, zeroF, Rdata, ldc);
+
+ for (i = 0; i < n; i++)
+ {
+ for (j = i + 1; j < n; j++)
+ {
+ *((npy_cfloat*)PyArray_GETPTR2(R, j, i)) = *((npy_cfloat*)PyArray_GETPTR2(R, i, j));
+ }
+ }
+ break;
+ }
+}
+
+
typedef enum {_scalar, _column, _row, _matrix} MatrixShape;
@@ -647,7 +715,34 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2,
Trans2 = CblasTrans;
ldb = (PyArray_DIM(ap2, 0) > 1 ? PyArray_DIM(ap2, 0) : 1);
}
- gemm(typenum, Order, Trans1, Trans2, L, N, M, ap1, lda, ap2, ldb, ret);
+
+ /*
+ * Use syrk if we have a case of a matrix times its transpose.
+ * Otherwise, use gemm for all other cases.
+ */
+ if (
+ (PyArray_BYTES(ap1) == PyArray_BYTES(ap2)) &&
+ (PyArray_DIM(ap1, 0) == PyArray_DIM(ap2, 1)) &&
+ (PyArray_DIM(ap1, 1) == PyArray_DIM(ap2, 0)) &&
+ (PyArray_STRIDE(ap1, 0) == PyArray_STRIDE(ap2, 1)) &&
+ (PyArray_STRIDE(ap1, 1) == PyArray_STRIDE(ap2, 0)) &&
+ ((Trans1 == CblasTrans) ^ (Trans2 == CblasTrans)) &&
+ ((Trans1 == CblasNoTrans) ^ (Trans2 == CblasNoTrans))
+ )
+ {
+ if (Trans1 == CblasNoTrans)
+ {
+ syrk(typenum, Order, Trans1, N, M, ap1, lda, ret);
+ }
+ else
+ {
+ syrk(typenum, Order, Trans1, N, M, ap2, ldb, ret);
+ }
+ }
+ else
+ {
+ gemm(typenum, Order, Trans1, Trans2, L, N, M, ap1, lda, ap2, ldb, ret);
+ }
NPY_END_ALLOW_THREADS;
}