diff options
author | Matti Picus <matti.picus@gmail.com> | 2020-10-07 21:40:27 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-10-07 21:40:27 +0300 |
commit | 4ebbaaeff09aca51e0006e62026a35e020c3b49f (patch) | |
tree | 0f19d56671d9a187a72c754be193de19d8fc890e | |
parent | 2a267e6a49ed68da01761c92deb7c90be207660d (diff) | |
parent | e5f2ce3c6465222f3dfdc186c930b3e049a4d597 (diff) | |
download | numpy-4ebbaaeff09aca51e0006e62026a35e020c3b49f.tar.gz |
Merge pull request #17320 from seberg/relax-object-dtype-with-ref
BUG: allow registration of hard-coded structured dtypes
-rw-r--r-- | numpy/core/src/multiarray/_multiarray_tests.c.src | 68 | ||||
-rw-r--r-- | numpy/core/src/multiarray/dtypemeta.c | 22 | ||||
-rw-r--r-- | numpy/core/src/multiarray/usertypes.c | 35 | ||||
-rw-r--r-- | numpy/core/tests/test_dtype.py | 37 |
4 files changed, 149 insertions, 13 deletions
diff --git a/numpy/core/src/multiarray/_multiarray_tests.c.src b/numpy/core/src/multiarray/_multiarray_tests.c.src index 0bf6958cd..5b6b6dc78 100644 --- a/numpy/core/src/multiarray/_multiarray_tests.c.src +++ b/numpy/core/src/multiarray/_multiarray_tests.c.src @@ -619,6 +619,71 @@ fromstring_null_term_c_api(PyObject *dummy, PyObject *byte_obj) } +/* + * Create a custom field dtype from an existing void one (and test some errors). + * The dtypes created by this function may be not be usable (or even crash + * while using). + */ +static PyObject * +create_custom_field_dtype(PyObject *NPY_UNUSED(mod), PyObject *args) +{ + PyArray_Descr *dtype; + PyTypeObject *scalar_type; + PyTypeObject *original_type = NULL; + int error_path; + + if (!PyArg_ParseTuple(args, "O!O!i", + &PyArrayDescr_Type, &dtype, + &PyType_Type, &scalar_type, + &error_path)) { + return NULL; + } + /* check that the result should be more or less valid */ + if (dtype->type_num != NPY_VOID || dtype->fields == NULL || + !PyDict_CheckExact(dtype->fields) || + PyTuple_Size(dtype->names) != 1 || + !PyDataType_REFCHK(dtype) || + dtype->elsize != sizeof(PyObject *)) { + PyErr_SetString(PyExc_ValueError, + "Bad dtype passed to test function, must be an object " + "containing void with a single field."); + return NULL; + } + + /* Copy and then appropriate this dtype */ + original_type = Py_TYPE(dtype); + dtype = PyArray_DescrNew(dtype); + if (dtype == NULL) { + return NULL; + } + + Py_INCREF(scalar_type); + Py_SETREF(dtype->typeobj, scalar_type); + if (error_path == 1) { + /* Test that we reject this, if fields was not already set */ + Py_SETREF(dtype->fields, NULL); + } + else if (error_path == 2) { + /* + * Test that we reject this if the type is not set to something that + * we are pretty sure can be safely replaced. + */ + Py_SET_TYPE(dtype, scalar_type); + } + else if (error_path != 0) { + PyErr_SetString(PyExc_ValueError, + "invalid error argument to test function."); + } + if (PyArray_RegisterDataType(dtype) < 0) { + /* Fix original type in the error_path == 2 case. */ + Py_SET_TYPE(dtype, original_type); + return NULL; + } + Py_INCREF(dtype); + return (PyObject *)dtype; +} + + /* check no elison for avoided increfs */ static PyObject * incref_elide(PyObject *dummy, PyObject *args) @@ -2090,6 +2155,9 @@ static PyMethodDef Multiarray_TestsMethods[] = { {"fromstring_null_term_c_api", fromstring_null_term_c_api, METH_O, NULL}, + {"create_custom_field_dtype", + create_custom_field_dtype, + METH_VARARGS, NULL}, {"incref_elide", incref_elide, METH_VARARGS, NULL}, diff --git a/numpy/core/src/multiarray/dtypemeta.c b/numpy/core/src/multiarray/dtypemeta.c index dbe5ba476..84d9dc381 100644 --- a/numpy/core/src/multiarray/dtypemeta.c +++ b/numpy/core/src/multiarray/dtypemeta.c @@ -455,10 +455,28 @@ object_common_dtype( NPY_NO_EXPORT int dtypemeta_wrap_legacy_descriptor(PyArray_Descr *descr) { - if (Py_TYPE(descr) != &PyArrayDescr_Type) { + int has_type_set = Py_TYPE(descr) == &PyArrayDescr_Type; + + if (!has_type_set) { + /* Accept if the type was filled in from an existing builtin dtype */ + for (int i = 0; i < NPY_NTYPES; i++) { + PyArray_Descr *builtin = PyArray_DescrFromType(i); + has_type_set = Py_TYPE(descr) == Py_TYPE(builtin); + Py_DECREF(builtin); + if (has_type_set) { + break; + } + } + } + if (!has_type_set) { PyErr_Format(PyExc_RuntimeError, "During creation/wrapping of legacy DType, the original class " - "was not PyArrayDescr_Type (it is replaced in this step)."); + "was not of PyArrayDescr_Type (it is replaced in this step). " + "The extension creating a custom DType for type %S must be " + "modified to ensure `Py_TYPE(descr) == &PyArrayDescr_Type` or " + "that of an existing dtype (with the assumption it is just " + "copied over and can be replaced).", + descr->typeobj, Py_TYPE(descr)); return -1; } diff --git a/numpy/core/src/multiarray/usertypes.c b/numpy/core/src/multiarray/usertypes.c index 3727567e0..f8bb5ece7 100644 --- a/numpy/core/src/multiarray/usertypes.c +++ b/numpy/core/src/multiarray/usertypes.c @@ -196,7 +196,7 @@ PyArray_RegisterDataType(PyArray_Descr *descr) } } typenum = NPY_USERDEF + NPY_NUMUSERTYPES; - descr->type_num = typenum; + descr->type_num = -1; if (PyDataType_ISUNSIZED(descr)) { PyErr_SetString(PyExc_ValueError, "cannot register a" \ "flexible data-type"); @@ -215,18 +215,31 @@ PyArray_RegisterDataType(PyArray_Descr *descr) " is missing."); return -1; } - if (descr->flags & (NPY_ITEM_IS_POINTER | NPY_ITEM_REFCOUNT)) { - PyErr_SetString(PyExc_ValueError, - "Legacy user dtypes referencing python objects or generally " - "allocated memory are unsupported. " - "If you see this error in an existing, working code base, " - "please contact the NumPy developers."); - return -1; - } if (descr->typeobj == NULL) { PyErr_SetString(PyExc_ValueError, "missing typeobject"); return -1; } + if (descr->flags & (NPY_ITEM_IS_POINTER | NPY_ITEM_REFCOUNT)) { + /* + * User dtype can't actually do reference counting, however, there + * are existing hacks (e.g. xpress), which use a structured one: + * dtype((xpress.var, [('variable', 'O')])) + * so we have to support this. But such a structure must be constant + * (i.e. fixed at registration time, this is the case for `xpress`). + */ + if (descr->names == NULL || descr->fields == NULL || + !PyDict_CheckExact(descr->fields)) { + PyErr_Format(PyExc_ValueError, + "Failed to register dtype for %S: Legacy user dtypes " + "using `NPY_ITEM_IS_POINTER` or `NPY_ITEM_REFCOUNT` are" + "unsupported. It is possible to create such a dtype only " + "if it is a structured dtype with names and fields " + "hardcoded at registration time.\n" + "Please contact the NumPy developers if this used to work " + "but now fails.", descr->typeobj); + return -1; + } + } if (test_deprecated_arrfuncs_members(f) < 0) { return -1; @@ -243,7 +256,7 @@ PyArray_RegisterDataType(PyArray_Descr *descr) if (dtypemeta_wrap_legacy_descriptor(descr) < 0) { return -1; } - + descr->type_num = typenum; return typenum; } @@ -303,7 +316,7 @@ PyArray_RegisterCanCast(PyArray_Descr *descr, int totype, if (!PyTypeNum_ISUSERDEF(descr->type_num) && !PyTypeNum_ISUSERDEF(totype)) { PyErr_SetString(PyExc_ValueError, - "At least one of the types provided to" + "At least one of the types provided to " "RegisterCanCast must be user-defined."); return -1; } diff --git a/numpy/core/tests/test_dtype.py b/numpy/core/tests/test_dtype.py index 2e2b0dbe2..898ceebcd 100644 --- a/numpy/core/tests/test_dtype.py +++ b/numpy/core/tests/test_dtype.py @@ -6,6 +6,7 @@ import gc import numpy as np from numpy.core._rational_tests import rational +from numpy.core._multiarray_tests import create_custom_field_dtype from numpy.testing import ( assert_, assert_equal, assert_array_equal, assert_raises, HAS_REFCOUNT) from numpy.compat import pickle @@ -1338,3 +1339,39 @@ class TestFromCTypes: pair_type = np.dtype('{},{}'.format(*pair)) expected = np.dtype([('f0', pair[0]), ('f1', pair[1])]) assert_equal(pair_type, expected) + + +class TestUserDType: + @pytest.mark.leaks_references(reason="dynamically creates custom dtype.") + def test_custom_structured_dtype(self): + class mytype: + pass + + blueprint = np.dtype([("field", object)]) + dt = create_custom_field_dtype(blueprint, mytype, 0) + assert dt.type == mytype + # We cannot (currently) *create* this dtype with `np.dtype` because + # mytype does not inherit from `np.generic`. This seems like an + # unnecessary restriction, but one that has been around forever: + assert np.dtype(mytype) == np.dtype("O") + + with pytest.raises(RuntimeError): + # Registering a second time should fail + create_custom_field_dtype(blueprint, mytype, 0) + + def test_custom_structured_dtype_errors(self): + class mytype: + pass + + blueprint = np.dtype([("field", object)]) + + with pytest.raises(ValueError): + # Tests what happens if fields are unset during creation + # which is currently rejected due to the containing object + # (see PyArray_RegisterDataType). + create_custom_field_dtype(blueprint, mytype, 1) + + with pytest.raises(RuntimeError): + # Tests that a dtype must have its type field set up to np.dtype + # or in this case a builtin instance. + create_custom_field_dtype(blueprint, mytype, 2) |