diff options
| author | Sebastian Berg <sebastian@sipsolutions.net> | 2022-05-27 15:10:59 -0700 |
|---|---|---|
| committer | Sebastian Berg <sebastian@sipsolutions.net> | 2022-06-15 11:42:02 -0700 |
| commit | c855cecce28d5f925c5b2e015a0bdfd1aa472590 (patch) | |
| tree | 35b7e271aee7b3d5df416ecea5e3bd55f4d1ecbb /numpy | |
| parent | 09d407a3cd24b712f8a40748262e00188e8b8efa (diff) | |
| download | numpy-c855cecce28d5f925c5b2e015a0bdfd1aa472590.tar.gz | |
API: Expose `get_promotion_state` and `set_promotion_state`
We need to be able to query the state for testing, probably should
be renamed before the end, but need to have something for now.
Diffstat (limited to 'numpy')
| -rw-r--r-- | numpy/core/multiarray.py | 5 | ||||
| -rw-r--r-- | numpy/core/numeric.py | 5 | ||||
| -rw-r--r-- | numpy/core/src/multiarray/convert_datatype.c | 43 | ||||
| -rw-r--r-- | numpy/core/src/multiarray/convert_datatype.h | 6 | ||||
| -rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.c | 6 |
5 files changed, 62 insertions, 3 deletions
diff --git a/numpy/core/multiarray.py b/numpy/core/multiarray.py index 8c14583e6..e8ac27987 100644 --- a/numpy/core/multiarray.py +++ b/numpy/core/multiarray.py @@ -40,7 +40,8 @@ __all__ = [ 'ravel_multi_index', 'result_type', 'scalar', 'set_datetimeparse_function', 'set_legacy_print_mode', 'set_numeric_ops', 'set_string_function', 'set_typeDict', 'shares_memory', 'tracemalloc_domain', 'typeinfo', - 'unpackbits', 'unravel_index', 'vdot', 'where', 'zeros'] + 'unpackbits', 'unravel_index', 'vdot', 'where', 'zeros', + 'get_promotion_state', 'set_promotion_state'] # For backward compatibility, make sure pickle imports these functions from here _reconstruct.__module__ = 'numpy.core.multiarray' @@ -68,6 +69,8 @@ promote_types.__module__ = 'numpy' set_numeric_ops.__module__ = 'numpy' seterrobj.__module__ = 'numpy' zeros.__module__ = 'numpy' +get_promotion_state.__module__ = 'numpy' +set_promotion_state.__module__ = 'numpy' # We can't verify dispatcher signatures because NumPy's C functions don't diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 38d85da6e..f81cfebce 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -17,7 +17,7 @@ from .multiarray import ( fromstring, inner, lexsort, matmul, may_share_memory, min_scalar_type, ndarray, nditer, nested_iters, promote_types, putmask, result_type, set_numeric_ops, shares_memory, vdot, where, - zeros, normalize_axis_index) + zeros, normalize_axis_index, get_promotion_state, set_promotion_state) from . import overrides from . import umath @@ -54,7 +54,8 @@ __all__ = [ 'False_', 'True_', 'bitwise_not', 'CLIP', 'RAISE', 'WRAP', 'MAXDIMS', 'BUFSIZE', 'ALLOW_THREADS', 'ComplexWarning', 'full', 'full_like', 'matmul', 'shares_memory', 'may_share_memory', 'MAY_SHARE_BOUNDS', - 'MAY_SHARE_EXACT', 'TooHardError', 'AxisError'] + 'MAY_SHARE_EXACT', 'TooHardError', 'AxisError', + 'get_promotion_state', 'set_promotion_state'] @set_module('numpy') diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c index 94b7d6d98..2b1eca333 100644 --- a/numpy/core/src/multiarray/convert_datatype.c +++ b/numpy/core/src/multiarray/convert_datatype.c @@ -91,6 +91,49 @@ npy_give_promotion_warnings(void) return val == Py_False; } + +NPY_NO_EXPORT PyObject * +npy_get_promotion_state(PyObject *NPY_UNUSED(mod), PyObject *NPY_UNUSED(arg)) { + if (npy_promotion_state == NPY_USE_WEAK_PROMOTION) { + return PyUnicode_FromString("weak"); + } + else if (npy_promotion_state == NPY_USE_WEAK_PROMOTION_AND_WARN) { + return PyUnicode_FromString("weak_and_warn"); + } + else if (npy_promotion_state == NPY_USE_LEGACY_PROMOTION) { + return PyUnicode_FromString("legacy"); + } + PyErr_SetString(PyExc_SystemError, "invalid promotion state!"); + return NULL; +} + + +NPY_NO_EXPORT PyObject * +npy_set_promotion_state(PyObject *NPY_UNUSED(mod), PyObject *arg) +{ + if (!PyUnicode_Check(arg)) { + PyErr_SetString(PyExc_TypeError, + "set_promotion_state() argument must be a string."); + return NULL; + } + if (PyUnicode_CompareWithASCIIString(arg, "weak")) { + npy_promotion_state = NPY_USE_WEAK_PROMOTION; + } + else if (PyUnicode_CompareWithASCIIString(arg, "weak_and_warn")) { + npy_promotion_state = NPY_USE_WEAK_PROMOTION_AND_WARN; + } + else if (PyUnicode_CompareWithASCIIString(arg, "legacy")) { + npy_promotion_state = NPY_USE_LEGACY_PROMOTION; + } + else { + PyErr_Format(PyExc_TypeError, + "set_promotion_state() argument must be " + "'weak', 'legacy', or 'weak_and_warn' but got '%.100S'", arg); + return NULL; + } + Py_RETURN_NONE; +} + /** * Fetch the casting implementation from one DType to another. * diff --git a/numpy/core/src/multiarray/convert_datatype.h b/numpy/core/src/multiarray/convert_datatype.h index 3550f45d2..2d99042f2 100644 --- a/numpy/core/src/multiarray/convert_datatype.h +++ b/numpy/core/src/multiarray/convert_datatype.h @@ -19,6 +19,12 @@ NPY_NO_EXPORT int npy_give_promotion_warnings(void); NPY_NO_EXPORT PyObject * +npy_get_promotion_state(PyObject *NPY_UNUSED(mod), PyObject *NPY_UNUSED(arg)); + +NPY_NO_EXPORT PyObject * +npy_set_promotion_state(PyObject *NPY_UNUSED(mod), PyObject *arg); + +NPY_NO_EXPORT PyObject * PyArray_GetCastingImpl(PyArray_DTypeMeta *from, PyArray_DTypeMeta *to); NPY_NO_EXPORT PyObject * diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index 13fb83ee5..b2512978e 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -4502,6 +4502,12 @@ static struct PyMethodDef array_module_methods[] = { {"get_handler_version", (PyCFunction) get_handler_version, METH_VARARGS, NULL}, + {"get_promotion_state", + (PyCFunction)npy_get_promotion_state, + METH_NOARGS, NULL}, + {"set_promotion_state", + (PyCFunction)npy_set_promotion_state, + METH_O, NULL}, {"_add_newdoc_ufunc", (PyCFunction)add_newdoc_ufunc, METH_VARARGS, NULL}, {"_get_sfloat_dtype", |
