summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/multiarray/array_coercion.c36
-rw-r--r--numpy/core/src/multiarray/usertypes.c6
-rw-r--r--numpy/core/tests/test_dtype.py4
3 files changed, 28 insertions, 18 deletions
diff --git a/numpy/core/src/multiarray/array_coercion.c b/numpy/core/src/multiarray/array_coercion.c
index 64a06d58b..4831dfca6 100644
--- a/numpy/core/src/multiarray/array_coercion.c
+++ b/numpy/core/src/multiarray/array_coercion.c
@@ -128,7 +128,9 @@ _prime_global_pytype_to_type_dict(void)
/**
- * Add a new mapping from a python type to the DType class.
+ * Add a new mapping from a python type to the DType class. For a user
+ * defined legacy dtype, this function does nothing unless the pytype
+ * subclass from `np.generic`.
*
* This assumes that the DType class is guaranteed to hold on the
* python type (this assumption is guaranteed).
@@ -145,21 +147,29 @@ _PyArray_MapPyTypeToDType(
{
PyObject *Dtype_obj = (PyObject *)DType;
- if (userdef) {
+ if (userdef && !PyObject_IsSubclass(
+ (PyObject *)pytype, (PyObject *)&PyGenericArrType_Type)) {
/*
- * It seems we did not strictly enforce this in the legacy dtype
- * API, but assume that it is always true. Further, this could be
- * relaxed in the future. In particular we should have a new
- * superclass of ``np.generic`` in order to note enforce the array
- * scalar behaviour.
+ * We expect that user dtypes (for now) will subclass some numpy
+ * scalar class to allow automatic discovery.
*/
- if (!PyObject_IsSubclass((PyObject *)pytype, (PyObject *)&PyGenericArrType_Type)) {
- PyErr_Format(PyExc_RuntimeError,
- "currently it is only possible to register a DType "
- "for scalars deriving from `np.generic`, got '%S'.",
- (PyObject *)pytype);
- return -1;
+ if (DType->legacy) {
+ /*
+ * For legacy user dtypes, discovery relied on subclassing, but
+ * arbitrary type objects are supported, so do nothing.
+ */
+ return 0;
}
+ /*
+ * We currently enforce that user DTypes subclass from `np.generic`
+ * (this should become a `np.generic` base class and may be lifted
+ * entirely).
+ */
+ PyErr_Format(PyExc_RuntimeError,
+ "currently it is only possible to register a DType "
+ "for scalars deriving from `np.generic`, got '%S'.",
+ (PyObject *)pytype);
+ return -1;
}
/* Create the global dictionary if it does not exist */
diff --git a/numpy/core/src/multiarray/usertypes.c b/numpy/core/src/multiarray/usertypes.c
index f8bb5ece7..1404c9b68 100644
--- a/numpy/core/src/multiarray/usertypes.c
+++ b/numpy/core/src/multiarray/usertypes.c
@@ -251,12 +251,16 @@ PyArray_RegisterDataType(PyArray_Descr *descr)
PyErr_SetString(PyExc_MemoryError, "RegisterDataType");
return -1;
}
+
userdescrs[NPY_NUMUSERTYPES++] = descr;
+ descr->type_num = typenum;
if (dtypemeta_wrap_legacy_descriptor(descr) < 0) {
+ descr->type_num = -1;
+ NPY_NUMUSERTYPES--;
return -1;
}
- descr->type_num = typenum;
+
return typenum;
}
diff --git a/numpy/core/tests/test_dtype.py b/numpy/core/tests/test_dtype.py
index 898ceebcd..45cc0b8b3 100644
--- a/numpy/core/tests/test_dtype.py
+++ b/numpy/core/tests/test_dtype.py
@@ -1355,10 +1355,6 @@ class TestUserDType:
# unnecessary restriction, but one that has been around forever:
assert np.dtype(mytype) == np.dtype("O")
- with pytest.raises(RuntimeError):
- # Registering a second time should fail
- create_custom_field_dtype(blueprint, mytype, 0)
-
def test_custom_structured_dtype_errors(self):
class mytype:
pass