diff options
-rw-r--r-- | numpy/core/src/multiarray/array_coercion.c | 36 | ||||
-rw-r--r-- | numpy/core/src/multiarray/usertypes.c | 6 | ||||
-rw-r--r-- | numpy/core/tests/test_dtype.py | 4 |
3 files changed, 28 insertions, 18 deletions
diff --git a/numpy/core/src/multiarray/array_coercion.c b/numpy/core/src/multiarray/array_coercion.c index 64a06d58b..4831dfca6 100644 --- a/numpy/core/src/multiarray/array_coercion.c +++ b/numpy/core/src/multiarray/array_coercion.c @@ -128,7 +128,9 @@ _prime_global_pytype_to_type_dict(void) /** - * Add a new mapping from a python type to the DType class. + * Add a new mapping from a python type to the DType class. For a user + * defined legacy dtype, this function does nothing unless the pytype + * subclass from `np.generic`. * * This assumes that the DType class is guaranteed to hold on the * python type (this assumption is guaranteed). @@ -145,21 +147,29 @@ _PyArray_MapPyTypeToDType( { PyObject *Dtype_obj = (PyObject *)DType; - if (userdef) { + if (userdef && !PyObject_IsSubclass( + (PyObject *)pytype, (PyObject *)&PyGenericArrType_Type)) { /* - * It seems we did not strictly enforce this in the legacy dtype - * API, but assume that it is always true. Further, this could be - * relaxed in the future. In particular we should have a new - * superclass of ``np.generic`` in order to note enforce the array - * scalar behaviour. + * We expect that user dtypes (for now) will subclass some numpy + * scalar class to allow automatic discovery. */ - if (!PyObject_IsSubclass((PyObject *)pytype, (PyObject *)&PyGenericArrType_Type)) { - PyErr_Format(PyExc_RuntimeError, - "currently it is only possible to register a DType " - "for scalars deriving from `np.generic`, got '%S'.", - (PyObject *)pytype); - return -1; + if (DType->legacy) { + /* + * For legacy user dtypes, discovery relied on subclassing, but + * arbitrary type objects are supported, so do nothing. + */ + return 0; } + /* + * We currently enforce that user DTypes subclass from `np.generic` + * (this should become a `np.generic` base class and may be lifted + * entirely). + */ + PyErr_Format(PyExc_RuntimeError, + "currently it is only possible to register a DType " + "for scalars deriving from `np.generic`, got '%S'.", + (PyObject *)pytype); + return -1; } /* Create the global dictionary if it does not exist */ diff --git a/numpy/core/src/multiarray/usertypes.c b/numpy/core/src/multiarray/usertypes.c index f8bb5ece7..1404c9b68 100644 --- a/numpy/core/src/multiarray/usertypes.c +++ b/numpy/core/src/multiarray/usertypes.c @@ -251,12 +251,16 @@ PyArray_RegisterDataType(PyArray_Descr *descr) PyErr_SetString(PyExc_MemoryError, "RegisterDataType"); return -1; } + userdescrs[NPY_NUMUSERTYPES++] = descr; + descr->type_num = typenum; if (dtypemeta_wrap_legacy_descriptor(descr) < 0) { + descr->type_num = -1; + NPY_NUMUSERTYPES--; return -1; } - descr->type_num = typenum; + return typenum; } diff --git a/numpy/core/tests/test_dtype.py b/numpy/core/tests/test_dtype.py index 898ceebcd..45cc0b8b3 100644 --- a/numpy/core/tests/test_dtype.py +++ b/numpy/core/tests/test_dtype.py @@ -1355,10 +1355,6 @@ class TestUserDType: # unnecessary restriction, but one that has been around forever: assert np.dtype(mytype) == np.dtype("O") - with pytest.raises(RuntimeError): - # Registering a second time should fail - create_custom_field_dtype(blueprint, mytype, 0) - def test_custom_structured_dtype_errors(self): class mytype: pass |