diff options
Diffstat (limited to 'numpy/core/src')
-rw-r--r-- | numpy/core/src/multiarray/einsum.c.src | 17 |
1 files changed, 15 insertions, 2 deletions
diff --git a/numpy/core/src/multiarray/einsum.c.src b/numpy/core/src/multiarray/einsum.c.src index e673437ef..e64145203 100644 --- a/numpy/core/src/multiarray/einsum.c.src +++ b/numpy/core/src/multiarray/einsum.c.src @@ -17,6 +17,7 @@ #include <numpy/npy_common.h> #include <numpy/arrayobject.h> #include <npy_pycompat.h> +#include <array_assign.h> //PyArray_AssignRawScalar #include <ctype.h> @@ -1045,9 +1046,21 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop, goto fail; } - /* Initialize the output to all zeros */ + /* Initialize the output to all zeros or None*/ ret = NpyIter_GetOperandArray(iter)[nop]; - if (PyArray_AssignZero(ret, NULL) < 0) { + if (PyArray_ISOBJECT(ret)) { + /* + * Return zero + */ + PyObject * pZero = PyLong_FromLong(0); + int assign_result = PyArray_AssignRawScalar(ret, PyArray_DESCR(ret), + (char *)&pZero, NULL, NPY_SAFE_CASTING); + Py_DECREF(pZero); + if (assign_result < 0) { + goto fail; + } + } + else if (PyArray_AssignZero(ret, NULL) < 0) { goto fail; } |