summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/multiarray/dtypemeta.c5
-rw-r--r--numpy/core/tests/test_numeric.py68
2 files changed, 38 insertions, 35 deletions
diff --git a/numpy/core/src/multiarray/dtypemeta.c b/numpy/core/src/multiarray/dtypemeta.c
index b1cd074a0..2931977c2 100644
--- a/numpy/core/src/multiarray/dtypemeta.c
+++ b/numpy/core/src/multiarray/dtypemeta.c
@@ -400,9 +400,10 @@ default_builtin_common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
static PyArray_DTypeMeta *
string_unicode_common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
{
- assert(cls->type_num < NPY_NTYPES);
+ assert(cls->type_num < NPY_NTYPES && cls != other);
if (!other->legacy || (!PyTypeNum_ISNUMBER(other->type_num) &&
- (cls->type_num == NPY_STRING && other->type_num == NPY_UNICODE))) {
+ /* Not numeric so defer unless cls is unicode and other is string */
+ !(cls->type_num == NPY_UNICODE && other->type_num == NPY_STRING))) {
Py_INCREF(Py_NotImplemented);
return (PyArray_DTypeMeta *)Py_NotImplemented;
}
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index 866a96e31..280874d21 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -888,39 +888,41 @@ class TestTypes:
assert np.can_cast(rational_dt, double_dt)
assert np.promote_types(double_dt, rational_dt) is double_dt
- def test_promote_types_strings(self):
- assert_equal(np.promote_types('bool', 'S'), np.dtype('S5'))
- assert_equal(np.promote_types('b', 'S'), np.dtype('S4'))
- assert_equal(np.promote_types('u1', 'S'), np.dtype('S3'))
- assert_equal(np.promote_types('u2', 'S'), np.dtype('S5'))
- assert_equal(np.promote_types('u4', 'S'), np.dtype('S10'))
- assert_equal(np.promote_types('u8', 'S'), np.dtype('S20'))
- assert_equal(np.promote_types('i1', 'S'), np.dtype('S4'))
- assert_equal(np.promote_types('i2', 'S'), np.dtype('S6'))
- assert_equal(np.promote_types('i4', 'S'), np.dtype('S11'))
- assert_equal(np.promote_types('i8', 'S'), np.dtype('S21'))
- assert_equal(np.promote_types('bool', 'U'), np.dtype('U5'))
- assert_equal(np.promote_types('b', 'U'), np.dtype('U4'))
- assert_equal(np.promote_types('u1', 'U'), np.dtype('U3'))
- assert_equal(np.promote_types('u2', 'U'), np.dtype('U5'))
- assert_equal(np.promote_types('u4', 'U'), np.dtype('U10'))
- assert_equal(np.promote_types('u8', 'U'), np.dtype('U20'))
- assert_equal(np.promote_types('i1', 'U'), np.dtype('U4'))
- assert_equal(np.promote_types('i2', 'U'), np.dtype('U6'))
- assert_equal(np.promote_types('i4', 'U'), np.dtype('U11'))
- assert_equal(np.promote_types('i8', 'U'), np.dtype('U21'))
- assert_equal(np.promote_types('bool', 'S1'), np.dtype('S5'))
- assert_equal(np.promote_types('bool', 'S30'), np.dtype('S30'))
- assert_equal(np.promote_types('b', 'S1'), np.dtype('S4'))
- assert_equal(np.promote_types('b', 'S30'), np.dtype('S30'))
- assert_equal(np.promote_types('u1', 'S1'), np.dtype('S3'))
- assert_equal(np.promote_types('u1', 'S30'), np.dtype('S30'))
- assert_equal(np.promote_types('u2', 'S1'), np.dtype('S5'))
- assert_equal(np.promote_types('u2', 'S30'), np.dtype('S30'))
- assert_equal(np.promote_types('u4', 'S1'), np.dtype('S10'))
- assert_equal(np.promote_types('u4', 'S30'), np.dtype('S30'))
- assert_equal(np.promote_types('u8', 'S1'), np.dtype('S20'))
- assert_equal(np.promote_types('u8', 'S30'), np.dtype('S30'))
+ @pytest.mark.parametrize("swap", ["", "swap"])
+ @pytest.mark.parametrize("string_dtype", ["U", "S"])
+ def test_promote_types_strings(self, swap, string_dtype):
+ if swap == "swap":
+ promote_types = lambda a, b: np.promote_types(b, a)
+ else:
+ promote_types = np.promote_types
+
+ S = string_dtype
+ # Promote numeric with unsized string:
+ assert_equal(promote_types('bool', S), np.dtype(S+'5'))
+ assert_equal(promote_types('b', S), np.dtype(S+'4'))
+ assert_equal(promote_types('u1', S), np.dtype(S+'3'))
+ assert_equal(promote_types('u2', S), np.dtype(S+'5'))
+ assert_equal(promote_types('u4', S), np.dtype(S+'10'))
+ assert_equal(promote_types('u8', S), np.dtype(S+'20'))
+ assert_equal(promote_types('i1', S), np.dtype(S+'4'))
+ assert_equal(promote_types('i2', S), np.dtype(S+'6'))
+ assert_equal(promote_types('i4', S), np.dtype(S+'11'))
+ assert_equal(promote_types('i8', S), np.dtype(S+'21'))
+ # Promote numeric with sized string:
+ assert_equal(promote_types('bool', S+'1'), np.dtype(S+'5'))
+ assert_equal(promote_types('bool', S+'30'), np.dtype(S+'30'))
+ assert_equal(promote_types('b', S+'1'), np.dtype(S+'4'))
+ assert_equal(promote_types('b', S+'30'), np.dtype(S+'30'))
+ assert_equal(promote_types('u1', S+'1'), np.dtype(S+'3'))
+ assert_equal(promote_types('u1', S+'30'), np.dtype(S+'30'))
+ assert_equal(promote_types('u2', S+'1'), np.dtype(S+'5'))
+ assert_equal(promote_types('u2', S+'30'), np.dtype(S+'30'))
+ assert_equal(promote_types('u4', S+'1'), np.dtype(S+'10'))
+ assert_equal(promote_types('u4', S+'30'), np.dtype(S+'30'))
+ assert_equal(promote_types('u8', S+'1'), np.dtype(S+'20'))
+ assert_equal(promote_types('u8', S+'30'), np.dtype(S+'30'))
+ # Promote with object:
+ assert_equal(promote_types('O', S+'30'), np.dtype('O'))
@pytest.mark.parametrize(["dtype1", "dtype2"],
[[np.dtype("V6"), np.dtype("V10")],