summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2021-01-06 05:56:46 -0700
committerGitHub <noreply@github.com>2021-01-06 05:56:46 -0700
commit10eeabee267e17c042bd1ff991823b16c85ed563 (patch)
tree9ba29c5db17c554c78e01e0e576d9f45d68185dd
parentda887a666ad975ece7fb7465005aa99c0ddef8d2 (diff)
parent9961807f24b5e53b2248cffd1274567a3a6139ba (diff)
downloadnumpy-10eeabee267e17c042bd1ff991823b16c85ed563.tar.gz
Merge pull request #18115 from seberg/half-promotion
BUG: Fix promotion of half and string
-rw-r--r--numpy/core/src/multiarray/dtypemeta.c10
-rw-r--r--numpy/core/tests/test_half.py15
2 files changed, 21 insertions, 4 deletions
diff --git a/numpy/core/src/multiarray/dtypemeta.c b/numpy/core/src/multiarray/dtypemeta.c
index 4c11723e7..b1cd074a0 100644
--- a/numpy/core/src/multiarray/dtypemeta.c
+++ b/numpy/core/src/multiarray/dtypemeta.c
@@ -375,7 +375,10 @@ default_builtin_common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
{
assert(cls->type_num < NPY_NTYPES);
if (!other->legacy || other->type_num > cls->type_num) {
- /* Let the more generic (larger type number) DType handle this */
+ /*
+ * Let the more generic (larger type number) DType handle this
+ * (note that half is after all others, which works out here.)
+ */
Py_INCREF(Py_NotImplemented);
return (PyArray_DTypeMeta *)Py_NotImplemented;
}
@@ -398,9 +401,8 @@ static PyArray_DTypeMeta *
string_unicode_common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
{
assert(cls->type_num < NPY_NTYPES);
- if (!other->legacy || other->type_num > cls->type_num ||
- other->type_num == NPY_OBJECT) {
- /* Let the more generic (larger type number) DType handle this */
+ if (!other->legacy || (!PyTypeNum_ISNUMBER(other->type_num) &&
+ (cls->type_num == NPY_STRING && other->type_num == NPY_UNICODE))) {
Py_INCREF(Py_NotImplemented);
return (PyArray_DTypeMeta *)Py_NotImplemented;
}
diff --git a/numpy/core/tests/test_half.py b/numpy/core/tests/test_half.py
index ae9827bc7..1b6fd21e1 100644
--- a/numpy/core/tests/test_half.py
+++ b/numpy/core/tests/test_half.py
@@ -67,6 +67,21 @@ class TestHalf:
j = np.array(i_f16, dtype=int)
assert_equal(i_int, j)
+ @pytest.mark.parametrize("string_dt", ["S", "U"])
+ def test_half_conversion_to_string(self, string_dt):
+ # Currently uses S/U32 (which is sufficient for float32)
+ expected_dt = np.dtype(f"{string_dt}32")
+ assert np.promote_types(np.float16, string_dt) == expected_dt
+ assert np.promote_types(string_dt, np.float16) == expected_dt
+
+ arr = np.ones(3, dtype=np.float16).astype(string_dt)
+ assert arr.dtype == expected_dt
+
+ @pytest.mark.parametrize("string_dt", ["S", "U"])
+ def test_half_conversion_from_string(self, string_dt):
+ string = np.array("3.1416", dtype=string_dt)
+ assert string.astype(np.float16) == np.array(3.1416, dtype=np.float16)
+
@pytest.mark.parametrize("offset", [None, "up", "down"])
@pytest.mark.parametrize("shift", [None, "up", "down"])
@pytest.mark.parametrize("float_t", [np.float32, np.float64])