summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2022-06-09 12:34:20 -0700
committerSebastian Berg <sebastian@sipsolutions.net>2022-06-15 11:42:02 -0700
commit983277185ea595880697e9289796c527be28aab3 (patch)
treeb493cf9bfa819f963f1427208eae1a787f247af2 /numpy
parentc66e6314f041653e2b289c03a09333ba5b4fd45a (diff)
downloadnumpy-983277185ea595880697e9289796c527be28aab3.tar.gz
BUG: Fix broken weak promotion (including legacy ones) and make it more robust
It seems that the (weird and probably non-existing in practice) case of uint8 vs. int8 promotion when the input is a single integer was broken at some point and this fixes it again. This is only really relevant for rational, which defines only a very selective number of integer promotions. This fixes up the previous chunk that relaxes promotion fallbacks a lot for legacy dtypes.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/abstractdtypes.c26
-rw-r--r--numpy/core/tests/test_dtype.py5
2 files changed, 29 insertions, 2 deletions
diff --git a/numpy/core/src/multiarray/abstractdtypes.c b/numpy/core/src/multiarray/abstractdtypes.c
index bff6ecbb7..3e89d045e 100644
--- a/numpy/core/src/multiarray/abstractdtypes.c
+++ b/numpy/core/src/multiarray/abstractdtypes.c
@@ -170,6 +170,12 @@ int_common_dtype(PyArray_DTypeMeta *NPY_UNUSED(cls), PyArray_DTypeMeta *other)
if (res == NULL) {
PyErr_Clear();
}
+ else if (res == (PyArray_DTypeMeta *)Py_NotImplemented) {
+ Py_DECREF(res);
+ }
+ else {
+ return res;
+ }
/* Try again with `int8`, an error may have been set, though */
PyArray_DTypeMeta *int8_dt = PyArray_DTypeFromTypeNum(NPY_INT8);
res = NPY_DT_CALL_common_dtype(other, int8_dt);
@@ -177,6 +183,12 @@ int_common_dtype(PyArray_DTypeMeta *NPY_UNUSED(cls), PyArray_DTypeMeta *other)
if (res == NULL) {
PyErr_Clear();
}
+ else if (res == (PyArray_DTypeMeta *)Py_NotImplemented) {
+ Py_DECREF(res);
+ }
+ else {
+ return res;
+ }
/* And finally, we will try the default integer, just for sports... */
PyArray_DTypeMeta *default_int = PyArray_DTypeFromTypeNum(NPY_LONG);
res = NPY_DT_CALL_common_dtype(other, default_int);
@@ -217,6 +229,12 @@ float_common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
if (res == NULL) {
PyErr_Clear();
}
+ else if (res == (PyArray_DTypeMeta *)Py_NotImplemented) {
+ Py_DECREF(res);
+ }
+ else {
+ return res;
+ }
/* Retry with double (the default float) */
PyArray_DTypeMeta *double_dt = PyArray_DTypeFromTypeNum(NPY_DOUBLE);
res = NPY_DT_CALL_common_dtype(other, double_dt);
@@ -265,7 +283,13 @@ complex_common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
if (res == NULL) {
PyErr_Clear();
}
- /* Retry with double (the default float) */
+ else if (res == (PyArray_DTypeMeta *)Py_NotImplemented) {
+ Py_DECREF(res);
+ }
+ else {
+ return res;
+ }
+ /* Retry with cdouble (the default complex) */
PyArray_DTypeMeta *cdouble_dt = PyArray_DTypeFromTypeNum(NPY_CDOUBLE);
res = NPY_DT_CALL_common_dtype(other, cdouble_dt);
Py_DECREF(cdouble_dt);
diff --git a/numpy/core/tests/test_dtype.py b/numpy/core/tests/test_dtype.py
index d895459ed..f95f95893 100644
--- a/numpy/core/tests/test_dtype.py
+++ b/numpy/core/tests/test_dtype.py
@@ -1373,7 +1373,10 @@ class TestPromotion:
# Note that rationals are a bit akward as they promote with float64
# or default ints, but not float16 or uint8/int8 (which looks
# inconsistent here). The new promotion fixes this (partially?)
- if not weak_promotion:
+ if not weak_promotion and type(other) == float:
+ # The float version, checks float16 in the legacy path, which fails
+ # the integer version seems to check int8 (also), so it can
+ # pass.
with pytest.raises(TypeError,
match=r".* do not have a common DType"):
np.result_type(other, rational)