diff options
author | Sebastian Berg <sebastian@sipsolutions.net> | 2020-08-21 16:40:22 -0500 |
---|---|---|
committer | Sebastian Berg <sebastian@sipsolutions.net> | 2020-09-02 12:53:32 -0500 |
commit | b0dd380451085c250f6239dabb30d4e7235ac41b (patch) | |
tree | 4ad1c8bb61d93bc096ace9f75dc813b4e9eecfec | |
parent | e3c84a44b68966ab887a3623a0ff57169e508deb (diff) | |
download | numpy-b0dd380451085c250f6239dabb30d4e7235ac41b.tar.gz |
MAINT: Move dtype instance to DType class cast
This function is better housed in convert_datatype.c and was only
in array_coercion, because we did not use it anywhere else before.
This also somewhat modifies the logic and cleans up use-cases of
it in array_coercion.c
-rw-r--r-- | numpy/core/src/multiarray/array_coercion.c | 61 | ||||
-rw-r--r-- | numpy/core/src/multiarray/convert_datatype.c | 49 | ||||
-rw-r--r-- | numpy/core/src/multiarray/convert_datatype.h | 3 |
3 files changed, 62 insertions, 51 deletions
diff --git a/numpy/core/src/multiarray/array_coercion.c b/numpy/core/src/multiarray/array_coercion.c index ffb5bd632..4e2354991 100644 --- a/numpy/core/src/multiarray/array_coercion.c +++ b/numpy/core/src/multiarray/array_coercion.c @@ -306,51 +306,6 @@ discover_dtype_from_pyobject( } -/* - * This function should probably become public API eventually. At this - * time it is implemented by falling back to `PyArray_AdaptFlexibleDType`. - * We will use `CastingImpl[from, to].adjust_descriptors(...)` to implement - * this logic. - */ -static NPY_INLINE PyArray_Descr * -cast_descriptor_to_fixed_dtype( - PyArray_Descr *descr, PyArray_DTypeMeta *fixed_DType) -{ - if (fixed_DType == NULL) { - /* Nothing to do, we only need to promote the new dtype */ - Py_INCREF(descr); - return descr; - } - - if (!fixed_DType->parametric) { - /* - * Don't actually do anything, the default is always the result - * of any cast. - */ - return fixed_DType->default_descr(fixed_DType); - } - if (PyObject_TypeCheck((PyObject *)descr, (PyTypeObject *)fixed_DType)) { - Py_INCREF(descr); - return descr; - } - /* - * TODO: When this is implemented for all dtypes, the special cases - * can be removed... - */ - if (fixed_DType->legacy && fixed_DType->parametric && - NPY_DTYPE(descr)->legacy) { - PyArray_Descr *flex_dtype = PyArray_DescrFromType(fixed_DType->type_num); - return PyArray_AdaptFlexibleDType(descr, flex_dtype); - } - - PyErr_SetString(PyExc_NotImplementedError, - "Must use casting to find the correct dtype, this is " - "not yet implemented! " - "(It should not be possible to hit this code currently!)"); - return NULL; -} - - /** * Discover the correct descriptor from a known DType class and scalar. * If the fixed DType can discover a dtype instance/descr all is fine, @@ -392,7 +347,7 @@ find_scalar_descriptor( return descr; } - Py_SETREF(descr, cast_descriptor_to_fixed_dtype(descr, fixed_DType)); + Py_SETREF(descr, PyArray_CastDescrToDType(descr, fixed_DType)); return descr; } @@ -727,8 +682,13 @@ find_descriptor_from_array( enum _dtype_discovery_flags flags = 0; *out_descr = NULL; - if (NPY_UNLIKELY(DType != NULL && DType->parametric && - PyArray_ISOBJECT(arr))) { + if (DType == NULL) { + *out_descr = PyArray_DESCR(arr); + Py_INCREF(*out_descr); + return 0; + } + + if (NPY_UNLIKELY(DType->parametric && PyArray_ISOBJECT(arr))) { /* * We have one special case, if (and only if) the input array is of * object DType and the dtype is not fixed already but parametric. @@ -777,7 +737,7 @@ find_descriptor_from_array( } Py_DECREF(iter); } - else if (DType != NULL && NPY_UNLIKELY(DType->type_num == NPY_DATETIME) && + else if (NPY_UNLIKELY(DType->type_num == NPY_DATETIME) && PyArray_ISSTRING(arr)) { /* * TODO: This branch should be deprecated IMO, the workaround is @@ -806,8 +766,7 @@ find_descriptor_from_array( * If this is not an object array figure out the dtype cast, * or simply use the returned DType. */ - *out_descr = cast_descriptor_to_fixed_dtype( - PyArray_DESCR(arr), DType); + *out_descr = PyArray_CastDescrToDType(PyArray_DESCR(arr), DType); if (*out_descr == NULL) { return -1; } diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c index d9121707b..0ccc0b538 100644 --- a/numpy/core/src/multiarray/convert_datatype.c +++ b/numpy/core/src/multiarray/convert_datatype.c @@ -1031,6 +1031,55 @@ ensure_dtype_nbo(PyArray_Descr *type) } } + +/** + * This function should possibly become public API eventually. At this + * time it is implemented by falling back to `PyArray_AdaptFlexibleDType`. + * We will use `CastingImpl[from, to].adjust_descriptors(...)` to implement + * this logic. + * Before that, the API needs to be reviewed though. + * + * WARNING: This function currently does not guarantee that `descr` can + * actually be cast to the given DType. + * + * @param descr The dtype instance to adapt "cast" + * @param given_DType The DType class for which we wish to find an instance able + * to represent `descr`. + * @returns Instance of `given_DType`. If `given_DType` is parametric the + * descr may be adapted to hold it. + */ +NPY_NO_EXPORT PyArray_Descr * +PyArray_CastDescrToDType(PyArray_Descr *descr, PyArray_DTypeMeta *given_DType) +{ + if (NPY_DTYPE(descr) == given_DType) { + Py_INCREF(descr); + return descr; + } + if (!given_DType->parametric) { + /* + * Don't actually do anything, the default is always the result + * of any cast. + */ + return given_DType->default_descr(given_DType); + } + if (PyObject_TypeCheck((PyObject *)descr, (PyTypeObject *)given_DType)) { + Py_INCREF(descr); + return descr; + } + + if (!given_DType->legacy) { + PyErr_SetString(PyExc_NotImplementedError, + "Must use casting to find the correct DType for a parametric " + "user DType. This is not yet implemented (this error should be " + "unreachable)."); + return NULL; + } + + PyArray_Descr *flex_dtype = PyArray_DescrNew(given_DType->singleton); + return PyArray_AdaptFlexibleDType(descr, flex_dtype); +} + + /*NUMPY_API * Produces the smallest size and lowest kind type to which both * input types can be cast. diff --git a/numpy/core/src/multiarray/convert_datatype.h b/numpy/core/src/multiarray/convert_datatype.h index 9b7f39db2..eef2f7313 100644 --- a/numpy/core/src/multiarray/convert_datatype.h +++ b/numpy/core/src/multiarray/convert_datatype.h @@ -49,4 +49,7 @@ npy_set_invalid_cast_error( NPY_NO_EXPORT PyArray_Descr * PyArray_AdaptFlexibleDType(PyArray_Descr *data_dtype, PyArray_Descr *flex_dtype); +NPY_NO_EXPORT PyArray_Descr * +PyArray_CastDescrToDType(PyArray_Descr *descr, PyArray_DTypeMeta *given_DType); + #endif |