diff options
author | Matti Picus <matti.picus@gmail.com> | 2022-11-16 18:19:13 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-11-16 18:19:13 +0200 |
commit | d428d4599034c6039c6eff1ad33518a3e3888104 (patch) | |
tree | ce7807b660efd211002cbb0abd7c7350ea36e38f /numpy | |
parent | 21195337da3e87c0d153794bb65d2d32da542395 (diff) | |
parent | c973fc90c65e6864444677b4d8e0f70060a17da3 (diff) | |
download | numpy-d428d4599034c6039c6eff1ad33518a3e3888104.tar.gz |
Merge pull request #22422 from seberg/expose-dtype-resolution-get-loop
ENH: Expose `ufunc.resolve_dtypes` and strided loop access
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/_add_newdocs.py | 159 | ||||
-rw-r--r-- | numpy/core/src/umath/ufunc_object.c | 498 | ||||
-rw-r--r-- | numpy/core/tests/test_ufunc.py | 133 |
3 files changed, 750 insertions, 40 deletions
diff --git a/numpy/core/_add_newdocs.py b/numpy/core/_add_newdocs.py index 9cfbce86a..d75f9ec64 100644 --- a/numpy/core/_add_newdocs.py +++ b/numpy/core/_add_newdocs.py @@ -5694,6 +5694,165 @@ add_newdoc('numpy.core', 'ufunc', ('at', """)) +add_newdoc('numpy.core', 'ufunc', ('resolve_dtypes', + """ + resolve_dtypes(dtypes, *, signature=None, casting=None, reduction=False) + + Find the dtypes NumPy will use for the operation. Both input and + output dtypes are returned and may differ from those provided. + + .. note:: + + This function always applies NEP 50 rules since it is not provided + any actual values. The Python types ``int``, ``float``, and + ``complex`` thus behave weak and should be passed for "untyped" + Python input. + + Parameters + ---------- + dtypes : tuple of dtypes, None, or literal int, float, complex + The input dtypes for each operand. Output operands can be + None, indicating that the dtype must be found. + signature : tuple of DTypes or None, optional + If given, enforces exact DType (classes) of the specific operand. + The ufunc ``dtype`` argument is equivalent to passing a tuple with + only output dtypes set. + casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional + The casting mode when casting is necessary. This is identical to + the ufunc call casting modes. + reduction : boolean + If given, the resolution assumes a reduce operation is happening + which slightly changes the promotion and type resolution rules. + `dtypes` is usually something like ``(None, np.dtype("i2"), None)`` + for reductions (first input is also the output). + + .. note:: + + The default casting mode is "same_kind", however, as of + NumPy 1.24, NumPy uses "unsafe" for reductions. + + Returns + ------- + dtypes : tuple of dtypes + The dtypes which NumPy would use for the calculation. Note that + dtypes may not match the passed in ones (casting is necessary). + + See Also + -------- + numpy.ufunc._resolve_dtypes_and_context : + Similar function to this, but returns additional information which + give access to the core C functionality of NumPy. + + Examples + -------- + This API requires passing dtypes, define them for convenience: + + >>> int32 = np.dtype("int32") + >>> float32 = np.dtype("float32") + + The typical ufunc call does not pass an output dtype. `np.add` has two + inputs and one output, so leave the output as ``None`` (not provided): + + >>> np.add.resolve_dtypes((int32, float32, None)) + (dtype('float64'), dtype('float64'), dtype('float64')) + + The loop found uses "float64" for all operands (including the output), the + first input would be cast. + + ``resolve_dtypes`` supports "weak" handling for Python scalars by passing + ``int``, ``float``, or ``complex``: + + >>> np.add.resolve_dtypes((float32, float, None)) + (dtype('float32'), dtype('float32'), dtype('float32')) + + Where the Python ``float`` behaves samilar to a Python value ``0.0`` + in a ufunc call. (See :ref:`NEP 50 <NEP50>` for details.) + + """)) + +add_newdoc('numpy.core', 'ufunc', ('_resolve_dtypes_and_context', + """ + _resolve_dtypes_and_context(dtypes, *, signature=None, casting=None, reduction=False) + + See `numpy.ufunc.resolve_dtypes` for parameter information. This + function is considered *unstable*. You may use it, but the returned + information is NumPy version specific and expected to change. + Large API/ABI changes are not expected, but a new NumPy version is + expected to require updating code using this functionality. + + This function is designed to be used in conjunction with + `numpy.ufunc._get_strided_loop`. The calls are split to mirror the C API + and allow future improvements. + + Returns + ------- + dtypes : tuple of dtypes + call_info : + PyCapsule with all necessary information to get access to low level + C calls. See `numpy.ufunc._get_strided_loop` for more information. + + """)) + +add_newdoc('numpy.core', 'ufunc', ('_get_strided_loop', + """ + _get_strided_loop(call_info, /, *, fixed_strides=None) + + This function fills in the ``call_info`` capsule to include all + information necessary to call the low-level strided loop from NumPy. + + See notes for more information. + + Parameters + ---------- + call_info : PyCapsule + The PyCapsule returned by `numpy.ufunc._resolve_dtypes_and_context`. + fixed_strides : tuple of int or None, optional + A tuple with fixed byte strides of all input arrays. NumPy may use + this information to find specialized loops, so any call must follow + the given stride. Use ``None`` to indicate that the stride is not + known (or not fixed) for all calls. + + Notes + ----- + Together with `numpy.ufunc._resolve_dtypes_and_context` this function + gives low-level access to the NumPy ufunc loops. + The first function does general preparation and returns the required + information. It returns this as a C capsule with the version specific + name ``numpy_1.24_ufunc_call_info``. + The NumPy 1.24 ufunc call info capsule has the following layout:: + + typedef struct { + PyArrayMethod_StridedLoop *strided_loop; + PyArrayMethod_Context *context; + NpyAuxData *auxdata; + + /* Flag information (expected to change) */ + npy_bool requires_pyapi; /* GIL is required by loop */ + + /* Loop doesn't set FPE flags; if not set check FPE flags */ + npy_bool no_floatingpoint_errors; + } ufunc_call_info; + + Note that the first call only fills in the ``context``. The call to + ``_get_strided_loop`` fills in all other data. + Please see the ``numpy/experimental_dtype_api.h`` header for exact + call information; the main thing to note is that the new-style loops + return 0 on success, -1 on failure. They are passed context as new + first input and ``auxdata`` as (replaced) last. + + Only the ``strided_loop``signature is considered guaranteed stable + for NumPy bug-fix releases. All other API is tied to the experimental + API versioning. + + The reason for the split call is that cast information is required to + decide what the fixed-strides will be. + + NumPy ties the lifetime of the ``auxdata`` information to the capsule. + + """)) + + + ############################################################################## # # Documentation for dtype attributes and methods diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index 6ebd3f788..b2955a6a5 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -52,6 +52,7 @@ #include "arrayobject.h" #include "common.h" +#include "ctors.h" #include "dtypemeta.h" #include "numpyos.h" #include "dispatching.h" @@ -2739,8 +2740,41 @@ reducelike_promote_and_resolve(PyUFuncObject *ufunc, PyArrayObject *arr, PyArrayObject *out, PyArray_DTypeMeta *signature[3], npy_bool enforce_uniform_args, PyArray_Descr *out_descrs[3], - char *method) + NPY_CASTING casting, char *method) { + /* + * If no dtype is specified and out is not specified, we override the + * integer and bool dtype used for add and multiply. + * + * TODO: The following should be handled by a promoter! + */ + if (signature[0] == NULL && out == NULL) { + /* + * For integer types --- make sure at least a long + * is used for add and multiply reduction to avoid overflow + */ + int typenum = PyArray_TYPE(arr); + if ((PyTypeNum_ISBOOL(typenum) || PyTypeNum_ISINTEGER(typenum)) + && ((strcmp(ufunc->name, "add") == 0) + || (strcmp(ufunc->name, "multiply") == 0))) { + if (PyTypeNum_ISBOOL(typenum)) { + typenum = NPY_LONG; + } + else if ((size_t)PyArray_DESCR(arr)->elsize < sizeof(long)) { + if (PyTypeNum_ISUNSIGNED(typenum)) { + typenum = NPY_ULONG; + } + else { + typenum = NPY_LONG; + } + } + signature[0] = PyArray_DTypeFromTypeNum(typenum); + } + } + assert(signature[2] == NULL); /* we always fill it here */ + Py_XINCREF(signature[0]); + signature[2] = signature[0]; + /* * Note that the `ops` is not really correct. But legacy resolution * cannot quite handle the correct ops (e.g. a NULL first item if `out` @@ -2802,7 +2836,7 @@ reducelike_promote_and_resolve(PyUFuncObject *ufunc, * (although this should possibly happen through a deprecation) */ if (resolve_descriptors(3, ufunc, ufuncimpl, - ops, out_descrs, signature, NPY_UNSAFE_CASTING) < 0) { + ops, out_descrs, signature, casting) < 0) { return NULL; } @@ -2825,8 +2859,7 @@ reducelike_promote_and_resolve(PyUFuncObject *ufunc, goto fail; } /* TODO: This really should _not_ be unsafe casting (same above)! */ - if (validate_casting(ufuncimpl, - ufunc, ops, out_descrs, NPY_UNSAFE_CASTING) < 0) { + if (validate_casting(ufuncimpl, ufunc, ops, out_descrs, casting) < 0) { goto fail; } @@ -2834,7 +2867,7 @@ reducelike_promote_and_resolve(PyUFuncObject *ufunc, fail: for (int i = 0; i < 3; ++i) { - Py_DECREF(out_descrs[i]); + Py_CLEAR(out_descrs[i]); } return NULL; } @@ -2989,7 +3022,7 @@ PyUFunc_Reduce(PyUFuncObject *ufunc, */ PyArray_Descr *descrs[3]; PyArrayMethodObject *ufuncimpl = reducelike_promote_and_resolve(ufunc, - arr, out, signature, NPY_FALSE, descrs, "reduce"); + arr, out, signature, NPY_FALSE, descrs, NPY_UNSAFE_CASTING, "reduce"); if (ufuncimpl == NULL) { return NULL; } @@ -3094,7 +3127,8 @@ PyUFunc_Accumulate(PyUFuncObject *ufunc, PyArrayObject *arr, PyArrayObject *out, PyArray_Descr *descrs[3]; PyArrayMethodObject *ufuncimpl = reducelike_promote_and_resolve(ufunc, - arr, out, signature, NPY_TRUE, descrs, "accumulate"); + arr, out, signature, NPY_TRUE, descrs, NPY_UNSAFE_CASTING, + "accumulate"); if (ufuncimpl == NULL) { return NULL; } @@ -3511,7 +3545,8 @@ PyUFunc_Reduceat(PyUFuncObject *ufunc, PyArrayObject *arr, PyArrayObject *ind, PyArray_Descr *descrs[3]; PyArrayMethodObject *ufuncimpl = reducelike_promote_and_resolve(ufunc, - arr, out, signature, NPY_TRUE, descrs, "reduceat"); + arr, out, signature, NPY_TRUE, descrs, NPY_UNSAFE_CASTING, + "reduceat"); if (ufuncimpl == NULL) { return NULL; } @@ -4169,38 +4204,6 @@ PyUFunc_GenericReduction(PyUFuncObject *ufunc, } } - /* - * If no dtype is specified and out is not specified, we override the - * integer and bool dtype used for add and multiply. - * - * TODO: The following should be handled by a promoter! - */ - if (signature[0] == NULL && out == NULL) { - /* - * For integer types --- make sure at least a long - * is used for add and multiply reduction to avoid overflow - */ - int typenum = PyArray_TYPE(mp); - if ((PyTypeNum_ISBOOL(typenum) || PyTypeNum_ISINTEGER(typenum)) - && ((strcmp(ufunc->name, "add") == 0) - || (strcmp(ufunc->name, "multiply") == 0))) { - if (PyTypeNum_ISBOOL(typenum)) { - typenum = NPY_LONG; - } - else if ((size_t)PyArray_DESCR(mp)->elsize < sizeof(long)) { - if (PyTypeNum_ISUNSIGNED(typenum)) { - typenum = NPY_ULONG; - } - else { - typenum = NPY_LONG; - } - } - signature[0] = PyArray_DTypeFromTypeNum(typenum); - } - } - Py_XINCREF(signature[0]); - signature[2] = signature[0]; - switch(operation) { case UFUNC_REDUCE: ret = PyUFunc_Reduce(ufunc, @@ -6315,6 +6318,406 @@ fail: } +typedef struct { + PyArrayMethod_StridedLoop *strided_loop; + PyArrayMethod_Context *context; + NpyAuxData *auxdata; + /* Should move to flags, but lets keep it bools for now: */ + npy_bool requires_pyapi; + npy_bool no_floatingpoint_errors; + PyArrayMethod_Context _full_context; + PyArray_Descr *_descrs[]; +} ufunc_call_info; + + +void +free_ufunc_call_info(PyObject *self) +{ + ufunc_call_info *call_info = PyCapsule_GetPointer( + self, "numpy_1.24_ufunc_call_info"); + + PyArrayMethod_Context *context = call_info->context; + + int nargs = context->method->nin + context->method->nout; + for (int i = 0; i < nargs; i++) { + Py_DECREF(context->descriptors[i]); + } + Py_DECREF(context->caller); + Py_DECREF(context->method); + NPY_AUXDATA_FREE(call_info->auxdata); + + PyObject_Free(call_info); +} + + +/* + * Python entry-point to ufunc promotion and dtype/descr resolution. + * + * This function does most of the work required to execute ufunc without + * actually executing it. + * This can be very useful for downstream libraries that reimplement NumPy + * functionality, such as Numba or Dask. + */ +static PyObject * +py_resolve_dtypes_generic(PyUFuncObject *ufunc, npy_bool return_context, + PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames) +{ + NPY_PREPARE_ARGPARSER; + + PyObject *descrs_tuple; + PyObject *signature_obj = NULL; + NPY_CASTING casting = NPY_DEFAULT_ASSIGN_CASTING; + npy_bool reduction = NPY_FALSE; + + if (npy_parse_arguments("resolve_dtypes", args, len_args, kwnames, + "", NULL, &descrs_tuple, + "$signature", NULL, &signature_obj, + "$casting", &PyArray_CastingConverter, &casting, + "$reduction", &PyArray_BoolConverter, &reduction, + NULL, NULL, NULL) < 0) { + return NULL; + } + + if (reduction && (ufunc->nin != 2 || ufunc->nout != 1)) { + PyErr_SetString(PyExc_ValueError, + "ufunc is not compatible with reduction operations."); + return NULL; + } + + /* + * Legacy type resolvers expect NumPy arrays as input. Until NEP 50 is + * adopted, it is most convenient to ensure that we have an "array" object + * before calling the type promotion. Eventually, this hack may be moved + * into the legacy type resolution code itself (probably after NumPy stops + * using legacy type resolution itself for the most part). + * + * We make the pretty safe assumptions here that: + * - Nobody will actually do anything with the array objects besides + * checking the descriptor or calling CanCast. + * - No type resolver will cause weird paths that mess with our promotion + * state (or mind us messing with it). + */ + PyObject *result = NULL; + PyObject *result_dtype_tuple = NULL; + + PyArrayObject *dummy_arrays[NPY_MAXARGS] = {NULL}; + PyArray_DTypeMeta *DTypes[NPY_MAXARGS] = {NULL}; + PyArray_DTypeMeta *signature[NPY_MAXARGS] = {NULL}; + PyArray_Descr *operation_descrs[NPY_MAXARGS] = {NULL}; + + /* This entry-point to promotion lives in the NEP 50 future: */ + int original_promotion_state = npy_promotion_state; + npy_promotion_state = NPY_USE_WEAK_PROMOTION; + + npy_bool promoting_pyscalars = NPY_FALSE; + npy_bool allow_legacy_promotion = NPY_TRUE; + + if (_get_fixed_signature(ufunc, NULL, signature_obj, signature) < 0) { + goto finish; + } + + if (!PyTuple_CheckExact(descrs_tuple) + || PyTuple_Size(descrs_tuple) != ufunc->nargs) { + PyErr_SetString(PyExc_TypeError, + "resolve_dtypes: The dtypes must be a tuple of " + "`ufunc.nargs` length."); + goto finish; + } + for (int i=0; i < ufunc->nargs; i++) { + /* + * We create dummy arrays for now. It should be OK to make this + * truly "dummy" (not even proper objects), but that is a hack better + * left for the legacy_type_resolution wrapper when NEP 50 is done. + */ + PyObject *descr_obj = PyTuple_GET_ITEM(descrs_tuple, i); + PyArray_Descr *descr; + + if (PyArray_DescrCheck(descr_obj)) { + descr = (PyArray_Descr *)descr_obj; + Py_INCREF(descr); + dummy_arrays[i] = (PyArrayObject *)PyArray_NewFromDescr_int( + &PyArray_Type, descr, 0, NULL, NULL, NULL, + 0, NULL, NULL, 0, 1); + if (dummy_arrays[i] == NULL) { + goto finish; + } + if (PyArray_DESCR(dummy_arrays[i]) != descr) { + PyErr_SetString(PyExc_NotImplementedError, + "dtype was replaced during array creation, the dtype is " + "unsupported currently (a subarray dtype?)."); + goto finish; + } + DTypes[i] = NPY_DTYPE(descr); + Py_INCREF(DTypes[i]); + if (!NPY_DT_is_legacy(DTypes[i])) { + allow_legacy_promotion = NPY_FALSE; + } + } + /* Explicitly allow int, float, and complex for the "weak" types. */ + else if (descr_obj == (PyObject *)&PyLong_Type) { + descr = PyArray_DescrFromType(NPY_LONG); + Py_INCREF(descr); + dummy_arrays[i] = (PyArrayObject *)PyArray_Empty(0, NULL, descr, 0); + if (dummy_arrays[i] == NULL) { + goto finish; + } + PyArray_ENABLEFLAGS(dummy_arrays[i], NPY_ARRAY_WAS_PYTHON_INT); + Py_INCREF(&PyArray_PyIntAbstractDType); + DTypes[i] = &PyArray_PyIntAbstractDType; + promoting_pyscalars = NPY_TRUE; + } + else if (descr_obj == (PyObject *)&PyFloat_Type) { + descr = PyArray_DescrFromType(NPY_DOUBLE); + Py_INCREF(descr); + dummy_arrays[i] = (PyArrayObject *)PyArray_Empty(0, NULL, descr, 0); + if (dummy_arrays[i] == NULL) { + goto finish; + } + PyArray_ENABLEFLAGS(dummy_arrays[i], NPY_ARRAY_WAS_PYTHON_FLOAT); + Py_INCREF(&PyArray_PyFloatAbstractDType); + DTypes[i] = &PyArray_PyFloatAbstractDType; + promoting_pyscalars = NPY_TRUE; + } + else if (descr_obj == (PyObject *)&PyComplex_Type) { + descr = PyArray_DescrFromType(NPY_CDOUBLE); + Py_INCREF(descr); + dummy_arrays[i] = (PyArrayObject *)PyArray_Empty(0, NULL, descr, 0); + if (dummy_arrays[i] == NULL) { + goto finish; + } + PyArray_ENABLEFLAGS(dummy_arrays[i], NPY_ARRAY_WAS_PYTHON_COMPLEX); + Py_INCREF(&PyArray_PyComplexAbstractDType); + DTypes[i] = &PyArray_PyComplexAbstractDType; + promoting_pyscalars = NPY_TRUE; + } + else if (descr_obj == Py_None) { + if (i < ufunc->nin && !(reduction && i == 0)) { + PyErr_SetString(PyExc_TypeError, + "All input dtypes must be provided " + "(except the first one in reductions)"); + goto finish; + } + } + else { + PyErr_SetString(PyExc_TypeError, + "Provided dtype must be a valid NumPy dtype, " + "int, float, complex, or None."); + goto finish; + } + } + + PyArrayMethodObject *ufuncimpl; + if (!reduction) { + ufuncimpl = promote_and_get_ufuncimpl(ufunc, + dummy_arrays, signature, DTypes, NPY_FALSE, + allow_legacy_promotion, promoting_pyscalars, NPY_FALSE); + if (ufuncimpl == NULL) { + goto finish; + } + + /* Find the correct descriptors for the operation */ + if (resolve_descriptors(ufunc->nargs, ufunc, ufuncimpl, + dummy_arrays, operation_descrs, signature, casting) < 0) { + goto finish; + } + + if (validate_casting( + ufuncimpl, ufunc, dummy_arrays, operation_descrs, casting) < 0) { + goto finish; + } + } + else { /* reduction */ + if (signature[2] != NULL) { + PyErr_SetString(PyExc_ValueError, + "Reduction signature must end with None, instead pass " + "the first DType in the signature."); + goto finish; + } + + if (dummy_arrays[2] != NULL) { + PyErr_SetString(PyExc_TypeError, + "Output dtype must not be passed for reductions, " + "pass the first input instead."); + goto finish; + } + + ufuncimpl = reducelike_promote_and_resolve(ufunc, + dummy_arrays[1], dummy_arrays[0], signature, NPY_FALSE, + operation_descrs, casting, "resolve_dtypes"); + + if (ufuncimpl == NULL) { + goto finish; + } + } + + result = PyArray_TupleFromItems( + ufunc->nargs, (PyObject **)operation_descrs, 0); + + if (result == NULL || !return_context) { + goto finish; + } + /* Result will be (dtype_tuple, call_info), so move it and clear result */ + result_dtype_tuple = result; + result = NULL; + + /* We may have to return the context: */ + ufunc_call_info *call_info; + call_info = PyObject_Malloc(sizeof(ufunc_call_info) + + ufunc->nargs * sizeof(PyArray_Descr *)); + if (call_info == NULL) { + PyErr_NoMemory(); + goto finish; + } + call_info->strided_loop = NULL; + call_info->auxdata = NULL; + call_info->context = &call_info->_full_context; + + /* + * We create a capsule with NumPy 1.24 in the name to signal that it is + * prone to change in version updates (it doesn't have to). + * This capsule is documented in the `ufunc._resolve_dtypes_and_context` + * docstring. + */ + PyObject *capsule = PyCapsule_New( + call_info, "numpy_1.24_ufunc_call_info", &free_ufunc_call_info); + if (capsule == NULL) { + PyObject_Free(call_info); + goto finish; + } + + PyArrayMethod_Context *context = call_info->context; + + Py_INCREF(ufunc); + context->caller = (PyObject *)ufunc; + Py_INCREF(ufuncimpl); + context->method = ufuncimpl; + context->descriptors = call_info->_descrs; + for (int i=0; i < ufunc->nargs; i++) { + Py_INCREF(operation_descrs[i]); + context->descriptors[i] = operation_descrs[i]; + } + + result = PyTuple_Pack(2, result_dtype_tuple, capsule); + /* cleanup and return */ + Py_DECREF(capsule); + + finish: + npy_promotion_state = original_promotion_state; + + Py_XDECREF(result_dtype_tuple); + for (int i = 0; i < ufunc->nargs; i++) { + Py_XDECREF(signature[i]); + Py_XDECREF(dummy_arrays[i]); + Py_XDECREF(operation_descrs[i]); + Py_XDECREF(DTypes[i]); + } + + return result; +} + + +static PyObject * +py_resolve_dtypes(PyUFuncObject *ufunc, + PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames) +{ + return py_resolve_dtypes_generic(ufunc, NPY_FALSE, args, len_args, kwnames); +} + + +static PyObject * +py_resolve_dtypes_and_context(PyUFuncObject *ufunc, + PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames) +{ + return py_resolve_dtypes_generic(ufunc, NPY_TRUE, args, len_args, kwnames); +} + + +static PyObject * +py_get_strided_loop(PyUFuncObject *ufunc, + PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames) +{ + NPY_PREPARE_ARGPARSER; + + PyObject *call_info_obj; + PyObject *fixed_strides_obj = Py_None; + npy_intp fixed_strides[NPY_MAXARGS]; + + if (npy_parse_arguments("_get_strided_loop", args, len_args, kwnames, + "", NULL, &call_info_obj, + "$fixed_strides", NULL, &fixed_strides_obj, + NULL, NULL, NULL) < 0) { + return NULL; + } + + ufunc_call_info *call_info = PyCapsule_GetPointer( + call_info_obj, "numpy_1.24_ufunc_call_info"); + if (call_info == NULL) { + /* Cannot have a context with NULL inside... */ + assert(PyErr_Occurred()); + return NULL; + } + if (call_info->strided_loop != NULL) { + PyErr_SetString(PyExc_TypeError, + "ufunc call_info has already been filled/used!"); + return NULL; + } + + if (call_info->context->caller != (PyObject *)ufunc) { + PyErr_SetString(PyExc_TypeError, + "calling get_strided_loop with incompatible context"); + return NULL; + } + + /* + * Strict conversion of fixed_strides, None, or tuple of int or None. + */ + if (fixed_strides_obj == Py_None) { + for (int i = 0; i < ufunc->nargs; i++) { + fixed_strides[i] = NPY_MAX_INTP; + } + } + else if (PyTuple_CheckExact(fixed_strides_obj) + && PyTuple_Size(fixed_strides_obj) == ufunc->nargs) { + for (int i = 0; i < ufunc->nargs; i++) { + PyObject *stride = PyTuple_GET_ITEM(fixed_strides_obj, i); + if (PyLong_CheckExact(stride)) { + fixed_strides[i] = PyLong_AsSsize_t(stride); + if (error_converting(fixed_strides[i])) { + return NULL; + } + } + else if (stride == Py_None) { + fixed_strides[i] = NPY_MAX_INTP; + } + else { + PyErr_SetString(PyExc_TypeError, + "_get_strided_loop(): fixed_strides tuple must contain " + "Python ints or None"); + return NULL; + } + } + } + else { + PyErr_SetString(PyExc_TypeError, + "_get_strided_loop(): fixed_strides must be a tuple or None"); + return NULL; + } + + NPY_ARRAYMETHOD_FLAGS flags; + if (call_info->context->method->get_strided_loop(call_info->context, + 1, 0, fixed_strides, &call_info->strided_loop, &call_info->auxdata, + &flags) < 0) { + return NULL; + } + + call_info->requires_pyapi = flags & NPY_METH_REQUIRES_PYAPI; + call_info->no_floatingpoint_errors = ( + flags & NPY_METH_NO_FLOATINGPOINT_ERRORS); + + Py_RETURN_NONE; +} + + static struct PyMethodDef ufunc_methods[] = { {"reduce", (PyCFunction)ufunc_reduce, @@ -6331,6 +6734,21 @@ static struct PyMethodDef ufunc_methods[] = { {"at", (PyCFunction)ufunc_at, METH_VARARGS, NULL}, + /* Lower level methods: */ + {"resolve_dtypes", + (PyCFunction)py_resolve_dtypes, + METH_FASTCALL | METH_KEYWORDS, NULL}, + /* + * The following two functions are public API, but underscored since they + * are C-user specific and allow direct access to the core of ufunc loops. + * (See their documentation for API stability.) + */ + {"_resolve_dtypes_and_context", + (PyCFunction)py_resolve_dtypes_and_context, + METH_FASTCALL | METH_KEYWORDS, NULL}, + {"_get_strided_loop", + (PyCFunction)py_get_strided_loop, + METH_FASTCALL | METH_KEYWORDS, NULL}, {NULL, NULL, 0, NULL} /* sentinel */ }; diff --git a/numpy/core/tests/test_ufunc.py b/numpy/core/tests/test_ufunc.py index 21b0d9e46..cc6c0839d 100644 --- a/numpy/core/tests/test_ufunc.py +++ b/numpy/core/tests/test_ufunc.py @@ -1,6 +1,7 @@ import warnings import itertools import sys +import ctypes as ct import pytest from pytest import param @@ -2656,3 +2657,135 @@ def test_addition_reduce_negative_zero(dtype, use_initial): # `sum([])` should probably be 0.0 and not -0.0 like `sum([-0.0])` assert not np.signbit(res.real) assert not np.signbit(res.imag) + +class TestLowlevelAPIAccess: + def test_resolve_dtypes_basic(self): + # Basic test for dtype resolution: + i4 = np.dtype("i4") + f4 = np.dtype("f4") + f8 = np.dtype("f8") + + r = np.add.resolve_dtypes((i4, f4, None)) + assert r == (f8, f8, f8) + + # Signature uses the same logic to parse as ufunc (less strict) + # the following is "same-kind" casting so works: + r = np.add.resolve_dtypes(( + i4, i4, None), signature=(None, None, "f4")) + assert r == (f4, f4, f4) + + # Check NEP 50 "weak" promotion also: + r = np.add.resolve_dtypes((f4, int, None)) + assert r == (f4, f4, f4) + + with pytest.raises(TypeError): + np.add.resolve_dtypes((i4, f4, None), casting="no") + + def test_weird_dtypes(self): + S0 = np.dtype("S0") + # S0 is often converted by NumPy to S1, but not here: + r = np.equal.resolve_dtypes((S0, S0, None)) + assert r == (S0, S0, np.dtype(bool)) + + # Subarray dtypes are weird and only really exist nested, they need + # the shift to full NEP 50 to be implemented nicely: + dts = np.dtype("10i") + with pytest.raises(NotImplementedError): + np.equal.resolve_dtypes((dts, dts, None)) + + def test_resolve_dtypes_reduction(self): + i4 = np.dtype("i4") + with pytest.raises(NotImplementedError): + np.add.resolve_dtypes((i4, i4, i4), reduction=True) + + @pytest.mark.parametrize("dtypes", [ + (np.dtype("i"), np.dtype("i")), + (None, np.dtype("i"), np.dtype("f")), + (np.dtype("i"), None, np.dtype("f")), + ("i4", "i4", None)]) + def test_resolve_dtypes_errors(self, dtypes): + with pytest.raises(TypeError): + np.add.resolve_dtypes(dtypes) + + def test_resolve_dtypes_reduction(self): + i2 = np.dtype("i2") + long_ = np.dtype("long") + # Check special addition resolution: + res = np.add.resolve_dtypes((None, i2, None), reduction=True) + assert res == (long_, long_, long_) + + def test_resolve_dtypes_reduction_errors(self): + i2 = np.dtype("i2") + + with pytest.raises(TypeError): + np.add.resolve_dtypes((None, i2, i2)) + + with pytest.raises(TypeError): + np.add.signature((None, None, "i4")) + + @pytest.mark.skipif(not hasattr(ct, "pythonapi"), + reason="`ctypes.pythonapi` required for capsule unpacking.") + def test_loop_access(self): + # This is a basic test for the full strided loop access + data_t = ct.ARRAY(ct.c_char_p, 2) + dim_t = ct.ARRAY(ct.c_ssize_t, 1) + strides_t = ct.ARRAY(ct.c_ssize_t, 2) + strided_loop_t = ct.CFUNCTYPE( + ct.c_int, ct.c_void_p, data_t, dim_t, strides_t, ct.c_void_p) + + class call_info_t(ct.Structure): + _fields_ = [ + ("strided_loop", strided_loop_t), + ("context", ct.c_void_p), + ("auxdata", ct.c_void_p), + ("requires_pyapi", ct.c_byte), + ("no_floatingpoint_errors", ct.c_byte), + ] + + i4 = np.dtype("i4") + dt, call_info_obj = np.negative._resolve_dtypes_and_context((i4, i4)) + assert dt == (i4, i4) # can be used without casting + + # Fill in the rest of the information: + np.negative._get_strided_loop(call_info_obj) + + ct.pythonapi.PyCapsule_GetPointer.restype = ct.c_void_p + call_info = ct.pythonapi.PyCapsule_GetPointer( + ct.py_object(call_info_obj), + ct.c_char_p(b"numpy_1.24_ufunc_call_info")) + + call_info = ct.cast(call_info, ct.POINTER(call_info_t)).contents + + arr = np.arange(10, dtype=i4) + call_info.strided_loop( + call_info.context, + data_t(arr.ctypes.data, arr.ctypes.data), + arr.ctypes.shape, # is a C-array with 10 here + strides_t(arr.ctypes.strides[0], arr.ctypes.strides[0]), + call_info.auxdata) + + # We just directly called the negative inner-loop in-place: + assert_array_equal(arr, -np.arange(10, dtype=i4)) + + @pytest.mark.parametrize("strides", [1, (1, 2, 3), (1, "2")]) + def test__get_strided_loop_errors_bad_strides(self, strides): + i4 = np.dtype("i4") + dt, call_info = np.negative._resolve_dtypes_and_context((i4, i4)) + + with pytest.raises(TypeError, match="fixed_strides.*tuple.*or None"): + np.negative._get_strided_loop(call_info, fixed_strides=strides) + + def test__get_strided_loop_errors_bad_call_info(self): + i4 = np.dtype("i4") + dt, call_info = np.negative._resolve_dtypes_and_context((i4, i4)) + + with pytest.raises(ValueError, match="PyCapsule"): + np.negative._get_strided_loop("not the capsule!") + + with pytest.raises(TypeError, match=".*incompatible context"): + np.add._get_strided_loop(call_info) + + np.negative._get_strided_loop(call_info) + with pytest.raises(TypeError): + # cannot call it a second time: + np.negative._get_strided_loop(call_info) |