diff options
author | mattip <matti.picus@gmail.com> | 2023-03-30 14:00:09 +0300 |
---|---|---|
committer | mattip <matti.picus@gmail.com> | 2023-03-30 16:02:55 +0300 |
commit | d1306cbb594bc78d9fef50924fb64d71e299d21a (patch) | |
tree | c429e83fbcb532d85b4a3f0eb66dcb1ddf729577 /numpy/core/src | |
parent | 31e21768ef38a6b60050060a3bce836059469745 (diff) | |
download | numpy-d1306cbb594bc78d9fef50924fb64d71e299d21a.tar.gz |
MAINT: expand PyArray_AssignZero to handle object dtype
Diffstat (limited to 'numpy/core/src')
-rw-r--r-- | numpy/core/src/multiarray/convert.c | 33 | ||||
-rw-r--r-- | numpy/core/src/multiarray/einsum.c.src | 14 | ||||
-rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.c | 15 |
3 files changed, 24 insertions, 38 deletions
diff --git a/numpy/core/src/multiarray/convert.c b/numpy/core/src/multiarray/convert.c index e99bc3fe4..ddaef0adc 100644 --- a/numpy/core/src/multiarray/convert.c +++ b/numpy/core/src/multiarray/convert.c @@ -423,7 +423,9 @@ PyArray_FillWithScalar(PyArrayObject *arr, PyObject *obj) } /* - * Fills an array with zeros. + * Internal function to fill an array with zeros. + * Used in einsum and dot, which ensures the dtype is, in some sense, numerical + * and not a str or struct * * dst: The destination array. * wheremask: If non-NULL, a boolean mask specifying where to set the values. @@ -434,21 +436,26 @@ NPY_NO_EXPORT int PyArray_AssignZero(PyArrayObject *dst, PyArrayObject *wheremask) { - npy_bool value; - PyArray_Descr *bool_dtype; - int retcode; - - /* Create a raw bool scalar with the value False */ - bool_dtype = PyArray_DescrFromType(NPY_BOOL); - if (bool_dtype == NULL) { - return -1; + int retcode = 0; + if (PyArray_ISOBJECT(dst)) { + PyObject * pZero = PyLong_FromLong(0); + retcode = PyArray_AssignRawScalar(dst, PyArray_DESCR(dst), + (char *)&pZero, wheremask, NPY_SAFE_CASTING); + Py_DECREF(pZero); } - value = 0; + else { + /* Create a raw bool scalar with the value False */ + PyArray_Descr *bool_dtype = PyArray_DescrFromType(NPY_BOOL); + if (bool_dtype == NULL) { + return -1; + } + npy_bool value = 0; - retcode = PyArray_AssignRawScalar(dst, bool_dtype, (char *)&value, - wheremask, NPY_SAFE_CASTING); + retcode = PyArray_AssignRawScalar(dst, bool_dtype, (char *)&value, + wheremask, NPY_SAFE_CASTING); - Py_DECREF(bool_dtype); + Py_DECREF(bool_dtype); + } return retcode; } diff --git a/numpy/core/src/multiarray/einsum.c.src b/numpy/core/src/multiarray/einsum.c.src index e64145203..856dce11b 100644 --- a/numpy/core/src/multiarray/einsum.c.src +++ b/numpy/core/src/multiarray/einsum.c.src @@ -1048,19 +1048,7 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop, /* Initialize the output to all zeros or None*/ ret = NpyIter_GetOperandArray(iter)[nop]; - if (PyArray_ISOBJECT(ret)) { - /* - * Return zero - */ - PyObject * pZero = PyLong_FromLong(0); - int assign_result = PyArray_AssignRawScalar(ret, PyArray_DESCR(ret), - (char *)&pZero, NULL, NPY_SAFE_CASTING); - Py_DECREF(pZero); - if (assign_result < 0) { - goto fail; - } - } - else if (PyArray_AssignZero(ret, NULL) < 0) { + if (PyArray_AssignZero(ret, NULL) < 0) { goto fail; } diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index 36dffe1f4..d7cb78ea8 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -66,6 +66,7 @@ NPY_NO_EXPORT int NPY_NUMUSERTYPES = 0; #include "compiled_base.h" #include "mem_overlap.h" #include "typeinfo.h" +#include "convert.h" /* for PyArray_AssignZero */ #include "get_attr_string.h" #include "experimental_public_dtype_api.h" /* _get_experimental_dtype_api */ @@ -1084,18 +1085,8 @@ PyArray_MatrixProduct2(PyObject *op1, PyObject *op2, PyArrayObject* out) } /* Ensure that multiarray.dot(<Nx0>,<0xM>) -> zeros((N,M)) */ if (PyArray_SIZE(ap1) == 0 && PyArray_SIZE(ap2) == 0) { - if (PyArray_ISOBJECT(out_buf)) { - // issue gh-23492: fill with int(0) when there is no iteration - PyObject * pZero = PyLong_FromLong(0); - int assign_result = PyArray_AssignRawScalar(out_buf, PyArray_DESCR(out_buf), - (char *)&pZero, NULL, NPY_SAFE_CASTING); - Py_DECREF(pZero); - if (assign_result < 0) { - goto fail; - } - } - else { - memset(PyArray_DATA(out_buf), 0, PyArray_NBYTES(out_buf)); + if (PyArray_AssignZero(out_buf, NULL) < 0) { + goto fail; } } |