diff options
author | John Kirkham <kirkhamj@janelia.hhmi.org> | 2016-01-07 15:53:13 -0500 |
---|---|---|
committer | John Kirkham <kirkhamj@janelia.hhmi.org> | 2016-01-11 19:42:53 -0500 |
commit | 88c8a9c22013eb6aa876adfe895b339bc602e6ac (patch) | |
tree | ca174976ca23245b7189c7b1366d0dfb86590506 /numpy | |
parent | bab118da49f051aecf35296bb9d8a00edd5b4198 (diff) | |
download | numpy-88c8a9c22013eb6aa876adfe895b339bc602e6ac.tar.gz |
MAINT: Refactor `cblas_innerproduct` to use `cblas_matrixproduct`.
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/cblasfuncs.c | 138 |
1 files changed, 9 insertions, 129 deletions
diff --git a/numpy/core/src/multiarray/cblasfuncs.c b/numpy/core/src/multiarray/cblasfuncs.c index 902946706..180d96e30 100644 --- a/numpy/core/src/multiarray/cblasfuncs.c +++ b/numpy/core/src/multiarray/cblasfuncs.c @@ -765,148 +765,28 @@ fail: NPY_NO_EXPORT PyObject * cblas_innerproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2) { - int j, l, lda, ldb; - int nd; - double prior1, prior2; - PyArrayObject *ret = NULL; - npy_intp dimensions[NPY_MAXDIMS]; - PyTypeObject *subtype; + PyArrayObject* ap2t = NULL; + PyArrayObject* ret = NULL; - /* assure contiguous arrays */ - if (!PyArray_IS_C_CONTIGUOUS(ap1)) { - PyObject *op1 = PyArray_NewCopy(ap1, NPY_CORDER); - Py_DECREF(ap1); - ap1 = (PyArrayObject *)op1; - if (ap1 == NULL) { - goto fail; - } - } - if (!PyArray_IS_C_CONTIGUOUS(ap2)) { - PyObject *op2 = PyArray_NewCopy(ap2, NPY_CORDER); - Py_DECREF(ap2); - ap2 = (PyArrayObject *)op2; - if (ap2 == NULL) { - goto fail; - } - } - - if (PyArray_NDIM(ap1) == 0 || PyArray_NDIM(ap2) == 0) { - /* One of ap1 or ap2 is a scalar */ - if (PyArray_NDIM(ap1) == 0) { - /* Make ap2 the scalar */ - PyArrayObject *t = ap1; - ap1 = ap2; - ap2 = t; - } - for (l = 1, j = 0; j < PyArray_NDIM(ap1); j++) { - dimensions[j] = PyArray_DIM(ap1, j); - l *= dimensions[j]; - } - nd = PyArray_NDIM(ap1); + if ((ap1 == NULL) || (ap2 == NULL)) { + goto fail; } - else { - /* - * (PyArray_NDIM(ap1) <= 2 && PyArray_NDIM(ap2) <= 2) - * Both ap1 and ap2 are vectors or matrices - */ - l = PyArray_DIM(ap1, PyArray_NDIM(ap1) - 1); - - if (PyArray_DIM(ap2, PyArray_NDIM(ap2) - 1) != l) { - dot_alignment_error(ap1, PyArray_NDIM(ap1) - 1, - ap2, PyArray_NDIM(ap2) - 1); - goto fail; - } - nd = PyArray_NDIM(ap1) + PyArray_NDIM(ap2) - 2; - if (nd == 1) - dimensions[0] = (PyArray_NDIM(ap1) == 2) ? - PyArray_DIM(ap1, 0) : PyArray_DIM(ap2, 0); - else if (nd == 2) { - dimensions[0] = PyArray_DIM(ap1, 0); - dimensions[1] = PyArray_DIM(ap2, 0); - } + ap2t = (PyArrayObject *)PyArray_Transpose(ap2, NULL); + if (ap2t == NULL) { + goto fail; } - /* Choose which subtype to return */ - prior2 = PyArray_GetPriority((PyObject *)ap2, 0.0); - prior1 = PyArray_GetPriority((PyObject *)ap1, 0.0); - subtype = (prior2 > prior1 ? Py_TYPE(ap2) : Py_TYPE(ap1)); - - ret = (PyArrayObject *)PyArray_New(subtype, nd, dimensions, - typenum, NULL, NULL, 0, 0, - (PyObject *) - (prior2 > prior1 ? ap2 : ap1)); - + ret = (PyArrayObject *)cblas_matrixproduct(typenum, ap1, ap2t, NULL); if (ret == NULL) { goto fail; } - NPY_BEGIN_ALLOW_THREADS; - memset(PyArray_DATA(ret), 0, PyArray_NBYTES(ret)); - - if (PyArray_NDIM(ap2) == 0) { - /* Multiplication by a scalar -- Level 1 BLAS */ - if (typenum == NPY_DOUBLE) { - cblas_daxpy(l, - *((double *)PyArray_DATA(ap2)), - (double *)PyArray_DATA(ap1), 1, - (double *)PyArray_DATA(ret), 1); - } - else if (typenum == NPY_CDOUBLE) { - cblas_zaxpy(l, - (double *)PyArray_DATA(ap2), - (double *)PyArray_DATA(ap1), 1, - (double *)PyArray_DATA(ret), 1); - } - else if (typenum == NPY_FLOAT) { - cblas_saxpy(l, - *((float *)PyArray_DATA(ap2)), - (float *)PyArray_DATA(ap1), 1, - (float *)PyArray_DATA(ret), 1); - } - else if (typenum == NPY_CFLOAT) { - cblas_caxpy(l, - (float *)PyArray_DATA(ap2), - (float *)PyArray_DATA(ap1), 1, - (float *)PyArray_DATA(ret), 1); - } - } - else if (PyArray_NDIM(ap1) == 1 && PyArray_NDIM(ap2) == 1) { - /* Dot product between two vectors -- Level 1 BLAS */ - blas_dot(typenum, l, - PyArray_DATA(ap1), PyArray_ITEMSIZE(ap1), - PyArray_DATA(ap2), PyArray_ITEMSIZE(ap2), - PyArray_DATA(ret)); - } - else if (PyArray_NDIM(ap1) == 2 && PyArray_NDIM(ap2) == 1) { - /* Matrix-vector multiplication -- Level 2 BLAS */ - lda = (PyArray_DIM(ap1, 1) > 1 ? PyArray_DIM(ap1, 1) : 1); - gemv(typenum, CblasRowMajor, CblasNoTrans, ap1, lda, ap2, 1, ret); - } - else if (PyArray_NDIM(ap1) == 1 && PyArray_NDIM(ap2) == 2) { - /* Vector matrix multiplication -- Level 2 BLAS */ - lda = (PyArray_DIM(ap2, 1) > 1 ? PyArray_DIM(ap2, 1) : 1); - gemv(typenum, CblasRowMajor, CblasNoTrans, ap2, lda, ap1, 1, ret); - } - else { - /* - * (PyArray_NDIM(ap1) == 2 && PyArray_NDIM(ap2) == 2) - * Matrix matrix multiplication -- Level 3 BLAS - */ - lda = (PyArray_DIM(ap1, 1) > 1 ? PyArray_DIM(ap1, 1) : 1); - ldb = (PyArray_DIM(ap2, 1) > 1 ? PyArray_DIM(ap2, 1) : 1); - gemm(typenum, CblasRowMajor, CblasNoTrans, CblasTrans, - PyArray_DIM(ap1, 0), PyArray_DIM(ap2, 0), PyArray_DIM(ap1, 1), - ap1, lda, ap2, ldb, ret); - } - NPY_END_ALLOW_THREADS; - Py_DECREF(ap1); Py_DECREF(ap2); return PyArray_Return(ret); - fail: - Py_XDECREF(ap1); +fail: Py_XDECREF(ap2); Py_XDECREF(ret); return NULL; |