summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorSebastian Berg <sebastianb@nvidia.com>2023-02-20 15:51:17 +0100
committerSebastian Berg <sebastianb@nvidia.com>2023-02-20 16:28:55 +0100
commit94472bd702afd2a063e630b3aa020d9ccc0ec7ec (patch)
tree378ffeca91701dccd0545dad6845ff12b921bd7d /numpy
parentaf9f656865d7948e8c8978e4002e7f5c1b8f0ecf (diff)
downloadnumpy-94472bd702afd2a063e630b3aa020d9ccc0ec7ec.tar.gz
ENH: Avoid use of item XINCREF and DECREF in fasttake
Rather, use the cast function directly when the copy is not trivial (which we know based on it not passing REFCHK, if that passes we assume memcpy is fine). Also uses memcpy, since no overlap is possible here.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/item_selection.c78
1 files changed, 55 insertions, 23 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c
index 44c553133..25ea011c3 100644
--- a/numpy/core/src/multiarray/item_selection.c
+++ b/numpy/core/src/multiarray/item_selection.c
@@ -17,6 +17,7 @@
#include "multiarraymodule.h"
#include "common.h"
+#include "dtype_transfer.h"
#include "arrayobject.h"
#include "ctors.h"
#include "lowlevel_strided_loops.h"
@@ -39,7 +40,26 @@ npy_fasttake_impl(
PyArray_Descr *dtype, int axis)
{
NPY_BEGIN_THREADS_DEF;
- NPY_BEGIN_THREADS_DESCR(dtype);
+
+ NPY_cast_info cast_info;
+ NPY_ARRAYMETHOD_FLAGS flags;
+ NPY_cast_info_init(&cast_info);
+
+ if (!needs_refcounting) {
+ /* if "refcounting" is not needed memcpy is safe for a simple copy */
+ NPY_BEGIN_THREADS;
+ }
+ else {
+ if (PyArray_GetDTypeTransferFunction(
+ 1, itemsize, itemsize, dtype, dtype, 0,
+ &cast_info, &flags) < 0) {
+ return -1;
+ }
+ if (!(flags & NPY_METH_REQUIRES_PYAPI)) {
+ NPY_BEGIN_THREADS;
+ }
+ }
+
switch (clipmode) {
case NPY_RAISE:
for (npy_intp i = 0; i < n; i++) {
@@ -47,20 +67,22 @@ npy_fasttake_impl(
npy_intp tmp = indices[j];
if (check_and_adjust_index(&tmp, max_item, axis,
_save) < 0) {
- return -1;
+ goto fail;
}
char *tmp_src = src + tmp * chunk;
if (needs_refcounting) {
- for (npy_intp k = 0; k < nelem; k++) {
- PyArray_Item_INCREF(tmp_src, dtype);
- PyArray_Item_XDECREF(dest, dtype);
- memmove(dest, tmp_src, itemsize);
- dest += itemsize;
- tmp_src += itemsize;
+ char *data[2] = {tmp_src, dest};
+ npy_intp strides[2] = {itemsize, itemsize};
+ if (cast_info.func(
+ &cast_info.context, data, &nelem, strides,
+ cast_info.auxdata) < 0) {
+ NPY_END_THREADS;
+ goto fail;
}
+ dest += itemsize * nelem;
}
else {
- memmove(dest, tmp_src, chunk);
+ memcpy(dest, tmp_src, chunk);
dest += chunk;
}
}
@@ -83,16 +105,18 @@ npy_fasttake_impl(
}
char *tmp_src = src + tmp * chunk;
if (needs_refcounting) {
- for (npy_intp k = 0; k < nelem; k++) {
- PyArray_Item_INCREF(tmp_src, dtype);
- PyArray_Item_XDECREF(dest, dtype);
- memmove(dest, tmp_src, itemsize);
- dest += itemsize;
- tmp_src += itemsize;
+ char *data[2] = {tmp_src, dest};
+ npy_intp strides[2] = {itemsize, itemsize};
+ if (cast_info.func(
+ &cast_info.context, data, &nelem, strides,
+ cast_info.auxdata) < 0) {
+ NPY_END_THREADS;
+ goto fail;
}
+ dest += itemsize * nelem;
}
else {
- memmove(dest, tmp_src, chunk);
+ memcpy(dest, tmp_src, chunk);
dest += chunk;
}
}
@@ -111,16 +135,18 @@ npy_fasttake_impl(
}
char *tmp_src = src + tmp * chunk;
if (needs_refcounting) {
- for (npy_intp k = 0; k < nelem; k++) {
- PyArray_Item_INCREF(tmp_src, dtype);
- PyArray_Item_XDECREF(dest, dtype);
- memmove(dest, tmp_src, itemsize);
- dest += itemsize;
- tmp_src += itemsize;
+ char *data[2] = {tmp_src, dest};
+ npy_intp strides[2] = {itemsize, itemsize};
+ if (cast_info.func(
+ &cast_info.context, data, &nelem, strides,
+ cast_info.auxdata) < 0) {
+ NPY_END_THREADS;
+ goto fail;
}
+ dest += itemsize * nelem;
}
else {
- memmove(dest, tmp_src, chunk);
+ memcpy(dest, tmp_src, chunk);
dest += chunk;
}
}
@@ -130,7 +156,13 @@ npy_fasttake_impl(
}
NPY_END_THREADS;
+ NPY_cast_info_xfree(&cast_info);
return 0;
+
+ fail:
+ /* NPY_END_THREADS already ensured. */
+ NPY_cast_info_xfree(&cast_info);
+ return -1;
}