diff options
| author | Sebastian Berg <sebastianb@nvidia.com> | 2023-02-20 15:51:17 +0100 |
|---|---|---|
| committer | Sebastian Berg <sebastianb@nvidia.com> | 2023-02-20 16:28:55 +0100 |
| commit | 94472bd702afd2a063e630b3aa020d9ccc0ec7ec (patch) | |
| tree | 378ffeca91701dccd0545dad6845ff12b921bd7d /numpy | |
| parent | af9f656865d7948e8c8978e4002e7f5c1b8f0ecf (diff) | |
| download | numpy-94472bd702afd2a063e630b3aa020d9ccc0ec7ec.tar.gz | |
ENH: Avoid use of item XINCREF and DECREF in fasttake
Rather, use the cast function directly when the copy is not trivial
(which we know based on it not passing REFCHK, if that passes we assume
memcpy is fine).
Also uses memcpy, since no overlap is possible here.
Diffstat (limited to 'numpy')
| -rw-r--r-- | numpy/core/src/multiarray/item_selection.c | 78 |
1 files changed, 55 insertions, 23 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c index 44c553133..25ea011c3 100644 --- a/numpy/core/src/multiarray/item_selection.c +++ b/numpy/core/src/multiarray/item_selection.c @@ -17,6 +17,7 @@ #include "multiarraymodule.h" #include "common.h" +#include "dtype_transfer.h" #include "arrayobject.h" #include "ctors.h" #include "lowlevel_strided_loops.h" @@ -39,7 +40,26 @@ npy_fasttake_impl( PyArray_Descr *dtype, int axis) { NPY_BEGIN_THREADS_DEF; - NPY_BEGIN_THREADS_DESCR(dtype); + + NPY_cast_info cast_info; + NPY_ARRAYMETHOD_FLAGS flags; + NPY_cast_info_init(&cast_info); + + if (!needs_refcounting) { + /* if "refcounting" is not needed memcpy is safe for a simple copy */ + NPY_BEGIN_THREADS; + } + else { + if (PyArray_GetDTypeTransferFunction( + 1, itemsize, itemsize, dtype, dtype, 0, + &cast_info, &flags) < 0) { + return -1; + } + if (!(flags & NPY_METH_REQUIRES_PYAPI)) { + NPY_BEGIN_THREADS; + } + } + switch (clipmode) { case NPY_RAISE: for (npy_intp i = 0; i < n; i++) { @@ -47,20 +67,22 @@ npy_fasttake_impl( npy_intp tmp = indices[j]; if (check_and_adjust_index(&tmp, max_item, axis, _save) < 0) { - return -1; + goto fail; } char *tmp_src = src + tmp * chunk; if (needs_refcounting) { - for (npy_intp k = 0; k < nelem; k++) { - PyArray_Item_INCREF(tmp_src, dtype); - PyArray_Item_XDECREF(dest, dtype); - memmove(dest, tmp_src, itemsize); - dest += itemsize; - tmp_src += itemsize; + char *data[2] = {tmp_src, dest}; + npy_intp strides[2] = {itemsize, itemsize}; + if (cast_info.func( + &cast_info.context, data, &nelem, strides, + cast_info.auxdata) < 0) { + NPY_END_THREADS; + goto fail; } + dest += itemsize * nelem; } else { - memmove(dest, tmp_src, chunk); + memcpy(dest, tmp_src, chunk); dest += chunk; } } @@ -83,16 +105,18 @@ npy_fasttake_impl( } char *tmp_src = src + tmp * chunk; if (needs_refcounting) { - for (npy_intp k = 0; k < nelem; k++) { - PyArray_Item_INCREF(tmp_src, dtype); - PyArray_Item_XDECREF(dest, dtype); - memmove(dest, tmp_src, itemsize); - dest += itemsize; - tmp_src += itemsize; + char *data[2] = {tmp_src, dest}; + npy_intp strides[2] = {itemsize, itemsize}; + if (cast_info.func( + &cast_info.context, data, &nelem, strides, + cast_info.auxdata) < 0) { + NPY_END_THREADS; + goto fail; } + dest += itemsize * nelem; } else { - memmove(dest, tmp_src, chunk); + memcpy(dest, tmp_src, chunk); dest += chunk; } } @@ -111,16 +135,18 @@ npy_fasttake_impl( } char *tmp_src = src + tmp * chunk; if (needs_refcounting) { - for (npy_intp k = 0; k < nelem; k++) { - PyArray_Item_INCREF(tmp_src, dtype); - PyArray_Item_XDECREF(dest, dtype); - memmove(dest, tmp_src, itemsize); - dest += itemsize; - tmp_src += itemsize; + char *data[2] = {tmp_src, dest}; + npy_intp strides[2] = {itemsize, itemsize}; + if (cast_info.func( + &cast_info.context, data, &nelem, strides, + cast_info.auxdata) < 0) { + NPY_END_THREADS; + goto fail; } + dest += itemsize * nelem; } else { - memmove(dest, tmp_src, chunk); + memcpy(dest, tmp_src, chunk); dest += chunk; } } @@ -130,7 +156,13 @@ npy_fasttake_impl( } NPY_END_THREADS; + NPY_cast_info_xfree(&cast_info); return 0; + + fail: + /* NPY_END_THREADS already ensured. */ + NPY_cast_info_xfree(&cast_info); + return -1; } |
