summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2018-06-29 08:43:56 -0600
committerGitHub <noreply@github.com>2018-06-29 08:43:56 -0600
commitc4924b0c66e103014b09f61eed18a3fab5bd8b3d (patch)
treeae4ae9226542c96c2018dd365df2bca74612b6ce /numpy
parent166b39fd8ed6ad07e129682c3d42ba55da706d58 (diff)
parentae5000f4787fcc99f27e55a171f14cc8405e8ed9 (diff)
downloadnumpy-c4924b0c66e103014b09f61eed18a3fab5bd8b3d.tar.gz
Merge pull request #11406 from mattip/einsum-out-is-res
BUG: ensure ret is out in einsum
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/einsum.c.src28
-rw-r--r--numpy/core/tests/test_einsum.py5
2 files changed, 15 insertions, 18 deletions
diff --git a/numpy/core/src/multiarray/einsum.c.src b/numpy/core/src/multiarray/einsum.c.src
index 33184d99a..1765982a0 100644
--- a/numpy/core/src/multiarray/einsum.c.src
+++ b/numpy/core/src/multiarray/einsum.c.src
@@ -2767,11 +2767,11 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
goto fail;
}
- /* Initialize the output to all zeros and reset the iterator */
+ /* Initialize the output to all zeros */
ret = NpyIter_GetOperandArray(iter)[nop];
- Py_INCREF(ret);
- PyArray_AssignZero(ret, NULL);
-
+ if (PyArray_AssignZero(ret, NULL) < 0) {
+ goto fail;
+ }
/***************************/
/*
@@ -2785,16 +2785,12 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
case 1:
if (ndim == 2) {
if (unbuffered_loop_nop1_ndim2(iter) < 0) {
- Py_DECREF(ret);
- ret = NULL;
goto fail;
}
goto finish;
}
else if (ndim == 3) {
if (unbuffered_loop_nop1_ndim3(iter) < 0) {
- Py_DECREF(ret);
- ret = NULL;
goto fail;
}
goto finish;
@@ -2803,16 +2799,12 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
case 2:
if (ndim == 2) {
if (unbuffered_loop_nop2_ndim2(iter) < 0) {
- Py_DECREF(ret);
- ret = NULL;
goto fail;
}
goto finish;
}
else if (ndim == 3) {
if (unbuffered_loop_nop2_ndim3(iter) < 0) {
- Py_DECREF(ret);
- ret = NULL;
goto fail;
}
goto finish;
@@ -2823,7 +2815,6 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
/***************************/
if (NpyIter_Reset(iter, NULL) != NPY_SUCCEED) {
- Py_DECREF(ret);
goto fail;
}
@@ -2845,8 +2836,6 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
if (sop == NULL) {
PyErr_SetString(PyExc_TypeError,
"invalid data type for einsum");
- Py_DECREF(ret);
- ret = NULL;
}
else if (NpyIter_GetIterSize(iter) != 0) {
NpyIter_IterNextFunc *iternext;
@@ -2858,7 +2847,6 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
iternext = NpyIter_GetIterNext(iter, NULL);
if (iternext == NULL) {
NpyIter_Deallocate(iter);
- Py_DECREF(ret);
goto fail;
}
dataptr = NpyIter_GetDataPtrArray(iter);
@@ -2874,12 +2862,16 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
/* If the API was needed, it may have thrown an error */
if (NpyIter_IterationNeedsAPI(iter) && PyErr_Occurred()) {
- Py_DECREF(ret);
- ret = NULL;
+ goto fail;
}
}
finish:
+ if (out != NULL) {
+ ret = out;
+ }
+ Py_INCREF(ret);
+
NpyIter_Deallocate(iter);
for (iop = 0; iop < nop; ++iop) {
Py_DECREF(op[iop]);
diff --git a/numpy/core/tests/test_einsum.py b/numpy/core/tests/test_einsum.py
index 647738831..a72079218 100644
--- a/numpy/core/tests/test_einsum.py
+++ b/numpy/core/tests/test_einsum.py
@@ -730,6 +730,11 @@ class TestEinSum(object):
res = np.einsum('...ij,...jk->...ik', a, a, out=out)
assert_equal(res, tgt)
+ def test_out_is_res(self):
+ a = np.arange(9).reshape(3, 3)
+ res = np.einsum('...ij,...jk->...ik', a, a, out=a)
+ assert res is a
+
def optimize_compare(self, string):
# Tests all paths of the optimization function against
# conventional einsum