summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2020-02-18 13:53:40 -0800
committerSebastian Berg <sebastian@sipsolutions.net>2020-03-18 17:29:34 -0500
commite02cae81b82b1c45ffeb62875934dfc49b7b779a (patch)
tree4e45c5d3c8b95adf799e62b5bf152fb5ef3816b2 /numpy
parentc3c0786fd24a5302a8bc6762f92a8e78deee1602 (diff)
downloadnumpy-e02cae81b82b1c45ffeb62875934dfc49b7b779a.tar.gz
Add tests and fixup __name__
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/dtypemeta.c22
-rw-r--r--numpy/core/tests/test_dtype.py11
2 files changed, 29 insertions, 4 deletions
diff --git a/numpy/core/src/multiarray/dtypemeta.c b/numpy/core/src/multiarray/dtypemeta.c
index 73aa79a66..23c39606f 100644
--- a/numpy/core/src/multiarray/dtypemeta.c
+++ b/numpy/core/src/multiarray/dtypemeta.c
@@ -147,13 +147,27 @@ dtypemeta_wrap_legacy_descriptor(PyArray_Descr *descr)
* In particular our own DTypes can be true static declarations.
* However, this function remains necessary for legacy user dtypes.
*/
- char *tp_name = PyDataMem_NEW(100);
+
+ const char *scalar_name = descr->typeobj->tp_name;
+ /*
+ * We have to take only the name, and ignore the module to get
+ * a reasonable __name__, since static types are limited in this regard
+ * (this is not ideal, but not a big issue in practice).
+ * This is what Python does to print __name__ for static types.
+ */
+ const char *dot = strrchr(scalar_name, '.');
+ if (dot) {
+ scalar_name = dot + 1;
+ }
+ ssize_t name_length = strlen(scalar_name) + 14;
+
+ char *tp_name = malloc(name_length);
if (tp_name == NULL) {
PyErr_NoMemory();
return -1;
}
- snprintf(tp_name, 100, "numpy.dtype[%s]",
- descr->typeobj->tp_name);
+
+ snprintf(tp_name, name_length, "numpy.dtype[%s]", scalar_name);
PyArray_DTypeMeta *dtype_class = malloc(sizeof(PyArray_DTypeMeta));
if (dtype_class == NULL) {
@@ -176,7 +190,7 @@ dtypemeta_wrap_legacy_descriptor(PyArray_Descr *descr)
.tp_basicsize = sizeof(PyArray_Descr),
.tp_flags = Py_TPFLAGS_DEFAULT,
.tp_base = &PyArrayDescr_Type,
- .tp_new = (newfunc)legacy_dtype_default_new
+ .tp_new = (newfunc)legacy_dtype_default_new,
},},
.is_legacy = 1,
.is_abstract = 0, /* this is a concrete DType */
diff --git a/numpy/core/tests/test_dtype.py b/numpy/core/tests/test_dtype.py
index c9a65cd9c..cbe8322a3 100644
--- a/numpy/core/tests/test_dtype.py
+++ b/numpy/core/tests/test_dtype.py
@@ -1091,6 +1091,17 @@ class TestFromDTypeAttribute:
with pytest.raises(RecursionError):
np.dtype(dt(1))
+
+class TestDTypeClasses:
+ @pytest.mark.parametrize("dtype", list(np.typecodes['All']) + [rational])
+ def test_dtypes_are_subclasses(self, dtype):
+ dtype = np.dtype(dtype)
+ assert isinstance(dtype, np.dtype)
+ assert type(dtype) is not np.dtype
+ assert type(dtype).__name__ == f"dtype[{dtype.type.__name__}]"
+ assert type(dtype).__module__ == "numpy"
+
+
class TestFromCTypes:
@staticmethod