summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatti Picus <matti.picus@gmail.com>2020-10-07 21:40:27 +0300
committerGitHub <noreply@github.com>2020-10-07 21:40:27 +0300
commit4ebbaaeff09aca51e0006e62026a35e020c3b49f (patch)
tree0f19d56671d9a187a72c754be193de19d8fc890e
parent2a267e6a49ed68da01761c92deb7c90be207660d (diff)
parente5f2ce3c6465222f3dfdc186c930b3e049a4d597 (diff)
downloadnumpy-4ebbaaeff09aca51e0006e62026a35e020c3b49f.tar.gz
Merge pull request #17320 from seberg/relax-object-dtype-with-ref
BUG: allow registration of hard-coded structured dtypes
-rw-r--r--numpy/core/src/multiarray/_multiarray_tests.c.src68
-rw-r--r--numpy/core/src/multiarray/dtypemeta.c22
-rw-r--r--numpy/core/src/multiarray/usertypes.c35
-rw-r--r--numpy/core/tests/test_dtype.py37
4 files changed, 149 insertions, 13 deletions
diff --git a/numpy/core/src/multiarray/_multiarray_tests.c.src b/numpy/core/src/multiarray/_multiarray_tests.c.src
index 0bf6958cd..5b6b6dc78 100644
--- a/numpy/core/src/multiarray/_multiarray_tests.c.src
+++ b/numpy/core/src/multiarray/_multiarray_tests.c.src
@@ -619,6 +619,71 @@ fromstring_null_term_c_api(PyObject *dummy, PyObject *byte_obj)
}
+/*
+ * Create a custom field dtype from an existing void one (and test some errors).
+ * The dtypes created by this function may be not be usable (or even crash
+ * while using).
+ */
+static PyObject *
+create_custom_field_dtype(PyObject *NPY_UNUSED(mod), PyObject *args)
+{
+ PyArray_Descr *dtype;
+ PyTypeObject *scalar_type;
+ PyTypeObject *original_type = NULL;
+ int error_path;
+
+ if (!PyArg_ParseTuple(args, "O!O!i",
+ &PyArrayDescr_Type, &dtype,
+ &PyType_Type, &scalar_type,
+ &error_path)) {
+ return NULL;
+ }
+ /* check that the result should be more or less valid */
+ if (dtype->type_num != NPY_VOID || dtype->fields == NULL ||
+ !PyDict_CheckExact(dtype->fields) ||
+ PyTuple_Size(dtype->names) != 1 ||
+ !PyDataType_REFCHK(dtype) ||
+ dtype->elsize != sizeof(PyObject *)) {
+ PyErr_SetString(PyExc_ValueError,
+ "Bad dtype passed to test function, must be an object "
+ "containing void with a single field.");
+ return NULL;
+ }
+
+ /* Copy and then appropriate this dtype */
+ original_type = Py_TYPE(dtype);
+ dtype = PyArray_DescrNew(dtype);
+ if (dtype == NULL) {
+ return NULL;
+ }
+
+ Py_INCREF(scalar_type);
+ Py_SETREF(dtype->typeobj, scalar_type);
+ if (error_path == 1) {
+ /* Test that we reject this, if fields was not already set */
+ Py_SETREF(dtype->fields, NULL);
+ }
+ else if (error_path == 2) {
+ /*
+ * Test that we reject this if the type is not set to something that
+ * we are pretty sure can be safely replaced.
+ */
+ Py_SET_TYPE(dtype, scalar_type);
+ }
+ else if (error_path != 0) {
+ PyErr_SetString(PyExc_ValueError,
+ "invalid error argument to test function.");
+ }
+ if (PyArray_RegisterDataType(dtype) < 0) {
+ /* Fix original type in the error_path == 2 case. */
+ Py_SET_TYPE(dtype, original_type);
+ return NULL;
+ }
+ Py_INCREF(dtype);
+ return (PyObject *)dtype;
+}
+
+
/* check no elison for avoided increfs */
static PyObject *
incref_elide(PyObject *dummy, PyObject *args)
@@ -2090,6 +2155,9 @@ static PyMethodDef Multiarray_TestsMethods[] = {
{"fromstring_null_term_c_api",
fromstring_null_term_c_api,
METH_O, NULL},
+ {"create_custom_field_dtype",
+ create_custom_field_dtype,
+ METH_VARARGS, NULL},
{"incref_elide",
incref_elide,
METH_VARARGS, NULL},
diff --git a/numpy/core/src/multiarray/dtypemeta.c b/numpy/core/src/multiarray/dtypemeta.c
index dbe5ba476..84d9dc381 100644
--- a/numpy/core/src/multiarray/dtypemeta.c
+++ b/numpy/core/src/multiarray/dtypemeta.c
@@ -455,10 +455,28 @@ object_common_dtype(
NPY_NO_EXPORT int
dtypemeta_wrap_legacy_descriptor(PyArray_Descr *descr)
{
- if (Py_TYPE(descr) != &PyArrayDescr_Type) {
+ int has_type_set = Py_TYPE(descr) == &PyArrayDescr_Type;
+
+ if (!has_type_set) {
+ /* Accept if the type was filled in from an existing builtin dtype */
+ for (int i = 0; i < NPY_NTYPES; i++) {
+ PyArray_Descr *builtin = PyArray_DescrFromType(i);
+ has_type_set = Py_TYPE(descr) == Py_TYPE(builtin);
+ Py_DECREF(builtin);
+ if (has_type_set) {
+ break;
+ }
+ }
+ }
+ if (!has_type_set) {
PyErr_Format(PyExc_RuntimeError,
"During creation/wrapping of legacy DType, the original class "
- "was not PyArrayDescr_Type (it is replaced in this step).");
+ "was not of PyArrayDescr_Type (it is replaced in this step). "
+ "The extension creating a custom DType for type %S must be "
+ "modified to ensure `Py_TYPE(descr) == &PyArrayDescr_Type` or "
+ "that of an existing dtype (with the assumption it is just "
+ "copied over and can be replaced).",
+ descr->typeobj, Py_TYPE(descr));
return -1;
}
diff --git a/numpy/core/src/multiarray/usertypes.c b/numpy/core/src/multiarray/usertypes.c
index 3727567e0..f8bb5ece7 100644
--- a/numpy/core/src/multiarray/usertypes.c
+++ b/numpy/core/src/multiarray/usertypes.c
@@ -196,7 +196,7 @@ PyArray_RegisterDataType(PyArray_Descr *descr)
}
}
typenum = NPY_USERDEF + NPY_NUMUSERTYPES;
- descr->type_num = typenum;
+ descr->type_num = -1;
if (PyDataType_ISUNSIZED(descr)) {
PyErr_SetString(PyExc_ValueError, "cannot register a" \
"flexible data-type");
@@ -215,18 +215,31 @@ PyArray_RegisterDataType(PyArray_Descr *descr)
" is missing.");
return -1;
}
- if (descr->flags & (NPY_ITEM_IS_POINTER | NPY_ITEM_REFCOUNT)) {
- PyErr_SetString(PyExc_ValueError,
- "Legacy user dtypes referencing python objects or generally "
- "allocated memory are unsupported. "
- "If you see this error in an existing, working code base, "
- "please contact the NumPy developers.");
- return -1;
- }
if (descr->typeobj == NULL) {
PyErr_SetString(PyExc_ValueError, "missing typeobject");
return -1;
}
+ if (descr->flags & (NPY_ITEM_IS_POINTER | NPY_ITEM_REFCOUNT)) {
+ /*
+ * User dtype can't actually do reference counting, however, there
+ * are existing hacks (e.g. xpress), which use a structured one:
+ * dtype((xpress.var, [('variable', 'O')]))
+ * so we have to support this. But such a structure must be constant
+ * (i.e. fixed at registration time, this is the case for `xpress`).
+ */
+ if (descr->names == NULL || descr->fields == NULL ||
+ !PyDict_CheckExact(descr->fields)) {
+ PyErr_Format(PyExc_ValueError,
+ "Failed to register dtype for %S: Legacy user dtypes "
+ "using `NPY_ITEM_IS_POINTER` or `NPY_ITEM_REFCOUNT` are"
+ "unsupported. It is possible to create such a dtype only "
+ "if it is a structured dtype with names and fields "
+ "hardcoded at registration time.\n"
+ "Please contact the NumPy developers if this used to work "
+ "but now fails.", descr->typeobj);
+ return -1;
+ }
+ }
if (test_deprecated_arrfuncs_members(f) < 0) {
return -1;
@@ -243,7 +256,7 @@ PyArray_RegisterDataType(PyArray_Descr *descr)
if (dtypemeta_wrap_legacy_descriptor(descr) < 0) {
return -1;
}
-
+ descr->type_num = typenum;
return typenum;
}
@@ -303,7 +316,7 @@ PyArray_RegisterCanCast(PyArray_Descr *descr, int totype,
if (!PyTypeNum_ISUSERDEF(descr->type_num) &&
!PyTypeNum_ISUSERDEF(totype)) {
PyErr_SetString(PyExc_ValueError,
- "At least one of the types provided to"
+ "At least one of the types provided to "
"RegisterCanCast must be user-defined.");
return -1;
}
diff --git a/numpy/core/tests/test_dtype.py b/numpy/core/tests/test_dtype.py
index 2e2b0dbe2..898ceebcd 100644
--- a/numpy/core/tests/test_dtype.py
+++ b/numpy/core/tests/test_dtype.py
@@ -6,6 +6,7 @@ import gc
import numpy as np
from numpy.core._rational_tests import rational
+from numpy.core._multiarray_tests import create_custom_field_dtype
from numpy.testing import (
assert_, assert_equal, assert_array_equal, assert_raises, HAS_REFCOUNT)
from numpy.compat import pickle
@@ -1338,3 +1339,39 @@ class TestFromCTypes:
pair_type = np.dtype('{},{}'.format(*pair))
expected = np.dtype([('f0', pair[0]), ('f1', pair[1])])
assert_equal(pair_type, expected)
+
+
+class TestUserDType:
+ @pytest.mark.leaks_references(reason="dynamically creates custom dtype.")
+ def test_custom_structured_dtype(self):
+ class mytype:
+ pass
+
+ blueprint = np.dtype([("field", object)])
+ dt = create_custom_field_dtype(blueprint, mytype, 0)
+ assert dt.type == mytype
+ # We cannot (currently) *create* this dtype with `np.dtype` because
+ # mytype does not inherit from `np.generic`. This seems like an
+ # unnecessary restriction, but one that has been around forever:
+ assert np.dtype(mytype) == np.dtype("O")
+
+ with pytest.raises(RuntimeError):
+ # Registering a second time should fail
+ create_custom_field_dtype(blueprint, mytype, 0)
+
+ def test_custom_structured_dtype_errors(self):
+ class mytype:
+ pass
+
+ blueprint = np.dtype([("field", object)])
+
+ with pytest.raises(ValueError):
+ # Tests what happens if fields are unset during creation
+ # which is currently rejected due to the containing object
+ # (see PyArray_RegisterDataType).
+ create_custom_field_dtype(blueprint, mytype, 1)
+
+ with pytest.raises(RuntimeError):
+ # Tests that a dtype must have its type field set up to np.dtype
+ # or in this case a builtin instance.
+ create_custom_field_dtype(blueprint, mytype, 2)