summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2022-05-27 15:10:59 -0700
committerSebastian Berg <sebastian@sipsolutions.net>2022-06-15 11:42:02 -0700
commitc855cecce28d5f925c5b2e015a0bdfd1aa472590 (patch)
tree35b7e271aee7b3d5df416ecea5e3bd55f4d1ecbb /numpy
parent09d407a3cd24b712f8a40748262e00188e8b8efa (diff)
downloadnumpy-c855cecce28d5f925c5b2e015a0bdfd1aa472590.tar.gz
API: Expose `get_promotion_state` and `set_promotion_state`
We need to be able to query the state for testing, probably should be renamed before the end, but need to have something for now.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/multiarray.py5
-rw-r--r--numpy/core/numeric.py5
-rw-r--r--numpy/core/src/multiarray/convert_datatype.c43
-rw-r--r--numpy/core/src/multiarray/convert_datatype.h6
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.c6
5 files changed, 62 insertions, 3 deletions
diff --git a/numpy/core/multiarray.py b/numpy/core/multiarray.py
index 8c14583e6..e8ac27987 100644
--- a/numpy/core/multiarray.py
+++ b/numpy/core/multiarray.py
@@ -40,7 +40,8 @@ __all__ = [
'ravel_multi_index', 'result_type', 'scalar', 'set_datetimeparse_function',
'set_legacy_print_mode', 'set_numeric_ops', 'set_string_function',
'set_typeDict', 'shares_memory', 'tracemalloc_domain', 'typeinfo',
- 'unpackbits', 'unravel_index', 'vdot', 'where', 'zeros']
+ 'unpackbits', 'unravel_index', 'vdot', 'where', 'zeros',
+ 'get_promotion_state', 'set_promotion_state']
# For backward compatibility, make sure pickle imports these functions from here
_reconstruct.__module__ = 'numpy.core.multiarray'
@@ -68,6 +69,8 @@ promote_types.__module__ = 'numpy'
set_numeric_ops.__module__ = 'numpy'
seterrobj.__module__ = 'numpy'
zeros.__module__ = 'numpy'
+get_promotion_state.__module__ = 'numpy'
+set_promotion_state.__module__ = 'numpy'
# We can't verify dispatcher signatures because NumPy's C functions don't
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 38d85da6e..f81cfebce 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -17,7 +17,7 @@ from .multiarray import (
fromstring, inner, lexsort, matmul, may_share_memory,
min_scalar_type, ndarray, nditer, nested_iters, promote_types,
putmask, result_type, set_numeric_ops, shares_memory, vdot, where,
- zeros, normalize_axis_index)
+ zeros, normalize_axis_index, get_promotion_state, set_promotion_state)
from . import overrides
from . import umath
@@ -54,7 +54,8 @@ __all__ = [
'False_', 'True_', 'bitwise_not', 'CLIP', 'RAISE', 'WRAP', 'MAXDIMS',
'BUFSIZE', 'ALLOW_THREADS', 'ComplexWarning', 'full', 'full_like',
'matmul', 'shares_memory', 'may_share_memory', 'MAY_SHARE_BOUNDS',
- 'MAY_SHARE_EXACT', 'TooHardError', 'AxisError']
+ 'MAY_SHARE_EXACT', 'TooHardError', 'AxisError',
+ 'get_promotion_state', 'set_promotion_state']
@set_module('numpy')
diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c
index 94b7d6d98..2b1eca333 100644
--- a/numpy/core/src/multiarray/convert_datatype.c
+++ b/numpy/core/src/multiarray/convert_datatype.c
@@ -91,6 +91,49 @@ npy_give_promotion_warnings(void)
return val == Py_False;
}
+
+NPY_NO_EXPORT PyObject *
+npy_get_promotion_state(PyObject *NPY_UNUSED(mod), PyObject *NPY_UNUSED(arg)) {
+ if (npy_promotion_state == NPY_USE_WEAK_PROMOTION) {
+ return PyUnicode_FromString("weak");
+ }
+ else if (npy_promotion_state == NPY_USE_WEAK_PROMOTION_AND_WARN) {
+ return PyUnicode_FromString("weak_and_warn");
+ }
+ else if (npy_promotion_state == NPY_USE_LEGACY_PROMOTION) {
+ return PyUnicode_FromString("legacy");
+ }
+ PyErr_SetString(PyExc_SystemError, "invalid promotion state!");
+ return NULL;
+}
+
+
+NPY_NO_EXPORT PyObject *
+npy_set_promotion_state(PyObject *NPY_UNUSED(mod), PyObject *arg)
+{
+ if (!PyUnicode_Check(arg)) {
+ PyErr_SetString(PyExc_TypeError,
+ "set_promotion_state() argument must be a string.");
+ return NULL;
+ }
+ if (PyUnicode_CompareWithASCIIString(arg, "weak")) {
+ npy_promotion_state = NPY_USE_WEAK_PROMOTION;
+ }
+ else if (PyUnicode_CompareWithASCIIString(arg, "weak_and_warn")) {
+ npy_promotion_state = NPY_USE_WEAK_PROMOTION_AND_WARN;
+ }
+ else if (PyUnicode_CompareWithASCIIString(arg, "legacy")) {
+ npy_promotion_state = NPY_USE_LEGACY_PROMOTION;
+ }
+ else {
+ PyErr_Format(PyExc_TypeError,
+ "set_promotion_state() argument must be "
+ "'weak', 'legacy', or 'weak_and_warn' but got '%.100S'", arg);
+ return NULL;
+ }
+ Py_RETURN_NONE;
+}
+
/**
* Fetch the casting implementation from one DType to another.
*
diff --git a/numpy/core/src/multiarray/convert_datatype.h b/numpy/core/src/multiarray/convert_datatype.h
index 3550f45d2..2d99042f2 100644
--- a/numpy/core/src/multiarray/convert_datatype.h
+++ b/numpy/core/src/multiarray/convert_datatype.h
@@ -19,6 +19,12 @@ NPY_NO_EXPORT int
npy_give_promotion_warnings(void);
NPY_NO_EXPORT PyObject *
+npy_get_promotion_state(PyObject *NPY_UNUSED(mod), PyObject *NPY_UNUSED(arg));
+
+NPY_NO_EXPORT PyObject *
+npy_set_promotion_state(PyObject *NPY_UNUSED(mod), PyObject *arg);
+
+NPY_NO_EXPORT PyObject *
PyArray_GetCastingImpl(PyArray_DTypeMeta *from, PyArray_DTypeMeta *to);
NPY_NO_EXPORT PyObject *
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index 13fb83ee5..b2512978e 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -4502,6 +4502,12 @@ static struct PyMethodDef array_module_methods[] = {
{"get_handler_version",
(PyCFunction) get_handler_version,
METH_VARARGS, NULL},
+ {"get_promotion_state",
+ (PyCFunction)npy_get_promotion_state,
+ METH_NOARGS, NULL},
+ {"set_promotion_state",
+ (PyCFunction)npy_set_promotion_state,
+ METH_O, NULL},
{"_add_newdoc_ufunc", (PyCFunction)add_newdoc_ufunc,
METH_VARARGS, NULL},
{"_get_sfloat_dtype",