diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2017-01-30 19:09:53 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-01-30 19:09:53 -0700 |
commit | d4f5bea46f328c9378e9548f2a87d0500c95955a (patch) | |
tree | 081575ddcc7387d2834bbd974efb33c51c2f3c65 /numpy | |
parent | af506553cbeb7c376b5b687b78ac296f3140ecf8 (diff) | |
parent | 24e551b56c5af03e0cbfdb90a30ff299711094f0 (diff) | |
download | numpy-d4f5bea46f328c9378e9548f2a87d0500c95955a.tar.gz |
Merge pull request #8539 from pv/dot-out-overlap
BUG: core: in dot(), make copies if out has memory overlap with input
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/cblasfuncs.c | 131 | ||||
-rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.c | 93 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 18 |
3 files changed, 168 insertions, 74 deletions
diff --git a/numpy/core/src/multiarray/cblasfuncs.c b/numpy/core/src/multiarray/cblasfuncs.c index ef05c7205..4b11be947 100644 --- a/numpy/core/src/multiarray/cblasfuncs.c +++ b/numpy/core/src/multiarray/cblasfuncs.c @@ -12,6 +12,7 @@ #include "npy_cblas.h" #include "arraytypes.h" #include "common.h" +#include "mem_overlap.h" /* @@ -242,7 +243,7 @@ NPY_NO_EXPORT PyObject * cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2, PyArrayObject *out) { - PyArrayObject *ret = NULL; + PyArrayObject *result = NULL, *out_buf = NULL; int j, lda, ldb; npy_intp l; int nd; @@ -412,25 +413,50 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2, goto fail; } } + + /* check for memory overlap */ + if (!(solve_may_share_memory(out, ap1, 1) == 0 && + solve_may_share_memory(out, ap2, 1) == 0)) { + /* allocate temporary output array */ + out_buf = (PyArrayObject *)PyArray_NewLikeArray(out, NPY_CORDER, + NULL, 0); + if (out_buf == NULL) { + goto fail; + } + + /* set copy-back */ + Py_INCREF(out); + if (PyArray_SetUpdateIfCopyBase(out_buf, out) < 0) { + Py_DECREF(out); + goto fail; + } + } + else { + Py_INCREF(out); + out_buf = out; + } Py_INCREF(out); - ret = out; + result = out; } else { PyObject *tmp = (PyObject *)(prior2 > prior1 ? ap2 : ap1); - ret = (PyArrayObject *)PyArray_New(subtype, nd, dimensions, - typenum, NULL, NULL, 0, 0, tmp); - } + out_buf = (PyArrayObject *)PyArray_New(subtype, nd, dimensions, + typenum, NULL, NULL, 0, 0, tmp); + if (out_buf == NULL) { + goto fail; + } - if (ret == NULL) { - goto fail; + Py_INCREF(out_buf); + result = out_buf; } - numbytes = PyArray_NBYTES(ret); - memset(PyArray_DATA(ret), 0, numbytes); + + numbytes = PyArray_NBYTES(out_buf); + memset(PyArray_DATA(out_buf), 0, numbytes); if (numbytes == 0 || l == 0) { Py_DECREF(ap1); Py_DECREF(ap2); - return PyArray_Return(ret); + return PyArray_Return(out_buf); } if (ap2shape == _scalar) { @@ -443,7 +469,7 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2, if (typenum == NPY_DOUBLE) { if (l == 1) { - *((double *)PyArray_DATA(ret)) = *((double *)PyArray_DATA(ap2)) * + *((double *)PyArray_DATA(out_buf)) = *((double *)PyArray_DATA(ap2)) * *((double *)PyArray_DATA(ap1)); } else if (ap1shape != _matrix) { @@ -451,26 +477,26 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2, *((double *)PyArray_DATA(ap2)), (double *)PyArray_DATA(ap1), ap1stride/sizeof(double), - (double *)PyArray_DATA(ret), 1); + (double *)PyArray_DATA(out_buf), 1); } else { - int maxind, oind, i, a1s, rets; - char *ptr, *rptr; + int maxind, oind, i, a1s, outs; + char *ptr, *optr; double val; maxind = (PyArray_DIM(ap1, 0) >= PyArray_DIM(ap1, 1) ? 0 : 1); oind = 1 - maxind; ptr = PyArray_DATA(ap1); - rptr = PyArray_DATA(ret); + optr = PyArray_DATA(out_buf); l = PyArray_DIM(ap1, maxind); val = *((double *)PyArray_DATA(ap2)); a1s = PyArray_STRIDE(ap1, maxind) / sizeof(double); - rets = PyArray_STRIDE(ret, maxind) / sizeof(double); + outs = PyArray_STRIDE(out_buf, maxind) / sizeof(double); for (i = 0; i < PyArray_DIM(ap1, oind); i++) { cblas_daxpy(l, val, (double *)ptr, a1s, - (double *)rptr, rets); + (double *)optr, outs); ptr += PyArray_STRIDE(ap1, oind); - rptr += PyArray_STRIDE(ret, oind); + optr += PyArray_STRIDE(out_buf, oind); } } } @@ -480,7 +506,7 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2, ptr1 = (npy_cdouble *)PyArray_DATA(ap2); ptr2 = (npy_cdouble *)PyArray_DATA(ap1); - res = (npy_cdouble *)PyArray_DATA(ret); + res = (npy_cdouble *)PyArray_DATA(out_buf); res->real = ptr1->real * ptr2->real - ptr1->imag * ptr2->imag; res->imag = ptr1->real * ptr2->imag + ptr1->imag * ptr2->real; } @@ -489,32 +515,32 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2, (double *)PyArray_DATA(ap2), (double *)PyArray_DATA(ap1), ap1stride/sizeof(npy_cdouble), - (double *)PyArray_DATA(ret), 1); + (double *)PyArray_DATA(out_buf), 1); } else { - int maxind, oind, i, a1s, rets; - char *ptr, *rptr; + int maxind, oind, i, a1s, outs; + char *ptr, *optr; double *pval; maxind = (PyArray_DIM(ap1, 0) >= PyArray_DIM(ap1, 1) ? 0 : 1); oind = 1 - maxind; ptr = PyArray_DATA(ap1); - rptr = PyArray_DATA(ret); + optr = PyArray_DATA(out_buf); l = PyArray_DIM(ap1, maxind); pval = (double *)PyArray_DATA(ap2); a1s = PyArray_STRIDE(ap1, maxind) / sizeof(npy_cdouble); - rets = PyArray_STRIDE(ret, maxind) / sizeof(npy_cdouble); + outs = PyArray_STRIDE(out_buf, maxind) / sizeof(npy_cdouble); for (i = 0; i < PyArray_DIM(ap1, oind); i++) { cblas_zaxpy(l, pval, (double *)ptr, a1s, - (double *)rptr, rets); + (double *)optr, outs); ptr += PyArray_STRIDE(ap1, oind); - rptr += PyArray_STRIDE(ret, oind); + optr += PyArray_STRIDE(out_buf, oind); } } } else if (typenum == NPY_FLOAT) { if (l == 1) { - *((float *)PyArray_DATA(ret)) = *((float *)PyArray_DATA(ap2)) * + *((float *)PyArray_DATA(out_buf)) = *((float *)PyArray_DATA(ap2)) * *((float *)PyArray_DATA(ap1)); } else if (ap1shape != _matrix) { @@ -522,26 +548,26 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2, *((float *)PyArray_DATA(ap2)), (float *)PyArray_DATA(ap1), ap1stride/sizeof(float), - (float *)PyArray_DATA(ret), 1); + (float *)PyArray_DATA(out_buf), 1); } else { - int maxind, oind, i, a1s, rets; - char *ptr, *rptr; + int maxind, oind, i, a1s, outs; + char *ptr, *optr; float val; maxind = (PyArray_DIM(ap1, 0) >= PyArray_DIM(ap1, 1) ? 0 : 1); oind = 1 - maxind; ptr = PyArray_DATA(ap1); - rptr = PyArray_DATA(ret); + optr = PyArray_DATA(out_buf); l = PyArray_DIM(ap1, maxind); val = *((float *)PyArray_DATA(ap2)); a1s = PyArray_STRIDE(ap1, maxind) / sizeof(float); - rets = PyArray_STRIDE(ret, maxind) / sizeof(float); + outs = PyArray_STRIDE(out_buf, maxind) / sizeof(float); for (i = 0; i < PyArray_DIM(ap1, oind); i++) { cblas_saxpy(l, val, (float *)ptr, a1s, - (float *)rptr, rets); + (float *)optr, outs); ptr += PyArray_STRIDE(ap1, oind); - rptr += PyArray_STRIDE(ret, oind); + optr += PyArray_STRIDE(out_buf, oind); } } } @@ -551,7 +577,7 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2, ptr1 = (npy_cfloat *)PyArray_DATA(ap2); ptr2 = (npy_cfloat *)PyArray_DATA(ap1); - res = (npy_cfloat *)PyArray_DATA(ret); + res = (npy_cfloat *)PyArray_DATA(out_buf); res->real = ptr1->real * ptr2->real - ptr1->imag * ptr2->imag; res->imag = ptr1->real * ptr2->imag + ptr1->imag * ptr2->real; } @@ -560,26 +586,26 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2, (float *)PyArray_DATA(ap2), (float *)PyArray_DATA(ap1), ap1stride/sizeof(npy_cfloat), - (float *)PyArray_DATA(ret), 1); + (float *)PyArray_DATA(out_buf), 1); } else { - int maxind, oind, i, a1s, rets; - char *ptr, *rptr; + int maxind, oind, i, a1s, outs; + char *ptr, *optr; float *pval; maxind = (PyArray_DIM(ap1, 0) >= PyArray_DIM(ap1, 1) ? 0 : 1); oind = 1 - maxind; ptr = PyArray_DATA(ap1); - rptr = PyArray_DATA(ret); + optr = PyArray_DATA(out_buf); l = PyArray_DIM(ap1, maxind); pval = (float *)PyArray_DATA(ap2); a1s = PyArray_STRIDE(ap1, maxind) / sizeof(npy_cfloat); - rets = PyArray_STRIDE(ret, maxind) / sizeof(npy_cfloat); + outs = PyArray_STRIDE(out_buf, maxind) / sizeof(npy_cfloat); for (i = 0; i < PyArray_DIM(ap1, oind); i++) { cblas_caxpy(l, pval, (float *)ptr, a1s, - (float *)rptr, rets); + (float *)optr, outs); ptr += PyArray_STRIDE(ap1, oind); - rptr += PyArray_STRIDE(ret, oind); + optr += PyArray_STRIDE(out_buf, oind); } } } @@ -592,7 +618,7 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2, blas_dot(typenum, l, PyArray_DATA(ap1), PyArray_STRIDE(ap1, (ap1shape == _row)), PyArray_DATA(ap2), PyArray_STRIDE(ap2, 0), - PyArray_DATA(ret)); + PyArray_DATA(out_buf)); NPY_END_ALLOW_THREADS; } else if (ap1shape == _matrix && ap2shape != _matrix) { @@ -620,7 +646,7 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2, lda = (PyArray_DIM(ap1, 0) > 1 ? PyArray_DIM(ap1, 0) : 1); } ap2s = PyArray_STRIDE(ap2, 0) / PyArray_ITEMSIZE(ap2); - gemv(typenum, Order, CblasNoTrans, ap1, lda, ap2, ap2s, ret); + gemv(typenum, Order, CblasNoTrans, ap1, lda, ap2, ap2s, out_buf); NPY_END_ALLOW_THREADS; } else if (ap1shape != _matrix && ap2shape == _matrix) { @@ -652,7 +678,7 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2, else { ap1s = PyArray_STRIDE(ap1, 0) / PyArray_ITEMSIZE(ap1); } - gemv(typenum, Order, CblasTrans, ap2, lda, ap1, ap1s, ret); + gemv(typenum, Order, CblasTrans, ap2, lda, ap1, ap1s, out_buf); NPY_END_ALLOW_THREADS; } else { @@ -726,15 +752,15 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2, ((Trans1 == CblasNoTrans) ^ (Trans2 == CblasNoTrans)) ) { if (Trans1 == CblasNoTrans) { - syrk(typenum, Order, Trans1, N, M, ap1, lda, ret); + syrk(typenum, Order, Trans1, N, M, ap1, lda, out_buf); } else { - syrk(typenum, Order, Trans1, N, M, ap2, ldb, ret); + syrk(typenum, Order, Trans1, N, M, ap2, ldb, out_buf); } } else { gemm(typenum, Order, Trans1, Trans2, L, N, M, ap1, lda, ap2, ldb, - ret); + out_buf); } NPY_END_ALLOW_THREADS; } @@ -742,11 +768,16 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2, Py_DECREF(ap1); Py_DECREF(ap2); - return PyArray_Return(ret); + + /* Trigger possible copyback into `result` */ + Py_DECREF(out_buf); + + return PyArray_Return(result); fail: Py_XDECREF(ap1); Py_XDECREF(ap2); - Py_XDECREF(ret); + Py_XDECREF(out_buf); + Py_XDECREF(result); return NULL; } diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index a9ed5d198..f00de46c4 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -753,12 +753,17 @@ PyArray_CanCoerceScalar(int thistype, int neededtype, /* * Make a new empty array, of the passed size, of a type that takes the * priority of ap1 and ap2 into account. + * + * If `out` is non-NULL, memory overlap is checked with ap1 and ap2, and an + * updateifcopy temporary array may be returned. If `result` is non-NULL, the + * output array to be returned (`out` if non-NULL and the newly allocated array + * otherwise) is incref'd and put to *result. */ static PyArrayObject * new_array_for_sum(PyArrayObject *ap1, PyArrayObject *ap2, PyArrayObject* out, - int nd, npy_intp dimensions[], int typenum) + int nd, npy_intp dimensions[], int typenum, PyArrayObject **result) { - PyArrayObject *ret; + PyArrayObject *out_buf; PyTypeObject *subtype; double prior1, prior2; /* @@ -776,6 +781,7 @@ new_array_for_sum(PyArrayObject *ap1, PyArrayObject *ap2, PyArrayObject* out, } if (out) { int d; + /* verify that out is usable */ if (Py_TYPE(out) != subtype || PyArray_NDIM(out) != nd || @@ -793,15 +799,49 @@ new_array_for_sum(PyArrayObject *ap1, PyArrayObject *ap2, PyArrayObject* out, return 0; } } - Py_INCREF(out); - return out; + + /* check for memory overlap */ + if (!(solve_may_share_memory(out, ap1, 1) == 0 && + solve_may_share_memory(out, ap2, 1) == 0)) { + /* allocate temporary output array */ + out_buf = (PyArrayObject *)PyArray_NewLikeArray(out, NPY_CORDER, + NULL, 0); + if (out_buf == NULL) { + return NULL; + } + + /* set copy-back */ + Py_INCREF(out); + if (PyArray_SetUpdateIfCopyBase(out_buf, out) < 0) { + Py_DECREF(out); + Py_DECREF(out_buf); + return NULL; + } + } + else { + Py_INCREF(out); + out_buf = out; + } + + if (result) { + Py_INCREF(out); + *result = out; + } + + return out_buf; } - ret = (PyArrayObject *)PyArray_New(subtype, nd, dimensions, - typenum, NULL, NULL, 0, 0, - (PyObject *) - (prior2 > prior1 ? ap2 : ap1)); - return ret; + out_buf = (PyArrayObject *)PyArray_New(subtype, nd, dimensions, + typenum, NULL, NULL, 0, 0, + (PyObject *) + (prior2 > prior1 ? ap2 : ap1)); + + if (out_buf != NULL && result) { + Py_INCREF(out_buf); + *result = out_buf; + } + + return out_buf; } /* Could perhaps be redone to not make contiguous arrays */ @@ -897,7 +937,7 @@ PyArray_MatrixProduct(PyObject *op1, PyObject *op2) NPY_NO_EXPORT PyObject * PyArray_MatrixProduct2(PyObject *op1, PyObject *op2, PyArrayObject* out) { - PyArrayObject *ap1, *ap2, *ret = NULL; + PyArrayObject *ap1, *ap2, *out_buf = NULL, *result = NULL; PyArrayIterObject *it1, *it2; npy_intp i, j, l; int typenum, nd, axis, matchDim; @@ -939,12 +979,12 @@ PyArray_MatrixProduct2(PyObject *op1, PyObject *op2, PyArrayObject* out) #endif if (PyArray_NDIM(ap1) == 0 || PyArray_NDIM(ap2) == 0) { - ret = (PyArray_NDIM(ap1) == 0 ? ap1 : ap2); - ret = (PyArrayObject *)Py_TYPE(ret)->tp_as_number->nb_multiply( + result = (PyArray_NDIM(ap1) == 0 ? ap1 : ap2); + result = (PyArrayObject *)Py_TYPE(result)->tp_as_number->nb_multiply( (PyObject *)ap1, (PyObject *)ap2); Py_DECREF(ap1); Py_DECREF(ap2); - return (PyObject *)ret; + return (PyObject *)result; } l = PyArray_DIMS(ap1)[PyArray_NDIM(ap1) - 1]; if (PyArray_NDIM(ap2) > 1) { @@ -976,24 +1016,24 @@ PyArray_MatrixProduct2(PyObject *op1, PyObject *op2, PyArrayObject* out) is1 = PyArray_STRIDES(ap1)[PyArray_NDIM(ap1)-1]; is2 = PyArray_STRIDES(ap2)[matchDim]; /* Choose which subtype to return */ - ret = new_array_for_sum(ap1, ap2, out, nd, dimensions, typenum); - if (ret == NULL) { + out_buf = new_array_for_sum(ap1, ap2, out, nd, dimensions, typenum, &result); + if (out_buf == NULL) { goto fail; } /* Ensure that multiarray.dot(<Nx0>,<0xM>) -> zeros((N,M)) */ if (PyArray_SIZE(ap1) == 0 && PyArray_SIZE(ap2) == 0) { - memset(PyArray_DATA(ret), 0, PyArray_NBYTES(ret)); + memset(PyArray_DATA(out_buf), 0, PyArray_NBYTES(out_buf)); } - dot = PyArray_DESCR(ret)->f->dotfunc; + dot = PyArray_DESCR(out_buf)->f->dotfunc; if (dot == NULL) { PyErr_SetString(PyExc_ValueError, "dot not available for this type"); goto fail; } - op = PyArray_DATA(ret); - os = PyArray_DESCR(ret)->elsize; + op = PyArray_DATA(out_buf); + os = PyArray_DESCR(out_buf)->elsize; axis = PyArray_NDIM(ap1)-1; it1 = (PyArrayIterObject *) PyArray_IterAllButAxis((PyObject *)ap1, &axis); @@ -1009,7 +1049,7 @@ PyArray_MatrixProduct2(PyObject *op1, PyObject *op2, PyArrayObject* out) NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2)); while (it1->index < it1->size) { while (it2->index < it2->size) { - dot(it1->dataptr, is1, it2->dataptr, is2, op, l, ret); + dot(it1->dataptr, is1, it2->dataptr, is2, op, l, out_buf); op += os; PyArray_ITER_NEXT(it2); } @@ -1025,12 +1065,17 @@ PyArray_MatrixProduct2(PyObject *op1, PyObject *op2, PyArrayObject* out) } Py_DECREF(ap1); Py_DECREF(ap2); - return (PyObject *)ret; + + /* Trigger possible copy-back into `result` */ + Py_DECREF(out_buf); + + return (PyObject *)result; fail: Py_XDECREF(ap1); Py_XDECREF(ap2); - Py_XDECREF(ret); + Py_XDECREF(out_buf); + Py_XDECREF(result); return NULL; } @@ -1142,7 +1187,7 @@ _pyarray_correlate(PyArrayObject *ap1, PyArrayObject *ap2, int typenum, * Need to choose an output array that can hold a sum * -- use priority to determine which subtype. */ - ret = new_array_for_sum(ap1, ap2, NULL, 1, &length, typenum); + ret = new_array_for_sum(ap1, ap2, NULL, 1, &length, typenum, NULL); if (ret == NULL) { return NULL; } @@ -2240,7 +2285,7 @@ array_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args) } /* array scalar output */ - ret = new_array_for_sum(ap1, ap2, NULL, 0, (npy_intp *)NULL, typenum); + ret = new_array_for_sum(ap1, ap2, NULL, 0, (npy_intp *)NULL, typenum, NULL); if (ret == NULL) { goto fail; } diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 10b243b35..3bbe99b8f 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -2326,6 +2326,24 @@ class TestMethods(TestCase): assert_raises(TypeError, np.dot, c, A) assert_raises(TypeError, np.dot, A, c) + def test_dot_out_mem_overlap(self): + np.random.seed(1) + + # Test BLAS and non-BLAS code paths, including all dtypes + # that dot() supports + dtypes = [np.dtype(code) for code in np.typecodes['All'] + if code not in 'USVM'] + for dtype in dtypes: + a = np.random.rand(3, 3).astype(dtype) + b = np.random.rand(3, 3).astype(dtype) + y = np.dot(a, b) + x = np.dot(a, b, out=b) + assert_equal(x, y, err_msg=repr(dtype)) + + # Check invalid output array + assert_raises(ValueError, np.dot, a, b, out=b[::2]) + assert_raises(ValueError, np.dot, a, b, out=b.T) + def test_diagonal(self): a = np.arange(12).reshape((3, 4)) assert_equal(a.diagonal(), [0, 5, 10]) |