diff options
author | Nathaniel J. Smith <njs@pobox.com> | 2016-01-06 23:28:32 +0000 |
---|---|---|
committer | Nathaniel J. Smith <njs@pobox.com> | 2016-01-06 23:28:32 +0000 |
commit | 25c8d1c2d3590c63645d5267807b15be9abd00e4 (patch) | |
tree | e2c0ae8bba755913f67d868cc26f452d85c709dd | |
parent | a7377d881469bb45ca4710df2a46c719aa9a80e3 (diff) | |
parent | 8d8a74d8b2f86d2548a04565744ab122537a4f62 (diff) | |
download | numpy-25c8d1c2d3590c63645d5267807b15be9abd00e4.tar.gz |
Merge pull request #6932 from jakirkham/opt_dot_trans
ENH: Use `syrk` to compute certain dot products more quickly and accurately
-rw-r--r-- | benchmarks/benchmarks/bench_linalg.py | 14 | ||||
-rw-r--r-- | doc/release/1.11.0-notes.rst | 8 | ||||
-rw-r--r-- | numpy/core/src/multiarray/cblasfuncs.c | 97 |
3 files changed, 118 insertions, 1 deletion
diff --git a/benchmarks/benchmarks/bench_linalg.py b/benchmarks/benchmarks/bench_linalg.py index a323609b7..c230d985a 100644 --- a/benchmarks/benchmarks/bench_linalg.py +++ b/benchmarks/benchmarks/bench_linalg.py @@ -8,6 +8,8 @@ import numpy as np class Eindot(Benchmark): def setup(self): self.a = np.arange(60000.0).reshape(150, 400) + self.at = self.a.T + self.atc = self.a.T.copy() self.b = np.arange(240000.0).reshape(400, 600) self.c = np.arange(600) self.d = np.arange(400) @@ -21,6 +23,18 @@ class Eindot(Benchmark): def time_dot_a_b(self): np.dot(self.a, self.b) + def time_dot_trans_a_at(self): + np.dot(self.a, self.at) + + def time_dot_trans_a_atc(self): + np.dot(self.a, self.atc) + + def time_dot_trans_at_a(self): + np.dot(self.at, self.a) + + def time_dot_trans_atc_a(self): + np.dot(self.atc, self.a) + def time_einsum_i_ij_j(self): np.einsum('i,ij,j', self.d, self.b, self.c) diff --git a/doc/release/1.11.0-notes.rst b/doc/release/1.11.0-notes.rst index 54b6874cc..099ce3976 100644 --- a/doc/release/1.11.0-notes.rst +++ b/doc/release/1.11.0-notes.rst @@ -130,6 +130,14 @@ useless computations when printing a masked array. The function now uses the fallocate system call to reserve sufficient diskspace on filesystems that support it. +``np.dot`` optimized for operations of the form ``A.T @ A`` and ``A @ A.T`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Previously, ``gemm`` BLAS operations were used for all matrix products. Now, +if the matrix product is between a matrix and its transpose, it will use +``syrk`` BLAS operations for a performance boost. + +**Note:** Requires the transposed and non-transposed matrices to share data. 
+ Changes ======= diff --git a/numpy/core/src/multiarray/cblasfuncs.c b/numpy/core/src/multiarray/cblasfuncs.c index 67f325ba1..516c6e8ae 100644 --- a/numpy/core/src/multiarray/cblasfuncs.c +++ b/numpy/core/src/multiarray/cblasfuncs.c @@ -111,6 +111,74 @@ gemv(int typenum, enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans, } +/* + * Helper: dispatch to appropriate cblas_?syrk for typenum. + */ +static void +syrk(int typenum, enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans, + int n, int k, + PyArrayObject *A, int lda, PyArrayObject *R) +{ + const void *Adata = PyArray_DATA(A); + void *Rdata = PyArray_DATA(R); + int ldc = PyArray_DIM(R, 1) > 1 ? PyArray_DIM(R, 1) : 1; + + npy_intp i; + npy_intp j; + + switch (typenum) { + case NPY_DOUBLE: + cblas_dsyrk(order, CblasUpper, trans, n, k, 1., + Adata, lda, 0., Rdata, ldc); + + for (i = 0; i < n; i++) + { + for (j = i + 1; j < n; j++) + { + *((npy_double*)PyArray_GETPTR2(R, j, i)) = *((npy_double*)PyArray_GETPTR2(R, i, j)); + } + } + break; + case NPY_FLOAT: + cblas_ssyrk(order, CblasUpper, trans, n, k, 1.f, + Adata, lda, 0.f, Rdata, ldc); + + for (i = 0; i < n; i++) + { + for (j = i + 1; j < n; j++) + { + *((npy_float*)PyArray_GETPTR2(R, j, i)) = *((npy_float*)PyArray_GETPTR2(R, i, j)); + } + } + break; + case NPY_CDOUBLE: + cblas_zsyrk(order, CblasUpper, trans, n, k, oneD, + Adata, lda, zeroD, Rdata, ldc); + + for (i = 0; i < n; i++) + { + for (j = i + 1; j < n; j++) + { + *((npy_cdouble*)PyArray_GETPTR2(R, j, i)) = *((npy_cdouble*)PyArray_GETPTR2(R, i, j)); + } + } + break; + case NPY_CFLOAT: + cblas_csyrk(order, CblasUpper, trans, n, k, oneF, + Adata, lda, zeroF, Rdata, ldc); + + for (i = 0; i < n; i++) + { + for (j = i + 1; j < n; j++) + { + *((npy_cfloat*)PyArray_GETPTR2(R, j, i)) = *((npy_cfloat*)PyArray_GETPTR2(R, i, j)); + } + } + break; + } +} + + typedef enum {_scalar, _column, _row, _matrix} MatrixShape; @@ -647,7 +715,34 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2, Trans2 = 
CblasTrans; ldb = (PyArray_DIM(ap2, 0) > 1 ? PyArray_DIM(ap2, 0) : 1); } - gemm(typenum, Order, Trans1, Trans2, L, N, M, ap1, lda, ap2, ldb, ret); + + /* + * Use syrk if we have a case of a matrix times its transpose. + * Otherwise, use gemm for all other cases. + */ + if ( + (PyArray_BYTES(ap1) == PyArray_BYTES(ap2)) && + (PyArray_DIM(ap1, 0) == PyArray_DIM(ap2, 1)) && + (PyArray_DIM(ap1, 1) == PyArray_DIM(ap2, 0)) && + (PyArray_STRIDE(ap1, 0) == PyArray_STRIDE(ap2, 1)) && + (PyArray_STRIDE(ap1, 1) == PyArray_STRIDE(ap2, 0)) && + ((Trans1 == CblasTrans) ^ (Trans2 == CblasTrans)) && + ((Trans1 == CblasNoTrans) ^ (Trans2 == CblasNoTrans)) + ) + { + if (Trans1 == CblasNoTrans) + { + syrk(typenum, Order, Trans1, N, M, ap1, lda, ret); + } + else + { + syrk(typenum, Order, Trans1, N, M, ap2, ldb, ret); + } + } + else + { + gemm(typenum, Order, Trans1, Trans2, L, N, M, ap1, lda, ap2, ldb, ret); + } NPY_END_ALLOW_THREADS; } |