diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2021-01-06 05:56:46 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-01-06 05:56:46 -0700 |
commit | 10eeabee267e17c042bd1ff991823b16c85ed563 (patch) | |
tree | 9ba29c5db17c554c78e01e0e576d9f45d68185dd | |
parent | da887a666ad975ece7fb7465005aa99c0ddef8d2 (diff) | |
parent | 9961807f24b5e53b2248cffd1274567a3a6139ba (diff) | |
download | numpy-10eeabee267e17c042bd1ff991823b16c85ed563.tar.gz |
Merge pull request #18115 from seberg/half-promotion
BUG: Fix promotion of half and string
-rw-r--r-- | numpy/core/src/multiarray/dtypemeta.c | 10 | ||||
-rw-r--r-- | numpy/core/tests/test_half.py | 15 |
2 files changed, 21 insertions, 4 deletions
diff --git a/numpy/core/src/multiarray/dtypemeta.c b/numpy/core/src/multiarray/dtypemeta.c index 4c11723e7..b1cd074a0 100644 --- a/numpy/core/src/multiarray/dtypemeta.c +++ b/numpy/core/src/multiarray/dtypemeta.c @@ -375,7 +375,10 @@ default_builtin_common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other) { assert(cls->type_num < NPY_NTYPES); if (!other->legacy || other->type_num > cls->type_num) { - /* Let the more generic (larger type number) DType handle this */ + /* + * Let the more generic (larger type number) DType handle this + * (note that half is after all others, which works out here.) + */ Py_INCREF(Py_NotImplemented); return (PyArray_DTypeMeta *)Py_NotImplemented; } @@ -398,9 +401,8 @@ static PyArray_DTypeMeta * string_unicode_common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other) { assert(cls->type_num < NPY_NTYPES); - if (!other->legacy || other->type_num > cls->type_num || - other->type_num == NPY_OBJECT) { - /* Let the more generic (larger type number) DType handle this */ + if (!other->legacy || (!PyTypeNum_ISNUMBER(other->type_num) && + (cls->type_num == NPY_STRING && other->type_num == NPY_UNICODE))) { Py_INCREF(Py_NotImplemented); return (PyArray_DTypeMeta *)Py_NotImplemented; } diff --git a/numpy/core/tests/test_half.py b/numpy/core/tests/test_half.py index ae9827bc7..1b6fd21e1 100644 --- a/numpy/core/tests/test_half.py +++ b/numpy/core/tests/test_half.py @@ -67,6 +67,21 @@ class TestHalf: j = np.array(i_f16, dtype=int) assert_equal(i_int, j) + @pytest.mark.parametrize("string_dt", ["S", "U"]) + def test_half_conversion_to_string(self, string_dt): + # Currently uses S/U32 (which is sufficient for float32) + expected_dt = np.dtype(f"{string_dt}32") + assert np.promote_types(np.float16, string_dt) == expected_dt + assert np.promote_types(string_dt, np.float16) == expected_dt + + arr = np.ones(3, dtype=np.float16).astype(string_dt) + assert arr.dtype == expected_dt + + @pytest.mark.parametrize("string_dt", ["S", "U"]) + def test_half_conversion_from_string(self, string_dt): + string = np.array("3.1416", dtype=string_dt) + assert string.astype(np.float16) == np.array(3.1416, dtype=np.float16) + @pytest.mark.parametrize("offset", [None, "up", "down"]) @pytest.mark.parametrize("shift", [None, "up", "down"]) @pytest.mark.parametrize("float_t", [np.float32, np.float64]) |