summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLars Buitinck <larsmans@gmail.com>2013-10-14 19:36:38 +0200
committerLars Buitinck <larsmans@gmail.com>2013-11-21 17:48:20 +0100
commit46dd681ff4aed02a87b27db684f8600b5b7a7fa3 (patch)
tree868c77952465ff3d6a2a9a0935fba900f9d0cf7e
parent78e29a323316642899f8ff85e538b785f0d5e31f (diff)
downloadnumpy-46dd681ff4aed02a87b27db684f8600b5b7a7fa3.tar.gz
BUG/ENH: blasdot: make *dot calls 64-bit safe
BLAS uses int for all its size types, so the size is going to be silently truncated on platforms where sizeof npy_intp > sizeof int, i.e. on modern 64-bit platforms. In this case, we do a dot product in reasonably-sized chunks and add the result ourselves. On platforms where sizeof(int) == sizeof(npy_intp), i.e. typical 32-bit platforms, it's up to the compiler to optimize the loop away. Also tried to make the code more readable, replaced int by npy_intp in a few places and removed useless register declarations (modern compilers ignore those).
-rw-r--r--numpy/core/blasdot/_dotblas.c120
1 files changed, 83 insertions, 37 deletions
diff --git a/numpy/core/blasdot/_dotblas.c b/numpy/core/blasdot/_dotblas.c
index a8b34593e..7fd6d25ff 100644
--- a/numpy/core/blasdot/_dotblas.c
+++ b/numpy/core/blasdot/_dotblas.c
@@ -15,6 +15,7 @@
#endif
#include CBLAS_HEADER
+#include <limits.h>
#include <stdio.h>
static char module_doc[] =
@@ -22,66 +23,111 @@ static char module_doc[] =
static PyArray_DotFunc *oldFunctions[NPY_NTYPES];
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+
+/*
+ * The following functions do a "chunked" dot product using BLAS when
+ * sizeof(npy_intp) > sizeof(int), because BLAS libraries can typically not
+ * handle more than INT_MAX elements per call.
+ *
+ * The chunksize is the greatest power of two less than INT_MAX.
+ */
+#if NPY_MAX_INTP > INT_MAX
+# define CHUNKSIZE (INT_MAX / 2 + 1)
+#else
+# define CHUNKSIZE NPY_MAX_INTP
+#endif
+
static void
FLOAT_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res,
npy_intp n, void *tmp)
{
- register npy_intp na = stridea / sizeof(float);
- register npy_intp nb = strideb / sizeof(float);
-
- if ((sizeof(float) * na == (size_t)stridea) &&
- (sizeof(float) * nb == (size_t)strideb) &&
- (na >= 0) && (nb >= 0))
- *((float *)res) = cblas_sdot((int)n, (float *)a, na, (float *)b, nb);
-
- else
- oldFunctions[NPY_FLOAT](a, stridea, b, strideb, res, n, tmp);
+ npy_intp na = stridea / (int)sizeof(float);
+ npy_intp nb = strideb / (int)sizeof(float);
+
+ /* XXX we could handle negative strides as well */
+ if (stridea > 0 && strideb > 0 &&
+ stridea == sizeof(float) * na &&
+ strideb == sizeof(float) * nb) {
+ double r = 0.; /* double for stability */
+ float *fa = a, *fb = b;
+
+ while (n > 0) {
+ int chunk = MIN(n, CHUNKSIZE);
+
+ r += cblas_sdot(chunk, fa, na, fb, nb);
+ fa += chunk * na;
+ fb += chunk * nb;
+ n -= chunk;
+ }
+ *((float *)res) = r;
+ }
+ else {
+ oldFunctions[NPY_FLOAT](a, stridea, b, strideb, res, n, tmp);
+ }
}
static void
DOUBLE_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res,
npy_intp n, void *tmp)
{
- register int na = stridea / sizeof(double);
- register int nb = strideb / sizeof(double);
-
- if ((sizeof(double) * na == (size_t)stridea) &&
- (sizeof(double) * nb == (size_t)strideb) &&
- (na >= 0) && (nb >= 0))
- *((double *)res) = cblas_ddot((int)n, (double *)a, na, (double *)b, nb);
- else
- oldFunctions[NPY_DOUBLE](a, stridea, b, strideb, res, n, tmp);
+ npy_intp na = stridea / (int)sizeof(double);
+ npy_intp nb = strideb / (int)sizeof(double);
+
+ if (stridea > 0 && strideb > 0 &&
+ stridea == sizeof(double) * na &&
+ strideb == sizeof(double) * nb) {
+ double r = 0.;
+ double *da = a, *db = b;
+
+ while (n > 0) {
+ int chunk = MIN(n, CHUNKSIZE);
+
+ r += cblas_ddot(chunk, da, na, db, nb);
+ da += chunk * na;
+ db += chunk * nb;
+ n -= chunk;
+ }
+ *((double *)res) = r;
+ }
+ else {
+ oldFunctions[NPY_DOUBLE](a, stridea, b, strideb, res, n, tmp);
+ }
}
static void
CFLOAT_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res,
npy_intp n, void *tmp)
{
+ npy_intp na = stridea / (int)sizeof(npy_cfloat);
+ npy_intp nb = strideb / (int)sizeof(npy_cfloat);
- register int na = stridea / sizeof(npy_cfloat);
- register int nb = strideb / sizeof(npy_cfloat);
-
- if ((sizeof(npy_cfloat) * na == (size_t)stridea) &&
- (sizeof(npy_cfloat) * nb == (size_t)strideb) &&
- (na >= 0) && (nb >= 0))
- cblas_cdotu_sub((int)n, (float *)a, na, (float *)b, nb, (float *)res);
- else
- oldFunctions[NPY_CFLOAT](a, stridea, b, strideb, res, n, tmp);
+ if (stridea > 0 && strideb > 0 && n <= INT_MAX &&
+ stridea == sizeof(npy_cfloat) * na &&
+ strideb == sizeof(npy_cfloat) * nb) {
+ cblas_cdotu_sub((int)n, (float *)a, na, (float *)b, nb, (float *)res);
+ }
+ else {
+ oldFunctions[NPY_CFLOAT](a, stridea, b, strideb, res, n, tmp);
+ }
}
static void
CDOUBLE_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res,
npy_intp n, void *tmp)
{
- register int na = stridea / sizeof(npy_cdouble);
- register int nb = strideb / sizeof(npy_cdouble);
-
- if ((sizeof(npy_cdouble) * na == (size_t)stridea) &&
- (sizeof(npy_cdouble) * nb == (size_t)strideb) &&
- (na >= 0) && (nb >= 0))
- cblas_zdotu_sub((int)n, (double *)a, na, (double *)b, nb, (double *)res);
- else
- oldFunctions[NPY_CDOUBLE](a, stridea, b, strideb, res, n, tmp);
+ npy_intp na = stridea / (int)sizeof(npy_cdouble);
+ npy_intp nb = strideb / (int)sizeof(npy_cdouble);
+
+ if (stridea > 0 && strideb > 0 && n <= INT_MAX &&
+ stridea == sizeof(npy_cdouble) * na &&
+ strideb == sizeof(npy_cdouble) * nb) {
+ cblas_zdotu_sub((int)n, (double *)a, na, (double *)b, nb,
+ (double *)res);
+ }
+ else {
+ oldFunctions[NPY_CDOUBLE](a, stridea, b, strideb, res, n, tmp);
+ }
}