summaryrefslogtreecommitdiff
path: root/numpy/core/src
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core/src')
-rw-r--r--numpy/core/src/multiarray/einsum.c.src17
1 files changed, 15 insertions, 2 deletions
diff --git a/numpy/core/src/multiarray/einsum.c.src b/numpy/core/src/multiarray/einsum.c.src
index e673437ef..e64145203 100644
--- a/numpy/core/src/multiarray/einsum.c.src
+++ b/numpy/core/src/multiarray/einsum.c.src
@@ -17,6 +17,7 @@
#include <numpy/npy_common.h>
#include <numpy/arrayobject.h>
#include <npy_pycompat.h>
+#include <array_assign.h> //PyArray_AssignRawScalar
#include <ctype.h>
@@ -1045,9 +1046,21 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
goto fail;
}
- /* Initialize the output to all zeros */
+ /* Initialize the output to all zeros or None*/
ret = NpyIter_GetOperandArray(iter)[nop];
- if (PyArray_AssignZero(ret, NULL) < 0) {
+ if (PyArray_ISOBJECT(ret)) {
+ /*
+ * Return zero
+ */
+ PyObject * pZero = PyLong_FromLong(0);
+ int assign_result = PyArray_AssignRawScalar(ret, PyArray_DESCR(ret),
+ (char *)&pZero, NULL, NPY_SAFE_CASTING);
+ Py_DECREF(pZero);
+ if (assign_result < 0) {
+ goto fail;
+ }
+ }
+ else if (PyArray_AssignZero(ret, NULL) < 0) {
goto fail;
}