summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/umath/_scaled_float_dtype.c52
-rw-r--r--numpy/core/tests/test_custom_dtypes.py20
2 files changed, 72 insertions, 0 deletions
diff --git a/numpy/core/src/umath/_scaled_float_dtype.c b/numpy/core/src/umath/_scaled_float_dtype.c
index eeef33a3d..b6c19362a 100644
--- a/numpy/core/src/umath/_scaled_float_dtype.c
+++ b/numpy/core/src/umath/_scaled_float_dtype.c
@@ -398,6 +398,42 @@ float_to_from_sfloat_resolve_descriptors(
}
+/*
+ * Cast to boolean (for testing the logical functions a bit better).
+ */
+static int
+cast_sfloat_to_bool(PyArrayMethod_Context *NPY_UNUSED(context),
+ char *const data[], npy_intp const dimensions[],
+ npy_intp const strides[], NpyAuxData *NPY_UNUSED(auxdata))
+{
+ npy_intp N = dimensions[0];
+ char *in = data[0];
+ char *out = data[1];
+ for (npy_intp i = 0; i < N; i++) {
+ *(npy_bool *)out = *(double *)in != 0;
+ in += strides[0];
+ out += strides[1];
+ }
+ return 0;
+}
+
+static NPY_CASTING
+sfloat_to_bool_resolve_descriptors(
+ PyArrayMethodObject *NPY_UNUSED(self),
+ PyArray_DTypeMeta *NPY_UNUSED(dtypes[2]),
+ PyArray_Descr *given_descrs[2],
+ PyArray_Descr *loop_descrs[2])
+{
+ Py_INCREF(given_descrs[0]);
+ loop_descrs[0] = given_descrs[0];
+ if (loop_descrs[0] == NULL) {
+ return -1;
+ }
+ loop_descrs[1] = PyArray_DescrFromType(NPY_BOOL); /* cannot fail */
+ return NPY_UNSAFE_CASTING;
+}
+
+
static int
init_casts(void)
{
@@ -453,6 +489,22 @@ init_casts(void)
return -1;
}
+ slots[0].slot = NPY_METH_resolve_descriptors;
+ slots[0].pfunc = &sfloat_to_bool_resolve_descriptors;
+ slots[1].slot = NPY_METH_strided_loop;
+ slots[1].pfunc = &cast_sfloat_to_bool;
+ slots[2].slot = 0;
+ slots[2].pfunc = NULL;
+
+ spec.name = "sfloat_to_bool_cast";
+ dtypes[0] = &PyArray_SFloatDType;
+ dtypes[1] = PyArray_DTypeFromTypeNum(NPY_BOOL);
+ Py_DECREF(dtypes[1]); /* immortal anyway */
+
+ if (PyArray_AddCastingImplementation_FromSpec(&spec, 0)) {
+ return -1;
+ }
+
return 0;
}
diff --git a/numpy/core/tests/test_custom_dtypes.py b/numpy/core/tests/test_custom_dtypes.py
index b9311e283..e502a87ed 100644
--- a/numpy/core/tests/test_custom_dtypes.py
+++ b/numpy/core/tests/test_custom_dtypes.py
@@ -161,3 +161,23 @@ class TestSFloat:
# Check that casting the output fails also (done by the ufunc here)
with pytest.raises(TypeError):
np.add(a, a, out=c, casting="safe")
+
+ @pytest.mark.parametrize("ufunc",
+ [np.logical_and, np.logical_or, np.logical_xor])
+ def test_logical_ufuncs_casts_to_bool(self, ufunc):
+ a = self._get_array(2.)
+ a[0] = 0. # make sure first element is considered False.
+
+ float_equiv = a.astype(float)
+ expected = ufunc(float_equiv, float_equiv)
+ res = ufunc(a, a)
+ assert_array_equal(res, expected)
+
+ # also check that the same works for reductions:
+ expected = ufunc.reduce(float_equiv)
+ res = ufunc.reduce(a)
+ assert_array_equal(res, expected)
+
+ # The output casting does not match the bool, bool -> bool loop:
+ with pytest.raises(TypeError):
+ ufunc(a, a, out=np.empty(a.shape, dtype=int), casting="equiv")