diff options
| -rw-r--r-- | numpy/core/src/multiarray/dtypemeta.c | 22 | ||||
| -rw-r--r-- | numpy/core/tests/test_dtype.py | 11 |
2 files changed, 29 insertions, 4 deletions
diff --git a/numpy/core/src/multiarray/dtypemeta.c b/numpy/core/src/multiarray/dtypemeta.c index 73aa79a66..23c39606f 100644 --- a/numpy/core/src/multiarray/dtypemeta.c +++ b/numpy/core/src/multiarray/dtypemeta.c @@ -147,13 +147,27 @@ dtypemeta_wrap_legacy_descriptor(PyArray_Descr *descr) * In particular our own DTypes can be true static declarations. * However, this function remains necessary for legacy user dtypes. */ - char *tp_name = PyDataMem_NEW(100); + + const char *scalar_name = descr->typeobj->tp_name; + /* + * We have to take only the name, and ignore the module to get + * a reasonable __name__, since static types are limited in this regard + * (this is not ideal, but not a big issue in practice). + * This is what Python does to print __name__ for static types. + */ + const char *dot = strrchr(scalar_name, '.'); + if (dot) { + scalar_name = dot + 1; + } + ssize_t name_length = strlen(scalar_name) + 14; + + char *tp_name = malloc(name_length); if (tp_name == NULL) { PyErr_NoMemory(); return -1; } - snprintf(tp_name, 100, "numpy.dtype[%s]", - descr->typeobj->tp_name); + + snprintf(tp_name, name_length, "numpy.dtype[%s]", scalar_name); PyArray_DTypeMeta *dtype_class = malloc(sizeof(PyArray_DTypeMeta)); if (dtype_class == NULL) { @@ -176,7 +190,7 @@ dtypemeta_wrap_legacy_descriptor(PyArray_Descr *descr) .tp_basicsize = sizeof(PyArray_Descr), .tp_flags = Py_TPFLAGS_DEFAULT, .tp_base = &PyArrayDescr_Type, - .tp_new = (newfunc)legacy_dtype_default_new + .tp_new = (newfunc)legacy_dtype_default_new, },}, .is_legacy = 1, .is_abstract = 0, /* this is a concrete DType */ diff --git a/numpy/core/tests/test_dtype.py b/numpy/core/tests/test_dtype.py index c9a65cd9c..cbe8322a3 100644 --- a/numpy/core/tests/test_dtype.py +++ b/numpy/core/tests/test_dtype.py @@ -1091,6 +1091,17 @@ class TestFromDTypeAttribute: with pytest.raises(RecursionError): np.dtype(dt(1)) + +class TestDTypeClasses: + @pytest.mark.parametrize("dtype", list(np.typecodes['All']) + [rational]) + def test_dtypes_are_subclasses(self, dtype): + dtype = np.dtype(dtype) + assert isinstance(dtype, np.dtype) + assert type(dtype) is not np.dtype + assert type(dtype).__name__ == f"dtype[{dtype.type.__name__}]" + assert type(dtype).__module__ == "numpy" + + class TestFromCTypes: @staticmethod |
