summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2022-06-27 20:39:08 -0700
committerSebastian Berg <sebastianb@nvidia.com>2022-10-12 10:41:37 +0200
commit409cccf70341a9ef9c80138dd9569f6ca3720cef (patch)
treef0c41275dcbab789e5cdb5edd74cbac8f7ec46d7
parent241c905c464a29c7b25858d57ea1a43131848530 (diff)
downloadnumpy-409cccf70341a9ef9c80138dd9569f6ca3720cef.tar.gz
ENH: Implement safe integers and weak python scalars for scalars
This requires adding a path that uses the "normal" Python object to dtype conversion (setitem) function, rather than always converting to the default dtype.
-rw-r--r--numpy/core/src/umath/scalarmath.c.src84
-rw-r--r--numpy/core/tests/test_nep50_promotions.py17
2 files changed, 90 insertions, 11 deletions
diff --git a/numpy/core/src/umath/scalarmath.c.src b/numpy/core/src/umath/scalarmath.c.src
index 7bfa29d7e..89c065549 100644
--- a/numpy/core/src/umath/scalarmath.c.src
+++ b/numpy/core/src/umath/scalarmath.c.src
@@ -26,11 +26,15 @@
#include "binop_override.h"
#include "npy_longdouble.h"
+#include "arraytypes.h"
#include "array_coercion.h"
#include "common.h"
#include "can_cast_table.h"
#include "umathmodule.h"
+#include "convert_datatype.h"
+#include "dtypemeta.h"
+
/* TODO: Used for some functions, should possibly move these to npy_math.h */
#include "loops.h"
@@ -792,7 +796,12 @@ typedef enum {
*/
CONVERSION_SUCCESS,
/*
- * Other object is an unknown scalar or array-like, we (typically) use
+ * We use the normal conversion (setitem) function when coercing from
+ * Python scalars.
+ */
+ CONVERT_PYSCALAR,
+ /*
+ * Other object is an unkown scalar or array-like, we (typically) use
* the generic path, which normally ends up in the ufunc machinery.
*/
OTHER_IS_UNKNOWN_OBJECT,
@@ -956,7 +965,15 @@ convert_to_@name@(PyObject *value, @type@ *result, npy_bool *may_need_deferring)
*may_need_deferring = NPY_TRUE;
}
if (!IS_SAFE(NPY_DOUBLE, NPY_@TYPE@)) {
- return PROMOTION_REQUIRED;
+ if (npy_promotion_state != NPY_USE_WEAK_PROMOTION) {
+ /* Legacy promotion and weak-and-warn not handled here */
+ return PROMOTION_REQUIRED;
+ }
+ /* Weak promotion is used when self is float or complex: */
+ if (!PyTypeNum_ISFLOAT(NPY_@TYPE@) && !PyTypeNum_ISCOMPLEX(NPY_@TYPE@)) {
+ return PROMOTION_REQUIRED;
+ }
+ return CONVERT_PYSCALAR;
}
CONVERT_TO_RESULT(PyFloat_AS_DOUBLE(value));
return CONVERSION_SUCCESS;
@@ -968,15 +985,19 @@ convert_to_@name@(PyObject *value, @type@ *result, npy_bool *may_need_deferring)
}
if (!IS_SAFE(NPY_LONG, NPY_@TYPE@)) {
/*
- * long -> (c)longdouble is safe, so `THER_IS_UNKNOWN_OBJECT` will
+ * long -> (c)longdouble is safe, so `OTHER_IS_UNKNOWN_OBJECT` will
* be returned below for huge integers.
*/
- return PROMOTION_REQUIRED;
+ if (npy_promotion_state != NPY_USE_WEAK_PROMOTION) {
+ /* Legacy promotion and weak-and-warn not handled here */
+ return PROMOTION_REQUIRED;
+ }
+ return CONVERT_PYSCALAR;
}
int overflow;
long val = PyLong_AsLongAndOverflow(value, &overflow);
if (overflow) {
- return OTHER_IS_UNKNOWN_OBJECT; /* handle as if arbitrary object */
+ return CONVERT_PYSCALAR; /* handle as if "unsafe" */
}
if (error_converting(val)) {
return CONVERSION_ERROR; /* should not be possible */
@@ -995,7 +1016,15 @@ convert_to_@name@(PyObject *value, @type@ *result, npy_bool *may_need_deferring)
*may_need_deferring = NPY_TRUE;
}
if (!IS_SAFE(NPY_CDOUBLE, NPY_@TYPE@)) {
- return PROMOTION_REQUIRED;
+ if (npy_promotion_state != NPY_USE_WEAK_PROMOTION) {
+ /* Legacy promotion and weak-and-warn not handled here */
+ return PROMOTION_REQUIRED;
+ }
+ /* Weak promotion is used when self is float or complex: */
+ if (!PyTypeNum_ISCOMPLEX(NPY_@TYPE@)) {
+ return PROMOTION_REQUIRED;
+ }
+ return CONVERT_PYSCALAR;
}
#if defined(IS_CFLOAT) || defined(IS_CDOUBLE) || defined(IS_CLONGDOUBLE)
Py_complex val = PyComplex_AsCComplex(value);
@@ -1164,12 +1193,24 @@ convert_to_@name@(PyObject *value, @type@ *result, npy_bool *may_need_deferring)
* (npy_half, npy_float, npy_double, npy_longdouble,
* npy_cfloat, npy_cdouble, npy_clongdouble)*4,
* (npy_half, npy_float, npy_double, npy_longdouble)*3#
+ * #oname = (byte, ubyte, short, ushort, int, uint,
+ * long, ulong, longlong, ulonglong)*11,
+ * double*10,
+ * (half, float, double, longdouble,
+ * cfloat, cdouble, clongdouble)*4,
+ * (half, float, double, longdouble)*3#
* #OName = (Byte, UByte, Short, UShort, Int, UInt,
* Long, ULong, LongLong, ULongLong)*11,
* Double*10,
* (Half, Float, Double, LongDouble,
* CFloat, CDouble, CLongDouble)*4,
* (Half, Float, Double, LongDouble)*3#
+ * #ONAME = (BYTE, UBYTE, SHORT, USHORT, INT, UINT,
+ * LONG, ULONG, LONGLONG, ULONGLONG)*11,
+ * DOUBLE*10,
+ * (HALF, FLOAT, DOUBLE, LONGDOUBLE,
+ * CFLOAT, CDOUBLE, CLONGDOUBLE)*4,
+ * (HALF, FLOAT, DOUBLE, LONGDOUBLE)*3#
*/
#define IS_@name@
/* drop the "true_" from "true_divide" for floating point warnings: */
@@ -1179,13 +1220,12 @@ convert_to_@name@(PyObject *value, @type@ *result, npy_bool *may_need_deferring)
#else
#define OP_NAME "@oper@"
#endif
-#undef IS_@oper@
static PyObject *
@name@_@oper@(PyObject *a, PyObject *b)
{
PyObject *ret;
- @type@ arg1, arg2, other_val;
+ @otype@ arg1, arg2, other_val;
/*
* Check if this operation may be considered forward. Note `is_forward`
@@ -1214,7 +1254,7 @@ static PyObject *
PyObject *other = is_forward ? b : a;
npy_bool may_need_deferring;
- conversion_result res = convert_to_@name@(
+ conversion_result res = convert_to_@oname@(
other, &other_val, &may_need_deferring);
if (res == CONVERSION_ERROR) {
return NULL; /* an error occurred (should never happen) */
@@ -1255,6 +1295,11 @@ static PyObject *
* correctly. (e.g. `uint8 * int8` cannot warn).
*/
return PyGenericArrType_Type.tp_as_number->nb_@oper@(a,b);
+ case CONVERT_PYSCALAR:
+ if (@ONAME@_setitem(other, (char *)&other_val, NULL) < 0) {
+ return NULL;
+ }
+ break;
default:
assert(0); /* error was checked already, impossible to reach */
return NULL;
@@ -1291,7 +1336,7 @@ static PyObject *
#if @twoout@
int retstatus = @name@_ctype_@oper@(arg1, arg2, &out, &out2);
#else
- int retstatus = @name@_ctype_@oper@(arg1, arg2, &out);
+ int retstatus = @oname@_ctype_@oper@(arg1, arg2, &out);
#endif
#if @fperr@
@@ -1336,6 +1381,7 @@ static PyObject *
#undef OP_NAME
+#undef IS_@oper@
#undef IS_@name@
/**end repeat**/
@@ -1358,6 +1404,10 @@ static PyObject *
* Long, ULong, LongLong, ULongLong,
* Half, Float, Double, LongDouble,
* CFloat, CDouble, CLongDouble#
+ * #NAME = BYTE, UBYTE, SHORT, USHORT, INT, UINT,
+ * LONG, ULONG, LONGLONG, ULONGLONG,
+ * HALF, FLOAT, DOUBLE, LONGDOUBLE,
+ * CFLOAT, CDOUBLE, CLONGDOUBLE#
*
* #isint = 1*10,0*7#
* #isuint = (0,1)*5,0*7#
@@ -1417,6 +1467,11 @@ static PyObject *
#endif
case PROMOTION_REQUIRED:
return PyGenericArrType_Type.tp_as_number->nb_power(a, b, modulo);
+ case CONVERT_PYSCALAR:
+ if (@NAME@_setitem(other, (char *)&other_val, NULL) < 0) {
+ return NULL;
+ }
+ break;
default:
assert(0); /* error was checked already, impossible to reach */
return NULL;
@@ -1759,6 +1814,10 @@ static PyObject *
* Long, ULong, LongLong, ULongLong,
* Half, Float, Double, LongDouble,
* CFloat, CDouble, CLongDouble#
+ * #NAME = BYTE, UBYTE, SHORT, USHORT, INT, UINT,
+ * LONG, ULONG, LONGLONG, ULONGLONG,
+ * HALF, FLOAT, DOUBLE, LONGDOUBLE,
+ * CFLOAT, CDOUBLE, CLONGDOUBLE#
* #simp = def*10, def_half, def*3, cmplx*3#
*/
#define IS_@name@
@@ -1791,6 +1850,11 @@ static PyObject*
#endif
case PROMOTION_REQUIRED:
return PyGenericArrType_Type.tp_richcompare(self, other, cmp_op);
+ case CONVERT_PYSCALAR:
+ if (@NAME@_setitem(other, (char *)&arg2, NULL) < 0) {
+ return NULL;
+ }
+ break;
default:
assert(0); /* error was checked already, impossible to reach */
return NULL;
diff --git a/numpy/core/tests/test_nep50_promotions.py b/numpy/core/tests/test_nep50_promotions.py
index 5c59a16ea..3957bdc3e 100644
--- a/numpy/core/tests/test_nep50_promotions.py
+++ b/numpy/core/tests/test_nep50_promotions.py
@@ -30,7 +30,7 @@ def test_nep50_examples():
assert res.dtype == np.int64
with pytest.warns(UserWarning, match="result dtype changed"):
- # Note: Should warn (error with the errstate), but does not:
+ # Note: Overflow would be nice, but does not warn with change warning
with np.errstate(over="raise"):
res = np.uint8(100) + 200
assert res.dtype == np.uint8
@@ -61,6 +61,21 @@ def test_nep50_examples():
assert res.dtype == np.float64
+def test_nep50_without_warnings():
+ # Test that avoid the "warn" method, since that may lead to different
+ # code paths in some cases.
+ # Set promotion to weak (no warning), the auto-fixture will reset it.
+ np._set_promotion_state("weak")
+ with np.errstate(over="warn"):
+ with pytest.warns(RuntimeWarning):
+ res = np.uint8(100) + 200
+ assert res.dtype == np.uint8
+
+ with pytest.warns(RuntimeWarning):
+ res = np.float32(1) + 3e100
+ assert res.dtype == np.float32
+
+
@pytest.mark.xfail
def test_nep50_integer_conversion_errors():
# Implementation for error paths is mostly missing (as of writing)