summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/umath/_scaled_float_dtype.c67
-rw-r--r--numpy/core/tests/test_custom_dtypes.py12
2 files changed, 70 insertions, 9 deletions
diff --git a/numpy/core/src/umath/_scaled_float_dtype.c b/numpy/core/src/umath/_scaled_float_dtype.c
index aa9c4549c..8ca31169f 100644
--- a/numpy/core/src/umath/_scaled_float_dtype.c
+++ b/numpy/core/src/umath/_scaled_float_dtype.c
@@ -464,9 +464,6 @@ init_casts(void)
* 2. Addition, which needs to use the common instance, and runs into
* cast safety subtleties since we will implement it without an additional
* cast.
- *
- * NOTE: When first writing this, promotion did not exist for new-style loops,
- * if it exists, we could use promotion to implement double * sfloat.
*/
static int
multiply_sfloats(PyArrayMethod_Context *NPY_UNUSED(context),
@@ -591,7 +588,8 @@ add_sfloats_resolve_descriptors(
static int
-add_loop(const char *ufunc_name, PyBoundArrayMethodObject *bmeth)
+add_loop(const char *ufunc_name,
+ PyArray_DTypeMeta *dtypes[3], PyObject *meth_or_promoter)
{
PyObject *mod = PyImport_ImportModule("numpy");
if (mod == NULL) {
@@ -605,13 +603,12 @@ add_loop(const char *ufunc_name, PyBoundArrayMethodObject *bmeth)
"numpy.%s was not a ufunc!", ufunc_name);
return -1;
}
- PyObject *dtype_tup = PyArray_TupleFromItems(
- 3, (PyObject **)bmeth->dtypes, 0);
+ PyObject *dtype_tup = PyArray_TupleFromItems(3, (PyObject **)dtypes, 1);
if (dtype_tup == NULL) {
Py_DECREF(ufunc);
return -1;
}
- PyObject *info = PyTuple_Pack(2, dtype_tup, bmeth->method);
+ PyObject *info = PyTuple_Pack(2, dtype_tup, meth_or_promoter);
Py_DECREF(dtype_tup);
if (info == NULL) {
Py_DECREF(ufunc);
@@ -624,6 +621,28 @@ add_loop(const char *ufunc_name, PyBoundArrayMethodObject *bmeth)
}
+
+/*
+ * We add some very basic promoters to allow multiplying normal and scaled
+ */
+static int
+promote_to_sfloat(PyUFuncObject *NPY_UNUSED(ufunc),
+ PyArray_DTypeMeta *const NPY_UNUSED(dtypes[3]),
+ PyArray_DTypeMeta *const signature[3],
+ PyArray_DTypeMeta *new_dtypes[3])
+{
+ for (int i = 0; i < 3; i++) {
+ PyArray_DTypeMeta *new = &PyArray_SFloatDType;
+ if (signature[i] != NULL) {
+ new = signature[i];
+ }
+ Py_INCREF(new);
+ new_dtypes[i] = new;
+ }
+ return 0;
+}
+
+
/*
* Add new ufunc loops (this is somewhat clumsy as of writing it, but should
* get less so with the introduction of public API).
@@ -650,7 +669,8 @@ init_ufuncs(void) {
if (bmeth == NULL) {
return -1;
}
- int res = add_loop("multiply", bmeth);
+ int res = add_loop("multiply",
+ bmeth->dtypes, (PyObject *)bmeth->method);
Py_DECREF(bmeth);
if (res < 0) {
return -1;
@@ -667,11 +687,40 @@ init_ufuncs(void) {
if (bmeth == NULL) {
return -1;
}
- res = add_loop("add", bmeth);
+ res = add_loop("add",
+ bmeth->dtypes, (PyObject *)bmeth->method);
Py_DECREF(bmeth);
if (res < 0) {
return -1;
}
+
+ /*
+ * Add a promoter for both directions of multiply with double.
+ */
+ PyArray_DTypeMeta *double_DType = PyArray_DTypeFromTypeNum(NPY_DOUBLE);
+ Py_DECREF(double_DType); /* immortal anyway */
+
+ PyArray_DTypeMeta *promoter_dtypes[3] = {
+ &PyArray_SFloatDType, double_DType, NULL};
+
+ PyObject *promoter = PyCapsule_New(
+ &promote_to_sfloat, "numpy._ufunc_promoter", NULL);
+ if (promoter == NULL) {
+ return -1;
+ }
+ res = add_loop("multiply", promoter_dtypes, promoter);
+ if (res < 0) {
+ Py_DECREF(promoter);
+ return -1;
+ }
+ promoter_dtypes[0] = double_DType;
+ promoter_dtypes[1] = &PyArray_SFloatDType;
+ res = add_loop("multiply", promoter_dtypes, promoter);
+ Py_DECREF(promoter);
+ if (res < 0) {
+ return -1;
+ }
+
return 0;
}
diff --git a/numpy/core/tests/test_custom_dtypes.py b/numpy/core/tests/test_custom_dtypes.py
index 3ec2363b9..5eb82bc93 100644
--- a/numpy/core/tests/test_custom_dtypes.py
+++ b/numpy/core/tests/test_custom_dtypes.py
@@ -101,6 +101,18 @@ class TestSFloat:
expected_view = a.view(np.float64) * b.view(np.float64)
assert_array_equal(res.view(np.float64), expected_view)
+ def test_basic_multiply_promotion(self):
+ float_a = np.array([1., 2., 3.])
+ b = self._get_array(2.)
+
+ res1 = float_a * b
+ res2 = b * float_a
+ # one factor is one, so we get the factor of b:
+ assert res1.dtype == res2.dtype == b.dtype
+ expected_view = float_a * b.view(np.float64)
+ assert_array_equal(res1.view(np.float64), expected_view)
+ assert_array_equal(res2.view(np.float64), expected_view)
+
def test_basic_addition(self):
a = self._get_array(2.)
b = self._get_array(4.)