summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorJohn Kirkham <kirkhamj@janelia.hhmi.org>2016-01-07 15:53:13 -0500
committerJohn Kirkham <kirkhamj@janelia.hhmi.org>2016-01-11 19:42:53 -0500
commit88c8a9c22013eb6aa876adfe895b339bc602e6ac (patch)
treeca174976ca23245b7189c7b1366d0dfb86590506 /numpy
parentbab118da49f051aecf35296bb9d8a00edd5b4198 (diff)
downloadnumpy-88c8a9c22013eb6aa876adfe895b339bc602e6ac.tar.gz
MAINT: Refactor `cblas_innerproduct` to use `cblas_matrixproduct`.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/cblasfuncs.c138
1 files changed, 9 insertions, 129 deletions
diff --git a/numpy/core/src/multiarray/cblasfuncs.c b/numpy/core/src/multiarray/cblasfuncs.c
index 902946706..180d96e30 100644
--- a/numpy/core/src/multiarray/cblasfuncs.c
+++ b/numpy/core/src/multiarray/cblasfuncs.c
@@ -765,148 +765,28 @@ fail:
NPY_NO_EXPORT PyObject *
cblas_innerproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2)
{
- int j, l, lda, ldb;
- int nd;
- double prior1, prior2;
- PyArrayObject *ret = NULL;
- npy_intp dimensions[NPY_MAXDIMS];
- PyTypeObject *subtype;
+ PyArrayObject* ap2t = NULL;
+ PyArrayObject* ret = NULL;
- /* assure contiguous arrays */
- if (!PyArray_IS_C_CONTIGUOUS(ap1)) {
- PyObject *op1 = PyArray_NewCopy(ap1, NPY_CORDER);
- Py_DECREF(ap1);
- ap1 = (PyArrayObject *)op1;
- if (ap1 == NULL) {
- goto fail;
- }
- }
- if (!PyArray_IS_C_CONTIGUOUS(ap2)) {
- PyObject *op2 = PyArray_NewCopy(ap2, NPY_CORDER);
- Py_DECREF(ap2);
- ap2 = (PyArrayObject *)op2;
- if (ap2 == NULL) {
- goto fail;
- }
- }
-
- if (PyArray_NDIM(ap1) == 0 || PyArray_NDIM(ap2) == 0) {
- /* One of ap1 or ap2 is a scalar */
- if (PyArray_NDIM(ap1) == 0) {
- /* Make ap2 the scalar */
- PyArrayObject *t = ap1;
- ap1 = ap2;
- ap2 = t;
- }
- for (l = 1, j = 0; j < PyArray_NDIM(ap1); j++) {
- dimensions[j] = PyArray_DIM(ap1, j);
- l *= dimensions[j];
- }
- nd = PyArray_NDIM(ap1);
+ if ((ap1 == NULL) || (ap2 == NULL)) {
+ goto fail;
}
- else {
- /*
- * (PyArray_NDIM(ap1) <= 2 && PyArray_NDIM(ap2) <= 2)
- * Both ap1 and ap2 are vectors or matrices
- */
- l = PyArray_DIM(ap1, PyArray_NDIM(ap1) - 1);
-
- if (PyArray_DIM(ap2, PyArray_NDIM(ap2) - 1) != l) {
- dot_alignment_error(ap1, PyArray_NDIM(ap1) - 1,
- ap2, PyArray_NDIM(ap2) - 1);
- goto fail;
- }
- nd = PyArray_NDIM(ap1) + PyArray_NDIM(ap2) - 2;
- if (nd == 1)
- dimensions[0] = (PyArray_NDIM(ap1) == 2) ?
- PyArray_DIM(ap1, 0) : PyArray_DIM(ap2, 0);
- else if (nd == 2) {
- dimensions[0] = PyArray_DIM(ap1, 0);
- dimensions[1] = PyArray_DIM(ap2, 0);
- }
+ ap2t = (PyArrayObject *)PyArray_Transpose(ap2, NULL);
+ if (ap2t == NULL) {
+ goto fail;
}
- /* Choose which subtype to return */
- prior2 = PyArray_GetPriority((PyObject *)ap2, 0.0);
- prior1 = PyArray_GetPriority((PyObject *)ap1, 0.0);
- subtype = (prior2 > prior1 ? Py_TYPE(ap2) : Py_TYPE(ap1));
-
- ret = (PyArrayObject *)PyArray_New(subtype, nd, dimensions,
- typenum, NULL, NULL, 0, 0,
- (PyObject *)
- (prior2 > prior1 ? ap2 : ap1));
-
+ ret = (PyArrayObject *)cblas_matrixproduct(typenum, ap1, ap2t, NULL);
if (ret == NULL) {
goto fail;
}
- NPY_BEGIN_ALLOW_THREADS;
- memset(PyArray_DATA(ret), 0, PyArray_NBYTES(ret));
-
- if (PyArray_NDIM(ap2) == 0) {
- /* Multiplication by a scalar -- Level 1 BLAS */
- if (typenum == NPY_DOUBLE) {
- cblas_daxpy(l,
- *((double *)PyArray_DATA(ap2)),
- (double *)PyArray_DATA(ap1), 1,
- (double *)PyArray_DATA(ret), 1);
- }
- else if (typenum == NPY_CDOUBLE) {
- cblas_zaxpy(l,
- (double *)PyArray_DATA(ap2),
- (double *)PyArray_DATA(ap1), 1,
- (double *)PyArray_DATA(ret), 1);
- }
- else if (typenum == NPY_FLOAT) {
- cblas_saxpy(l,
- *((float *)PyArray_DATA(ap2)),
- (float *)PyArray_DATA(ap1), 1,
- (float *)PyArray_DATA(ret), 1);
- }
- else if (typenum == NPY_CFLOAT) {
- cblas_caxpy(l,
- (float *)PyArray_DATA(ap2),
- (float *)PyArray_DATA(ap1), 1,
- (float *)PyArray_DATA(ret), 1);
- }
- }
- else if (PyArray_NDIM(ap1) == 1 && PyArray_NDIM(ap2) == 1) {
- /* Dot product between two vectors -- Level 1 BLAS */
- blas_dot(typenum, l,
- PyArray_DATA(ap1), PyArray_ITEMSIZE(ap1),
- PyArray_DATA(ap2), PyArray_ITEMSIZE(ap2),
- PyArray_DATA(ret));
- }
- else if (PyArray_NDIM(ap1) == 2 && PyArray_NDIM(ap2) == 1) {
- /* Matrix-vector multiplication -- Level 2 BLAS */
- lda = (PyArray_DIM(ap1, 1) > 1 ? PyArray_DIM(ap1, 1) : 1);
- gemv(typenum, CblasRowMajor, CblasNoTrans, ap1, lda, ap2, 1, ret);
- }
- else if (PyArray_NDIM(ap1) == 1 && PyArray_NDIM(ap2) == 2) {
- /* Vector matrix multiplication -- Level 2 BLAS */
- lda = (PyArray_DIM(ap2, 1) > 1 ? PyArray_DIM(ap2, 1) : 1);
- gemv(typenum, CblasRowMajor, CblasNoTrans, ap2, lda, ap1, 1, ret);
- }
- else {
- /*
- * (PyArray_NDIM(ap1) == 2 && PyArray_NDIM(ap2) == 2)
- * Matrix matrix multiplication -- Level 3 BLAS
- */
- lda = (PyArray_DIM(ap1, 1) > 1 ? PyArray_DIM(ap1, 1) : 1);
- ldb = (PyArray_DIM(ap2, 1) > 1 ? PyArray_DIM(ap2, 1) : 1);
- gemm(typenum, CblasRowMajor, CblasNoTrans, CblasTrans,
- PyArray_DIM(ap1, 0), PyArray_DIM(ap2, 0), PyArray_DIM(ap1, 1),
- ap1, lda, ap2, ldb, ret);
- }
- NPY_END_ALLOW_THREADS;
- Py_DECREF(ap1);
Py_DECREF(ap2);
return PyArray_Return(ret);
- fail:
- Py_XDECREF(ap1);
+fail:
Py_XDECREF(ap2);
Py_XDECREF(ret);
return NULL;