diff options
-rw-r--r-- | numpy/core/src/multiarray/dtypemeta.c | 5 | ||||
-rw-r--r-- | numpy/core/tests/test_numeric.py | 68 |
2 files changed, 38 insertions, 35 deletions
diff --git a/numpy/core/src/multiarray/dtypemeta.c b/numpy/core/src/multiarray/dtypemeta.c index b1cd074a0..2931977c2 100644 --- a/numpy/core/src/multiarray/dtypemeta.c +++ b/numpy/core/src/multiarray/dtypemeta.c @@ -400,9 +400,10 @@ default_builtin_common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other) static PyArray_DTypeMeta * string_unicode_common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other) { - assert(cls->type_num < NPY_NTYPES); + assert(cls->type_num < NPY_NTYPES && cls != other); if (!other->legacy || (!PyTypeNum_ISNUMBER(other->type_num) && - (cls->type_num == NPY_STRING && other->type_num == NPY_UNICODE))) { + /* Not numeric so defer unless cls is unicode and other is string */ + !(cls->type_num == NPY_UNICODE && other->type_num == NPY_STRING))) { Py_INCREF(Py_NotImplemented); return (PyArray_DTypeMeta *)Py_NotImplemented; } diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py index 866a96e31..280874d21 100644 --- a/numpy/core/tests/test_numeric.py +++ b/numpy/core/tests/test_numeric.py @@ -888,39 +888,41 @@ class TestTypes: assert np.can_cast(rational_dt, double_dt) assert np.promote_types(double_dt, rational_dt) is double_dt - def test_promote_types_strings(self): - assert_equal(np.promote_types('bool', 'S'), np.dtype('S5')) - assert_equal(np.promote_types('b', 'S'), np.dtype('S4')) - assert_equal(np.promote_types('u1', 'S'), np.dtype('S3')) - assert_equal(np.promote_types('u2', 'S'), np.dtype('S5')) - assert_equal(np.promote_types('u4', 'S'), np.dtype('S10')) - assert_equal(np.promote_types('u8', 'S'), np.dtype('S20')) - assert_equal(np.promote_types('i1', 'S'), np.dtype('S4')) - assert_equal(np.promote_types('i2', 'S'), np.dtype('S6')) - assert_equal(np.promote_types('i4', 'S'), np.dtype('S11')) - assert_equal(np.promote_types('i8', 'S'), np.dtype('S21')) - assert_equal(np.promote_types('bool', 'U'), np.dtype('U5')) - assert_equal(np.promote_types('b', 'U'), np.dtype('U4')) - assert_equal(np.promote_types('u1', 'U'), np.dtype('U3')) - assert_equal(np.promote_types('u2', 'U'), np.dtype('U5')) - assert_equal(np.promote_types('u4', 'U'), np.dtype('U10')) - assert_equal(np.promote_types('u8', 'U'), np.dtype('U20')) - assert_equal(np.promote_types('i1', 'U'), np.dtype('U4')) - assert_equal(np.promote_types('i2', 'U'), np.dtype('U6')) - assert_equal(np.promote_types('i4', 'U'), np.dtype('U11')) - assert_equal(np.promote_types('i8', 'U'), np.dtype('U21')) - assert_equal(np.promote_types('bool', 'S1'), np.dtype('S5')) - assert_equal(np.promote_types('bool', 'S30'), np.dtype('S30')) - assert_equal(np.promote_types('b', 'S1'), np.dtype('S4')) - assert_equal(np.promote_types('b', 'S30'), np.dtype('S30')) - assert_equal(np.promote_types('u1', 'S1'), np.dtype('S3')) - assert_equal(np.promote_types('u1', 'S30'), np.dtype('S30')) - assert_equal(np.promote_types('u2', 'S1'), np.dtype('S5')) - assert_equal(np.promote_types('u2', 'S30'), np.dtype('S30')) - assert_equal(np.promote_types('u4', 'S1'), np.dtype('S10')) - assert_equal(np.promote_types('u4', 'S30'), np.dtype('S30')) - assert_equal(np.promote_types('u8', 'S1'), np.dtype('S20')) - assert_equal(np.promote_types('u8', 'S30'), np.dtype('S30')) + @pytest.mark.parametrize("swap", ["", "swap"]) + @pytest.mark.parametrize("string_dtype", ["U", "S"]) + def test_promote_types_strings(self, swap, string_dtype): + if swap == "swap": + promote_types = lambda a, b: np.promote_types(b, a) + else: + promote_types = np.promote_types + + S = string_dtype + # Promote numeric with unsized string: + assert_equal(promote_types('bool', S), np.dtype(S+'5')) + assert_equal(promote_types('b', S), np.dtype(S+'4')) + assert_equal(promote_types('u1', S), np.dtype(S+'3')) + assert_equal(promote_types('u2', S), np.dtype(S+'5')) + assert_equal(promote_types('u4', S), np.dtype(S+'10')) + assert_equal(promote_types('u8', S), np.dtype(S+'20')) + assert_equal(promote_types('i1', S), np.dtype(S+'4')) + assert_equal(promote_types('i2', S), np.dtype(S+'6')) + assert_equal(promote_types('i4', S), np.dtype(S+'11')) + assert_equal(promote_types('i8', S), np.dtype(S+'21')) + # Promote numeric with sized string: + assert_equal(promote_types('bool', S+'1'), np.dtype(S+'5')) + assert_equal(promote_types('bool', S+'30'), np.dtype(S+'30')) + assert_equal(promote_types('b', S+'1'), np.dtype(S+'4')) + assert_equal(promote_types('b', S+'30'), np.dtype(S+'30')) + assert_equal(promote_types('u1', S+'1'), np.dtype(S+'3')) + assert_equal(promote_types('u1', S+'30'), np.dtype(S+'30')) + assert_equal(promote_types('u2', S+'1'), np.dtype(S+'5')) + assert_equal(promote_types('u2', S+'30'), np.dtype(S+'30')) + assert_equal(promote_types('u4', S+'1'), np.dtype(S+'10')) + assert_equal(promote_types('u4', S+'30'), np.dtype(S+'30')) + assert_equal(promote_types('u8', S+'1'), np.dtype(S+'20')) + assert_equal(promote_types('u8', S+'30'), np.dtype(S+'30')) + # Promote with object: + assert_equal(promote_types('O', S+'30'), np.dtype('O')) @pytest.mark.parametrize(["dtype1", "dtype2"], [[np.dtype("V6"), np.dtype("V10")], |