diff options
author | Sebastian Berg <sebastian@sipsolutions.net> | 2015-09-27 17:50:12 +0200 |
---|---|---|
committer | Sebastian Berg <sebastian@sipsolutions.net> | 2015-09-27 18:59:23 +0200 |
commit | 3fd6de4c9492daf401ea1e58ea12c4ef3c8c2b2d (patch) | |
tree | ea7b0e13e02284eb11e0ca06fd3dd94c317de3e3 | |
parent | 97c35365beda55c6dead8c50df785eb857f843f0 (diff) | |
download | numpy-3fd6de4c9492daf401ea1e58ea12c4ef3c8c2b2d.tar.gz |
BUG: Fix vdot for uncontiguous arrays.
Note that using Newshape also means that less copying is done
in principle, because ravel will always return a contiguous
array.
-rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.c | 16 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 22 |
2 files changed, 32 insertions, 6 deletions
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index 0cf6a7cbb..14df0899e 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -2253,8 +2253,10 @@ array_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args) { int typenum; char *ip1, *ip2, *op; - npy_intp n, stride; + npy_intp n, stride1, stride2; PyObject *op1, *op2; + npy_intp newdimptr[1] = {-1}; + PyArray_Dims newdims = {newdimptr, 1}; PyArrayObject *ap1 = NULL, *ap2 = NULL, *ret = NULL; PyArray_Descr *type; PyArray_DotFunc *vdot; @@ -2278,7 +2280,8 @@ array_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args) Py_DECREF(type); goto fail; } - op1 = PyArray_Ravel(ap1, NPY_CORDER); + + op1 = PyArray_Newshape(ap1, &newdims, NPY_CORDER); if (op1 == NULL) { Py_DECREF(type); goto fail; @@ -2290,7 +2293,7 @@ array_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args) if (ap2 == NULL) { goto fail; } - op2 = PyArray_Ravel(ap2, NPY_CORDER); + op2 = PyArray_Newshape(ap2, &newdims, NPY_CORDER); if (op2 == NULL) { goto fail; } @@ -2310,7 +2313,8 @@ array_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args) } n = PyArray_DIM(ap1, 0); - stride = type->elsize; + stride1 = PyArray_STRIDE(ap1, 0); + stride2 = PyArray_STRIDE(ap2, 0); ip1 = PyArray_DATA(ap1); ip2 = PyArray_DATA(ap2); op = PyArray_DATA(ret); @@ -2338,11 +2342,11 @@ array_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args) } if (n < 500) { - vdot(ip1, stride, ip2, stride, op, n, NULL); + vdot(ip1, stride1, ip2, stride2, op, n, NULL); } else { NPY_BEGIN_THREADS_DESCR(type); - vdot(ip1, stride, ip2, stride, op, n, NULL); + vdot(ip1, stride1, ip2, stride2, op, n, NULL); NPY_END_THREADS_DESCR(type); } diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index a2667172c..44bb9e02f 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -3997,6 +3997,28 @@ class TestVdot(TestCase): assert_equal(np.vdot(b, a), res) assert_equal(np.vdot(b, b), res) + def test_vdot_uncontiguous(self): + for size in [2, 1000]: + # Different sizes match different branches in vdot. + a = np.zeros((size, 2, 2)) + b = np.zeros((size, 2, 2)) + a[:, 0, 0] = np.arange(size) + b[:, 0, 0] = np.arange(size) + 1 + # Make a and b uncontiguous: + a = a[..., 0] + b = b[..., 0] + + assert_equal(np.vdot(a, b), + np.vdot(a.flatten(), b.flatten())) + assert_equal(np.vdot(a, b.copy()), + np.vdot(a.flatten(), b.flatten())) + assert_equal(np.vdot(a.copy(), b), + np.vdot(a.flatten(), b.flatten())) + assert_equal(np.vdot(a.copy('F'), b), + np.vdot(a.flatten(), b.flatten())) + assert_equal(np.vdot(a, b.copy('F')), + np.vdot(a.flatten(), b.flatten())) + class TestDot(TestCase): def setUp(self): |