diff options
author | Lars Buitinck <larsmans@gmail.com> | 2013-10-23 21:55:33 +0200 |
---|---|---|
committer | Lars Buitinck <larsmans@gmail.com> | 2013-11-21 17:50:01 +0100 |
commit | 1d1803257467c554db208a24b857d613b872b9bd (patch) | |
tree | ce2f8937b13dc757097e99418815214856907c50 | |
parent | 46dd681ff4aed02a87b27db684f8600b5b7a7fa3 (diff) | |
download | numpy-1d1803257467c554db208a24b857d613b872b9bd.tar.gz |
ENH: dotblas: re-route all cblas_?dot calls to chunked versions
Also factored out and simplified NumPy->BLAS stride conversion.
-rw-r--r-- | numpy/core/blasdot/_dotblas.c | 147 |
1 files changed, 71 insertions, 76 deletions
diff --git a/numpy/core/blasdot/_dotblas.c b/numpy/core/blasdot/_dotblas.c index 7fd6d25ff..48aa39ff8 100644 --- a/numpy/core/blasdot/_dotblas.c +++ b/numpy/core/blasdot/_dotblas.c @@ -15,6 +15,7 @@ #endif #include CBLAS_HEADER +#include <assert.h> #include <limits.h> #include <stdio.h> @@ -26,6 +27,24 @@ static PyArray_DotFunc *oldFunctions[NPY_NTYPES]; #define MIN(a, b) ((a) < (b) ? (a) : (b)) /* + * Convert NumPy stride to BLAS stride. Returns 0 if conversion cannot be done + * (BLAS won't handle negative or zero strides the way we want). + */ +static NPY_INLINE int +blas_stride(npy_intp stride, unsigned itemsize) +{ + if (stride <= 0 || stride % itemsize != 0) { + return 0; + } + stride /= itemsize; + + if (stride > INT_MAX) { + return 0; + } + return stride; +} + +/* * The following functions do a "chunked" dot product using BLAS when * sizeof(npy_intp) > sizeof(int), because BLAS libraries can typically not * handle more than INT_MAX elements per call. @@ -42,13 +61,10 @@ static void FLOAT_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res, npy_intp n, void *tmp) { - npy_intp na = stridea / (int)sizeof(float); - npy_intp nb = strideb / (int)sizeof(float); + int na = blas_stride(stridea, sizeof(float)); + int nb = blas_stride(strideb, sizeof(float)); - /* XXX we could handle negative strides as well */ - if (stridea > 0 && strideb > 0 && - stridea == sizeof(float) * na && - strideb == sizeof(float) * nb) { + if (na && nb) { double r = 0.; /* double for stability */ float *fa = a, *fb = b; @@ -71,12 +87,10 @@ static void DOUBLE_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res, npy_intp n, void *tmp) { - npy_intp na = stridea / (int)sizeof(double); - npy_intp nb = strideb / (int)sizeof(double); + int na = blas_stride(stridea, sizeof(double)); + int nb = blas_stride(strideb, sizeof(double)); - if (stridea > 0 && strideb > 0 && - stridea == sizeof(double) * na && - strideb == sizeof(double) * nb) { + if (na && nb) { double r = 0.; double *da = a, *db = b; @@ -99,12 +113,10 @@ static void CFLOAT_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res, npy_intp n, void *tmp) { - npy_intp na = stridea / (int)sizeof(npy_cfloat); - npy_intp nb = strideb / (int)sizeof(npy_cfloat); + int na = blas_stride(stridea, sizeof(npy_cfloat)); + int nb = blas_stride(strideb, sizeof(npy_cfloat)); - if (stridea > 0 && strideb > 0 && n <= INT_MAX && - stridea == sizeof(npy_cfloat) * na && - strideb == sizeof(npy_cfloat) * nb) { + if (na && nb) { cblas_cdotu_sub((int)n, (float *)a, na, (float *)b, nb, (float *)res); } else { @@ -116,12 +128,10 @@ static void CDOUBLE_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res, npy_intp n, void *tmp) { - npy_intp na = stridea / (int)sizeof(npy_cdouble); - npy_intp nb = strideb / (int)sizeof(npy_cdouble); + int na = blas_stride(stridea, sizeof(npy_cdouble)); + int nb = blas_stride(strideb, sizeof(npy_cdouble)); - if (stridea > 0 && strideb > 0 && n <= INT_MAX && - stridea == sizeof(npy_cdouble) * na && - strideb == sizeof(npy_cdouble) * nb) { + if (na && nb) { cblas_zdotu_sub((int)n, (double *)a, na, (double *)b, nb, (double *)res); } @@ -130,6 +140,33 @@ CDOUBLE_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res, } } +/* + * Helper: call appropriate BLAS dot function for typenum. + * Strides are NumPy strides. + */ +static void +blas_dot(int typenum, npy_intp n, + void *a, npy_intp stridea, void *b, npy_intp strideb, void *res) +{ + PyArray_DotFunc *dot = NULL; + switch (typenum) { + case NPY_DOUBLE: + dot = DOUBLE_dot; + break; + case NPY_FLOAT: + dot = FLOAT_dot; + break; + case NPY_CDOUBLE: + dot = CDOUBLE_dot; + break; + case NPY_CFLOAT: + dot = CFLOAT_dot; + break; + } + assert(dot != NULL); + dot(a, stridea, b, strideb, res, n, NULL); +} + static const double oneD[2] = {1.0, 0.0}, zeroD[2] = {0.0, 0.0}; static const float oneF[2] = {1.0, 0.0}, zeroF[2] = {0.0, 0.0}; @@ -342,7 +379,8 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa PyObject *op1, *op2; PyArrayObject *ap1 = NULL, *ap2 = NULL, *out = NULL, *ret = NULL; int errval; - int j, l, lda, ldb; + int j, lda, ldb; + npy_intp l; int typenum, nd; npy_intp ap1stride = 0; npy_intp dimensions[NPY_MAXDIMS]; @@ -746,36 +784,13 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa NPY_END_ALLOW_THREADS; } else if ((ap2shape == _column) && (ap1shape != _matrix)) { - int ap1s, ap2s; NPY_BEGIN_ALLOW_THREADS; - ap2s = PyArray_STRIDE(ap2, 0) / PyArray_ITEMSIZE(ap2); - if (ap1shape == _row) { - ap1s = PyArray_STRIDE(ap1, 1) / PyArray_ITEMSIZE(ap1); - } - else { - ap1s = PyArray_STRIDE(ap1, 0) / PyArray_ITEMSIZE(ap1); - } - /* Dot product between two vectors -- Level 1 BLAS */ - if (typenum == NPY_DOUBLE) { - double result = cblas_ddot(l, (double *)PyArray_DATA(ap1), ap1s, - (double *)PyArray_DATA(ap2), ap2s); - *((double *)PyArray_DATA(ret)) = result; - } - else if (typenum == NPY_FLOAT) { - float result = cblas_sdot(l, (float *)PyArray_DATA(ap1), ap1s, - (float *)PyArray_DATA(ap2), ap2s); - *((float *)PyArray_DATA(ret)) = result; - } - else if (typenum == NPY_CDOUBLE) { - cblas_zdotu_sub(l, (double *)PyArray_DATA(ap1), ap1s, - (double *)PyArray_DATA(ap2), ap2s, (double *)PyArray_DATA(ret)); - } - else if (typenum == NPY_CFLOAT) { - cblas_cdotu_sub(l, (float *)PyArray_DATA(ap1), ap1s, - (float *)PyArray_DATA(ap2), ap2s, (float *)PyArray_DATA(ret)); - } + blas_dot(typenum, l, + PyArray_DATA(ap1), PyArray_STRIDE(ap1, (ap1shape == _row)), + PyArray_DATA(ap2), PyArray_STRIDE(ap2, 0), + PyArray_DATA(ret)); NPY_END_ALLOW_THREADS; } else if (ap1shape == _matrix && ap2shape != _matrix) { @@ -1040,24 +1055,8 @@ dotblas_innerproduct(PyObject *NPY_UNUSED(dummy), PyObject *args) } else if (PyArray_NDIM(ap1) == 1 && PyArray_NDIM(ap2) == 1) { /* Dot product between two vectors -- Level 1 BLAS */ - if (typenum == NPY_DOUBLE) { - double result = cblas_ddot(l, (double *)PyArray_DATA(ap1), 1, - (double *)PyArray_DATA(ap2), 1); - *((double *)PyArray_DATA(ret)) = result; - } - else if (typenum == NPY_CDOUBLE) { - cblas_zdotu_sub(l, (double *)PyArray_DATA(ap1), 1, - (double *)PyArray_DATA(ap2), 1, (double *)PyArray_DATA(ret)); - } - else if (typenum == NPY_FLOAT) { - float result = cblas_sdot(l, (float *)PyArray_DATA(ap1), 1, - (float *)PyArray_DATA(ap2), 1); - *((float *)PyArray_DATA(ret)) = result; - } - else if (typenum == NPY_CFLOAT) { - cblas_cdotu_sub(l, (float *)PyArray_DATA(ap1), 1, - (float *)PyArray_DATA(ap2), 1, (float *)PyArray_DATA(ret)); - } + blas_dot(typenum, l, PyArray_DATA(ap1), PyArray_ITEMSIZE(ap1), + PyArray_DATA(ap2), PyArray_ITEMSIZE(ap2), PyArray_DATA(ret)); } else if (PyArray_NDIM(ap1) == 2 && PyArray_NDIM(ap2) == 1) { /* Matrix-vector multiplication -- Level 2 BLAS */ @@ -1162,16 +1161,12 @@ static PyObject *dotblas_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args) { ret = (PyArrayObject *)PyArray_SimpleNew(0, dimensions, typenum); if (ret == NULL) goto fail; - NPY_BEGIN_ALLOW_THREADS + NPY_BEGIN_ALLOW_THREADS; /* Dot product between two vectors -- Level 1 BLAS */ - if (typenum == NPY_DOUBLE) { - *((double *)PyArray_DATA(ret)) = cblas_ddot(l, (double *)PyArray_DATA(ap1), 1, - (double *)PyArray_DATA(ap2), 1); - } - else if (typenum == NPY_FLOAT) { - *((float *)PyArray_DATA(ret)) = cblas_sdot(l, (float *)PyArray_DATA(ap1), 1, - (float *)PyArray_DATA(ap2), 1); + if (typenum == NPY_DOUBLE || typenum == NPY_FLOAT) { + blas_dot(typenum, l, PyArray_DATA(ap1), PyArray_ITEMSIZE(ap1), + PyArray_DATA(ap2), PyArray_ITEMSIZE(ap2), PyArray_DATA(ret)); } else if (typenum == NPY_CDOUBLE) { cblas_zdotc_sub(l, (double *)PyArray_DATA(ap1), 1, @@ -1182,7 +1177,7 @@ static PyObject *dotblas_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args) { (float *)PyArray_DATA(ap2), 1, (float *)PyArray_DATA(ret)); } - NPY_END_ALLOW_THREADS + NPY_END_ALLOW_THREADS; Py_DECREF(ap1); Py_DECREF(ap2); |