diff options
author | Sebastian Berg <sebastian@sipsolutions.net> | 2022-06-27 20:39:08 -0700 |
---|---|---|
committer | Sebastian Berg <sebastianb@nvidia.com> | 2022-10-12 10:41:37 +0200 |
commit | 409cccf70341a9ef9c80138dd9569f6ca3720cef (patch) | |
tree | f0c41275dcbab789e5cdb5edd74cbac8f7ec46d7 | |
parent | 241c905c464a29c7b25858d57ea1a43131848530 (diff) | |
download | numpy-409cccf70341a9ef9c80138dd9569f6ca3720cef.tar.gz |
ENH: Implement safe integers and weak python scalars for scalars
This requires adding a path that uses the "normal" Python object
to dtype conversion (setitem) function, rather than always converting
to the default dtype.
-rw-r--r-- | numpy/core/src/umath/scalarmath.c.src | 84 | ||||
-rw-r--r-- | numpy/core/tests/test_nep50_promotions.py | 17 |
2 files changed, 90 insertions, 11 deletions
diff --git a/numpy/core/src/umath/scalarmath.c.src b/numpy/core/src/umath/scalarmath.c.src index 7bfa29d7e..89c065549 100644 --- a/numpy/core/src/umath/scalarmath.c.src +++ b/numpy/core/src/umath/scalarmath.c.src @@ -26,11 +26,15 @@ #include "binop_override.h" #include "npy_longdouble.h" +#include "arraytypes.h" #include "array_coercion.h" #include "common.h" #include "can_cast_table.h" #include "umathmodule.h" +#include "convert_datatype.h" +#include "dtypemeta.h" + /* TODO: Used for some functions, should possibly move these to npy_math.h */ #include "loops.h" @@ -792,7 +796,12 @@ typedef enum { */ CONVERSION_SUCCESS, /* - * Other object is an unknown scalar or array-like, we (typically) use + * We use the normal conversion (setitem) function when coercing from + * Python scalars. + */ + CONVERT_PYSCALAR, + /* + * Other object is an unkown scalar or array-like, we (typically) use * the generic path, which normally ends up in the ufunc machinery. */ OTHER_IS_UNKNOWN_OBJECT, @@ -956,7 +965,15 @@ convert_to_@name@(PyObject *value, @type@ *result, npy_bool *may_need_deferring) *may_need_deferring = NPY_TRUE; } if (!IS_SAFE(NPY_DOUBLE, NPY_@TYPE@)) { - return PROMOTION_REQUIRED; + if (npy_promotion_state != NPY_USE_WEAK_PROMOTION) { + /* Legacy promotion and weak-and-warn not handled here */ + return PROMOTION_REQUIRED; + } + /* Weak promotion is used when self is float or complex: */ + if (!PyTypeNum_ISFLOAT(NPY_@TYPE@) && !PyTypeNum_ISCOMPLEX(NPY_@TYPE@)) { + return PROMOTION_REQUIRED; + } + return CONVERT_PYSCALAR; } CONVERT_TO_RESULT(PyFloat_AS_DOUBLE(value)); return CONVERSION_SUCCESS; @@ -968,15 +985,19 @@ convert_to_@name@(PyObject *value, @type@ *result, npy_bool *may_need_deferring) } if (!IS_SAFE(NPY_LONG, NPY_@TYPE@)) { /* - * long -> (c)longdouble is safe, so `THER_IS_UNKNOWN_OBJECT` will + * long -> (c)longdouble is safe, so `OTHER_IS_UNKNOWN_OBJECT` will * be returned below for huge integers. */ - return PROMOTION_REQUIRED; + if (npy_promotion_state != NPY_USE_WEAK_PROMOTION) { + /* Legacy promotion and weak-and-warn not handled here */ + return PROMOTION_REQUIRED; + } + return CONVERT_PYSCALAR; } int overflow; long val = PyLong_AsLongAndOverflow(value, &overflow); if (overflow) { - return OTHER_IS_UNKNOWN_OBJECT; /* handle as if arbitrary object */ + return CONVERT_PYSCALAR; /* handle as if "unsafe" */ } if (error_converting(val)) { return CONVERSION_ERROR; /* should not be possible */ @@ -995,7 +1016,15 @@ convert_to_@name@(PyObject *value, @type@ *result, npy_bool *may_need_deferring) *may_need_deferring = NPY_TRUE; } if (!IS_SAFE(NPY_CDOUBLE, NPY_@TYPE@)) { - return PROMOTION_REQUIRED; + if (npy_promotion_state != NPY_USE_WEAK_PROMOTION) { + /* Legacy promotion and weak-and-warn not handled here */ + return PROMOTION_REQUIRED; + } + /* Weak promotion is used when self is float or complex: */ + if (!PyTypeNum_ISCOMPLEX(NPY_@TYPE@)) { + return PROMOTION_REQUIRED; + } + return CONVERT_PYSCALAR; } #if defined(IS_CFLOAT) || defined(IS_CDOUBLE) || defined(IS_CLONGDOUBLE) Py_complex val = PyComplex_AsCComplex(value); @@ -1164,12 +1193,24 @@ convert_to_@name@(PyObject *value, @type@ *result, npy_bool *may_need_deferring) * (npy_half, npy_float, npy_double, npy_longdouble, * npy_cfloat, npy_cdouble, npy_clongdouble)*4, * (npy_half, npy_float, npy_double, npy_longdouble)*3# + * #oname = (byte, ubyte, short, ushort, int, uint, + * long, ulong, longlong, ulonglong)*11, + * double*10, + * (half, float, double, longdouble, + * cfloat, cdouble, clongdouble)*4, + * (half, float, double, longdouble)*3# * #OName = (Byte, UByte, Short, UShort, Int, UInt, * Long, ULong, LongLong, ULongLong)*11, * Double*10, * (Half, Float, Double, LongDouble, * CFloat, CDouble, CLongDouble)*4, * (Half, Float, Double, LongDouble)*3# + * #ONAME = (BYTE, UBYTE, SHORT, USHORT, INT, UINT, + * LONG, ULONG, LONGLONG, ULONGLONG)*11, + * DOUBLE*10, + * (HALF, FLOAT, DOUBLE, LONGDOUBLE, + * CFLOAT, CDOUBLE, CLONGDOUBLE)*4, + * (HALF, FLOAT, DOUBLE, LONGDOUBLE)*3# */ #define IS_@name@ /* drop the "true_" from "true_divide" for floating point warnings: */ @@ -1179,13 +1220,12 @@ convert_to_@name@(PyObject *value, @type@ *result, npy_bool *may_need_deferring) #else #define OP_NAME "@oper@" #endif -#undef IS_@oper@ static PyObject * @name@_@oper@(PyObject *a, PyObject *b) { PyObject *ret; - @type@ arg1, arg2, other_val; + @otype@ arg1, arg2, other_val; /* * Check if this operation may be considered forward. Note `is_forward` @@ -1214,7 +1254,7 @@ static PyObject * PyObject *other = is_forward ? b : a; npy_bool may_need_deferring; - conversion_result res = convert_to_@name@( + conversion_result res = convert_to_@oname@( other, &other_val, &may_need_deferring); if (res == CONVERSION_ERROR) { return NULL; /* an error occurred (should never happen) */ @@ -1255,6 +1295,11 @@ static PyObject * * correctly. (e.g. `uint8 * int8` cannot warn). */ return PyGenericArrType_Type.tp_as_number->nb_@oper@(a,b); + case CONVERT_PYSCALAR: + if (@ONAME@_setitem(other, (char *)&other_val, NULL) < 0) { + return NULL; + } + break; default: assert(0); /* error was checked already, impossible to reach */ return NULL; @@ -1291,7 +1336,7 @@ static PyObject * #if @twoout@ int retstatus = @name@_ctype_@oper@(arg1, arg2, &out, &out2); #else - int retstatus = @name@_ctype_@oper@(arg1, arg2, &out); + int retstatus = @oname@_ctype_@oper@(arg1, arg2, &out); #endif #if @fperr@ @@ -1336,6 +1381,7 @@ static PyObject * #undef OP_NAME +#undef IS_@oper@ #undef IS_@name@ /**end repeat**/ @@ -1358,6 +1404,10 @@ static PyObject * * Long, ULong, LongLong, ULongLong, * Half, Float, Double, LongDouble, * CFloat, CDouble, CLongDouble# + * #NAME = BYTE, UBYTE, SHORT, USHORT, INT, UINT, + * LONG, ULONG, LONGLONG, ULONGLONG, + * HALF, FLOAT, DOUBLE, LONGDOUBLE, + * CFLOAT, CDOUBLE, CLONGDOUBLE# * * #isint = 1*10,0*7# * #isuint = (0,1)*5,0*7# @@ -1417,6 +1467,11 @@ static PyObject * #endif case PROMOTION_REQUIRED: return PyGenericArrType_Type.tp_as_number->nb_power(a, b, modulo); + case CONVERT_PYSCALAR: + if (@NAME@_setitem(other, (char *)&other_val, NULL) < 0) { + return NULL; + } + break; default: assert(0); /* error was checked already, impossible to reach */ return NULL; @@ -1759,6 +1814,10 @@ static PyObject * * Long, ULong, LongLong, ULongLong, * Half, Float, Double, LongDouble, * CFloat, CDouble, CLongDouble# + * #NAME = BYTE, UBYTE, SHORT, USHORT, INT, UINT, + * LONG, ULONG, LONGLONG, ULONGLONG, + * HALF, FLOAT, DOUBLE, LONGDOUBLE, + * CFLOAT, CDOUBLE, CLONGDOUBLE# * #simp = def*10, def_half, def*3, cmplx*3# */ #define IS_@name@ @@ -1791,6 +1850,11 @@ static PyObject* #endif case PROMOTION_REQUIRED: return PyGenericArrType_Type.tp_richcompare(self, other, cmp_op); + case CONVERT_PYSCALAR: + if (@NAME@_setitem(other, (char *)&arg2, NULL) < 0) { + return NULL; + } + break; default: assert(0); /* error was checked already, impossible to reach */ return NULL; diff --git a/numpy/core/tests/test_nep50_promotions.py b/numpy/core/tests/test_nep50_promotions.py index 5c59a16ea..3957bdc3e 100644 --- a/numpy/core/tests/test_nep50_promotions.py +++ b/numpy/core/tests/test_nep50_promotions.py @@ -30,7 +30,7 @@ def test_nep50_examples(): assert res.dtype == np.int64 with pytest.warns(UserWarning, match="result dtype changed"): - # Note: Should warn (error with the errstate), but does not: + # Note: Overflow would be nice, but does not warn with change warning with np.errstate(over="raise"): res = np.uint8(100) + 200 assert res.dtype == np.uint8 @@ -61,6 +61,21 @@ def test_nep50_examples(): assert res.dtype == np.float64 +def test_nep50_without_warnings(): + # Test that avoid the "warn" method, since that may lead to different + # code paths in some cases. + # Set promotion to weak (no warning), the auto-fixture will reset it. + np._set_promotion_state("weak") + with np.errstate(over="warn"): + with pytest.warns(RuntimeWarning): + res = np.uint8(100) + 200 + assert res.dtype == np.uint8 + + with pytest.warns(RuntimeWarning): + res = np.float32(1) + 3e100 + assert res.dtype == np.float32 + + @pytest.mark.xfail def test_nep50_integer_conversion_errors(): # Implementation for error paths is mostly missing (as of writing) |