author    Matti Picus <matti.picus@gmail.com>    2022-11-16 18:19:13 +0200
committer GitHub <noreply@github.com>            2022-11-16 18:19:13 +0200
commit    d428d4599034c6039c6eff1ad33518a3e3888104 (patch)
tree      ce7807b660efd211002cbb0abd7c7350ea36e38f /numpy
parent    21195337da3e87c0d153794bb65d2d32da542395 (diff)
parent    c973fc90c65e6864444677b4d8e0f70060a17da3 (diff)
Merge pull request #22422 from seberg/expose-dtype-resolution-get-loop
ENH: Expose `ufunc.resolve_dtypes` and strided loop access
Diffstat (limited to 'numpy')
-rw-r--r--  numpy/core/_add_newdocs.py           | 159
-rw-r--r--  numpy/core/src/umath/ufunc_object.c  | 498
-rw-r--r--  numpy/core/tests/test_ufunc.py       | 133
3 files changed, 750 insertions, 40 deletions
diff --git a/numpy/core/_add_newdocs.py b/numpy/core/_add_newdocs.py
index 9cfbce86a..d75f9ec64 100644
--- a/numpy/core/_add_newdocs.py
+++ b/numpy/core/_add_newdocs.py
@@ -5694,6 +5694,165 @@ add_newdoc('numpy.core', 'ufunc', ('at',
"""))
+add_newdoc('numpy.core', 'ufunc', ('resolve_dtypes',
+ """
+ resolve_dtypes(dtypes, *, signature=None, casting=None, reduction=False)
+
+ Find the dtypes NumPy will use for the operation. Both input and
+ output dtypes are returned and may differ from those provided.
+
+ .. note::
+
+ This function always applies NEP 50 rules since it is not provided
+ any actual values. The Python types ``int``, ``float``, and
+ ``complex`` thus behave as weakly typed and should be passed for "untyped"
+ Python input.
+
+ Parameters
+ ----------
+ dtypes : tuple of dtypes, None, or literal int, float, complex
+ The dtypes for all operands. Output operands may be
+ None, indicating that the dtype must be found.
+ signature : tuple of DTypes or None, optional
+ If given, enforces the exact DType (class) of the corresponding operand.
+ The ufunc ``dtype`` argument is equivalent to passing a tuple with
+ only output dtypes set.
+ casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
+ The casting mode when casting is necessary. This is identical to
+ the ufunc call casting modes.
+ reduction : boolean
+ If True, the resolution assumes a reduce operation is happening,
+ which slightly changes the promotion and type resolution rules.
+ `dtypes` is usually something like ``(None, np.dtype("i2"), None)``
+ for reductions (first input is also the output).
+
+ .. note::
+
+ The default casting mode is "same_kind"; note, however, that as of
+ NumPy 1.24, NumPy uses "unsafe" casting for reductions.
+
+ Returns
+ -------
+ dtypes : tuple of dtypes
+ The dtypes which NumPy would use for the calculation. Note that
+ these may not match the passed-in dtypes (casting is then necessary).
+
+ See Also
+ --------
+ numpy.ufunc._resolve_dtypes_and_context :
+ Similar function to this, but returns additional information which
+ gives access to the core C functionality of NumPy.
+
+ Examples
+ --------
+ This API requires passing dtypes; define them for convenience:
+
+ >>> int32 = np.dtype("int32")
+ >>> float32 = np.dtype("float32")
+
+ The typical ufunc call does not pass an output dtype. `np.add` has two
+ inputs and one output, so leave the output as ``None`` (not provided):
+
+ >>> np.add.resolve_dtypes((int32, float32, None))
+ (dtype('float64'), dtype('float64'), dtype('float64'))
+
+ The loop found uses "float64" for all operands (including the output); the
+ first input would be cast.
+
+ ``resolve_dtypes`` supports "weak" handling for Python scalars by passing
+ ``int``, ``float``, or ``complex``:
+
+ >>> np.add.resolve_dtypes((float32, float, None))
+ (dtype('float32'), dtype('float32'), dtype('float32'))
+
+ Here the Python ``float`` behaves similarly to a Python value ``0.0``
+ in a ufunc call. (See :ref:`NEP 50 <NEP50>` for details.)
+
+ """))
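
A rough sketch of the ``signature`` and ``reduction`` keywords described in the docstring above, mirroring the tests added further down in this diff (the reduction output assumes a platform where the default integer, the C ``long``, is 64-bit):

>>> i2, i4 = np.dtype("i2"), np.dtype("i4")
>>> # force a "float32" loop via the signature ("same_kind" casting still applies)
>>> np.add.resolve_dtypes((i4, i4, None), signature=(None, None, "f4"))
(dtype('float32'), dtype('float32'), dtype('float32'))
>>> # reductions pass None for the first operand, which doubles as the output;
>>> # add/multiply reductions of small integers use at least the default integer
>>> np.add.resolve_dtypes((None, i2, None), reduction=True)
(dtype('int64'), dtype('int64'), dtype('int64'))
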
+
+add_newdoc('numpy.core', 'ufunc', ('_resolve_dtypes_and_context',
+ """
+ _resolve_dtypes_and_context(dtypes, *, signature=None, casting=None, reduction=False)
+
+ See `numpy.ufunc.resolve_dtypes` for parameter information. This
+ function is considered *unstable*. You may use it, but the returned
+ information is NumPy version specific and expected to change.
+ Large API/ABI changes are not expected, but a new NumPy version may
+ require updating code that uses this functionality.
+
+ This function is designed to be used in conjunction with
+ `numpy.ufunc._get_strided_loop`. The calls are split to mirror the C API
+ and allow future improvements.
+
+ Returns
+ -------
+ dtypes : tuple of dtypes
+ call_info :
+ PyCapsule with all necessary information to get access to low level
+ C calls. See `numpy.ufunc._get_strided_loop` for more information.
+
+ """))
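
As a sketch, the split usage documented here (resolve first, then fetch the loop) looks roughly like the following, condensed from the tests added later in this diff; the shown reprs are illustrative:

>>> i4 = np.dtype("i4")
>>> dt, call_info = np.negative._resolve_dtypes_and_context((i4, i4))
>>> dt                      # same result as ``resolve_dtypes`` would give
(dtype('int32'), dtype('int32'))
>>> call_info               # opaque capsule; only the context is filled in so far
<capsule object "numpy_1.24_ufunc_call_info" at 0x...>
>>> np.negative._get_strided_loop(call_info)    # fills in the strided loop and auxdata
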
+
+add_newdoc('numpy.core', 'ufunc', ('_get_strided_loop',
+ """
+ _get_strided_loop(call_info, /, *, fixed_strides=None)
+
+ This function fills in the ``call_info`` capsule to include all
+ information necessary to call the low-level strided loop from NumPy.
+
+ See notes for more information.
+
+ Parameters
+ ----------
+ call_info : PyCapsule
+ The PyCapsule returned by `numpy.ufunc._resolve_dtypes_and_context`.
+ fixed_strides : tuple of int or None, optional
+ A tuple with the fixed byte strides of all operands. NumPy may use
+ this information to find specialized loops, so any subsequent call
+ must use the given strides. Use ``None`` to indicate that a stride is not
+ known (or not fixed) for all calls.
+
+ Notes
+ -----
+ Together with `numpy.ufunc._resolve_dtypes_and_context` this function
+ gives low-level access to the NumPy ufunc loops.
+ The first function does general preparation and returns the required
+ information. It returns this as a C capsule with the version specific
+ name ``numpy_1.24_ufunc_call_info``.
+ The NumPy 1.24 ufunc call info capsule has the following layout::
+
+ typedef struct {
+ PyArrayMethod_StridedLoop *strided_loop;
+ PyArrayMethod_Context *context;
+ NpyAuxData *auxdata;
+
+ /* Flag information (expected to change) */
+ npy_bool requires_pyapi; /* GIL is required by loop */
+
+ /* Loop doesn't set FPE flags; if not set check FPE flags */
+ npy_bool no_floatingpoint_errors;
+ } ufunc_call_info;
+
+ Note that the first call only fills in the ``context``. The call to
+ ``_get_strided_loop`` fills in all other data.
+ Please see the ``numpy/experimental_dtype_api.h`` header for exact
+ call information; the main thing to note is that the new-style loops
+ return 0 on success, -1 on failure. They are passed the context as a
+ new first argument and ``auxdata`` as the (replaced) last one.
+
+ Only the ``strided_loop`` signature is considered guaranteed stable
+ for NumPy bug-fix releases. All other API is tied to the experimental
+ API versioning.
+
+ The reason for splitting into two calls is that cast information is
+ required to decide what the fixed strides will be.
+
+ NumPy ties the lifetime of the ``auxdata`` information to the capsule.
+
+ """))
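
A minimal sketch of consuming the capsule from Python with ``ctypes``, condensed from ``test_loop_access`` added below; the struct fields mirror the layout documented above (``strided_loop`` is kept as a plain pointer here since only the flags are read), and the printed flag value assumes the integer negation loop does not require the GIL:

>>> import ctypes as ct
>>> class call_info_t(ct.Structure):
...     _fields_ = [("strided_loop", ct.c_void_p), ("context", ct.c_void_p),
...                 ("auxdata", ct.c_void_p), ("requires_pyapi", ct.c_byte),
...                 ("no_floatingpoint_errors", ct.c_byte)]
...
>>> i4 = np.dtype("i4")
>>> dt, capsule = np.negative._resolve_dtypes_and_context((i4, i4))
>>> np.negative._get_strided_loop(capsule)      # fill in strided_loop/auxdata
>>> ct.pythonapi.PyCapsule_GetPointer.restype = ct.c_void_p
>>> ptr = ct.pythonapi.PyCapsule_GetPointer(
...     ct.py_object(capsule), ct.c_char_p(b"numpy_1.24_ufunc_call_info"))
>>> info = ct.cast(ptr, ct.POINTER(call_info_t)).contents
>>> bool(info.requires_pyapi)
False
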
+
+
+
##############################################################################
#
# Documentation for dtype attributes and methods
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index 6ebd3f788..b2955a6a5 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -52,6 +52,7 @@
#include "arrayobject.h"
#include "common.h"
+#include "ctors.h"
#include "dtypemeta.h"
#include "numpyos.h"
#include "dispatching.h"
@@ -2739,8 +2740,41 @@ reducelike_promote_and_resolve(PyUFuncObject *ufunc,
PyArrayObject *arr, PyArrayObject *out,
PyArray_DTypeMeta *signature[3],
npy_bool enforce_uniform_args, PyArray_Descr *out_descrs[3],
- char *method)
+ NPY_CASTING casting, char *method)
{
+ /*
+ * If no dtype is specified and out is not specified, we override the
+ * integer and bool dtype used for add and multiply.
+ *
+ * TODO: The following should be handled by a promoter!
+ */
+ if (signature[0] == NULL && out == NULL) {
+ /*
+ * For integer types --- make sure at least a long
+ * is used for add and multiply reduction to avoid overflow
+ */
+ int typenum = PyArray_TYPE(arr);
+ if ((PyTypeNum_ISBOOL(typenum) || PyTypeNum_ISINTEGER(typenum))
+ && ((strcmp(ufunc->name, "add") == 0)
+ || (strcmp(ufunc->name, "multiply") == 0))) {
+ if (PyTypeNum_ISBOOL(typenum)) {
+ typenum = NPY_LONG;
+ }
+ else if ((size_t)PyArray_DESCR(arr)->elsize < sizeof(long)) {
+ if (PyTypeNum_ISUNSIGNED(typenum)) {
+ typenum = NPY_ULONG;
+ }
+ else {
+ typenum = NPY_LONG;
+ }
+ }
+ signature[0] = PyArray_DTypeFromTypeNum(typenum);
+ }
+ }
+ assert(signature[2] == NULL); /* we always fill it here */
+ Py_XINCREF(signature[0]);
+ signature[2] = signature[0];
+
/*
* Note that the `ops` is not really correct. But legacy resolution
* cannot quite handle the correct ops (e.g. a NULL first item if `out`
@@ -2802,7 +2836,7 @@ reducelike_promote_and_resolve(PyUFuncObject *ufunc,
* (although this should possibly happen through a deprecation)
*/
if (resolve_descriptors(3, ufunc, ufuncimpl,
- ops, out_descrs, signature, NPY_UNSAFE_CASTING) < 0) {
+ ops, out_descrs, signature, casting) < 0) {
return NULL;
}
@@ -2825,8 +2859,7 @@ reducelike_promote_and_resolve(PyUFuncObject *ufunc,
goto fail;
}
/* TODO: This really should _not_ be unsafe casting (same above)! */
- if (validate_casting(ufuncimpl,
- ufunc, ops, out_descrs, NPY_UNSAFE_CASTING) < 0) {
+ if (validate_casting(ufuncimpl, ufunc, ops, out_descrs, casting) < 0) {
goto fail;
}
@@ -2834,7 +2867,7 @@ reducelike_promote_and_resolve(PyUFuncObject *ufunc,
fail:
for (int i = 0; i < 3; ++i) {
- Py_DECREF(out_descrs[i]);
+ Py_CLEAR(out_descrs[i]);
}
return NULL;
}
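
The block moved here preserves the long-standing behaviour that ``add`` and ``multiply`` reductions over bool and small integer types are computed in at least the default integer (the platform ``long``), and it now also applies to ``resolve_dtypes(..., reduction=True)``. Roughly, assuming a 64-bit Linux build where ``long`` is ``int64``:

>>> np.add.reduce(np.ones(3, dtype=np.int8)).dtype
dtype('int64')
>>> np.multiply.reduce(np.ones(3, dtype=np.uint16)).dtype
dtype('uint64')
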
@@ -2989,7 +3022,7 @@ PyUFunc_Reduce(PyUFuncObject *ufunc,
*/
PyArray_Descr *descrs[3];
PyArrayMethodObject *ufuncimpl = reducelike_promote_and_resolve(ufunc,
- arr, out, signature, NPY_FALSE, descrs, "reduce");
+ arr, out, signature, NPY_FALSE, descrs, NPY_UNSAFE_CASTING, "reduce");
if (ufuncimpl == NULL) {
return NULL;
}
@@ -3094,7 +3127,8 @@ PyUFunc_Accumulate(PyUFuncObject *ufunc, PyArrayObject *arr, PyArrayObject *out,
PyArray_Descr *descrs[3];
PyArrayMethodObject *ufuncimpl = reducelike_promote_and_resolve(ufunc,
- arr, out, signature, NPY_TRUE, descrs, "accumulate");
+ arr, out, signature, NPY_TRUE, descrs, NPY_UNSAFE_CASTING,
+ "accumulate");
if (ufuncimpl == NULL) {
return NULL;
}
@@ -3511,7 +3545,8 @@ PyUFunc_Reduceat(PyUFuncObject *ufunc, PyArrayObject *arr, PyArrayObject *ind,
PyArray_Descr *descrs[3];
PyArrayMethodObject *ufuncimpl = reducelike_promote_and_resolve(ufunc,
- arr, out, signature, NPY_TRUE, descrs, "reduceat");
+ arr, out, signature, NPY_TRUE, descrs, NPY_UNSAFE_CASTING,
+ "reduceat");
if (ufuncimpl == NULL) {
return NULL;
}
@@ -4169,38 +4204,6 @@ PyUFunc_GenericReduction(PyUFuncObject *ufunc,
}
}
- /*
- * If no dtype is specified and out is not specified, we override the
- * integer and bool dtype used for add and multiply.
- *
- * TODO: The following should be handled by a promoter!
- */
- if (signature[0] == NULL && out == NULL) {
- /*
- * For integer types --- make sure at least a long
- * is used for add and multiply reduction to avoid overflow
- */
- int typenum = PyArray_TYPE(mp);
- if ((PyTypeNum_ISBOOL(typenum) || PyTypeNum_ISINTEGER(typenum))
- && ((strcmp(ufunc->name, "add") == 0)
- || (strcmp(ufunc->name, "multiply") == 0))) {
- if (PyTypeNum_ISBOOL(typenum)) {
- typenum = NPY_LONG;
- }
- else if ((size_t)PyArray_DESCR(mp)->elsize < sizeof(long)) {
- if (PyTypeNum_ISUNSIGNED(typenum)) {
- typenum = NPY_ULONG;
- }
- else {
- typenum = NPY_LONG;
- }
- }
- signature[0] = PyArray_DTypeFromTypeNum(typenum);
- }
- }
- Py_XINCREF(signature[0]);
- signature[2] = signature[0];
-
switch(operation) {
case UFUNC_REDUCE:
ret = PyUFunc_Reduce(ufunc,
@@ -6315,6 +6318,406 @@ fail:
}
+typedef struct {
+ PyArrayMethod_StridedLoop *strided_loop;
+ PyArrayMethod_Context *context;
+ NpyAuxData *auxdata;
+ /* Should move to flags, but let's keep it as bools for now: */
+ npy_bool requires_pyapi;
+ npy_bool no_floatingpoint_errors;
+ PyArrayMethod_Context _full_context;
+ PyArray_Descr *_descrs[];
+} ufunc_call_info;
+
+
+void
+free_ufunc_call_info(PyObject *self)
+{
+ ufunc_call_info *call_info = PyCapsule_GetPointer(
+ self, "numpy_1.24_ufunc_call_info");
+
+ PyArrayMethod_Context *context = call_info->context;
+
+ int nargs = context->method->nin + context->method->nout;
+ for (int i = 0; i < nargs; i++) {
+ Py_DECREF(context->descriptors[i]);
+ }
+ Py_DECREF(context->caller);
+ Py_DECREF(context->method);
+ NPY_AUXDATA_FREE(call_info->auxdata);
+
+ PyObject_Free(call_info);
+}
+
+
+/*
+ * Python entry-point to ufunc promotion and dtype/descr resolution.
+ *
+ * This function does most of the work required to execute ufunc without
+ * actually executing it.
+ * This can be very useful for downstream libraries that reimplement NumPy
+ * functionality, such as Numba or Dask.
+ */
+static PyObject *
+py_resolve_dtypes_generic(PyUFuncObject *ufunc, npy_bool return_context,
+ PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames)
+{
+ NPY_PREPARE_ARGPARSER;
+
+ PyObject *descrs_tuple;
+ PyObject *signature_obj = NULL;
+ NPY_CASTING casting = NPY_DEFAULT_ASSIGN_CASTING;
+ npy_bool reduction = NPY_FALSE;
+
+ if (npy_parse_arguments("resolve_dtypes", args, len_args, kwnames,
+ "", NULL, &descrs_tuple,
+ "$signature", NULL, &signature_obj,
+ "$casting", &PyArray_CastingConverter, &casting,
+ "$reduction", &PyArray_BoolConverter, &reduction,
+ NULL, NULL, NULL) < 0) {
+ return NULL;
+ }
+
+ if (reduction && (ufunc->nin != 2 || ufunc->nout != 1)) {
+ PyErr_SetString(PyExc_ValueError,
+ "ufunc is not compatible with reduction operations.");
+ return NULL;
+ }
+
+ /*
+ * Legacy type resolvers expect NumPy arrays as input. Until NEP 50 is
+ * adopted, it is most convenient to ensure that we have an "array" object
+ * before calling the type promotion. Eventually, this hack may be moved
+ * into the legacy type resolution code itself (probably after NumPy stops
+ * using legacy type resolution itself for the most part).
+ *
+ * We make the pretty safe assumptions here that:
+ * - Nobody will actually do anything with the array objects besides
+ * checking the descriptor or calling CanCast.
+ * - No type resolver will cause weird paths that mess with our promotion
+ * state (or mind us messing with it).
+ */
+ PyObject *result = NULL;
+ PyObject *result_dtype_tuple = NULL;
+
+ PyArrayObject *dummy_arrays[NPY_MAXARGS] = {NULL};
+ PyArray_DTypeMeta *DTypes[NPY_MAXARGS] = {NULL};
+ PyArray_DTypeMeta *signature[NPY_MAXARGS] = {NULL};
+ PyArray_Descr *operation_descrs[NPY_MAXARGS] = {NULL};
+
+ /* This entry-point to promotion lives in the NEP 50 future: */
+ int original_promotion_state = npy_promotion_state;
+ npy_promotion_state = NPY_USE_WEAK_PROMOTION;
+
+ npy_bool promoting_pyscalars = NPY_FALSE;
+ npy_bool allow_legacy_promotion = NPY_TRUE;
+
+ if (_get_fixed_signature(ufunc, NULL, signature_obj, signature) < 0) {
+ goto finish;
+ }
+
+ if (!PyTuple_CheckExact(descrs_tuple)
+ || PyTuple_Size(descrs_tuple) != ufunc->nargs) {
+ PyErr_SetString(PyExc_TypeError,
+ "resolve_dtypes: The dtypes must be a tuple of "
+ "`ufunc.nargs` length.");
+ goto finish;
+ }
+ for (int i=0; i < ufunc->nargs; i++) {
+ /*
+ * We create dummy arrays for now. It should be OK to make this
+ * truly "dummy" (not even proper objects), but that is a hack better
+ * left for the legacy_type_resolution wrapper when NEP 50 is done.
+ */
+ PyObject *descr_obj = PyTuple_GET_ITEM(descrs_tuple, i);
+ PyArray_Descr *descr;
+
+ if (PyArray_DescrCheck(descr_obj)) {
+ descr = (PyArray_Descr *)descr_obj;
+ Py_INCREF(descr);
+ dummy_arrays[i] = (PyArrayObject *)PyArray_NewFromDescr_int(
+ &PyArray_Type, descr, 0, NULL, NULL, NULL,
+ 0, NULL, NULL, 0, 1);
+ if (dummy_arrays[i] == NULL) {
+ goto finish;
+ }
+ if (PyArray_DESCR(dummy_arrays[i]) != descr) {
+ PyErr_SetString(PyExc_NotImplementedError,
+ "dtype was replaced during array creation, the dtype is "
+ "unsupported currently (a subarray dtype?).");
+ goto finish;
+ }
+ DTypes[i] = NPY_DTYPE(descr);
+ Py_INCREF(DTypes[i]);
+ if (!NPY_DT_is_legacy(DTypes[i])) {
+ allow_legacy_promotion = NPY_FALSE;
+ }
+ }
+ /* Explicitly allow int, float, and complex for the "weak" types. */
+ else if (descr_obj == (PyObject *)&PyLong_Type) {
+ descr = PyArray_DescrFromType(NPY_LONG);
+ Py_INCREF(descr);
+ dummy_arrays[i] = (PyArrayObject *)PyArray_Empty(0, NULL, descr, 0);
+ if (dummy_arrays[i] == NULL) {
+ goto finish;
+ }
+ PyArray_ENABLEFLAGS(dummy_arrays[i], NPY_ARRAY_WAS_PYTHON_INT);
+ Py_INCREF(&PyArray_PyIntAbstractDType);
+ DTypes[i] = &PyArray_PyIntAbstractDType;
+ promoting_pyscalars = NPY_TRUE;
+ }
+ else if (descr_obj == (PyObject *)&PyFloat_Type) {
+ descr = PyArray_DescrFromType(NPY_DOUBLE);
+ Py_INCREF(descr);
+ dummy_arrays[i] = (PyArrayObject *)PyArray_Empty(0, NULL, descr, 0);
+ if (dummy_arrays[i] == NULL) {
+ goto finish;
+ }
+ PyArray_ENABLEFLAGS(dummy_arrays[i], NPY_ARRAY_WAS_PYTHON_FLOAT);
+ Py_INCREF(&PyArray_PyFloatAbstractDType);
+ DTypes[i] = &PyArray_PyFloatAbstractDType;
+ promoting_pyscalars = NPY_TRUE;
+ }
+ else if (descr_obj == (PyObject *)&PyComplex_Type) {
+ descr = PyArray_DescrFromType(NPY_CDOUBLE);
+ Py_INCREF(descr);
+ dummy_arrays[i] = (PyArrayObject *)PyArray_Empty(0, NULL, descr, 0);
+ if (dummy_arrays[i] == NULL) {
+ goto finish;
+ }
+ PyArray_ENABLEFLAGS(dummy_arrays[i], NPY_ARRAY_WAS_PYTHON_COMPLEX);
+ Py_INCREF(&PyArray_PyComplexAbstractDType);
+ DTypes[i] = &PyArray_PyComplexAbstractDType;
+ promoting_pyscalars = NPY_TRUE;
+ }
+ else if (descr_obj == Py_None) {
+ if (i < ufunc->nin && !(reduction && i == 0)) {
+ PyErr_SetString(PyExc_TypeError,
+ "All input dtypes must be provided "
+ "(except the first one in reductions)");
+ goto finish;
+ }
+ }
+ else {
+ PyErr_SetString(PyExc_TypeError,
+ "Provided dtype must be a valid NumPy dtype, "
+ "int, float, complex, or None.");
+ goto finish;
+ }
+ }
+
+ PyArrayMethodObject *ufuncimpl;
+ if (!reduction) {
+ ufuncimpl = promote_and_get_ufuncimpl(ufunc,
+ dummy_arrays, signature, DTypes, NPY_FALSE,
+ allow_legacy_promotion, promoting_pyscalars, NPY_FALSE);
+ if (ufuncimpl == NULL) {
+ goto finish;
+ }
+
+ /* Find the correct descriptors for the operation */
+ if (resolve_descriptors(ufunc->nargs, ufunc, ufuncimpl,
+ dummy_arrays, operation_descrs, signature, casting) < 0) {
+ goto finish;
+ }
+
+ if (validate_casting(
+ ufuncimpl, ufunc, dummy_arrays, operation_descrs, casting) < 0) {
+ goto finish;
+ }
+ }
+ else { /* reduction */
+ if (signature[2] != NULL) {
+ PyErr_SetString(PyExc_ValueError,
+ "Reduction signature must end with None, instead pass "
+ "the first DType in the signature.");
+ goto finish;
+ }
+
+ if (dummy_arrays[2] != NULL) {
+ PyErr_SetString(PyExc_TypeError,
+ "Output dtype must not be passed for reductions, "
+ "pass the first input instead.");
+ goto finish;
+ }
+
+ ufuncimpl = reducelike_promote_and_resolve(ufunc,
+ dummy_arrays[1], dummy_arrays[0], signature, NPY_FALSE,
+ operation_descrs, casting, "resolve_dtypes");
+
+ if (ufuncimpl == NULL) {
+ goto finish;
+ }
+ }
+
+ result = PyArray_TupleFromItems(
+ ufunc->nargs, (PyObject **)operation_descrs, 0);
+
+ if (result == NULL || !return_context) {
+ goto finish;
+ }
+ /* Result will be (dtype_tuple, call_info), so move it and clear result */
+ result_dtype_tuple = result;
+ result = NULL;
+
+ /* We may have to return the context: */
+ ufunc_call_info *call_info;
+ call_info = PyObject_Malloc(sizeof(ufunc_call_info)
+ + ufunc->nargs * sizeof(PyArray_Descr *));
+ if (call_info == NULL) {
+ PyErr_NoMemory();
+ goto finish;
+ }
+ call_info->strided_loop = NULL;
+ call_info->auxdata = NULL;
+ call_info->context = &call_info->_full_context;
+
+ /*
+ * We create a capsule with NumPy 1.24 in the name to signal that it is
+ * prone to change in version updates (it doesn't have to).
+ * This capsule is documented in the `ufunc._resolve_dtypes_and_context`
+ * docstring.
+ */
+ PyObject *capsule = PyCapsule_New(
+ call_info, "numpy_1.24_ufunc_call_info", &free_ufunc_call_info);
+ if (capsule == NULL) {
+ PyObject_Free(call_info);
+ goto finish;
+ }
+
+ PyArrayMethod_Context *context = call_info->context;
+
+ Py_INCREF(ufunc);
+ context->caller = (PyObject *)ufunc;
+ Py_INCREF(ufuncimpl);
+ context->method = ufuncimpl;
+ context->descriptors = call_info->_descrs;
+ for (int i=0; i < ufunc->nargs; i++) {
+ Py_INCREF(operation_descrs[i]);
+ context->descriptors[i] = operation_descrs[i];
+ }
+
+ result = PyTuple_Pack(2, result_dtype_tuple, capsule);
+ /* cleanup and return */
+ Py_DECREF(capsule);
+
+ finish:
+ npy_promotion_state = original_promotion_state;
+
+ Py_XDECREF(result_dtype_tuple);
+ for (int i = 0; i < ufunc->nargs; i++) {
+ Py_XDECREF(signature[i]);
+ Py_XDECREF(dummy_arrays[i]);
+ Py_XDECREF(operation_descrs[i]);
+ Py_XDECREF(DTypes[i]);
+ }
+
+ return result;
+}
+
+
+static PyObject *
+py_resolve_dtypes(PyUFuncObject *ufunc,
+ PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames)
+{
+ return py_resolve_dtypes_generic(ufunc, NPY_FALSE, args, len_args, kwnames);
+}
+
+
+static PyObject *
+py_resolve_dtypes_and_context(PyUFuncObject *ufunc,
+ PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames)
+{
+ return py_resolve_dtypes_generic(ufunc, NPY_TRUE, args, len_args, kwnames);
+}
+
+
+static PyObject *
+py_get_strided_loop(PyUFuncObject *ufunc,
+ PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames)
+{
+ NPY_PREPARE_ARGPARSER;
+
+ PyObject *call_info_obj;
+ PyObject *fixed_strides_obj = Py_None;
+ npy_intp fixed_strides[NPY_MAXARGS];
+
+ if (npy_parse_arguments("_get_strided_loop", args, len_args, kwnames,
+ "", NULL, &call_info_obj,
+ "$fixed_strides", NULL, &fixed_strides_obj,
+ NULL, NULL, NULL) < 0) {
+ return NULL;
+ }
+
+ ufunc_call_info *call_info = PyCapsule_GetPointer(
+ call_info_obj, "numpy_1.24_ufunc_call_info");
+ if (call_info == NULL) {
+ /* Cannot have a context with NULL inside... */
+ assert(PyErr_Occurred());
+ return NULL;
+ }
+ if (call_info->strided_loop != NULL) {
+ PyErr_SetString(PyExc_TypeError,
+ "ufunc call_info has already been filled/used!");
+ return NULL;
+ }
+
+ if (call_info->context->caller != (PyObject *)ufunc) {
+ PyErr_SetString(PyExc_TypeError,
+ "calling get_strided_loop with incompatible context");
+ return NULL;
+ }
+
+ /*
+ * Strict conversion of fixed_strides, None, or tuple of int or None.
+ */
+ if (fixed_strides_obj == Py_None) {
+ for (int i = 0; i < ufunc->nargs; i++) {
+ fixed_strides[i] = NPY_MAX_INTP;
+ }
+ }
+ else if (PyTuple_CheckExact(fixed_strides_obj)
+ && PyTuple_Size(fixed_strides_obj) == ufunc->nargs) {
+ for (int i = 0; i < ufunc->nargs; i++) {
+ PyObject *stride = PyTuple_GET_ITEM(fixed_strides_obj, i);
+ if (PyLong_CheckExact(stride)) {
+ fixed_strides[i] = PyLong_AsSsize_t(stride);
+ if (error_converting(fixed_strides[i])) {
+ return NULL;
+ }
+ }
+ else if (stride == Py_None) {
+ fixed_strides[i] = NPY_MAX_INTP;
+ }
+ else {
+ PyErr_SetString(PyExc_TypeError,
+ "_get_strided_loop(): fixed_strides tuple must contain "
+ "Python ints or None");
+ return NULL;
+ }
+ }
+ }
+ else {
+ PyErr_SetString(PyExc_TypeError,
+ "_get_strided_loop(): fixed_strides must be a tuple or None");
+ return NULL;
+ }
+
+ NPY_ARRAYMETHOD_FLAGS flags;
+ if (call_info->context->method->get_strided_loop(call_info->context,
+ 1, 0, fixed_strides, &call_info->strided_loop, &call_info->auxdata,
+ &flags) < 0) {
+ return NULL;
+ }
+
+ call_info->requires_pyapi = flags & NPY_METH_REQUIRES_PYAPI;
+ call_info->no_floatingpoint_errors = (
+ flags & NPY_METH_NO_FLOATINGPOINT_ERRORS);
+
+ Py_RETURN_NONE;
+}
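
The error paths guarded above (a non-capsule argument, a capsule belonging to a different ufunc, and a second fill of the same capsule) match the tests added below; a rough sketch, with messages taken from this function:

>>> i4 = np.dtype("i4")
>>> dt, call_info = np.negative._resolve_dtypes_and_context((i4, i4))
>>> np.add._get_strided_loop(call_info)          # capsule belongs to np.negative
Traceback (most recent call last):
    ...
TypeError: calling get_strided_loop with incompatible context
>>> np.negative._get_strided_loop(call_info, fixed_strides=(4, 4))
>>> np.negative._get_strided_loop(call_info)     # cannot be filled a second time
Traceback (most recent call last):
    ...
TypeError: ufunc call_info has already been filled/used!
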
+
+
static struct PyMethodDef ufunc_methods[] = {
{"reduce",
(PyCFunction)ufunc_reduce,
@@ -6331,6 +6734,21 @@ static struct PyMethodDef ufunc_methods[] = {
{"at",
(PyCFunction)ufunc_at,
METH_VARARGS, NULL},
+ /* Lower level methods: */
+ {"resolve_dtypes",
+ (PyCFunction)py_resolve_dtypes,
+ METH_FASTCALL | METH_KEYWORDS, NULL},
+ /*
+ * The following two functions are public API, but underscored since they
+ * are C-user specific and allow direct access to the core of ufunc loops.
+ * (See their documentation for API stability.)
+ */
+ {"_resolve_dtypes_and_context",
+ (PyCFunction)py_resolve_dtypes_and_context,
+ METH_FASTCALL | METH_KEYWORDS, NULL},
+ {"_get_strided_loop",
+ (PyCFunction)py_get_strided_loop,
+ METH_FASTCALL | METH_KEYWORDS, NULL},
{NULL, NULL, 0, NULL} /* sentinel */
};
diff --git a/numpy/core/tests/test_ufunc.py b/numpy/core/tests/test_ufunc.py
index 21b0d9e46..cc6c0839d 100644
--- a/numpy/core/tests/test_ufunc.py
+++ b/numpy/core/tests/test_ufunc.py
@@ -1,6 +1,7 @@
import warnings
import itertools
import sys
+import ctypes as ct
import pytest
from pytest import param
@@ -2656,3 +2657,135 @@ def test_addition_reduce_negative_zero(dtype, use_initial):
# `sum([])` should probably be 0.0 and not -0.0 like `sum([-0.0])`
assert not np.signbit(res.real)
assert not np.signbit(res.imag)
+
+class TestLowlevelAPIAccess:
+ def test_resolve_dtypes_basic(self):
+ # Basic test for dtype resolution:
+ i4 = np.dtype("i4")
+ f4 = np.dtype("f4")
+ f8 = np.dtype("f8")
+
+ r = np.add.resolve_dtypes((i4, f4, None))
+ assert r == (f8, f8, f8)
+
+ # The signature is parsed with the same (less strict) logic as a ufunc
+ # call; the following is "same-kind" casting, so it works:
+ r = np.add.resolve_dtypes((
+ i4, i4, None), signature=(None, None, "f4"))
+ assert r == (f4, f4, f4)
+
+ # Check NEP 50 "weak" promotion also:
+ r = np.add.resolve_dtypes((f4, int, None))
+ assert r == (f4, f4, f4)
+
+ with pytest.raises(TypeError):
+ np.add.resolve_dtypes((i4, f4, None), casting="no")
+
+ def test_weird_dtypes(self):
+ S0 = np.dtype("S0")
+ # S0 is often converted by NumPy to S1, but not here:
+ r = np.equal.resolve_dtypes((S0, S0, None))
+ assert r == (S0, S0, np.dtype(bool))
+
+ # Subarray dtypes are weird and only really exist nested; they need
+ # the shift to full NEP 50 to be implemented nicely:
+ dts = np.dtype("10i")
+ with pytest.raises(NotImplementedError):
+ np.equal.resolve_dtypes((dts, dts, None))
+
+ def test_resolve_dtypes_reduction_no_output(self):
+ i4 = np.dtype("i4")
+ with pytest.raises(TypeError):
+ np.add.resolve_dtypes((i4, i4, i4), reduction=True)
+
+ @pytest.mark.parametrize("dtypes", [
+ (np.dtype("i"), np.dtype("i")),
+ (None, np.dtype("i"), np.dtype("f")),
+ (np.dtype("i"), None, np.dtype("f")),
+ ("i4", "i4", None)])
+ def test_resolve_dtypes_errors(self, dtypes):
+ with pytest.raises(TypeError):
+ np.add.resolve_dtypes(dtypes)
+
+ def test_resolve_dtypes_reduction(self):
+ i2 = np.dtype("i2")
+ long_ = np.dtype("long")
+ # Check special addition resolution:
+ res = np.add.resolve_dtypes((None, i2, None), reduction=True)
+ assert res == (long_, long_, long_)
+
+ def test_resolve_dtypes_reduction_errors(self):
+ i2 = np.dtype("i2")
+
+ with pytest.raises(TypeError):
+ np.add.resolve_dtypes((None, i2, i2))
+
+ with pytest.raises(TypeError):
+ np.add.signature((None, None, "i4"))
+
+ @pytest.mark.skipif(not hasattr(ct, "pythonapi"),
+ reason="`ctypes.pythonapi` required for capsule unpacking.")
+ def test_loop_access(self):
+ # This is a basic test for the full strided loop access
+ data_t = ct.ARRAY(ct.c_char_p, 2)
+ dim_t = ct.ARRAY(ct.c_ssize_t, 1)
+ strides_t = ct.ARRAY(ct.c_ssize_t, 2)
+ strided_loop_t = ct.CFUNCTYPE(
+ ct.c_int, ct.c_void_p, data_t, dim_t, strides_t, ct.c_void_p)
+
+ class call_info_t(ct.Structure):
+ _fields_ = [
+ ("strided_loop", strided_loop_t),
+ ("context", ct.c_void_p),
+ ("auxdata", ct.c_void_p),
+ ("requires_pyapi", ct.c_byte),
+ ("no_floatingpoint_errors", ct.c_byte),
+ ]
+
+ i4 = np.dtype("i4")
+ dt, call_info_obj = np.negative._resolve_dtypes_and_context((i4, i4))
+ assert dt == (i4, i4) # can be used without casting
+
+ # Fill in the rest of the information:
+ np.negative._get_strided_loop(call_info_obj)
+
+ ct.pythonapi.PyCapsule_GetPointer.restype = ct.c_void_p
+ call_info = ct.pythonapi.PyCapsule_GetPointer(
+ ct.py_object(call_info_obj),
+ ct.c_char_p(b"numpy_1.24_ufunc_call_info"))
+
+ call_info = ct.cast(call_info, ct.POINTER(call_info_t)).contents
+
+ arr = np.arange(10, dtype=i4)
+ call_info.strided_loop(
+ call_info.context,
+ data_t(arr.ctypes.data, arr.ctypes.data),
+ arr.ctypes.shape, # is a C-array with 10 here
+ strides_t(arr.ctypes.strides[0], arr.ctypes.strides[0]),
+ call_info.auxdata)
+
+ # We just directly called the negative inner-loop in-place:
+ assert_array_equal(arr, -np.arange(10, dtype=i4))
+
+ @pytest.mark.parametrize("strides", [1, (1, 2, 3), (1, "2")])
+ def test__get_strided_loop_errors_bad_strides(self, strides):
+ i4 = np.dtype("i4")
+ dt, call_info = np.negative._resolve_dtypes_and_context((i4, i4))
+
+ with pytest.raises(TypeError, match="fixed_strides.*tuple.*or None"):
+ np.negative._get_strided_loop(call_info, fixed_strides=strides)
+
+ def test__get_strided_loop_errors_bad_call_info(self):
+ i4 = np.dtype("i4")
+ dt, call_info = np.negative._resolve_dtypes_and_context((i4, i4))
+
+ with pytest.raises(ValueError, match="PyCapsule"):
+ np.negative._get_strided_loop("not the capsule!")
+
+ with pytest.raises(TypeError, match=".*incompatible context"):
+ np.add._get_strided_loop(call_info)
+
+ np.negative._get_strided_loop(call_info)
+ with pytest.raises(TypeError):
+ # cannot call it a second time:
+ np.negative._get_strided_loop(call_info)