summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
author: Pauli Virtanen <pav@iki.fi> 2017-01-28 18:32:45 +0100
committer: Pauli Virtanen <pav@iki.fi> 2017-01-28 18:38:00 +0100
commit: cacc2ed5fd6bc06f16061f6f230a592ea8c68d3e (patch)
tree: e3fa5bce523f4d961148791b9a1276f0ea80016e /numpy
parent: c5e1773f0d77755e21d072eb106b8e51a672bfa8 (diff)
download: numpy-cacc2ed5fd6bc06f16061f6f230a592ea8c68d3e.tar.gz
BUG: core: in dot(), make copies if out has memory overlap with input
BLAS does not allow aliased inputs. If a user-provided out= argument may overlap in memory with one of the inputs to dot(), write the output to a temporary work array and copy it back to out after the operation.
Diffstat (limited to 'numpy')
-rw-r--r-- numpy/core/src/multiarray/cblasfuncs.c 33
-rw-r--r-- numpy/core/src/multiarray/multiarraymodule.c 58
-rw-r--r-- numpy/core/tests/test_multiarray.py 16
3 files changed, 96 insertions, 11 deletions
diff --git a/numpy/core/src/multiarray/cblasfuncs.c b/numpy/core/src/multiarray/cblasfuncs.c
index ef05c7205..f1402e3b4 100644
--- a/numpy/core/src/multiarray/cblasfuncs.c
+++ b/numpy/core/src/multiarray/cblasfuncs.c
@@ -12,6 +12,7 @@
#include "npy_cblas.h"
#include "arraytypes.h"
#include "common.h"
+#include "mem_overlap.h"
/*
@@ -242,7 +243,7 @@ NPY_NO_EXPORT PyObject *
cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2,
PyArrayObject *out)
{
- PyArrayObject *ret = NULL;
+ PyArrayObject *ret = NULL, *result = NULL;
int j, lda, ldb;
npy_intp l;
int nd;
@@ -412,14 +413,38 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2,
goto fail;
}
}
+
+ /* check for memory overlap */
+ if (!(solve_may_share_memory(out, ap1, 1) == 0 &&
+ solve_may_share_memory(out, ap2, 1) == 0)) {
+ /* allocate temporary output array */
+ ret = (PyArrayObject *)PyArray_NewLikeArray(out, NPY_CORDER,
+ NULL, 0);
+ if (ret == NULL) {
+ goto fail;
+ }
+
+ /* set copy-back */
+ Py_INCREF(out);
+ if (PyArray_SetUpdateIfCopyBase(ret, out) < 0) {
+ Py_DECREF(out);
+ goto fail;
+ }
+ }
+ else {
+ Py_INCREF(out);
+ ret = out;
+ }
Py_INCREF(out);
- ret = out;
+ result = out;
}
else {
PyObject *tmp = (PyObject *)(prior2 > prior1 ? ap2 : ap1);
ret = (PyArrayObject *)PyArray_New(subtype, nd, dimensions,
typenum, NULL, NULL, 0, 0, tmp);
+ Py_INCREF(ret);
+ result = ret;
}
if (ret == NULL) {
@@ -742,11 +767,13 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2,
Py_DECREF(ap1);
Py_DECREF(ap2);
- return PyArray_Return(ret);
+ Py_DECREF(ret);
+ return PyArray_Return(result);
fail:
Py_XDECREF(ap1);
Py_XDECREF(ap2);
Py_XDECREF(ret);
+ Py_XDECREF(result);
return NULL;
}
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index a9ed5d198..2722e6bbb 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -753,10 +753,15 @@ PyArray_CanCoerceScalar(int thistype, int neededtype,
/*
* Make a new empty array, of the passed size, of a type that takes the
* priority of ap1 and ap2 into account.
+ *
+ * If `out` is non-NULL, it is checked for memory overlap with ap1 and ap2,
+ * and an updateifcopy temporary array may be returned instead. If `result` is
+ * non-NULL, the array to hand back to the caller (`out` if non-NULL, the newly
+ * allocated array otherwise) is incref'd and stored in *result.
*/
static PyArrayObject *
new_array_for_sum(PyArrayObject *ap1, PyArrayObject *ap2, PyArrayObject* out,
- int nd, npy_intp dimensions[], int typenum)
+ int nd, npy_intp dimensions[], int typenum, PyArrayObject **result)
{
PyArrayObject *ret;
PyTypeObject *subtype;
@@ -776,6 +781,7 @@ new_array_for_sum(PyArrayObject *ap1, PyArrayObject *ap2, PyArrayObject* out,
}
if (out) {
int d;
+
/* verify that out is usable */
if (Py_TYPE(out) != subtype ||
PyArray_NDIM(out) != nd ||
@@ -793,14 +799,48 @@ new_array_for_sum(PyArrayObject *ap1, PyArrayObject *ap2, PyArrayObject* out,
return 0;
}
}
- Py_INCREF(out);
- return out;
+
+ /* check for memory overlap */
+ if (!(solve_may_share_memory(out, ap1, 1) == 0 &&
+ solve_may_share_memory(out, ap2, 1) == 0)) {
+ /* allocate temporary output array */
+ ret = (PyArrayObject *)PyArray_NewLikeArray(out, NPY_CORDER,
+ NULL, 0);
+ if (ret == NULL) {
+ return NULL;
+ }
+
+ /* set copy-back */
+ Py_INCREF(out);
+ if (PyArray_SetUpdateIfCopyBase(ret, out) < 0) {
+ Py_DECREF(out);
+ Py_DECREF(ret);
+ return NULL;
+ }
+ }
+ else {
+ Py_INCREF(out);
+ ret = out;
+ }
+
+ if (result) {
+ Py_INCREF(out);
+ *result = out;
+ }
+
+ return ret;
}
ret = (PyArrayObject *)PyArray_New(subtype, nd, dimensions,
typenum, NULL, NULL, 0, 0,
(PyObject *)
(prior2 > prior1 ? ap2 : ap1));
+
+ if (ret != NULL && result) {
+ Py_INCREF(ret);
+ *result = ret;
+ }
+
return ret;
}
@@ -897,7 +937,7 @@ PyArray_MatrixProduct(PyObject *op1, PyObject *op2)
NPY_NO_EXPORT PyObject *
PyArray_MatrixProduct2(PyObject *op1, PyObject *op2, PyArrayObject* out)
{
- PyArrayObject *ap1, *ap2, *ret = NULL;
+ PyArrayObject *ap1, *ap2, *ret = NULL, *result = NULL;
PyArrayIterObject *it1, *it2;
npy_intp i, j, l;
int typenum, nd, axis, matchDim;
@@ -976,7 +1016,7 @@ PyArray_MatrixProduct2(PyObject *op1, PyObject *op2, PyArrayObject* out)
is1 = PyArray_STRIDES(ap1)[PyArray_NDIM(ap1)-1];
is2 = PyArray_STRIDES(ap2)[matchDim];
/* Choose which subtype to return */
- ret = new_array_for_sum(ap1, ap2, out, nd, dimensions, typenum);
+ ret = new_array_for_sum(ap1, ap2, out, nd, dimensions, typenum, &result);
if (ret == NULL) {
goto fail;
}
@@ -1025,12 +1065,14 @@ PyArray_MatrixProduct2(PyObject *op1, PyObject *op2, PyArrayObject* out)
}
Py_DECREF(ap1);
Py_DECREF(ap2);
- return (PyObject *)ret;
+ Py_DECREF(ret);
+ return (PyObject *)result;
fail:
Py_XDECREF(ap1);
Py_XDECREF(ap2);
Py_XDECREF(ret);
+ Py_XDECREF(result);
return NULL;
}
@@ -1142,7 +1184,7 @@ _pyarray_correlate(PyArrayObject *ap1, PyArrayObject *ap2, int typenum,
* Need to choose an output array that can hold a sum
* -- use priority to determine which subtype.
*/
- ret = new_array_for_sum(ap1, ap2, NULL, 1, &length, typenum);
+ ret = new_array_for_sum(ap1, ap2, NULL, 1, &length, typenum, NULL);
if (ret == NULL) {
return NULL;
}
@@ -2240,7 +2282,7 @@ array_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args)
}
/* array scalar output */
- ret = new_array_for_sum(ap1, ap2, NULL, 0, (npy_intp *)NULL, typenum);
+ ret = new_array_for_sum(ap1, ap2, NULL, 0, (npy_intp *)NULL, typenum, NULL);
if (ret == NULL) {
goto fail;
}
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 10b243b35..2621daf37 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -2326,6 +2326,22 @@ class TestMethods(TestCase):
assert_raises(TypeError, np.dot, c, A)
assert_raises(TypeError, np.dot, A, c)
+ def test_dot_out_mem_overlap(self):
+ np.random.seed(1)
+
+ for dtype in [np.object_, np.float32, np.complex128, np.int64]:
+ a0 = np.random.rand(3, 3).astype(dtype)
+ b0 = np.random.rand(3, 3).astype(dtype)
+ for a, b in [(a0.copy(), b0.copy()),
+ (a0.copy().T, b0.copy())]:
+ y = np.dot(a, b)
+ x = np.dot(a, b, out=b)
+ assert_equal(x, y, err_msg=repr(dtype))
+
+ # Check invalid output array
+ assert_raises(ValueError, np.dot, a, b, out=b[::2])
+ assert_raises(ValueError, np.dot, a, b, out=b.T)
+
def test_diagonal(self):
a = np.arange(12).reshape((3, 4))
assert_equal(a.diagonal(), [0, 5, 10])