diff options
-rw-r--r-- | numpy/core/src/umath/_scaled_float_dtype.c | 67 | ||||
-rw-r--r-- | numpy/core/tests/test_custom_dtypes.py | 12 |
2 files changed, 70 insertions, 9 deletions
diff --git a/numpy/core/src/umath/_scaled_float_dtype.c b/numpy/core/src/umath/_scaled_float_dtype.c index aa9c4549c..8ca31169f 100644 --- a/numpy/core/src/umath/_scaled_float_dtype.c +++ b/numpy/core/src/umath/_scaled_float_dtype.c @@ -464,9 +464,6 @@ init_casts(void) * 2. Addition, which needs to use the common instance, and runs into * cast safety subtleties since we will implement it without an additional * cast. - * - * NOTE: When first writing this, promotion did not exist for new-style loops, - * if it exists, we could use promotion to implement double * sfloat. */ static int multiply_sfloats(PyArrayMethod_Context *NPY_UNUSED(context), @@ -591,7 +588,8 @@ add_sfloats_resolve_descriptors( static int -add_loop(const char *ufunc_name, PyBoundArrayMethodObject *bmeth) +add_loop(const char *ufunc_name, + PyArray_DTypeMeta *dtypes[3], PyObject *meth_or_promoter) { PyObject *mod = PyImport_ImportModule("numpy"); if (mod == NULL) { @@ -605,13 +603,12 @@ add_loop(const char *ufunc_name, PyBoundArrayMethodObject *bmeth) "numpy.%s was not a ufunc!", ufunc_name); return -1; } - PyObject *dtype_tup = PyArray_TupleFromItems( - 3, (PyObject **)bmeth->dtypes, 0); + PyObject *dtype_tup = PyArray_TupleFromItems(3, (PyObject **)dtypes, 1); if (dtype_tup == NULL) { Py_DECREF(ufunc); return -1; } - PyObject *info = PyTuple_Pack(2, dtype_tup, bmeth->method); + PyObject *info = PyTuple_Pack(2, dtype_tup, meth_or_promoter); Py_DECREF(dtype_tup); if (info == NULL) { Py_DECREF(ufunc); @@ -624,6 +621,28 @@ add_loop(const char *ufunc_name, PyBoundArrayMethodObject *bmeth) } + +/* + * We add some very basic promoters to allow multiplying normal and scaled + */ +static int +promote_to_sfloat(PyUFuncObject *NPY_UNUSED(ufunc), + PyArray_DTypeMeta *const NPY_UNUSED(dtypes[3]), + PyArray_DTypeMeta *const signature[3], + PyArray_DTypeMeta *new_dtypes[3]) +{ + for (int i = 0; i < 3; i++) { + PyArray_DTypeMeta *new = &PyArray_SFloatDType; + if (signature[i] != NULL) { + new = signature[i]; + } + Py_INCREF(new); + new_dtypes[i] = new; + } + return 0; +} + + /* * Add new ufunc loops (this is somewhat clumsy as of writing it, but should * get less so with the introduction of public API). @@ -650,7 +669,8 @@ init_ufuncs(void) { if (bmeth == NULL) { return -1; } - int res = add_loop("multiply", bmeth); + int res = add_loop("multiply", + bmeth->dtypes, (PyObject *)bmeth->method); Py_DECREF(bmeth); if (res < 0) { return -1; @@ -667,11 +687,40 @@ init_ufuncs(void) { if (bmeth == NULL) { return -1; } - res = add_loop("add", bmeth); + res = add_loop("add", + bmeth->dtypes, (PyObject *)bmeth->method); Py_DECREF(bmeth); if (res < 0) { return -1; } + + /* + * Add a promoter for both directions of multiply with double. + */ + PyArray_DTypeMeta *double_DType = PyArray_DTypeFromTypeNum(NPY_DOUBLE); + Py_DECREF(double_DType); /* immortal anyway */ + + PyArray_DTypeMeta *promoter_dtypes[3] = { + &PyArray_SFloatDType, double_DType, NULL}; + + PyObject *promoter = PyCapsule_New( + &promote_to_sfloat, "numpy._ufunc_promoter", NULL); + if (promoter == NULL) { + return -1; + } + res = add_loop("multiply", promoter_dtypes, promoter); + if (res < 0) { + Py_DECREF(promoter); + return -1; + } + promoter_dtypes[0] = double_DType; + promoter_dtypes[1] = &PyArray_SFloatDType; + res = add_loop("multiply", promoter_dtypes, promoter); + Py_DECREF(promoter); + if (res < 0) { + return -1; + } + return 0; } diff --git a/numpy/core/tests/test_custom_dtypes.py b/numpy/core/tests/test_custom_dtypes.py index 3ec2363b9..5eb82bc93 100644 --- a/numpy/core/tests/test_custom_dtypes.py +++ b/numpy/core/tests/test_custom_dtypes.py @@ -101,6 +101,18 @@ class TestSFloat: expected_view = a.view(np.float64) * b.view(np.float64) assert_array_equal(res.view(np.float64), expected_view) + def test_basic_multiply_promotion(self): + float_a = np.array([1., 2., 3.]) + b = self._get_array(2.) + + res1 = float_a * b + res2 = b * float_a + # one factor is one, so we get the factor of b: + assert res1.dtype == res2.dtype == b.dtype + expected_view = float_a * b.view(np.float64) + assert_array_equal(res1.view(np.float64), expected_view) + assert_array_equal(res2.view(np.float64), expected_view) + def test_basic_addition(self): a = self._get_array(2.) b = self._get_array(4.) |