diff options
author | John Kirkham <kirkhamj@janelia.hhmi.org> | 2016-01-03 19:42:22 -0500 |
---|---|---|
committer | John Kirkham <kirkhamj@janelia.hhmi.org> | 2016-01-06 15:03:12 -0500 |
commit | fee3ccbe8c2248bb8946d057fab94fa7012df41b (patch) | |
tree | 2c9b311aba90dd152825c4189d322090c62c46aa /numpy | |
parent | 89dc3cfe0c1ad749b6b7e987883bb3bdddbb7792 (diff) | |
download | numpy-fee3ccbe8c2248bb8946d057fab94fa7012df41b.tar.gz |
ENH: Added the helper function `syrk` that computes `a.T @ a` or `a @ a.T`.
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/cblasfuncs.c | 68 |
1 files changed, 68 insertions, 0 deletions
diff --git a/numpy/core/src/multiarray/cblasfuncs.c b/numpy/core/src/multiarray/cblasfuncs.c index 67f325ba1..1789b2caf 100644 --- a/numpy/core/src/multiarray/cblasfuncs.c +++ b/numpy/core/src/multiarray/cblasfuncs.c @@ -111,6 +111,74 @@ gemv(int typenum, enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans, } +/* + * Helper: dispatch to appropriate cblas_?syrk for typenum. + */ +static void +syrk(int typenum, enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans, + int n, int k, + PyArrayObject *A, int lda, PyArrayObject *R) +{ + const void *Adata = PyArray_DATA(A); + void *Rdata = PyArray_DATA(R); + int ldc = PyArray_DIM(R, 1) > 1 ? PyArray_DIM(R, 1) : 1; + + npy_intp i; + npy_intp j; + + switch (typenum) { + case NPY_DOUBLE: + cblas_dsyrk(order, CblasUpper, trans, n, k, 1., + Adata, lda, 0., Rdata, ldc); + + for (i = 0; i < n; i++) + { + for (j = i + 1; j < n; j++) + { + *((npy_double*)PyArray_GETPTR2(R, j, i)) = *((npy_double*)PyArray_GETPTR2(R, i, j)); + } + } + break; + case NPY_FLOAT: + cblas_ssyrk(order, CblasUpper, trans, n, k, 1.f, + Adata, lda, 0.f, Rdata, ldc); + + for (i = 0; i < n; i++) + { + for (j = i + 1; j < n; j++) + { + *((npy_float*)PyArray_GETPTR2(R, j, i)) = *((npy_float*)PyArray_GETPTR2(R, i, j)); + } + } + break; + case NPY_CDOUBLE: + cblas_zsyrk(order, CblasUpper, trans, n, k, oneD, + Adata, lda, zeroD, Rdata, ldc); + + for (i = 0; i < n; i++) + { + for (j = i + 1; j < n; j++) + { + *((npy_cdouble*)PyArray_GETPTR2(R, j, i)) = *((npy_cdouble*)PyArray_GETPTR2(R, i, j)); + } + } + break; + case NPY_CFLOAT: + cblas_csyrk(order, CblasUpper, trans, n, k, oneF, + Adata, lda, zeroF, Rdata, ldc); + + for (i = 0; i < n; i++) + { + for (j = i + 1; j < n; j++) + { + *((npy_cfloat*)PyArray_GETPTR2(R, j, i)) = *((npy_cfloat*)PyArray_GETPTR2(R, i, j)); + } + } + break; + } +} + + typedef enum {_scalar, _column, _row, _matrix} MatrixShape; |