diff options
| author | Charles Harris <charlesr.harris@gmail.com> | 2018-06-29 08:43:56 -0600 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2018-06-29 08:43:56 -0600 |
| commit | c4924b0c66e103014b09f61eed18a3fab5bd8b3d (patch) | |
| tree | ae4ae9226542c96c2018dd365df2bca74612b6ce /numpy | |
| parent | 166b39fd8ed6ad07e129682c3d42ba55da706d58 (diff) | |
| parent | ae5000f4787fcc99f27e55a171f14cc8405e8ed9 (diff) | |
| download | numpy-c4924b0c66e103014b09f61eed18a3fab5bd8b3d.tar.gz | |
Merge pull request #11406 from mattip/einsum-out-is-res
BUG: ensure ret is out in einsum
Diffstat (limited to 'numpy')
| -rw-r--r-- | numpy/core/src/multiarray/einsum.c.src | 28 | ||||
| -rw-r--r-- | numpy/core/tests/test_einsum.py | 5 |
2 files changed, 15 insertions, 18 deletions
diff --git a/numpy/core/src/multiarray/einsum.c.src b/numpy/core/src/multiarray/einsum.c.src index 33184d99a..1765982a0 100644 --- a/numpy/core/src/multiarray/einsum.c.src +++ b/numpy/core/src/multiarray/einsum.c.src @@ -2767,11 +2767,11 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop, goto fail; } - /* Initialize the output to all zeros and reset the iterator */ + /* Initialize the output to all zeros */ ret = NpyIter_GetOperandArray(iter)[nop]; - Py_INCREF(ret); - PyArray_AssignZero(ret, NULL); - + if (PyArray_AssignZero(ret, NULL) < 0) { + goto fail; + } /***************************/ /* @@ -2785,16 +2785,12 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop, case 1: if (ndim == 2) { if (unbuffered_loop_nop1_ndim2(iter) < 0) { - Py_DECREF(ret); - ret = NULL; goto fail; } goto finish; } else if (ndim == 3) { if (unbuffered_loop_nop1_ndim3(iter) < 0) { - Py_DECREF(ret); - ret = NULL; goto fail; } goto finish; @@ -2803,16 +2799,12 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop, case 2: if (ndim == 2) { if (unbuffered_loop_nop2_ndim2(iter) < 0) { - Py_DECREF(ret); - ret = NULL; goto fail; } goto finish; } else if (ndim == 3) { if (unbuffered_loop_nop2_ndim3(iter) < 0) { - Py_DECREF(ret); - ret = NULL; goto fail; } goto finish; @@ -2823,7 +2815,6 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop, /***************************/ if (NpyIter_Reset(iter, NULL) != NPY_SUCCEED) { - Py_DECREF(ret); goto fail; } @@ -2845,8 +2836,6 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop, if (sop == NULL) { PyErr_SetString(PyExc_TypeError, "invalid data type for einsum"); - Py_DECREF(ret); - ret = NULL; } else if (NpyIter_GetIterSize(iter) != 0) { NpyIter_IterNextFunc *iternext; @@ -2858,7 +2847,6 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop, iternext = NpyIter_GetIterNext(iter, NULL); if (iternext == NULL) { NpyIter_Deallocate(iter); - Py_DECREF(ret); goto fail; } dataptr = NpyIter_GetDataPtrArray(iter); @@ -2874,12 +2862,16 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop, /* If the API was needed, it may have thrown an error */ if (NpyIter_IterationNeedsAPI(iter) && PyErr_Occurred()) { - Py_DECREF(ret); - ret = NULL; + goto fail; } } finish: + if (out != NULL) { + ret = out; + } + Py_INCREF(ret); + NpyIter_Deallocate(iter); for (iop = 0; iop < nop; ++iop) { Py_DECREF(op[iop]); diff --git a/numpy/core/tests/test_einsum.py b/numpy/core/tests/test_einsum.py index 647738831..a72079218 100644 --- a/numpy/core/tests/test_einsum.py +++ b/numpy/core/tests/test_einsum.py @@ -730,6 +730,11 @@ class TestEinSum(object): res = np.einsum('...ij,...jk->...ik', a, a, out=out) assert_equal(res, tgt) + def test_out_is_res(self): + a = np.arange(9).reshape(3, 3) + res = np.einsum('...ij,...jk->...ik', a, a, out=a) + assert res is a + def optimize_compare(self, string): # Tests all paths of the optimization function against # conventional einsum |
