summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLars Buitinck <larsmans@gmail.com>2013-10-23 21:55:33 +0200
committerLars Buitinck <larsmans@gmail.com>2013-11-21 17:50:01 +0100
commit1d1803257467c554db208a24b857d613b872b9bd (patch)
treece2f8937b13dc757097e99418815214856907c50
parent46dd681ff4aed02a87b27db684f8600b5b7a7fa3 (diff)
downloadnumpy-1d1803257467c554db208a24b857d613b872b9bd.tar.gz
ENH: dotblas: re-route all cblas_?dot calls to chunked versions
Also factored out and simplified NumPy->BLAS stride conversion.
-rw-r--r--numpy/core/blasdot/_dotblas.c147
1 files changed, 71 insertions, 76 deletions
diff --git a/numpy/core/blasdot/_dotblas.c b/numpy/core/blasdot/_dotblas.c
index 7fd6d25ff..48aa39ff8 100644
--- a/numpy/core/blasdot/_dotblas.c
+++ b/numpy/core/blasdot/_dotblas.c
@@ -15,6 +15,7 @@
#endif
#include CBLAS_HEADER
+#include <assert.h>
#include <limits.h>
#include <stdio.h>
@@ -26,6 +27,24 @@ static PyArray_DotFunc *oldFunctions[NPY_NTYPES];
#define MIN(a, b) ((a) < (b) ? (a) : (b))
/*
+ * Convert NumPy stride to BLAS stride. Returns 0 if conversion cannot be done
+ * (BLAS won't handle negative or zero strides the way we want).
+ *
+ * The result is stride / itemsize. Callers in this patch pass the sizeof of
+ * the element type as itemsize. itemsize must be nonzero, since it is used
+ * as the right operand of both % and /.
+ *
+ * 0 is returned in three cases:
+ *   - stride is zero or negative;
+ *   - stride is not an exact multiple of itemsize;
+ *   - the quotient stride / itemsize is larger than INT_MAX.
+ * Every successful result is at least 1, so callers can test the return
+ * value for truth to decide whether the BLAS path is usable.
+ */
+static NPY_INLINE int
+blas_stride(npy_intp stride, unsigned itemsize)
+{
+ if (stride <= 0 || stride % itemsize != 0) { /* sign is tested first, so the modulo only sees positive strides */
+ return 0;
+ }
+ stride /= itemsize; /* exact division: the remainder was checked above */
+
+ if (stride > INT_MAX) { /* quotient would not fit in the int return type */
+ return 0;
+ }
+ return stride; /* in the range 1..INT_MAX at this point */
+}
+
+/*
* The following functions do a "chunked" dot product using BLAS when
* sizeof(npy_intp) > sizeof(int), because BLAS libraries can typically not
* handle more than INT_MAX elements per call.
@@ -42,13 +61,10 @@ static void
FLOAT_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res,
npy_intp n, void *tmp)
{
- npy_intp na = stridea / (int)sizeof(float);
- npy_intp nb = strideb / (int)sizeof(float);
+ int na = blas_stride(stridea, sizeof(float));
+ int nb = blas_stride(strideb, sizeof(float));
- /* XXX we could handle negative strides as well */
- if (stridea > 0 && strideb > 0 &&
- stridea == sizeof(float) * na &&
- strideb == sizeof(float) * nb) {
+ if (na && nb) {
double r = 0.; /* double for stability */
float *fa = a, *fb = b;
@@ -71,12 +87,10 @@ static void
DOUBLE_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res,
npy_intp n, void *tmp)
{
- npy_intp na = stridea / (int)sizeof(double);
- npy_intp nb = strideb / (int)sizeof(double);
+ int na = blas_stride(stridea, sizeof(double));
+ int nb = blas_stride(strideb, sizeof(double));
- if (stridea > 0 && strideb > 0 &&
- stridea == sizeof(double) * na &&
- strideb == sizeof(double) * nb) {
+ if (na && nb) {
double r = 0.;
double *da = a, *db = b;
@@ -99,12 +113,10 @@ static void
CFLOAT_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res,
npy_intp n, void *tmp)
{
- npy_intp na = stridea / (int)sizeof(npy_cfloat);
- npy_intp nb = strideb / (int)sizeof(npy_cfloat);
+ int na = blas_stride(stridea, sizeof(npy_cfloat));
+ int nb = blas_stride(strideb, sizeof(npy_cfloat));
- if (stridea > 0 && strideb > 0 && n <= INT_MAX &&
- stridea == sizeof(npy_cfloat) * na &&
- strideb == sizeof(npy_cfloat) * nb) {
+ if (na && nb) {
cblas_cdotu_sub((int)n, (float *)a, na, (float *)b, nb, (float *)res);
}
else {
@@ -116,12 +128,10 @@ static void
CDOUBLE_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res,
npy_intp n, void *tmp)
{
- npy_intp na = stridea / (int)sizeof(npy_cdouble);
- npy_intp nb = strideb / (int)sizeof(npy_cdouble);
+ int na = blas_stride(stridea, sizeof(npy_cdouble));
+ int nb = blas_stride(strideb, sizeof(npy_cdouble));
- if (stridea > 0 && strideb > 0 && n <= INT_MAX &&
- stridea == sizeof(npy_cdouble) * na &&
- strideb == sizeof(npy_cdouble) * nb) {
+ if (na && nb) {
cblas_zdotu_sub((int)n, (double *)a, na, (double *)b, nb,
(double *)res);
}
@@ -130,6 +140,33 @@ CDOUBLE_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res,
}
}
+/*
+ * Helper: call appropriate BLAS dot function for typenum.
+ * Strides are NumPy strides.
+ *
+ * Dispatches to DOUBLE_dot, FLOAT_dot, CDOUBLE_dot or CFLOAT_dot for
+ * NPY_DOUBLE, NPY_FLOAT, NPY_CDOUBLE and NPY_CFLOAT respectively. The
+ * arguments a, stridea, b, strideb, res and n are forwarded unchanged, and
+ * NULL is passed as the final (tmp) argument of the selected function.
+ *
+ * NOTE(review): the switch has no default case, so any other typenum leaves
+ * dot set to NULL. That is only caught by assert(); when assertions are
+ * compiled out (NDEBUG) the call below would go through a null function
+ * pointer. Confirm that every caller restricts typenum to these four values.
+ */
+static void
+blas_dot(int typenum, npy_intp n,
+ void *a, npy_intp stridea, void *b, npy_intp strideb, void *res)
+{
+ PyArray_DotFunc *dot = NULL; /* stays NULL if typenum matches no case */
+ switch (typenum) {
+ case NPY_DOUBLE:
+ dot = DOUBLE_dot;
+ break;
+ case NPY_FLOAT:
+ dot = FLOAT_dot;
+ break;
+ case NPY_CDOUBLE:
+ dot = CDOUBLE_dot;
+ break;
+ case NPY_CFLOAT:
+ dot = CFLOAT_dot;
+ break;
+ }
+ assert(dot != NULL); /* debug-build check only; see the note in the header comment */
+ dot(a, stridea, b, strideb, res, n, NULL); /* n moves to sixth position; tmp is passed as NULL */
+}
+
static const double oneD[2] = {1.0, 0.0}, zeroD[2] = {0.0, 0.0};
static const float oneF[2] = {1.0, 0.0}, zeroF[2] = {0.0, 0.0};
@@ -342,7 +379,8 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa
PyObject *op1, *op2;
PyArrayObject *ap1 = NULL, *ap2 = NULL, *out = NULL, *ret = NULL;
int errval;
- int j, l, lda, ldb;
+ int j, lda, ldb;
+ npy_intp l;
int typenum, nd;
npy_intp ap1stride = 0;
npy_intp dimensions[NPY_MAXDIMS];
@@ -746,36 +784,13 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa
NPY_END_ALLOW_THREADS;
}
else if ((ap2shape == _column) && (ap1shape != _matrix)) {
- int ap1s, ap2s;
NPY_BEGIN_ALLOW_THREADS;
- ap2s = PyArray_STRIDE(ap2, 0) / PyArray_ITEMSIZE(ap2);
- if (ap1shape == _row) {
- ap1s = PyArray_STRIDE(ap1, 1) / PyArray_ITEMSIZE(ap1);
- }
- else {
- ap1s = PyArray_STRIDE(ap1, 0) / PyArray_ITEMSIZE(ap1);
- }
-
/* Dot product between two vectors -- Level 1 BLAS */
- if (typenum == NPY_DOUBLE) {
- double result = cblas_ddot(l, (double *)PyArray_DATA(ap1), ap1s,
- (double *)PyArray_DATA(ap2), ap2s);
- *((double *)PyArray_DATA(ret)) = result;
- }
- else if (typenum == NPY_FLOAT) {
- float result = cblas_sdot(l, (float *)PyArray_DATA(ap1), ap1s,
- (float *)PyArray_DATA(ap2), ap2s);
- *((float *)PyArray_DATA(ret)) = result;
- }
- else if (typenum == NPY_CDOUBLE) {
- cblas_zdotu_sub(l, (double *)PyArray_DATA(ap1), ap1s,
- (double *)PyArray_DATA(ap2), ap2s, (double *)PyArray_DATA(ret));
- }
- else if (typenum == NPY_CFLOAT) {
- cblas_cdotu_sub(l, (float *)PyArray_DATA(ap1), ap1s,
- (float *)PyArray_DATA(ap2), ap2s, (float *)PyArray_DATA(ret));
- }
+ blas_dot(typenum, l,
+ PyArray_DATA(ap1), PyArray_STRIDE(ap1, (ap1shape == _row)),
+ PyArray_DATA(ap2), PyArray_STRIDE(ap2, 0),
+ PyArray_DATA(ret));
NPY_END_ALLOW_THREADS;
}
else if (ap1shape == _matrix && ap2shape != _matrix) {
@@ -1040,24 +1055,8 @@ dotblas_innerproduct(PyObject *NPY_UNUSED(dummy), PyObject *args)
}
else if (PyArray_NDIM(ap1) == 1 && PyArray_NDIM(ap2) == 1) {
/* Dot product between two vectors -- Level 1 BLAS */
- if (typenum == NPY_DOUBLE) {
- double result = cblas_ddot(l, (double *)PyArray_DATA(ap1), 1,
- (double *)PyArray_DATA(ap2), 1);
- *((double *)PyArray_DATA(ret)) = result;
- }
- else if (typenum == NPY_CDOUBLE) {
- cblas_zdotu_sub(l, (double *)PyArray_DATA(ap1), 1,
- (double *)PyArray_DATA(ap2), 1, (double *)PyArray_DATA(ret));
- }
- else if (typenum == NPY_FLOAT) {
- float result = cblas_sdot(l, (float *)PyArray_DATA(ap1), 1,
- (float *)PyArray_DATA(ap2), 1);
- *((float *)PyArray_DATA(ret)) = result;
- }
- else if (typenum == NPY_CFLOAT) {
- cblas_cdotu_sub(l, (float *)PyArray_DATA(ap1), 1,
- (float *)PyArray_DATA(ap2), 1, (float *)PyArray_DATA(ret));
- }
+ blas_dot(typenum, l, PyArray_DATA(ap1), PyArray_ITEMSIZE(ap1),
+ PyArray_DATA(ap2), PyArray_ITEMSIZE(ap2), PyArray_DATA(ret));
}
else if (PyArray_NDIM(ap1) == 2 && PyArray_NDIM(ap2) == 1) {
/* Matrix-vector multiplication -- Level 2 BLAS */
@@ -1162,16 +1161,12 @@ static PyObject *dotblas_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args) {
ret = (PyArrayObject *)PyArray_SimpleNew(0, dimensions, typenum);
if (ret == NULL) goto fail;
- NPY_BEGIN_ALLOW_THREADS
+ NPY_BEGIN_ALLOW_THREADS;
/* Dot product between two vectors -- Level 1 BLAS */
- if (typenum == NPY_DOUBLE) {
- *((double *)PyArray_DATA(ret)) = cblas_ddot(l, (double *)PyArray_DATA(ap1), 1,
- (double *)PyArray_DATA(ap2), 1);
- }
- else if (typenum == NPY_FLOAT) {
- *((float *)PyArray_DATA(ret)) = cblas_sdot(l, (float *)PyArray_DATA(ap1), 1,
- (float *)PyArray_DATA(ap2), 1);
+ if (typenum == NPY_DOUBLE || typenum == NPY_FLOAT) {
+ blas_dot(typenum, l, PyArray_DATA(ap1), PyArray_ITEMSIZE(ap1),
+ PyArray_DATA(ap2), PyArray_ITEMSIZE(ap2), PyArray_DATA(ret));
}
else if (typenum == NPY_CDOUBLE) {
cblas_zdotc_sub(l, (double *)PyArray_DATA(ap1), 1,
@@ -1182,7 +1177,7 @@ static PyObject *dotblas_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args) {
(float *)PyArray_DATA(ap2), 1, (float *)PyArray_DATA(ret));
}
- NPY_END_ALLOW_THREADS
+ NPY_END_ALLOW_THREADS;
Py_DECREF(ap1);
Py_DECREF(ap2);