diff options
author | Pauli Virtanen <pav@iki.fi> | 2017-01-28 18:32:45 +0100 |
---|---|---|
committer | Pauli Virtanen <pav@iki.fi> | 2017-01-28 18:38:00 +0100 |
commit | cacc2ed5fd6bc06f16061f6f230a592ea8c68d3e (patch) | |
tree | e3fa5bce523f4d961148791b9a1276f0ea80016e /numpy | |
parent | c5e1773f0d77755e21d072eb106b8e51a672bfa8 (diff) | |
download | numpy-cacc2ed5fd6bc06f16061f6f230a592ea8c68d3e.tar.gz |
BUG: core: in dot(), make copies if out has memory overlap with input
BLAS does not allow aliased inputs. If user-provided out= argument may
overlap in memory with one of the inputs to dot(), put the output to a
temporary work array and copy back after the operation.
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/cblasfuncs.c | 33 | ||||
-rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.c | 58 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 16 |
3 files changed, 96 insertions, 11 deletions
diff --git a/numpy/core/src/multiarray/cblasfuncs.c b/numpy/core/src/multiarray/cblasfuncs.c index ef05c7205..f1402e3b4 100644 --- a/numpy/core/src/multiarray/cblasfuncs.c +++ b/numpy/core/src/multiarray/cblasfuncs.c @@ -12,6 +12,7 @@ #include "npy_cblas.h" #include "arraytypes.h" #include "common.h" +#include "mem_overlap.h" /* @@ -242,7 +243,7 @@ NPY_NO_EXPORT PyObject * cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2, PyArrayObject *out) { - PyArrayObject *ret = NULL; + PyArrayObject *ret = NULL, *result = NULL; int j, lda, ldb; npy_intp l; int nd; @@ -412,14 +413,38 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2, goto fail; } } + + /* check for memory overlap */ + if (!(solve_may_share_memory(out, ap1, 1) == 0 && + solve_may_share_memory(out, ap2, 1) == 0)) { + /* allocate temporary output array */ + ret = (PyArrayObject *)PyArray_NewLikeArray(out, NPY_CORDER, + NULL, 0); + if (ret == NULL) { + goto fail; + } + + /* set copy-back */ + Py_INCREF(out); + if (PyArray_SetUpdateIfCopyBase(ret, out) < 0) { + Py_DECREF(out); + goto fail; + } + } + else { + Py_INCREF(out); + ret = out; + } Py_INCREF(out); - ret = out; + result = out; } else { PyObject *tmp = (PyObject *)(prior2 > prior1 ? ap2 : ap1); ret = (PyArrayObject *)PyArray_New(subtype, nd, dimensions, typenum, NULL, NULL, 0, 0, tmp); + Py_INCREF(ret); + result = ret; } if (ret == NULL) { @@ -742,11 +767,13 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2, Py_DECREF(ap1); Py_DECREF(ap2); - return PyArray_Return(ret); + Py_DECREF(ret); + return PyArray_Return(result); fail: Py_XDECREF(ap1); Py_XDECREF(ap2); Py_XDECREF(ret); + Py_XDECREF(result); return NULL; } diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index a9ed5d198..2722e6bbb 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -753,10 +753,15 @@ PyArray_CanCoerceScalar(int thistype, int neededtype, /* * Make a new empty array, of the passed size, of a type that takes the * priority of ap1 and ap2 into account. + * + * If `out` is non-NULL, memory overlap is checked with ap1 and ap2, and an + * updateifcopy temporary array may be returned. If `result` is non-NULL, the + * output array to be returned (`out` if non-NULL and the newly allocated array + * otherwise) is incref'd and put to *result. */ static PyArrayObject * new_array_for_sum(PyArrayObject *ap1, PyArrayObject *ap2, PyArrayObject* out, - int nd, npy_intp dimensions[], int typenum) + int nd, npy_intp dimensions[], int typenum, PyArrayObject **result) { PyArrayObject *ret; PyTypeObject *subtype; @@ -776,6 +781,7 @@ new_array_for_sum(PyArrayObject *ap1, PyArrayObject *ap2, PyArrayObject* out, } if (out) { int d; + /* verify that out is usable */ if (Py_TYPE(out) != subtype || PyArray_NDIM(out) != nd || @@ -793,14 +799,48 @@ new_array_for_sum(PyArrayObject *ap1, PyArrayObject *ap2, PyArrayObject* out, return 0; } } - Py_INCREF(out); - return out; + + /* check for memory overlap */ + if (!(solve_may_share_memory(out, ap1, 1) == 0 && + solve_may_share_memory(out, ap2, 1) == 0)) { + /* allocate temporary output array */ + ret = (PyArrayObject *)PyArray_NewLikeArray(out, NPY_CORDER, + NULL, 0); + if (ret == NULL) { + return NULL; + } + + /* set copy-back */ + Py_INCREF(out); + if (PyArray_SetUpdateIfCopyBase(ret, out) < 0) { + Py_DECREF(out); + Py_DECREF(ret); + return NULL; + } + } + else { + Py_INCREF(out); + ret = out; + } + + if (result) { + Py_INCREF(out); + *result = out; + } + + return ret; } ret = (PyArrayObject *)PyArray_New(subtype, nd, dimensions, typenum, NULL, NULL, 0, 0, (PyObject *) (prior2 > prior1 ? ap2 : ap1)); + + if (ret != NULL && result) { + Py_INCREF(ret); + *result = ret; + } + return ret; } @@ -897,7 +937,7 @@ PyArray_MatrixProduct(PyObject *op1, PyObject *op2) NPY_NO_EXPORT PyObject * PyArray_MatrixProduct2(PyObject *op1, PyObject *op2, PyArrayObject* out) { - PyArrayObject *ap1, *ap2, *ret = NULL; + PyArrayObject *ap1, *ap2, *ret = NULL, *result = NULL; PyArrayIterObject *it1, *it2; npy_intp i, j, l; int typenum, nd, axis, matchDim; @@ -976,7 +1016,7 @@ PyArray_MatrixProduct2(PyObject *op1, PyObject *op2, PyArrayObject* out) is1 = PyArray_STRIDES(ap1)[PyArray_NDIM(ap1)-1]; is2 = PyArray_STRIDES(ap2)[matchDim]; /* Choose which subtype to return */ - ret = new_array_for_sum(ap1, ap2, out, nd, dimensions, typenum); + ret = new_array_for_sum(ap1, ap2, out, nd, dimensions, typenum, &result); if (ret == NULL) { goto fail; } @@ -1025,12 +1065,14 @@ PyArray_MatrixProduct2(PyObject *op1, PyObject *op2, PyArrayObject* out) } Py_DECREF(ap1); Py_DECREF(ap2); - return (PyObject *)ret; + Py_DECREF(ret); + return (PyObject *)result; fail: Py_XDECREF(ap1); Py_XDECREF(ap2); Py_XDECREF(ret); + Py_XDECREF(result); return NULL; } @@ -1142,7 +1184,7 @@ _pyarray_correlate(PyArrayObject *ap1, PyArrayObject *ap2, int typenum, * Need to choose an output array that can hold a sum * -- use priority to determine which subtype. */ - ret = new_array_for_sum(ap1, ap2, NULL, 1, &length, typenum); + ret = new_array_for_sum(ap1, ap2, NULL, 1, &length, typenum, NULL); if (ret == NULL) { return NULL; } @@ -2240,7 +2282,7 @@ array_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args) } /* array scalar output */ - ret = new_array_for_sum(ap1, ap2, NULL, 0, (npy_intp *)NULL, typenum); + ret = new_array_for_sum(ap1, ap2, NULL, 0, (npy_intp *)NULL, typenum, NULL); if (ret == NULL) { goto fail; } diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 10b243b35..2621daf37 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -2326,6 +2326,22 @@ class TestMethods(TestCase): assert_raises(TypeError, np.dot, c, A) assert_raises(TypeError, np.dot, A, c) + def test_dot_out_mem_overlap(self): + np.random.seed(1) + + for dtype in [np.object_, np.float32, np.complex128, np.int64]: + a0 = np.random.rand(3, 3).astype(dtype) + b0 = np.random.rand(3, 3).astype(dtype) + for a, b in [(a0.copy(), b0.copy()), + (a0.copy().T, b0.copy())]: + y = np.dot(a, b) + x = np.dot(a, b, out=b) + assert_equal(x, y, err_msg=repr(dtype)) + + # Check invalid output array + assert_raises(ValueError, np.dot, a, b, out=b[::2]) + assert_raises(ValueError, np.dot, a, b, out=b.T) + def test_diagonal(self): a = np.arange(12).reshape((3, 4)) assert_equal(a.diagonal(), [0, 5, 10]) |