diff options
author | Lars Buitinck <larsmans@gmail.com> | 2013-10-14 19:36:38 +0200 |
---|---|---|
committer | Lars Buitinck <larsmans@gmail.com> | 2013-11-21 17:48:20 +0100 |
commit | 46dd681ff4aed02a87b27db684f8600b5b7a7fa3 (patch) | |
tree | 868c77952465ff3d6a2a9a0935fba900f9d0cf7e | |
parent | 78e29a323316642899f8ff85e538b785f0d5e31f (diff) | |
download | numpy-46dd681ff4aed02a87b27db684f8600b5b7a7fa3.tar.gz |
BUG/ENH: blasdot: make *dot calls 64-bit safe
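BLAS uses int for all its size arguments, so the length is going to be
silently truncated on platforms where sizeof(npy_intp) > sizeof(int),
i.e. on modern 64-bit platforms. In this case, the float and double dot
products are done in reasonably-sized chunks and the partial results are summed here; the complex variants instead fall back to the saved oldFunctions entry when n > INT_MAX.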
On platforms where sizeof(int) == sizeof(npy_intp), i.e. typical
32-bit platforms, it's up to the compiler to optimize the loop away.
Also tried to make the code more readable, replaced int by npy_intp
in a few places and removed useless register declarations (modern
compilers ignore those).
-rw-r--r-- | numpy/core/blasdot/_dotblas.c | 120 |
1 file changed, 83 insertions, 37 deletions
diff --git a/numpy/core/blasdot/_dotblas.c b/numpy/core/blasdot/_dotblas.c index a8b34593e..7fd6d25ff 100644 --- a/numpy/core/blasdot/_dotblas.c +++ b/numpy/core/blasdot/_dotblas.c @@ -15,6 +15,7 @@ #endif #include CBLAS_HEADER +#include <limits.h> #include <stdio.h> static char module_doc[] = @@ -22,66 +23,111 @@ static char module_doc[] = static PyArray_DotFunc *oldFunctions[NPY_NTYPES]; +#define MIN(a, b) ((a) < (b) ? (a) : (b)) + +/* + * The following functions do a "chunked" dot product using BLAS when + * sizeof(npy_intp) > sizeof(int), because BLAS libraries can typically not + * handle more than INT_MAX elements per call. + * + * The chunksize is the greatest power of two less than INT_MAX. + */ +#if NPY_MAX_INTP > INT_MAX +# define CHUNKSIZE (INT_MAX / 2 + 1) +#else +# define CHUNKSIZE NPY_MAX_INTP +#endif + static void FLOAT_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res, npy_intp n, void *tmp) { - register npy_intp na = stridea / sizeof(float); - register npy_intp nb = strideb / sizeof(float); - - if ((sizeof(float) * na == (size_t)stridea) && - (sizeof(float) * nb == (size_t)strideb) && - (na >= 0) && (nb >= 0)) - *((float *)res) = cblas_sdot((int)n, (float *)a, na, (float *)b, nb); - - else - oldFunctions[NPY_FLOAT](a, stridea, b, strideb, res, n, tmp); + npy_intp na = stridea / (int)sizeof(float); + npy_intp nb = strideb / (int)sizeof(float); + + /* XXX we could handle negative strides as well */ + if (stridea > 0 && strideb > 0 && + stridea == sizeof(float) * na && + strideb == sizeof(float) * nb) { + double r = 0.; /* double for stability */ + float *fa = a, *fb = b; + + while (n > 0) { + int chunk = MIN(n, CHUNKSIZE); + + r += cblas_sdot(chunk, fa, na, fb, nb); + fa += chunk * na; + fb += chunk * nb; + n -= chunk; + } + *((float *)res) = r; + } + else { + oldFunctions[NPY_FLOAT](a, stridea, b, strideb, res, n, tmp); + } } static void DOUBLE_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res, npy_intp n, void 
*tmp) { - register int na = stridea / sizeof(double); - register int nb = strideb / sizeof(double); - - if ((sizeof(double) * na == (size_t)stridea) && - (sizeof(double) * nb == (size_t)strideb) && - (na >= 0) && (nb >= 0)) - *((double *)res) = cblas_ddot((int)n, (double *)a, na, (double *)b, nb); - else - oldFunctions[NPY_DOUBLE](a, stridea, b, strideb, res, n, tmp); + npy_intp na = stridea / (int)sizeof(double); + npy_intp nb = strideb / (int)sizeof(double); + + if (stridea > 0 && strideb > 0 && + stridea == sizeof(double) * na && + strideb == sizeof(double) * nb) { + double r = 0.; + double *da = a, *db = b; + + while (n > 0) { + int chunk = MIN(n, CHUNKSIZE); + + r += cblas_ddot(chunk, da, na, db, nb); + da += chunk * na; + db += chunk * nb; + n -= chunk; + } + *((double *)res) = r; + } + else { + oldFunctions[NPY_DOUBLE](a, stridea, b, strideb, res, n, tmp); + } } static void CFLOAT_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res, npy_intp n, void *tmp) { + npy_intp na = stridea / (int)sizeof(npy_cfloat); + npy_intp nb = strideb / (int)sizeof(npy_cfloat); - register int na = stridea / sizeof(npy_cfloat); - register int nb = strideb / sizeof(npy_cfloat); - - if ((sizeof(npy_cfloat) * na == (size_t)stridea) && - (sizeof(npy_cfloat) * nb == (size_t)strideb) && - (na >= 0) && (nb >= 0)) - cblas_cdotu_sub((int)n, (float *)a, na, (float *)b, nb, (float *)res); - else - oldFunctions[NPY_CFLOAT](a, stridea, b, strideb, res, n, tmp); + if (stridea > 0 && strideb > 0 && n <= INT_MAX && + stridea == sizeof(npy_cfloat) * na && + strideb == sizeof(npy_cfloat) * nb) { + cblas_cdotu_sub((int)n, (float *)a, na, (float *)b, nb, (float *)res); + } + else { + oldFunctions[NPY_CFLOAT](a, stridea, b, strideb, res, n, tmp); + } } static void CDOUBLE_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res, npy_intp n, void *tmp) { - register int na = stridea / sizeof(npy_cdouble); - register int nb = strideb / sizeof(npy_cdouble); - - if 
((sizeof(npy_cdouble) * na == (size_t)stridea) && - (sizeof(npy_cdouble) * nb == (size_t)strideb) && - (na >= 0) && (nb >= 0)) - cblas_zdotu_sub((int)n, (double *)a, na, (double *)b, nb, (double *)res); - else - oldFunctions[NPY_CDOUBLE](a, stridea, b, strideb, res, n, tmp); + npy_intp na = stridea / (int)sizeof(npy_cdouble); + npy_intp nb = strideb / (int)sizeof(npy_cdouble); + + if (stridea > 0 && strideb > 0 && n <= INT_MAX && + stridea == sizeof(npy_cdouble) * na && + strideb == sizeof(npy_cdouble) * nb) { + cblas_zdotu_sub((int)n, (double *)a, na, (double *)b, nb, + (double *)res); + } + else { + oldFunctions[NPY_CDOUBLE](a, stridea, b, strideb, res, n, tmp); + } } |