diff options
author | mattip <matti.picus@gmail.com> | 2023-01-03 10:38:34 +0200 |
---|---|---|
committer | mattip <matti.picus@gmail.com> | 2023-01-03 10:57:23 +0200 |
commit | 1d4bef632205f1db9f8337198eb20e40ebefc859 (patch) | |
tree | 87d6c7fc657f2353964e116d38ee46581632d253 | |
parent | 333d01246db2e270288123095fc8be1ba5206ef4 (diff) | |
download | numpy-1d4bef632205f1db9f8337198eb20e40ebefc859.tar.gz |
MAINT: enable fast_path in more cases (from review)
-rw-r--r-- | numpy/core/src/umath/legacy_array_method.c | 4 | ||||
-rw-r--r-- | numpy/core/src/umath/ufunc_object.c | 22 |
2 files changed, 8 insertions, 18 deletions
diff --git a/numpy/core/src/umath/legacy_array_method.c b/numpy/core/src/umath/legacy_array_method.c index 96cc7a2f1..a8627d55c 100644 --- a/numpy/core/src/umath/legacy_array_method.c +++ b/numpy/core/src/umath/legacy_array_method.c @@ -91,10 +91,6 @@ generic_wrapped_legacy_loop(PyArrayMethod_Context *NPY_UNUSED(context), return 0; } -int -is_generic_wrapped_legacy_loop(PyArrayMethod_StridedLoop *strided_loop) { - return strided_loop == generic_wrapped_legacy_loop; -} /* * Signal that the old type-resolution function must be used to resolve * the descriptors (mainly/only used for datetimes due to the unit). diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index 969bfbca5..3f36aced4 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -6411,16 +6411,14 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args) goto fail; } int fast_path = 1; - if (is_generic_wrapped_legacy_loop(strided_loop)) { - /* - * only user-defined dtypes will use a non-generic_wrapped_legacy loop - * so this path is very common - */ - for (int i=0; i<3; i++) { - if (operation_descrs[i] && !PyDataType_ISNUMBER(operation_descrs[i])) { - fast_path = 0; - } + for (int i=0; i<3; i++) { + /* over-cautious? Only use built-in numeric dtypes */ + if (operation_descrs[i] && !PyDataType_ISNUMBER(operation_descrs[i])) { + fast_path = 0; } + } + if (fast_path == 1) { + /* check no casting, alignment */ if (PyArray_DESCR(op1_array) != operation_descrs[0]) { fast_path = 0; } @@ -6434,11 +6432,7 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args) fast_path = 0; } } - else { - /* Not a known loop function */ - fast_path = 0; - } - if (fast_path) { + if (fast_path == 1) { res = ufunc_at__fast_iter(ufunc, flags, iter, iter2, op1_array, op2_array, strided_loop, &context, strides, auxdata); } else { |