summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2015-09-27 17:50:12 +0200
committerSebastian Berg <sebastian@sipsolutions.net>2015-09-27 18:59:23 +0200
commit3fd6de4c9492daf401ea1e58ea12c4ef3c8c2b2d (patch)
treeea7b0e13e02284eb11e0ca06fd3dd94c317de3e3
parent97c35365beda55c6dead8c50df785eb857f843f0 (diff)
downloadnumpy-3fd6de4c9492daf401ea1e58ea12c4ef3c8c2b2d.tar.gz
BUG: Fix vdot for uncontiguous arrays.
Note that using Newshape also means that less copying is done in principle, because ravel will always return a contiguous array.
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.c16
-rw-r--r--numpy/core/tests/test_multiarray.py22
2 files changed, 32 insertions, 6 deletions
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index 0cf6a7cbb..14df0899e 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -2253,8 +2253,10 @@ array_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args)
{
int typenum;
char *ip1, *ip2, *op;
- npy_intp n, stride;
+ npy_intp n, stride1, stride2;
PyObject *op1, *op2;
+ npy_intp newdimptr[1] = {-1};
+ PyArray_Dims newdims = {newdimptr, 1};
PyArrayObject *ap1 = NULL, *ap2 = NULL, *ret = NULL;
PyArray_Descr *type;
PyArray_DotFunc *vdot;
@@ -2278,7 +2280,8 @@ array_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args)
Py_DECREF(type);
goto fail;
}
- op1 = PyArray_Ravel(ap1, NPY_CORDER);
+
+ op1 = PyArray_Newshape(ap1, &newdims, NPY_CORDER);
if (op1 == NULL) {
Py_DECREF(type);
goto fail;
@@ -2290,7 +2293,7 @@ array_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args)
if (ap2 == NULL) {
goto fail;
}
- op2 = PyArray_Ravel(ap2, NPY_CORDER);
+ op2 = PyArray_Newshape(ap2, &newdims, NPY_CORDER);
if (op2 == NULL) {
goto fail;
}
@@ -2310,7 +2313,8 @@ array_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args)
}
n = PyArray_DIM(ap1, 0);
- stride = type->elsize;
+ stride1 = PyArray_STRIDE(ap1, 0);
+ stride2 = PyArray_STRIDE(ap2, 0);
ip1 = PyArray_DATA(ap1);
ip2 = PyArray_DATA(ap2);
op = PyArray_DATA(ret);
@@ -2338,11 +2342,11 @@ array_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args)
}
if (n < 500) {
- vdot(ip1, stride, ip2, stride, op, n, NULL);
+ vdot(ip1, stride1, ip2, stride2, op, n, NULL);
}
else {
NPY_BEGIN_THREADS_DESCR(type);
- vdot(ip1, stride, ip2, stride, op, n, NULL);
+ vdot(ip1, stride1, ip2, stride2, op, n, NULL);
NPY_END_THREADS_DESCR(type);
}
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index a2667172c..44bb9e02f 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -3997,6 +3997,28 @@ class TestVdot(TestCase):
assert_equal(np.vdot(b, a), res)
assert_equal(np.vdot(b, b), res)
+ def test_vdot_uncontiguous(self):
+ for size in [2, 1000]:
+ # Different sizes match different branches in vdot.
+ a = np.zeros((size, 2, 2))
+ b = np.zeros((size, 2, 2))
+ a[:, 0, 0] = np.arange(size)
+ b[:, 0, 0] = np.arange(size) + 1
+ # Make a and b uncontiguous:
+ a = a[..., 0]
+ b = b[..., 0]
+
+ assert_equal(np.vdot(a, b),
+ np.vdot(a.flatten(), b.flatten()))
+ assert_equal(np.vdot(a, b.copy()),
+ np.vdot(a.flatten(), b.flatten()))
+ assert_equal(np.vdot(a.copy(), b),
+ np.vdot(a.flatten(), b.flatten()))
+ assert_equal(np.vdot(a.copy('F'), b),
+ np.vdot(a.flatten(), b.flatten()))
+ assert_equal(np.vdot(a, b.copy('F')),
+ np.vdot(a.flatten(), b.flatten()))
+
class TestDot(TestCase):
def setUp(self):