summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorseberg <sebastian@sipsolutions.net>2013-12-30 02:23:49 -0800
committerseberg <sebastian@sipsolutions.net>2013-12-30 02:23:49 -0800
commit14fb7000694523919fe4045096fbfa3703b1d1fe (patch)
treeb1577b6a365976b9fd0f08d7f38cf60b4f3a7176
parent61998c22df08e7ec0938bedef931ac824cfb634a (diff)
parent1d1803257467c554db208a24b857d613b872b9bd (diff)
downloadnumpy-14fb7000694523919fe4045096fbfa3703b1d1fe.tar.gz
Merge pull request #3916 from larsmans/dotblas-64bit-safe
64-bit safe _dotblas
-rw-r--r--numpy/core/blasdot/_dotblas.c221
1 files changed, 131 insertions, 90 deletions
diff --git a/numpy/core/blasdot/_dotblas.c b/numpy/core/blasdot/_dotblas.c
index a8b34593e..48aa39ff8 100644
--- a/numpy/core/blasdot/_dotblas.c
+++ b/numpy/core/blasdot/_dotblas.c
@@ -15,6 +15,8 @@
#endif
#include CBLAS_HEADER
+#include <assert.h>
+#include <limits.h>
#include <stdio.h>
static char module_doc[] =
@@ -22,66 +24,147 @@ static char module_doc[] =
static PyArray_DotFunc *oldFunctions[NPY_NTYPES];
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+
+/*
+ * Convert NumPy stride to BLAS stride. Returns 0 if conversion cannot be done
+ * (BLAS won't handle negative or zero strides the way we want).
+ */
+static NPY_INLINE int
+blas_stride(npy_intp stride, unsigned itemsize)
+{
+ if (stride <= 0 || stride % itemsize != 0) {
+ return 0;
+ }
+ stride /= itemsize;
+
+ if (stride > INT_MAX) {
+ return 0;
+ }
+ return stride;
+}
+
+/*
+ * The following functions do a "chunked" dot product using BLAS when
+ * sizeof(npy_intp) > sizeof(int), because BLAS libraries can typically not
+ * handle more than INT_MAX elements per call.
+ *
+ * The chunksize is the greatest power of two less than INT_MAX.
+ */
+#if NPY_MAX_INTP > INT_MAX
+# define CHUNKSIZE (INT_MAX / 2 + 1)
+#else
+# define CHUNKSIZE NPY_MAX_INTP
+#endif
+
static void
FLOAT_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res,
npy_intp n, void *tmp)
{
- register npy_intp na = stridea / sizeof(float);
- register npy_intp nb = strideb / sizeof(float);
+ int na = blas_stride(stridea, sizeof(float));
+ int nb = blas_stride(strideb, sizeof(float));
- if ((sizeof(float) * na == (size_t)stridea) &&
- (sizeof(float) * nb == (size_t)strideb) &&
- (na >= 0) && (nb >= 0))
- *((float *)res) = cblas_sdot((int)n, (float *)a, na, (float *)b, nb);
+ if (na && nb) {
+ double r = 0.; /* double for stability */
+ float *fa = a, *fb = b;
- else
- oldFunctions[NPY_FLOAT](a, stridea, b, strideb, res, n, tmp);
+ while (n > 0) {
+ int chunk = MIN(n, CHUNKSIZE);
+
+ r += cblas_sdot(chunk, fa, na, fb, nb);
+ fa += chunk * na;
+ fb += chunk * nb;
+ n -= chunk;
+ }
+ *((float *)res) = r;
+ }
+ else {
+ oldFunctions[NPY_FLOAT](a, stridea, b, strideb, res, n, tmp);
+ }
}
static void
DOUBLE_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res,
npy_intp n, void *tmp)
{
- register int na = stridea / sizeof(double);
- register int nb = strideb / sizeof(double);
-
- if ((sizeof(double) * na == (size_t)stridea) &&
- (sizeof(double) * nb == (size_t)strideb) &&
- (na >= 0) && (nb >= 0))
- *((double *)res) = cblas_ddot((int)n, (double *)a, na, (double *)b, nb);
- else
- oldFunctions[NPY_DOUBLE](a, stridea, b, strideb, res, n, tmp);
+ int na = blas_stride(stridea, sizeof(double));
+ int nb = blas_stride(strideb, sizeof(double));
+
+ if (na && nb) {
+ double r = 0.;
+ double *da = a, *db = b;
+
+ while (n > 0) {
+ int chunk = MIN(n, CHUNKSIZE);
+
+ r += cblas_ddot(chunk, da, na, db, nb);
+ da += chunk * na;
+ db += chunk * nb;
+ n -= chunk;
+ }
+ *((double *)res) = r;
+ }
+ else {
+ oldFunctions[NPY_DOUBLE](a, stridea, b, strideb, res, n, tmp);
+ }
}
static void
CFLOAT_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res,
npy_intp n, void *tmp)
{
+ int na = blas_stride(stridea, sizeof(npy_cfloat));
+ int nb = blas_stride(strideb, sizeof(npy_cfloat));
- register int na = stridea / sizeof(npy_cfloat);
- register int nb = strideb / sizeof(npy_cfloat);
-
- if ((sizeof(npy_cfloat) * na == (size_t)stridea) &&
- (sizeof(npy_cfloat) * nb == (size_t)strideb) &&
- (na >= 0) && (nb >= 0))
- cblas_cdotu_sub((int)n, (float *)a, na, (float *)b, nb, (float *)res);
- else
- oldFunctions[NPY_CFLOAT](a, stridea, b, strideb, res, n, tmp);
+ if (na && nb) {
+ cblas_cdotu_sub((int)n, (float *)a, na, (float *)b, nb, (float *)res);
+ }
+ else {
+ oldFunctions[NPY_CFLOAT](a, stridea, b, strideb, res, n, tmp);
+ }
}
static void
CDOUBLE_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res,
npy_intp n, void *tmp)
{
- register int na = stridea / sizeof(npy_cdouble);
- register int nb = strideb / sizeof(npy_cdouble);
-
- if ((sizeof(npy_cdouble) * na == (size_t)stridea) &&
- (sizeof(npy_cdouble) * nb == (size_t)strideb) &&
- (na >= 0) && (nb >= 0))
- cblas_zdotu_sub((int)n, (double *)a, na, (double *)b, nb, (double *)res);
- else
- oldFunctions[NPY_CDOUBLE](a, stridea, b, strideb, res, n, tmp);
+ int na = blas_stride(stridea, sizeof(npy_cdouble));
+ int nb = blas_stride(strideb, sizeof(npy_cdouble));
+
+ if (na && nb) {
+ cblas_zdotu_sub((int)n, (double *)a, na, (double *)b, nb,
+ (double *)res);
+ }
+ else {
+ oldFunctions[NPY_CDOUBLE](a, stridea, b, strideb, res, n, tmp);
+ }
+}
+
+/*
+ * Helper: call appropriate BLAS dot function for typenum.
+ * Strides are NumPy strides.
+ */
+static void
+blas_dot(int typenum, npy_intp n,
+ void *a, npy_intp stridea, void *b, npy_intp strideb, void *res)
+{
+ PyArray_DotFunc *dot = NULL;
+ switch (typenum) {
+ case NPY_DOUBLE:
+ dot = DOUBLE_dot;
+ break;
+ case NPY_FLOAT:
+ dot = FLOAT_dot;
+ break;
+ case NPY_CDOUBLE:
+ dot = CDOUBLE_dot;
+ break;
+ case NPY_CFLOAT:
+ dot = CFLOAT_dot;
+ break;
+ }
+ assert(dot != NULL);
+ dot(a, stridea, b, strideb, res, n, NULL);
}
@@ -296,7 +379,8 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa
PyObject *op1, *op2;
PyArrayObject *ap1 = NULL, *ap2 = NULL, *out = NULL, *ret = NULL;
int errval;
- int j, l, lda, ldb;
+ int j, lda, ldb;
+ npy_intp l;
int typenum, nd;
npy_intp ap1stride = 0;
npy_intp dimensions[NPY_MAXDIMS];
@@ -700,36 +784,13 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa
NPY_END_ALLOW_THREADS;
}
else if ((ap2shape == _column) && (ap1shape != _matrix)) {
- int ap1s, ap2s;
NPY_BEGIN_ALLOW_THREADS;
- ap2s = PyArray_STRIDE(ap2, 0) / PyArray_ITEMSIZE(ap2);
- if (ap1shape == _row) {
- ap1s = PyArray_STRIDE(ap1, 1) / PyArray_ITEMSIZE(ap1);
- }
- else {
- ap1s = PyArray_STRIDE(ap1, 0) / PyArray_ITEMSIZE(ap1);
- }
-
/* Dot product between two vectors -- Level 1 BLAS */
- if (typenum == NPY_DOUBLE) {
- double result = cblas_ddot(l, (double *)PyArray_DATA(ap1), ap1s,
- (double *)PyArray_DATA(ap2), ap2s);
- *((double *)PyArray_DATA(ret)) = result;
- }
- else if (typenum == NPY_FLOAT) {
- float result = cblas_sdot(l, (float *)PyArray_DATA(ap1), ap1s,
- (float *)PyArray_DATA(ap2), ap2s);
- *((float *)PyArray_DATA(ret)) = result;
- }
- else if (typenum == NPY_CDOUBLE) {
- cblas_zdotu_sub(l, (double *)PyArray_DATA(ap1), ap1s,
- (double *)PyArray_DATA(ap2), ap2s, (double *)PyArray_DATA(ret));
- }
- else if (typenum == NPY_CFLOAT) {
- cblas_cdotu_sub(l, (float *)PyArray_DATA(ap1), ap1s,
- (float *)PyArray_DATA(ap2), ap2s, (float *)PyArray_DATA(ret));
- }
+ blas_dot(typenum, l,
+ PyArray_DATA(ap1), PyArray_STRIDE(ap1, (ap1shape == _row)),
+ PyArray_DATA(ap2), PyArray_STRIDE(ap2, 0),
+ PyArray_DATA(ret));
NPY_END_ALLOW_THREADS;
}
else if (ap1shape == _matrix && ap2shape != _matrix) {
@@ -994,24 +1055,8 @@ dotblas_innerproduct(PyObject *NPY_UNUSED(dummy), PyObject *args)
}
else if (PyArray_NDIM(ap1) == 1 && PyArray_NDIM(ap2) == 1) {
/* Dot product between two vectors -- Level 1 BLAS */
- if (typenum == NPY_DOUBLE) {
- double result = cblas_ddot(l, (double *)PyArray_DATA(ap1), 1,
- (double *)PyArray_DATA(ap2), 1);
- *((double *)PyArray_DATA(ret)) = result;
- }
- else if (typenum == NPY_CDOUBLE) {
- cblas_zdotu_sub(l, (double *)PyArray_DATA(ap1), 1,
- (double *)PyArray_DATA(ap2), 1, (double *)PyArray_DATA(ret));
- }
- else if (typenum == NPY_FLOAT) {
- float result = cblas_sdot(l, (float *)PyArray_DATA(ap1), 1,
- (float *)PyArray_DATA(ap2), 1);
- *((float *)PyArray_DATA(ret)) = result;
- }
- else if (typenum == NPY_CFLOAT) {
- cblas_cdotu_sub(l, (float *)PyArray_DATA(ap1), 1,
- (float *)PyArray_DATA(ap2), 1, (float *)PyArray_DATA(ret));
- }
+ blas_dot(typenum, l, PyArray_DATA(ap1), PyArray_ITEMSIZE(ap1),
+ PyArray_DATA(ap2), PyArray_ITEMSIZE(ap2), PyArray_DATA(ret));
}
else if (PyArray_NDIM(ap1) == 2 && PyArray_NDIM(ap2) == 1) {
/* Matrix-vector multiplication -- Level 2 BLAS */
@@ -1116,16 +1161,12 @@ static PyObject *dotblas_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args) {
ret = (PyArrayObject *)PyArray_SimpleNew(0, dimensions, typenum);
if (ret == NULL) goto fail;
- NPY_BEGIN_ALLOW_THREADS
+ NPY_BEGIN_ALLOW_THREADS;
/* Dot product between two vectors -- Level 1 BLAS */
- if (typenum == NPY_DOUBLE) {
- *((double *)PyArray_DATA(ret)) = cblas_ddot(l, (double *)PyArray_DATA(ap1), 1,
- (double *)PyArray_DATA(ap2), 1);
- }
- else if (typenum == NPY_FLOAT) {
- *((float *)PyArray_DATA(ret)) = cblas_sdot(l, (float *)PyArray_DATA(ap1), 1,
- (float *)PyArray_DATA(ap2), 1);
+ if (typenum == NPY_DOUBLE || typenum == NPY_FLOAT) {
+ blas_dot(typenum, l, PyArray_DATA(ap1), PyArray_ITEMSIZE(ap1),
+ PyArray_DATA(ap2), PyArray_ITEMSIZE(ap2), PyArray_DATA(ret));
}
else if (typenum == NPY_CDOUBLE) {
cblas_zdotc_sub(l, (double *)PyArray_DATA(ap1), 1,
@@ -1136,7 +1177,7 @@ static PyObject *dotblas_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args) {
(float *)PyArray_DATA(ap2), 1, (float *)PyArray_DATA(ret));
}
- NPY_END_ALLOW_THREADS
+ NPY_END_ALLOW_THREADS;
Py_DECREF(ap1);
Py_DECREF(ap2);