summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNathaniel J. Smith <njs@pobox.com>2016-01-06 23:28:32 +0000
committerNathaniel J. Smith <njs@pobox.com>2016-01-06 23:28:32 +0000
commit25c8d1c2d3590c63645d5267807b15be9abd00e4 (patch)
treee2c0ae8bba755913f67d868cc26f452d85c709dd
parenta7377d881469bb45ca4710df2a46c719aa9a80e3 (diff)
parent8d8a74d8b2f86d2548a04565744ab122537a4f62 (diff)
downloadnumpy-25c8d1c2d3590c63645d5267807b15be9abd00e4.tar.gz
Merge pull request #6932 from jakirkham/opt_dot_trans
ENH: Use `syrk` to compute certain dot products more quickly and accurately
-rw-r--r--benchmarks/benchmarks/bench_linalg.py14
-rw-r--r--doc/release/1.11.0-notes.rst8
-rw-r--r--numpy/core/src/multiarray/cblasfuncs.c97
3 files changed, 118 insertions, 1 deletions
diff --git a/benchmarks/benchmarks/bench_linalg.py b/benchmarks/benchmarks/bench_linalg.py
index a323609b7..c230d985a 100644
--- a/benchmarks/benchmarks/bench_linalg.py
+++ b/benchmarks/benchmarks/bench_linalg.py
@@ -8,6 +8,8 @@ import numpy as np
class Eindot(Benchmark):
def setup(self):
self.a = np.arange(60000.0).reshape(150, 400)
+ self.at = self.a.T
+ self.atc = self.a.T.copy()
self.b = np.arange(240000.0).reshape(400, 600)
self.c = np.arange(600)
self.d = np.arange(400)
@@ -21,6 +23,18 @@ class Eindot(Benchmark):
def time_dot_a_b(self):
np.dot(self.a, self.b)
+ def time_dot_trans_a_at(self):
+ np.dot(self.a, self.at)
+
+ def time_dot_trans_a_atc(self):
+ np.dot(self.a, self.atc)
+
+ def time_dot_trans_at_a(self):
+ np.dot(self.at, self.a)
+
+ def time_dot_trans_atc_a(self):
+ np.dot(self.atc, self.a)
+
def time_einsum_i_ij_j(self):
np.einsum('i,ij,j', self.d, self.b, self.c)
diff --git a/doc/release/1.11.0-notes.rst b/doc/release/1.11.0-notes.rst
index 54b6874cc..099ce3976 100644
--- a/doc/release/1.11.0-notes.rst
+++ b/doc/release/1.11.0-notes.rst
@@ -130,6 +130,14 @@ useless computations when printing a masked array.
The function now uses the fallocate system call to reserve sufficient
diskspace on filesystems that support it.
+``np.dot`` optimized for operations of the form ``A.T @ A`` and ``A @ A.T``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Previously, ``gemm`` BLAS operations were used for all matrix products. Now,
+if the matrix product is between a matrix and its transpose, it will use
+``syrk`` BLAS operations for a performance boost.
+
+**Note:** Requires the transposed and non-transposed matrices to share data.
+
Changes
=======
diff --git a/numpy/core/src/multiarray/cblasfuncs.c b/numpy/core/src/multiarray/cblasfuncs.c
index 67f325ba1..516c6e8ae 100644
--- a/numpy/core/src/multiarray/cblasfuncs.c
+++ b/numpy/core/src/multiarray/cblasfuncs.c
@@ -111,6 +111,74 @@ gemv(int typenum, enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans,
}
+/*
+ * Helper: dispatch to appropriate cblas_?syrk for typenum.
+ */
+static void
+syrk(int typenum, enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans,
+ int n, int k,
+ PyArrayObject *A, int lda, PyArrayObject *R)
+{
+ const void *Adata = PyArray_DATA(A);
+ void *Rdata = PyArray_DATA(R);
+ int ldc = PyArray_DIM(R, 1) > 1 ? PyArray_DIM(R, 1) : 1;
+
+ npy_intp i;
+ npy_intp j;
+
+ switch (typenum) {
+ case NPY_DOUBLE:
+ cblas_dsyrk(order, CblasUpper, trans, n, k, 1.,
+ Adata, lda, 0., Rdata, ldc);
+
+ for (i = 0; i < n; i++)
+ {
+ for (j = i + 1; j < n; j++)
+ {
+ *((npy_double*)PyArray_GETPTR2(R, j, i)) = *((npy_double*)PyArray_GETPTR2(R, i, j));
+ }
+ }
+ break;
+ case NPY_FLOAT:
+ cblas_ssyrk(order, CblasUpper, trans, n, k, 1.f,
+ Adata, lda, 0.f, Rdata, ldc);
+
+ for (i = 0; i < n; i++)
+ {
+ for (j = i + 1; j < n; j++)
+ {
+ *((npy_float*)PyArray_GETPTR2(R, j, i)) = *((npy_float*)PyArray_GETPTR2(R, i, j));
+ }
+ }
+ break;
+ case NPY_CDOUBLE:
+ cblas_zsyrk(order, CblasUpper, trans, n, k, oneD,
+ Adata, lda, zeroD, Rdata, ldc);
+
+ for (i = 0; i < n; i++)
+ {
+ for (j = i + 1; j < n; j++)
+ {
+ *((npy_cdouble*)PyArray_GETPTR2(R, j, i)) = *((npy_cdouble*)PyArray_GETPTR2(R, i, j));
+ }
+ }
+ break;
+ case NPY_CFLOAT:
+ cblas_csyrk(order, CblasUpper, trans, n, k, oneF,
+ Adata, lda, zeroF, Rdata, ldc);
+
+ for (i = 0; i < n; i++)
+ {
+ for (j = i + 1; j < n; j++)
+ {
+ *((npy_cfloat*)PyArray_GETPTR2(R, j, i)) = *((npy_cfloat*)PyArray_GETPTR2(R, i, j));
+ }
+ }
+ break;
+ }
+}
+
+
typedef enum {_scalar, _column, _row, _matrix} MatrixShape;
@@ -647,7 +715,34 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2,
Trans2 = CblasTrans;
ldb = (PyArray_DIM(ap2, 0) > 1 ? PyArray_DIM(ap2, 0) : 1);
}
- gemm(typenum, Order, Trans1, Trans2, L, N, M, ap1, lda, ap2, ldb, ret);
+
+ /*
+ * Use syrk if we have a case of a matrix times its transpose.
+ * Otherwise, use gemm for all other cases.
+ */
+ if (
+ (PyArray_BYTES(ap1) == PyArray_BYTES(ap2)) &&
+ (PyArray_DIM(ap1, 0) == PyArray_DIM(ap2, 1)) &&
+ (PyArray_DIM(ap1, 1) == PyArray_DIM(ap2, 0)) &&
+ (PyArray_STRIDE(ap1, 0) == PyArray_STRIDE(ap2, 1)) &&
+ (PyArray_STRIDE(ap1, 1) == PyArray_STRIDE(ap2, 0)) &&
+ ((Trans1 == CblasTrans) ^ (Trans2 == CblasTrans)) &&
+ ((Trans1 == CblasNoTrans) ^ (Trans2 == CblasNoTrans))
+ )
+ {
+ if (Trans1 == CblasNoTrans)
+ {
+ syrk(typenum, Order, Trans1, N, M, ap1, lda, ret);
+ }
+ else
+ {
+ syrk(typenum, Order, Trans1, N, M, ap2, ldb, ret);
+ }
+ }
+ else
+ {
+ gemm(typenum, Order, Trans1, Trans2, L, N, M, ap1, lda, ap2, ldb, ret);
+ }
NPY_END_ALLOW_THREADS;
}