summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2020-08-21 16:40:22 -0500
committerSebastian Berg <sebastian@sipsolutions.net>2020-09-02 12:53:32 -0500
commitb0dd380451085c250f6239dabb30d4e7235ac41b (patch)
tree4ad1c8bb61d93bc096ace9f75dc813b4e9eecfec
parente3c84a44b68966ab887a3623a0ff57169e508deb (diff)
downloadnumpy-b0dd380451085c250f6239dabb30d4e7235ac41b.tar.gz
MAINT: Move dtype instance to DType class cast
This function is better housed in convert_datatype.c and was only in array_coercion, because we did not use it anywhere else before. This also somewhat modifies the logic and cleans up use-cases of it in array_coercion.c
-rw-r--r--numpy/core/src/multiarray/array_coercion.c61
-rw-r--r--numpy/core/src/multiarray/convert_datatype.c49
-rw-r--r--numpy/core/src/multiarray/convert_datatype.h3
3 files changed, 62 insertions, 51 deletions
diff --git a/numpy/core/src/multiarray/array_coercion.c b/numpy/core/src/multiarray/array_coercion.c
index ffb5bd632..4e2354991 100644
--- a/numpy/core/src/multiarray/array_coercion.c
+++ b/numpy/core/src/multiarray/array_coercion.c
@@ -306,51 +306,6 @@ discover_dtype_from_pyobject(
}
-/*
- * This function should probably become public API eventually. At this
- * time it is implemented by falling back to `PyArray_AdaptFlexibleDType`.
- * We will use `CastingImpl[from, to].adjust_descriptors(...)` to implement
- * this logic.
- */
-static NPY_INLINE PyArray_Descr *
-cast_descriptor_to_fixed_dtype(
- PyArray_Descr *descr, PyArray_DTypeMeta *fixed_DType)
-{
- if (fixed_DType == NULL) {
- /* Nothing to do, we only need to promote the new dtype */
- Py_INCREF(descr);
- return descr;
- }
-
- if (!fixed_DType->parametric) {
- /*
- * Don't actually do anything, the default is always the result
- * of any cast.
- */
- return fixed_DType->default_descr(fixed_DType);
- }
- if (PyObject_TypeCheck((PyObject *)descr, (PyTypeObject *)fixed_DType)) {
- Py_INCREF(descr);
- return descr;
- }
- /*
- * TODO: When this is implemented for all dtypes, the special cases
- * can be removed...
- */
- if (fixed_DType->legacy && fixed_DType->parametric &&
- NPY_DTYPE(descr)->legacy) {
- PyArray_Descr *flex_dtype = PyArray_DescrFromType(fixed_DType->type_num);
- return PyArray_AdaptFlexibleDType(descr, flex_dtype);
- }
-
- PyErr_SetString(PyExc_NotImplementedError,
- "Must use casting to find the correct dtype, this is "
- "not yet implemented! "
- "(It should not be possible to hit this code currently!)");
- return NULL;
-}
-
-
/**
* Discover the correct descriptor from a known DType class and scalar.
* If the fixed DType can discover a dtype instance/descr all is fine,
@@ -392,7 +347,7 @@ find_scalar_descriptor(
return descr;
}
- Py_SETREF(descr, cast_descriptor_to_fixed_dtype(descr, fixed_DType));
+ Py_SETREF(descr, PyArray_CastDescrToDType(descr, fixed_DType));
return descr;
}
@@ -727,8 +682,13 @@ find_descriptor_from_array(
enum _dtype_discovery_flags flags = 0;
*out_descr = NULL;
- if (NPY_UNLIKELY(DType != NULL && DType->parametric &&
- PyArray_ISOBJECT(arr))) {
+ if (DType == NULL) {
+ *out_descr = PyArray_DESCR(arr);
+ Py_INCREF(*out_descr);
+ return 0;
+ }
+
+ if (NPY_UNLIKELY(DType->parametric && PyArray_ISOBJECT(arr))) {
/*
* We have one special case, if (and only if) the input array is of
* object DType and the dtype is not fixed already but parametric.
@@ -777,7 +737,7 @@ find_descriptor_from_array(
}
Py_DECREF(iter);
}
- else if (DType != NULL && NPY_UNLIKELY(DType->type_num == NPY_DATETIME) &&
+ else if (NPY_UNLIKELY(DType->type_num == NPY_DATETIME) &&
PyArray_ISSTRING(arr)) {
/*
* TODO: This branch should be deprecated IMO, the workaround is
@@ -806,8 +766,7 @@ find_descriptor_from_array(
* If this is not an object array figure out the dtype cast,
* or simply use the returned DType.
*/
- *out_descr = cast_descriptor_to_fixed_dtype(
- PyArray_DESCR(arr), DType);
+ *out_descr = PyArray_CastDescrToDType(PyArray_DESCR(arr), DType);
if (*out_descr == NULL) {
return -1;
}
diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c
index d9121707b..0ccc0b538 100644
--- a/numpy/core/src/multiarray/convert_datatype.c
+++ b/numpy/core/src/multiarray/convert_datatype.c
@@ -1031,6 +1031,55 @@ ensure_dtype_nbo(PyArray_Descr *type)
}
}
+
+/**
+ * This function should possibly become public API eventually. At this
+ * time it is implemented by falling back to `PyArray_AdaptFlexibleDType`.
+ * We will use `CastingImpl[from, to].adjust_descriptors(...)` to implement
+ * this logic.
+ * Before that, the API needs to be reviewed though.
+ *
+ * WARNING: This function currently does not guarantee that `descr` can
+ * actually be cast to the given DType.
+ *
+ * @param descr The dtype instance to adapt "cast"
+ * @param given_DType The DType class for which we wish to find an instance able
+ * to represent `descr`.
+ * @returns Instance of `given_DType`. If `given_DType` is parametric the
+ * descr may be adapted to hold it.
+ */
+NPY_NO_EXPORT PyArray_Descr *
+PyArray_CastDescrToDType(PyArray_Descr *descr, PyArray_DTypeMeta *given_DType)
+{
+ if (NPY_DTYPE(descr) == given_DType) {
+ Py_INCREF(descr);
+ return descr;
+ }
+ if (!given_DType->parametric) {
+ /*
+ * Don't actually do anything, the default is always the result
+ * of any cast.
+ */
+ return given_DType->default_descr(given_DType);
+ }
+ if (PyObject_TypeCheck((PyObject *)descr, (PyTypeObject *)given_DType)) {
+ Py_INCREF(descr);
+ return descr;
+ }
+
+ if (!given_DType->legacy) {
+ PyErr_SetString(PyExc_NotImplementedError,
+ "Must use casting to find the correct DType for a parametric "
+ "user DType. This is not yet implemented (this error should be "
+ "unreachable).");
+ return NULL;
+ }
+
+ PyArray_Descr *flex_dtype = PyArray_DescrNew(given_DType->singleton);
+ return PyArray_AdaptFlexibleDType(descr, flex_dtype);
+}
+
+
/*NUMPY_API
* Produces the smallest size and lowest kind type to which both
* input types can be cast.
diff --git a/numpy/core/src/multiarray/convert_datatype.h b/numpy/core/src/multiarray/convert_datatype.h
index 9b7f39db2..eef2f7313 100644
--- a/numpy/core/src/multiarray/convert_datatype.h
+++ b/numpy/core/src/multiarray/convert_datatype.h
@@ -49,4 +49,7 @@ npy_set_invalid_cast_error(
NPY_NO_EXPORT PyArray_Descr *
PyArray_AdaptFlexibleDType(PyArray_Descr *data_dtype, PyArray_Descr *flex_dtype);
+NPY_NO_EXPORT PyArray_Descr *
+PyArray_CastDescrToDType(PyArray_Descr *descr, PyArray_DTypeMeta *given_DType);
+
#endif