summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormattip <matti.picus@gmail.com>2023-01-03 10:38:34 +0200
committermattip <matti.picus@gmail.com>2023-01-03 10:57:23 +0200
commit1d4bef632205f1db9f8337198eb20e40ebefc859 (patch)
tree87d6c7fc657f2353964e116d38ee46581632d253
parent333d01246db2e270288123095fc8be1ba5206ef4 (diff)
downloadnumpy-1d4bef632205f1db9f8337198eb20e40ebefc859.tar.gz
MAINT: enable fast_path in more cases (from review)
-rw-r--r--numpy/core/src/umath/legacy_array_method.c4
-rw-r--r--numpy/core/src/umath/ufunc_object.c22
2 files changed, 8 insertions, 18 deletions
diff --git a/numpy/core/src/umath/legacy_array_method.c b/numpy/core/src/umath/legacy_array_method.c
index 96cc7a2f1..a8627d55c 100644
--- a/numpy/core/src/umath/legacy_array_method.c
+++ b/numpy/core/src/umath/legacy_array_method.c
@@ -91,10 +91,6 @@ generic_wrapped_legacy_loop(PyArrayMethod_Context *NPY_UNUSED(context),
return 0;
}
-int
-is_generic_wrapped_legacy_loop(PyArrayMethod_StridedLoop *strided_loop) {
- return strided_loop == generic_wrapped_legacy_loop;
-}
/*
* Signal that the old type-resolution function must be used to resolve
* the descriptors (mainly/only used for datetimes due to the unit).
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index 969bfbca5..3f36aced4 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -6411,16 +6411,14 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
goto fail;
}
int fast_path = 1;
- if (is_generic_wrapped_legacy_loop(strided_loop)) {
- /*
- * only user-defined dtypes will use a non-generic_wrapped_legacy loop
- * so this path is very common
- */
- for (int i=0; i<3; i++) {
- if (operation_descrs[i] && !PyDataType_ISNUMBER(operation_descrs[i])) {
- fast_path = 0;
- }
+ for (int i=0; i<3; i++) {
+ /* over-cautious? Only use built-in numeric dtypes */
+ if (operation_descrs[i] && !PyDataType_ISNUMBER(operation_descrs[i])) {
+ fast_path = 0;
}
+ }
+ if (fast_path == 1) {
+ /* check no casting, alignment */
if (PyArray_DESCR(op1_array) != operation_descrs[0]) {
fast_path = 0;
}
@@ -6434,11 +6432,7 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
fast_path = 0;
}
}
- else {
- /* Not a known loop function */
- fast_path = 0;
- }
- if (fast_path) {
+ if (fast_path == 1) {
res = ufunc_at__fast_iter(ufunc, flags, iter, iter2, op1_array, op2_array,
strided_loop, &context, strides, auxdata);
} else {