diff options
author | seberg <sebastian@sipsolutions.net> | 2013-12-30 02:23:49 -0800 |
---|---|---|
committer | seberg <sebastian@sipsolutions.net> | 2013-12-30 02:23:49 -0800 |
commit | 14fb7000694523919fe4045096fbfa3703b1d1fe (patch) | |
tree | b1577b6a365976b9fd0f08d7f38cf60b4f3a7176 | |
parent | 61998c22df08e7ec0938bedef931ac824cfb634a (diff) | |
parent | 1d1803257467c554db208a24b857d613b872b9bd (diff) | |
download | numpy-14fb7000694523919fe4045096fbfa3703b1d1fe.tar.gz |
Merge pull request #3916 from larsmans/dotblas-64bit-safe
64-bit safe _dotblas
-rw-r--r-- | numpy/core/blasdot/_dotblas.c | 221 |
1 files changed, 131 insertions, 90 deletions
diff --git a/numpy/core/blasdot/_dotblas.c b/numpy/core/blasdot/_dotblas.c index a8b34593e..48aa39ff8 100644 --- a/numpy/core/blasdot/_dotblas.c +++ b/numpy/core/blasdot/_dotblas.c @@ -15,6 +15,8 @@ #endif #include CBLAS_HEADER +#include <assert.h> +#include <limits.h> #include <stdio.h> static char module_doc[] = @@ -22,66 +24,147 @@ static char module_doc[] = static PyArray_DotFunc *oldFunctions[NPY_NTYPES]; +#define MIN(a, b) ((a) < (b) ? (a) : (b)) + +/* + * Convert NumPy stride to BLAS stride. Returns 0 if conversion cannot be done + * (BLAS won't handle negative or zero strides the way we want). + */ +static NPY_INLINE int +blas_stride(npy_intp stride, unsigned itemsize) +{ + if (stride <= 0 || stride % itemsize != 0) { + return 0; + } + stride /= itemsize; + + if (stride > INT_MAX) { + return 0; + } + return stride; +} + +/* + * The following functions do a "chunked" dot product using BLAS when + * sizeof(npy_intp) > sizeof(int), because BLAS libraries can typically not + * handle more than INT_MAX elements per call. + * + * The chunksize is the greatest power of two less than INT_MAX. + */ +#if NPY_MAX_INTP > INT_MAX +# define CHUNKSIZE (INT_MAX / 2 + 1) +#else +# define CHUNKSIZE NPY_MAX_INTP +#endif + static void FLOAT_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res, npy_intp n, void *tmp) { - register npy_intp na = stridea / sizeof(float); - register npy_intp nb = strideb / sizeof(float); + int na = blas_stride(stridea, sizeof(float)); + int nb = blas_stride(strideb, sizeof(float)); - if ((sizeof(float) * na == (size_t)stridea) && - (sizeof(float) * nb == (size_t)strideb) && - (na >= 0) && (nb >= 0)) - *((float *)res) = cblas_sdot((int)n, (float *)a, na, (float *)b, nb); + if (na && nb) { + double r = 0.; /* double for stability */ + float *fa = a, *fb = b; - else - oldFunctions[NPY_FLOAT](a, stridea, b, strideb, res, n, tmp); + while (n > 0) { + int chunk = MIN(n, CHUNKSIZE); + + r += cblas_sdot(chunk, fa, na, fb, nb); + fa += chunk * na; + fb += chunk * nb; + n -= chunk; + } + *((float *)res) = r; + } + else { + oldFunctions[NPY_FLOAT](a, stridea, b, strideb, res, n, tmp); + } } static void DOUBLE_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res, npy_intp n, void *tmp) { - register int na = stridea / sizeof(double); - register int nb = strideb / sizeof(double); - - if ((sizeof(double) * na == (size_t)stridea) && - (sizeof(double) * nb == (size_t)strideb) && - (na >= 0) && (nb >= 0)) - *((double *)res) = cblas_ddot((int)n, (double *)a, na, (double *)b, nb); - else - oldFunctions[NPY_DOUBLE](a, stridea, b, strideb, res, n, tmp); + int na = blas_stride(stridea, sizeof(double)); + int nb = blas_stride(strideb, sizeof(double)); + + if (na && nb) { + double r = 0.; + double *da = a, *db = b; + + while (n > 0) { + int chunk = MIN(n, CHUNKSIZE); + + r += cblas_ddot(chunk, da, na, db, nb); + da += chunk * na; + db += chunk * nb; + n -= chunk; + } + *((double *)res) = r; + } + else { + oldFunctions[NPY_DOUBLE](a, stridea, b, strideb, res, n, tmp); + } } static void CFLOAT_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res, npy_intp n, void *tmp) { + int na = blas_stride(stridea, sizeof(npy_cfloat)); + int nb = blas_stride(strideb, sizeof(npy_cfloat)); - register int na = stridea / sizeof(npy_cfloat); - register int nb = strideb / sizeof(npy_cfloat); - - if ((sizeof(npy_cfloat) * na == (size_t)stridea) && - (sizeof(npy_cfloat) * nb == (size_t)strideb) && - (na >= 0) && (nb >= 0)) - cblas_cdotu_sub((int)n, (float *)a, na, (float *)b, nb, (float *)res); - else - oldFunctions[NPY_CFLOAT](a, stridea, b, strideb, res, n, tmp); + if (na && nb) { + cblas_cdotu_sub((int)n, (float *)a, na, (float *)b, nb, (float *)res); + } + else { + oldFunctions[NPY_CFLOAT](a, stridea, b, strideb, res, n, tmp); + } } static void CDOUBLE_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res, npy_intp n, void *tmp) { - register int na = stridea / sizeof(npy_cdouble); - register int nb = strideb / sizeof(npy_cdouble); - - if ((sizeof(npy_cdouble) * na == (size_t)stridea) && - (sizeof(npy_cdouble) * nb == (size_t)strideb) && - (na >= 0) && (nb >= 0)) - cblas_zdotu_sub((int)n, (double *)a, na, (double *)b, nb, (double *)res); - else - oldFunctions[NPY_CDOUBLE](a, stridea, b, strideb, res, n, tmp); + int na = blas_stride(stridea, sizeof(npy_cdouble)); + int nb = blas_stride(strideb, sizeof(npy_cdouble)); + + if (na && nb) { + cblas_zdotu_sub((int)n, (double *)a, na, (double *)b, nb, + (double *)res); + } + else { + oldFunctions[NPY_CDOUBLE](a, stridea, b, strideb, res, n, tmp); + } +} + +/* + * Helper: call appropriate BLAS dot function for typenum. + * Strides are NumPy strides. + */ +static void +blas_dot(int typenum, npy_intp n, + void *a, npy_intp stridea, void *b, npy_intp strideb, void *res) +{ + PyArray_DotFunc *dot = NULL; + switch (typenum) { + case NPY_DOUBLE: + dot = DOUBLE_dot; + break; + case NPY_FLOAT: + dot = FLOAT_dot; + break; + case NPY_CDOUBLE: + dot = CDOUBLE_dot; + break; + case NPY_CFLOAT: + dot = CFLOAT_dot; + break; + } + assert(dot != NULL); + dot(a, stridea, b, strideb, res, n, NULL); } @@ -296,7 +379,8 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa PyObject *op1, *op2; PyArrayObject *ap1 = NULL, *ap2 = NULL, *out = NULL, *ret = NULL; int errval; - int j, l, lda, ldb; + int j, lda, ldb; + npy_intp l; int typenum, nd; npy_intp ap1stride = 0; npy_intp dimensions[NPY_MAXDIMS]; @@ -700,36 +784,13 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa NPY_END_ALLOW_THREADS; } else if ((ap2shape == _column) && (ap1shape != _matrix)) { - int ap1s, ap2s; NPY_BEGIN_ALLOW_THREADS; - ap2s = PyArray_STRIDE(ap2, 0) / PyArray_ITEMSIZE(ap2); - if (ap1shape == _row) { - ap1s = PyArray_STRIDE(ap1, 1) / PyArray_ITEMSIZE(ap1); - } - else { - ap1s = PyArray_STRIDE(ap1, 0) / PyArray_ITEMSIZE(ap1); - } - /* Dot product between two vectors -- Level 1 BLAS */ - if (typenum == NPY_DOUBLE) { - double result = cblas_ddot(l, (double *)PyArray_DATA(ap1), ap1s, - (double *)PyArray_DATA(ap2), ap2s); - *((double *)PyArray_DATA(ret)) = result; - } - else if (typenum == NPY_FLOAT) { - float result = cblas_sdot(l, (float *)PyArray_DATA(ap1), ap1s, - (float *)PyArray_DATA(ap2), ap2s); - *((float *)PyArray_DATA(ret)) = result; - } - else if (typenum == NPY_CDOUBLE) { - cblas_zdotu_sub(l, (double *)PyArray_DATA(ap1), ap1s, - (double *)PyArray_DATA(ap2), ap2s, (double *)PyArray_DATA(ret)); - } - else if (typenum == NPY_CFLOAT) { - cblas_cdotu_sub(l, (float *)PyArray_DATA(ap1), ap1s, - (float *)PyArray_DATA(ap2), ap2s, (float *)PyArray_DATA(ret)); - } + blas_dot(typenum, l, + PyArray_DATA(ap1), PyArray_STRIDE(ap1, (ap1shape == _row)), + PyArray_DATA(ap2), PyArray_STRIDE(ap2, 0), + PyArray_DATA(ret)); NPY_END_ALLOW_THREADS; } else if (ap1shape == _matrix && ap2shape != _matrix) { @@ -994,24 +1055,8 @@ dotblas_innerproduct(PyObject *NPY_UNUSED(dummy), PyObject *args) } else if (PyArray_NDIM(ap1) == 1 && PyArray_NDIM(ap2) == 1) { /* Dot product between two vectors -- Level 1 BLAS */ - if (typenum == NPY_DOUBLE) { - double result = cblas_ddot(l, (double *)PyArray_DATA(ap1), 1, - (double *)PyArray_DATA(ap2), 1); - *((double *)PyArray_DATA(ret)) = result; - } - else if (typenum == NPY_CDOUBLE) { - cblas_zdotu_sub(l, (double *)PyArray_DATA(ap1), 1, - (double *)PyArray_DATA(ap2), 1, (double *)PyArray_DATA(ret)); - } - else if (typenum == NPY_FLOAT) { - float result = cblas_sdot(l, (float *)PyArray_DATA(ap1), 1, - (float *)PyArray_DATA(ap2), 1); - *((float *)PyArray_DATA(ret)) = result; - } - else if (typenum == NPY_CFLOAT) { - cblas_cdotu_sub(l, (float *)PyArray_DATA(ap1), 1, - (float *)PyArray_DATA(ap2), 1, (float *)PyArray_DATA(ret)); - } + blas_dot(typenum, l, PyArray_DATA(ap1), PyArray_ITEMSIZE(ap1), + PyArray_DATA(ap2), PyArray_ITEMSIZE(ap2), PyArray_DATA(ret)); } else if (PyArray_NDIM(ap1) == 2 && PyArray_NDIM(ap2) == 1) { /* Matrix-vector multiplication -- Level 2 BLAS */ @@ -1116,16 +1161,12 @@ static PyObject *dotblas_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args) { ret = (PyArrayObject *)PyArray_SimpleNew(0, dimensions, typenum); if (ret == NULL) goto fail; - NPY_BEGIN_ALLOW_THREADS + NPY_BEGIN_ALLOW_THREADS; /* Dot product between two vectors -- Level 1 BLAS */ - if (typenum == NPY_DOUBLE) { - *((double *)PyArray_DATA(ret)) = cblas_ddot(l, (double *)PyArray_DATA(ap1), 1, - (double *)PyArray_DATA(ap2), 1); - } - else if (typenum == NPY_FLOAT) { - *((float *)PyArray_DATA(ret)) = cblas_sdot(l, (float *)PyArray_DATA(ap1), 1, - (float *)PyArray_DATA(ap2), 1); + if (typenum == NPY_DOUBLE || typenum == NPY_FLOAT) { + blas_dot(typenum, l, PyArray_DATA(ap1), PyArray_ITEMSIZE(ap1), + PyArray_DATA(ap2), PyArray_ITEMSIZE(ap2), PyArray_DATA(ret)); } else if (typenum == NPY_CDOUBLE) { cblas_zdotc_sub(l, (double *)PyArray_DATA(ap1), 1, @@ -1136,7 +1177,7 @@ static PyObject *dotblas_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args) { (float *)PyArray_DATA(ap2), 1, (float *)PyArray_DATA(ret)); } - NPY_END_ALLOW_THREADS + NPY_END_ALLOW_THREADS; Py_DECREF(ap1); Py_DECREF(ap2); |