summaryrefslogtreecommitdiff
path: root/numpy/core/src/multiarray
diff options
context:
space:
mode:
authorThe Dog Lulu <Iamsoto@users.noreply.github.com>2020-12-26 17:20:29 -0800
committermattip <matti.picus@gmail.com>2023-03-29 14:10:58 +0300
commit20ea97fcc47ce15f2236192c6e826f3e2609bbe2 (patch)
tree78f5105af96d9c9eb5a5ba0f9818209eccb8bf2c /numpy/core/src/multiarray
parent53f7b55cefa7207240e8bf5b8ec0f661c7b36491 (diff)
downloadnumpy-20ea97fcc47ce15f2236192c6e826f3e2609bbe2.tar.gz
Update numpy/core/src/multiarray/einsum_sumprod.c.src
Implementing Erics suggestions Co-authored-by: Eric Wieser <wieser.eric@gmail.com>
Diffstat (limited to 'numpy/core/src/multiarray')
-rw-r--r--numpy/core/src/multiarray/einsum_sumprod.c.src42
1 files changed, 17 insertions, 25 deletions
diff --git a/numpy/core/src/multiarray/einsum_sumprod.c.src b/numpy/core/src/multiarray/einsum_sumprod.c.src
index e39ff0cee..fa59223fe 100644
--- a/numpy/core/src/multiarray/einsum_sumprod.c.src
+++ b/numpy/core/src/multiarray/einsum_sumprod.c.src
@@ -1051,45 +1051,37 @@ static void
@fn_name@(int nop, char **dataptr,
npy_intp const *strides, npy_intp count)
{
- PyObject *tmp1, *tmp2 = NULL;
-
while(count--){
- tmp1 = *(PyObject **)dataptr[0];
- if(!tmp1){
- return;
+ PyObject *prod = *(PyObject **)dataptr[0];
+ if (!prod) {
+ prod = Py_None; // convention is to treat nulls as None
}
- int i;
- Py_INCREF(tmp1);
- for(i = 1; i < nop; ++i){
- if((*(PyObject **)dataptr[i]) == NULL){
- return;
+ Py_INCREF(prod);
+ for (int i = 1; i < nop; ++i){
+ PyObject *curr = *(PyObject **)dataptr[i];
+ if (!curr) {
+ curr = Py_None; // convention is to treat nulls as None
}
- tmp2 = PyNumber_Multiply(*(PyObject **)dataptr[i], tmp1);
- Py_XDECREF(tmp1);
- if(!tmp2){
- /* Potential raised Exception */
+ Py_SETREF(prod, PyNumber_Multiply(curr, prod));
+ if (!prod) {
return;
}
- tmp1 = tmp2;
}
- if(*(PyObject **)dataptr[i] == NULL){
- Py_XDECREF(tmp2);
+ if (*(PyObject **)dataptr[nop] == NULL) {
+ Py_DECREF(prod);
return;
}
- tmp2 = PyNumber_Add(*(PyObject **)dataptr[i], tmp1);
- Py_XDECREF(tmp1);
-
- if(!tmp2){
- /* Potential raised Exception */
+ PyObject *sum = PyNumber_Add(*(PyObject **)dataptr[nop], prod);
+ Py_DECREF(prod);
+ if (!sum) {
return;
}
Py_XDECREF(*(PyObject **)dataptr[nop]);
- *(PyObject **)dataptr[nop] = tmp2;
- tmp2 = 0;
- for (i = 0; i <= nop; ++i){
+ *(PyObject **)dataptr[nop] = sum;
+ for (int i = 0; i <= nop; ++i) {
dataptr[i] += strides[i];
}
}