author     Matti Picus <matti.picus@gmail.com>   2020-09-22 15:54:18 +0300
committer  GitHub <noreply@github.com>           2020-09-22 15:54:18 +0300
commit     0d2536650ff1b750a9008cfd2893dcc1db58f94f (patch)
tree       c4c54da90acf8e3919119c704fcaf4b620678bb5
parent     f981a7ff4b02204f215eb089ce77f9a3f3906b1e (diff)
parent     b40f6bb22d7e71533e0b450493530e8fdd08afa5 (diff)
Merge pull request #17137 from seberg/restructure-dtype-promotion
API,MAINT: Rewrite promotion using common DType and common instance
-rw-r--r--   numpy/core/include/numpy/ndarraytypes.h        6
-rw-r--r--   numpy/core/src/multiarray/array_coercion.c    73
-rw-r--r--   numpy/core/src/multiarray/convert_datatype.c  413
-rw-r--r--   numpy/core/src/multiarray/convert_datatype.h    9
-rw-r--r--   numpy/core/src/multiarray/dtypemeta.c         145
-rw-r--r--   numpy/core/src/multiarray/dtypemeta.h          16
-rw-r--r--   numpy/core/src/multiarray/usertypes.c         121
-rw-r--r--   numpy/core/src/multiarray/usertypes.h           4
-rw-r--r--   numpy/core/src/umath/ufunc_type_resolution.c   15
-rw-r--r--   numpy/core/tests/test_numeric.py               57
10 files changed, 451 insertions, 408 deletions
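
This merge replaces the large switch in PyArray_PromoteTypes with a two-step scheme: first the common DType class is found, then the common descriptor instance. The user-visible results of np.promote_types are intended to stay the same; a couple of well-known cases as a quick sanity check (an illustration, not taken from this commit):

    import numpy as np

    # kind-level promotion is decided between DType classes ...
    assert np.promote_types(np.uint8, np.int8) == np.dtype(np.int16)
    # ... while the size of a parametric result is decided between instances:
    assert np.promote_types("S5", "U3") == np.dtype("U5")
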
diff --git a/numpy/core/include/numpy/ndarraytypes.h b/numpy/core/include/numpy/ndarraytypes.h
index 6eca4afdb..6bf54938f 100644
--- a/numpy/core/include/numpy/ndarraytypes.h
+++ b/numpy/core/include/numpy/ndarraytypes.h
@@ -1839,6 +1839,10 @@ typedef void (PyDataMem_EventHookFunc)(void *inp, void *outp, size_t size,
PyArray_DTypeMeta *cls, PyTypeObject *obj);
typedef PyArray_Descr *(default_descr_function)(PyArray_DTypeMeta *cls);
+ typedef PyArray_DTypeMeta *(common_dtype_function)(
+ PyArray_DTypeMeta *dtype1, PyArray_DTypeMeta *dtype2);
+ typedef PyArray_Descr *(common_instance_function)(
+ PyArray_Descr *dtype1, PyArray_Descr *dtype2);
/*
* While NumPy DTypes would not need to be heap types the plan is to
@@ -1894,6 +1898,8 @@ typedef void (PyDataMem_EventHookFunc)(void *inp, void *outp, size_t size,
discover_descr_from_pyobject_function *discover_descr_from_pyobject;
is_known_scalar_type_function *is_known_scalar_type;
default_descr_function *default_descr;
+ common_dtype_function *common_dtype;
+ common_instance_function *common_instance;
};
#endif /* NPY_INTERNAL_BUILD */
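
The two new slots split promotion between the class level and the instance level. A minimal Python sketch of that division of labor for a hypothetical parametric string DType (the class and method names here are illustrative, not the NumPy C API):

    class HypotheticalStringDType:
        """Illustrative stand-in for a parametric DType class."""
        parametric = True

        @classmethod
        def common_dtype(cls, other):
            # Class-level operator: decide *which* DType can hold both,
            # or defer to the other operand (mirrors common_dtype_function).
            return cls if other is cls else NotImplemented

        @staticmethod
        def common_instance(descr1, descr2):
            # Instance-level operator: pick the concrete descriptor, here
            # simply the larger itemsize (mirrors common_instance_function).
            return descr1 if descr1.itemsize >= descr2.itemsize else descr2
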
diff --git a/numpy/core/src/multiarray/array_coercion.c b/numpy/core/src/multiarray/array_coercion.c
index 3f3fd1387..aae8d5141 100644
--- a/numpy/core/src/multiarray/array_coercion.c
+++ b/numpy/core/src/multiarray/array_coercion.c
@@ -306,51 +306,6 @@ discover_dtype_from_pyobject(
}
-/*
- * This function should probably become public API eventually. At this
- * time it is implemented by falling back to `PyArray_AdaptFlexibleDType`.
- * We will use `CastingImpl[from, to].adjust_descriptors(...)` to implement
- * this logic.
- */
-static NPY_INLINE PyArray_Descr *
-cast_descriptor_to_fixed_dtype(
- PyArray_Descr *descr, PyArray_DTypeMeta *fixed_DType)
-{
- if (fixed_DType == NULL) {
- /* Nothing to do, we only need to promote the new dtype */
- Py_INCREF(descr);
- return descr;
- }
-
- if (!fixed_DType->parametric) {
- /*
- * Don't actually do anything, the default is always the result
- * of any cast.
- */
- return fixed_DType->default_descr(fixed_DType);
- }
- if (PyObject_TypeCheck((PyObject *)descr, (PyTypeObject *)fixed_DType)) {
- Py_INCREF(descr);
- return descr;
- }
- /*
- * TODO: When this is implemented for all dtypes, the special cases
- * can be removed...
- */
- if (fixed_DType->legacy && fixed_DType->parametric &&
- NPY_DTYPE(descr)->legacy) {
- PyArray_Descr *flex_dtype = PyArray_DescrFromType(fixed_DType->type_num);
- return PyArray_AdaptFlexibleDType(descr, flex_dtype);
- }
-
- PyErr_SetString(PyExc_NotImplementedError,
- "Must use casting to find the correct dtype, this is "
- "not yet implemented! "
- "(It should not be possible to hit this code currently!)");
- return NULL;
-}
-
-
/**
* Discover the correct descriptor from a known DType class and scalar.
* If the fixed DType can discover a dtype instance/descr all is fine,
@@ -392,7 +347,7 @@ find_scalar_descriptor(
return descr;
}
- Py_SETREF(descr, cast_descriptor_to_fixed_dtype(descr, fixed_DType));
+ Py_SETREF(descr, PyArray_CastDescrToDType(descr, fixed_DType));
return descr;
}
@@ -727,8 +682,13 @@ find_descriptor_from_array(
enum _dtype_discovery_flags flags = 0;
*out_descr = NULL;
- if (NPY_UNLIKELY(DType != NULL && DType->parametric &&
- PyArray_ISOBJECT(arr))) {
+ if (DType == NULL) {
+ *out_descr = PyArray_DESCR(arr);
+ Py_INCREF(*out_descr);
+ return 0;
+ }
+
+ if (NPY_UNLIKELY(DType->parametric && PyArray_ISOBJECT(arr))) {
/*
* We have one special case, if (and only if) the input array is of
* object DType and the dtype is not fixed already but parametric.
@@ -777,7 +737,7 @@ find_descriptor_from_array(
}
Py_DECREF(iter);
}
- else if (DType != NULL && NPY_UNLIKELY(DType->type_num == NPY_DATETIME) &&
+ else if (NPY_UNLIKELY(DType->type_num == NPY_DATETIME) &&
PyArray_ISSTRING(arr)) {
/*
* TODO: This branch should be deprecated IMO, the workaround is
@@ -806,8 +766,7 @@ find_descriptor_from_array(
* If this is not an object array figure out the dtype cast,
* or simply use the returned DType.
*/
- *out_descr = cast_descriptor_to_fixed_dtype(
- PyArray_DESCR(arr), DType);
+ *out_descr = PyArray_CastDescrToDType(PyArray_DESCR(arr), DType);
if (*out_descr == NULL) {
return -1;
}
@@ -1325,15 +1284,9 @@ PyArray_DiscoverDTypeAndShape(
* the correct default.
*/
if (fixed_DType != NULL) {
- if (fixed_DType->default_descr == NULL) {
- Py_INCREF(fixed_DType->singleton);
- *out_descr = fixed_DType->singleton;
- }
- else {
- *out_descr = fixed_DType->default_descr(fixed_DType);
- if (*out_descr == NULL) {
- goto fail;
- }
+ *out_descr = fixed_DType->default_descr(fixed_DType);
+ if (*out_descr == NULL) {
+ goto fail;
}
}
}
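
In the coercion code the old cast_descriptor_to_fixed_dtype helper is replaced by the new PyArray_CastDescrToDType, which adapts a descriptor to a requested DType class. A small, hedged illustration of the user-visible side of that adaptation (behavior that predates this change):

    import numpy as np

    # Requesting the parametric bytes DType without a size ("S") lets the
    # coercion machinery adapt the itemsize to the discovered values.
    a = np.array([1, 10, 100], dtype="S")
    print(a, a.dtype)   # itemsize is derived from the string form of the data
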
diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c
index d9121707b..f700bdc99 100644
--- a/numpy/core/src/multiarray/convert_datatype.c
+++ b/numpy/core/src/multiarray/convert_datatype.c
@@ -1019,7 +1019,7 @@ promote_types(PyArray_Descr *type1, PyArray_Descr *type2,
* Returns a new reference to type if it is already NBO, otherwise
* returns a copy converted to NBO.
*/
-static PyArray_Descr *
+NPY_NO_EXPORT PyArray_Descr *
ensure_dtype_nbo(PyArray_Descr *type)
{
if (PyArray_ISNBO(type->byteorder)) {
@@ -1031,327 +1031,148 @@ ensure_dtype_nbo(PyArray_Descr *type)
}
}
-/*NUMPY_API
- * Produces the smallest size and lowest kind type to which both
- * input types can be cast.
+
+/**
+ * This function should possibly become public API eventually. At this
+ * time it is implemented by falling back to `PyArray_AdaptFlexibleDType`.
+ * We will use `CastingImpl[from, to].adjust_descriptors(...)` to implement
+ * this logic.
+ * Before that, the API needs to be reviewed though.
+ *
+ * WARNING: This function currently does not guarantee that `descr` can
+ * actually be cast to the given DType.
+ *
+ * @param descr The dtype instance to adapt "cast"
+ * @param given_DType The DType class for which we wish to find an instance able
+ * to represent `descr`.
+ * @returns Instance of `given_DType`. If `given_DType` is parametric the
+ * descr may be adapted to hold it.
*/
NPY_NO_EXPORT PyArray_Descr *
-PyArray_PromoteTypes(PyArray_Descr *type1, PyArray_Descr *type2)
+PyArray_CastDescrToDType(PyArray_Descr *descr, PyArray_DTypeMeta *given_DType)
{
- int type_num1, type_num2, ret_type_num;
-
- /*
- * Fast path for identical dtypes.
- *
- * Non-native-byte-order types are converted to native ones below, so we
- * can't quit early.
- */
- if (type1 == type2 && PyArray_ISNBO(type1->byteorder)) {
- Py_INCREF(type1);
- return type1;
+ if (NPY_DTYPE(descr) == given_DType) {
+ Py_INCREF(descr);
+ return descr;
}
-
- type_num1 = type1->type_num;
- type_num2 = type2->type_num;
-
- /* If they're built-in types, use the promotion table */
- if (type_num1 < NPY_NTYPES && type_num2 < NPY_NTYPES) {
- ret_type_num = _npy_type_promotion_table[type_num1][type_num2];
+ if (!given_DType->parametric) {
/*
- * The table doesn't handle string/unicode/void/datetime/timedelta,
- * so check the result
+ * Don't actually do anything, the default is always the result
+ * of any cast.
*/
- if (ret_type_num >= 0) {
- return PyArray_DescrFromType(ret_type_num);
- }
+ return given_DType->default_descr(given_DType);
+ }
+ if (PyObject_TypeCheck((PyObject *)descr, (PyTypeObject *)given_DType)) {
+ Py_INCREF(descr);
+ return descr;
}
- /* If one or both are user defined, calculate it */
- else {
- int skind1 = NPY_NOSCALAR, skind2 = NPY_NOSCALAR, skind;
- if (PyArray_CanCastTo(type2, type1)) {
- /* Promoted types are always native byte order */
- return ensure_dtype_nbo(type1);
- }
- else if (PyArray_CanCastTo(type1, type2)) {
- /* Promoted types are always native byte order */
- return ensure_dtype_nbo(type2);
- }
+ if (!given_DType->legacy) {
+ PyErr_SetString(PyExc_NotImplementedError,
+ "Must use casting to find the correct DType for a parametric "
+ "user DType. This is not yet implemented (this error should be "
+ "unreachable).");
+ return NULL;
+ }
- /* Convert the 'kind' char into a scalar kind */
- switch (type1->kind) {
- case 'b':
- skind1 = NPY_BOOL_SCALAR;
- break;
- case 'u':
- skind1 = NPY_INTPOS_SCALAR;
- break;
- case 'i':
- skind1 = NPY_INTNEG_SCALAR;
- break;
- case 'f':
- skind1 = NPY_FLOAT_SCALAR;
- break;
- case 'c':
- skind1 = NPY_COMPLEX_SCALAR;
- break;
- }
- switch (type2->kind) {
- case 'b':
- skind2 = NPY_BOOL_SCALAR;
- break;
- case 'u':
- skind2 = NPY_INTPOS_SCALAR;
- break;
- case 'i':
- skind2 = NPY_INTNEG_SCALAR;
- break;
- case 'f':
- skind2 = NPY_FLOAT_SCALAR;
- break;
- case 'c':
- skind2 = NPY_COMPLEX_SCALAR;
- break;
- }
+ PyArray_Descr *flex_dtype = PyArray_DescrNew(given_DType->singleton);
+ return PyArray_AdaptFlexibleDType(descr, flex_dtype);
+}
- /* If both are scalars, there may be a promotion possible */
- if (skind1 != NPY_NOSCALAR && skind2 != NPY_NOSCALAR) {
- /* Start with the larger scalar kind */
- skind = (skind1 > skind2) ? skind1 : skind2;
- ret_type_num = _npy_smallest_type_of_kind_table[skind];
+/**
+ * This function defines the common DType operator.
+ *
+ * Note that the common DType will not be "object" (unless one of the dtypes
+ * is object), even though object can technically represent all values
+ * correctly.
+ *
+ * TODO: Before exposure, we should review the return value (e.g. no error
+ * when no common DType is found).
+ *
+ * @param dtype1 DType class to find the common type for.
+ * @param dtype2 Second DType class.
+ * @return The common DType or NULL with an error set
+ */
+NPY_NO_EXPORT PyArray_DTypeMeta *
+PyArray_CommonDType(PyArray_DTypeMeta *dtype1, PyArray_DTypeMeta *dtype2)
+{
+ if (dtype1 == dtype2) {
+ Py_INCREF(dtype1);
+ return dtype1;
+ }
- for (;;) {
+ PyArray_DTypeMeta *common_dtype;
- /* If there is no larger type of this kind, try a larger kind */
- if (ret_type_num < 0) {
- ++skind;
- /* Use -1 to signal no promoted type found */
- if (skind < NPY_NSCALARKINDS) {
- ret_type_num = _npy_smallest_type_of_kind_table[skind];
- }
- else {
- break;
- }
- }
+ common_dtype = dtype1->common_dtype(dtype1, dtype2);
+ if (common_dtype == (PyArray_DTypeMeta *)Py_NotImplemented) {
+ Py_DECREF(common_dtype);
+ common_dtype = dtype2->common_dtype(dtype2, dtype1);
+ }
+ if (common_dtype == NULL) {
+ return NULL;
+ }
+ if (common_dtype == (PyArray_DTypeMeta *)Py_NotImplemented) {
+ Py_DECREF(Py_NotImplemented);
+ PyErr_Format(PyExc_TypeError,
+ "The DTypes %S and %S do not have a common DType. "
+ "For example they cannot be stored in a single array unless "
+ "the dtype is `object`.", dtype1, dtype2);
+ return NULL;
+ }
+ return common_dtype;
+}
- /* If we found a type to which we can promote both, done! */
- if (PyArray_CanCastSafely(type_num1, ret_type_num) &&
- PyArray_CanCastSafely(type_num2, ret_type_num)) {
- return PyArray_DescrFromType(ret_type_num);
- }
- /* Try the next larger type of this kind */
- ret_type_num = _npy_next_larger_type_table[ret_type_num];
- }
+/*NUMPY_API
+ * Produces the smallest size and lowest kind type to which both
+ * input types can be cast.
+ */
+NPY_NO_EXPORT PyArray_Descr *
+PyArray_PromoteTypes(PyArray_Descr *type1, PyArray_Descr *type2)
+{
+ PyArray_DTypeMeta *common_dtype;
+ PyArray_Descr *res;
- }
+ /* Fast path for identical inputs (NOTE: This path preserves metadata!) */
+ if (type1 == type2 && PyArray_ISNBO(type1->byteorder)) {
+ Py_INCREF(type1);
+ return type1;
+ }
- PyErr_SetString(PyExc_TypeError,
- "invalid type promotion with custom data type");
+ common_dtype = PyArray_CommonDType(NPY_DTYPE(type1), NPY_DTYPE(type2));
+ if (common_dtype == NULL) {
return NULL;
}
- switch (type_num1) {
- /* BOOL can convert to anything except datetime/void */
- case NPY_BOOL:
- if (type_num2 == NPY_STRING || type_num2 == NPY_UNICODE) {
- int char_size = 1;
- if (type_num2 == NPY_UNICODE) {
- char_size = 4;
- }
- if (type2->elsize < 5 * char_size) {
- PyArray_Descr *ret = NULL;
- PyArray_Descr *temp = PyArray_DescrNew(type2);
- ret = ensure_dtype_nbo(temp);
- ret->elsize = 5 * char_size;
- Py_DECREF(temp);
- return ret;
- }
- return ensure_dtype_nbo(type2);
- }
- else if (type_num2 != NPY_DATETIME && type_num2 != NPY_VOID) {
- return ensure_dtype_nbo(type2);
- }
- break;
- /* For strings and unicodes, take the larger size */
- case NPY_STRING:
- if (type_num2 == NPY_STRING) {
- if (type1->elsize > type2->elsize) {
- return ensure_dtype_nbo(type1);
- }
- else {
- return ensure_dtype_nbo(type2);
- }
- }
- else if (type_num2 == NPY_UNICODE) {
- if (type2->elsize >= type1->elsize * 4) {
- return ensure_dtype_nbo(type2);
- }
- else {
- PyArray_Descr *d = PyArray_DescrNewFromType(NPY_UNICODE);
- if (d == NULL) {
- return NULL;
- }
- d->elsize = type1->elsize * 4;
- return d;
- }
- }
- /* Allow NUMBER -> STRING */
- else if (PyTypeNum_ISNUMBER(type_num2)) {
- PyArray_Descr *ret = NULL;
- PyArray_Descr *temp = PyArray_DescrNew(type1);
- PyDataType_MAKEUNSIZED(temp);
-
- temp = PyArray_AdaptFlexibleDType(type2, temp);
- if (temp == NULL) {
- return NULL;
- }
- if (temp->elsize > type1->elsize) {
- ret = ensure_dtype_nbo(temp);
- }
- else {
- ret = ensure_dtype_nbo(type1);
- }
- Py_DECREF(temp);
- return ret;
- }
- break;
- case NPY_UNICODE:
- if (type_num2 == NPY_UNICODE) {
- if (type1->elsize > type2->elsize) {
- return ensure_dtype_nbo(type1);
- }
- else {
- return ensure_dtype_nbo(type2);
- }
- }
- else if (type_num2 == NPY_STRING) {
- if (type1->elsize >= type2->elsize * 4) {
- return ensure_dtype_nbo(type1);
- }
- else {
- PyArray_Descr *d = PyArray_DescrNewFromType(NPY_UNICODE);
- if (d == NULL) {
- return NULL;
- }
- d->elsize = type2->elsize * 4;
- return d;
- }
- }
- /* Allow NUMBER -> UNICODE */
- else if (PyTypeNum_ISNUMBER(type_num2)) {
- PyArray_Descr *ret = NULL;
- PyArray_Descr *temp = PyArray_DescrNew(type1);
- PyDataType_MAKEUNSIZED(temp);
- temp = PyArray_AdaptFlexibleDType(type2, temp);
- if (temp == NULL) {
- return NULL;
- }
- if (temp->elsize > type1->elsize) {
- ret = ensure_dtype_nbo(temp);
- }
- else {
- ret = ensure_dtype_nbo(type1);
- }
- Py_DECREF(temp);
- return ret;
- }
- break;
- case NPY_DATETIME:
- case NPY_TIMEDELTA:
- if (type_num2 == NPY_DATETIME || type_num2 == NPY_TIMEDELTA) {
- return datetime_type_promotion(type1, type2);
- }
- break;
+ if (!common_dtype->parametric) {
+ res = common_dtype->default_descr(common_dtype);
+ Py_DECREF(common_dtype);
+ return res;
}
- switch (type_num2) {
- /* BOOL can convert to almost anything */
- case NPY_BOOL:
- if (type_num2 == NPY_STRING || type_num2 == NPY_UNICODE) {
- int char_size = 1;
- if (type_num2 == NPY_UNICODE) {
- char_size = 4;
- }
- if (type2->elsize < 5 * char_size) {
- PyArray_Descr *ret = NULL;
- PyArray_Descr *temp = PyArray_DescrNew(type2);
- ret = ensure_dtype_nbo(temp);
- ret->elsize = 5 * char_size;
- Py_DECREF(temp);
- return ret;
- }
- return ensure_dtype_nbo(type2);
- }
- else if (type_num1 != NPY_DATETIME && type_num1 != NPY_TIMEDELTA &&
- type_num1 != NPY_VOID) {
- return ensure_dtype_nbo(type1);
- }
- break;
- case NPY_STRING:
- /* Allow NUMBER -> STRING */
- if (PyTypeNum_ISNUMBER(type_num1)) {
- PyArray_Descr *ret = NULL;
- PyArray_Descr *temp = PyArray_DescrNew(type2);
- PyDataType_MAKEUNSIZED(temp);
- temp = PyArray_AdaptFlexibleDType(type1, temp);
- if (temp == NULL) {
- return NULL;
- }
- if (temp->elsize > type2->elsize) {
- ret = ensure_dtype_nbo(temp);
- }
- else {
- ret = ensure_dtype_nbo(type2);
- }
- Py_DECREF(temp);
- return ret;
- }
- break;
- case NPY_UNICODE:
- /* Allow NUMBER -> UNICODE */
- if (PyTypeNum_ISNUMBER(type_num1)) {
- PyArray_Descr *ret = NULL;
- PyArray_Descr *temp = PyArray_DescrNew(type2);
- PyDataType_MAKEUNSIZED(temp);
- temp = PyArray_AdaptFlexibleDType(type1, temp);
- if (temp == NULL) {
- return NULL;
- }
- if (temp->elsize > type2->elsize) {
- ret = ensure_dtype_nbo(temp);
- }
- else {
- ret = ensure_dtype_nbo(type2);
- }
- Py_DECREF(temp);
- return ret;
- }
- break;
- case NPY_TIMEDELTA:
- if (PyTypeNum_ISSIGNED(type_num1)) {
- return ensure_dtype_nbo(type2);
- }
- break;
+ /* Cast the input types to the common DType if necessary */
+ type1 = PyArray_CastDescrToDType(type1, common_dtype);
+ if (type1 == NULL) {
+ Py_DECREF(common_dtype);
+ return NULL;
}
-
- /* For types equivalent up to endianness, can return either */
- if (PyArray_CanCastTypeTo(type1, type2, NPY_EQUIV_CASTING)) {
- return ensure_dtype_nbo(type1);
+ type2 = PyArray_CastDescrToDType(type2, common_dtype);
+ if (type2 == NULL) {
+ Py_DECREF(type1);
+ Py_DECREF(common_dtype);
+ return NULL;
}
- /* TODO: Also combine fields, subarrays, strings, etc */
-
/*
- printf("invalid type promotion: ");
- PyObject_Print(type1, stdout, 0);
- printf(" ");
- PyObject_Print(type2, stdout, 0);
- printf("\n");
- */
- PyErr_SetString(PyExc_TypeError, "invalid type promotion");
- return NULL;
+ * And find the common instance of the two inputs
+ * NOTE: Common instance preserves metadata (normally that of one input)
+ */
+ res = common_dtype->common_instance(type1, type2);
+ Py_DECREF(type1);
+ Py_DECREF(type2);
+ Py_DECREF(common_dtype);
+ return res;
}
/*
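
After the rewrite, PyArray_PromoteTypes reduces to: fast path, PyArray_CommonDType, cast both descriptors to the common DType, then common_instance. Below is a self-contained toy restatement of that flow, restricted to the bytes/str case so it stays runnable; the helper is illustrative only, not a NumPy function:

    import numpy as np

    def promote_strings_sketch(t1, t2):
        # Toy restatement of the new control flow for S/U dtypes only.
        t1, t2 = np.dtype(t1), np.dtype(t2)
        if t1 is t2 and t1.byteorder in "=|":
            return t1                                  # fast path, keeps metadata
        # 1) class level: unicode is the more general of the two kinds
        common_kind = "U" if "U" in (t1.kind, t2.kind) else "S"
        # 2) instance level: adapt both to the common kind, keep the larger size
        chars = max(t1.itemsize // (4 if t1.kind == "U" else 1),
                    t2.itemsize // (4 if t2.kind == "U" else 1))
        return np.dtype(f"{common_kind}{chars}")

    assert promote_strings_sketch("S5", "U3") == np.promote_types("S5", "U3")
    assert promote_strings_sketch("S2", "S7") == np.promote_types("S2", "S7")
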
diff --git a/numpy/core/src/multiarray/convert_datatype.h b/numpy/core/src/multiarray/convert_datatype.h
index 9b7f39db2..a2b36b497 100644
--- a/numpy/core/src/multiarray/convert_datatype.h
+++ b/numpy/core/src/multiarray/convert_datatype.h
@@ -10,6 +10,9 @@ PyArray_ObjectType(PyObject *op, int minimum_type);
NPY_NO_EXPORT PyArrayObject **
PyArray_ConvertToCommonType(PyObject *op, int *retn);
+NPY_NO_EXPORT PyArray_DTypeMeta *
+PyArray_CommonDType(PyArray_DTypeMeta *dtype1, PyArray_DTypeMeta *dtype2);
+
NPY_NO_EXPORT int
PyArray_ValidType(int type);
@@ -18,6 +21,9 @@ NPY_NO_EXPORT npy_bool
can_cast_scalar_to(PyArray_Descr *scal_type, char *scal_data,
PyArray_Descr *to, NPY_CASTING casting);
+NPY_NO_EXPORT PyArray_Descr *
+ensure_dtype_nbo(PyArray_Descr *type);
+
NPY_NO_EXPORT int
should_use_min_scalar(npy_intp narrs, PyArrayObject **arr,
npy_intp ndtypes, PyArray_Descr **dtypes);
@@ -49,4 +55,7 @@ npy_set_invalid_cast_error(
NPY_NO_EXPORT PyArray_Descr *
PyArray_AdaptFlexibleDType(PyArray_Descr *data_dtype, PyArray_Descr *flex_dtype);
+NPY_NO_EXPORT PyArray_Descr *
+PyArray_CastDescrToDType(PyArray_Descr *descr, PyArray_DTypeMeta *given_DType);
+
#endif
diff --git a/numpy/core/src/multiarray/dtypemeta.c b/numpy/core/src/multiarray/dtypemeta.c
index d07dc700d..531f746d8 100644
--- a/numpy/core/src/multiarray/dtypemeta.c
+++ b/numpy/core/src/multiarray/dtypemeta.c
@@ -15,6 +15,9 @@
#include "dtypemeta.h"
#include "_datetime.h"
#include "array_coercion.h"
+#include "scalartypes.h"
+#include "convert_datatype.h"
+#include "usertypes.h"
static void
@@ -194,6 +197,14 @@ discover_datetime_and_timedelta_from_pyobject(
static PyArray_Descr *
+nonparametric_default_descr(PyArray_DTypeMeta *cls)
+{
+ Py_INCREF(cls->singleton);
+ return cls->singleton;
+}
+
+
+static PyArray_Descr *
flexible_default_descr(PyArray_DTypeMeta *cls)
{
PyArray_Descr *res = PyArray_DescrNewFromType(cls->type_num);
@@ -208,6 +219,34 @@ flexible_default_descr(PyArray_DTypeMeta *cls)
}
+static PyArray_Descr *
+string_unicode_common_instance(PyArray_Descr *descr1, PyArray_Descr *descr2)
+{
+ if (descr1->elsize >= descr2->elsize) {
+ return ensure_dtype_nbo(descr1);
+ }
+ else {
+ return ensure_dtype_nbo(descr2);
+ }
+}
+
+
+static PyArray_Descr *
+void_common_instance(PyArray_Descr *descr1, PyArray_Descr *descr2)
+{
+ /*
+ * We currently do not support promotion of void types unless they
+ * are equivalent.
+ */
+ if (!PyArray_CanCastTypeTo(descr1, descr2, NPY_EQUIV_CASTING)) {
+ PyErr_SetString(PyExc_TypeError,
+ "invalid type promotion with structured or void datatype(s).");
+ return NULL;
+ }
+ Py_INCREF(descr1);
+ return descr1;
+}
+
static int
python_builtins_are_known_scalar_types(
PyArray_DTypeMeta *NPY_UNUSED(cls), PyTypeObject *pytype)
@@ -281,6 +320,86 @@ string_known_scalar_types(
}
+/*
+ * The following set of functions define the common dtype operator for
+ * the builtin types.
+ */
+static PyArray_DTypeMeta *
+default_builtin_common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
+{
+ assert(cls->type_num < NPY_NTYPES);
+ if (!other->legacy || other->type_num > cls->type_num) {
+ /* Let the more generic (larger type number) DType handle this */
+ Py_INCREF(Py_NotImplemented);
+ return (PyArray_DTypeMeta *)Py_NotImplemented;
+ }
+
+ /*
+ * Note: The use of the promotion table should probably be revised at
+ * some point. It may be most useful to remove it entirely and then
+ * consider adding a fast path/cache `PyArray_CommonDType()` itself.
+ */
+ int common_num = _npy_type_promotion_table[cls->type_num][other->type_num];
+ if (common_num < 0) {
+ Py_INCREF(Py_NotImplemented);
+ return (PyArray_DTypeMeta *)Py_NotImplemented;
+ }
+ return PyArray_DTypeFromTypeNum(common_num);
+}
+
+
+static PyArray_DTypeMeta *
+string_unicode_common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
+{
+ assert(cls->type_num < NPY_NTYPES);
+ if (!other->legacy || other->type_num > cls->type_num ||
+ other->type_num == NPY_OBJECT) {
+ /* Let the more generic (larger type number) DType handle this */
+ Py_INCREF(Py_NotImplemented);
+ return (PyArray_DTypeMeta *)Py_NotImplemented;
+ }
+ /*
+ * The builtin types are ordered by complexity (aside from object) here.
+ * Arguably, we should not consider numbers and strings "common", but
+ * we currently do.
+ */
+ Py_INCREF(cls);
+ return cls;
+}
+
+static PyArray_DTypeMeta *
+datetime_common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
+{
+ if (cls->type_num == NPY_DATETIME && other->type_num == NPY_TIMEDELTA) {
+ /*
+ * TODO: We actually currently do allow promotion here. This is
+ * currently relied on within `np.add(datetime, timedelta)`,
+ * while for concatenation the cast step will fail.
+ */
+ Py_INCREF(cls);
+ return cls;
+ }
+ return default_builtin_common_dtype(cls, other);
+}
+
+
+
+static PyArray_DTypeMeta *
+object_common_dtype(
+ PyArray_DTypeMeta *cls, PyArray_DTypeMeta *NPY_UNUSED(other))
+{
+ /*
+ * The object DType is special in that it can represent everything,
+ * including all potential user DTypes.
+ * One reason to defer (or error) here might be if the other DType
+ * does not support scalars so that e.g. `arr1d[0]` returns a 0-D array
+ * and `arr.astype(object)` would fail. But object casts are special.
+ */
+ Py_INCREF(cls);
+ return cls;
+}
+
+
/**
* This function takes a PyArray_Descr and replaces its base class with
* a newly created dtype subclass (DTypeMeta instances).
@@ -398,14 +517,28 @@ dtypemeta_wrap_legacy_descriptor(PyArray_Descr *descr)
dtype_class->f = descr->f;
dtype_class->kind = descr->kind;
- /* Strings and voids have (strange) logic around scalars. */
+ /* Set default functions (correct for most dtypes, override below) */
+ dtype_class->default_descr = nonparametric_default_descr;
+ dtype_class->discover_descr_from_pyobject = (
+ nonparametric_discover_descr_from_pyobject);
dtype_class->is_known_scalar_type = python_builtins_are_known_scalar_types;
+ dtype_class->common_dtype = default_builtin_common_dtype;
+ dtype_class->common_instance = NULL;
- if (PyTypeNum_ISDATETIME(descr->type_num)) {
+ if (PyTypeNum_ISUSERDEF(descr->type_num)) {
+ dtype_class->common_dtype = legacy_userdtype_common_dtype_function;
+ }
+ else if (descr->type_num == NPY_OBJECT) {
+ dtype_class->common_dtype = object_common_dtype;
+ }
+ else if (PyTypeNum_ISDATETIME(descr->type_num)) {
/* Datetimes are flexible, but were not considered previously */
dtype_class->parametric = NPY_TRUE;
+ dtype_class->default_descr = flexible_default_descr;
dtype_class->discover_descr_from_pyobject = (
discover_datetime_and_timedelta_from_pyobject);
+ dtype_class->common_dtype = datetime_common_dtype;
+ dtype_class->common_instance = datetime_type_promotion;
if (descr->type_num == NPY_DATETIME) {
dtype_class->is_known_scalar_type = datetime_known_scalar_types;
}
@@ -416,18 +549,16 @@ dtypemeta_wrap_legacy_descriptor(PyArray_Descr *descr)
if (descr->type_num == NPY_VOID) {
dtype_class->discover_descr_from_pyobject = (
void_discover_descr_from_pyobject);
+ dtype_class->common_instance = void_common_instance;
}
else {
dtype_class->is_known_scalar_type = string_known_scalar_types;
dtype_class->discover_descr_from_pyobject = (
string_discover_descr_from_pyobject);
+ dtype_class->common_dtype = string_unicode_common_dtype;
+ dtype_class->common_instance = string_unicode_common_instance;
}
}
- else {
- /* nonparametric case */
- dtype_class->discover_descr_from_pyobject = (
- nonparametric_discover_descr_from_pyobject);
- }
if (_PyArray_MapPyTypeToDType(dtype_class, descr->typeobj,
PyTypeNum_ISUSERDEF(dtype_class->type_num)) < 0) {
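
dtypemeta.c wires the new slots up for the builtin DTypes: datetime/timedelta reuse datetime_type_promotion as their common_instance, strings keep the larger itemsize, and void only promotes equivalent layouts. A few observable consequences (an illustration, not taken from this commit's tests):

    import numpy as np

    # datetime/timedelta: the common instance is datetime_type_promotion,
    # so the finer unit wins.
    assert np.promote_types("M8[s]", "M8[ms]") == np.dtype("M8[ms]")
    # strings/unicode: the common instance keeps the larger itemsize.
    assert np.promote_types("U3", "U7") == np.dtype("U7")
    # void/structured dtypes only promote when equivalent; otherwise the
    # new void_common_instance raises TypeError (see the tests below).
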
diff --git a/numpy/core/src/multiarray/dtypemeta.h b/numpy/core/src/multiarray/dtypemeta.h
index e0909a7eb..83cf7c07e 100644
--- a/numpy/core/src/multiarray/dtypemeta.h
+++ b/numpy/core/src/multiarray/dtypemeta.h
@@ -2,6 +2,22 @@
#define _NPY_DTYPEMETA_H
#define NPY_DTYPE(descr) ((PyArray_DTypeMeta *)Py_TYPE(descr))
+/*
+ * This function will hopefully be phased out or replaced, but was convenient
+ * for incremental implementation of new DTypes based on DTypeMeta.
+ * (Error checking is not required for DescrFromType, assuming that the
+ * type is valid.)
+ */
+static NPY_INLINE PyArray_DTypeMeta *
+PyArray_DTypeFromTypeNum(int typenum)
+{
+ PyArray_Descr *descr = PyArray_DescrFromType(typenum);
+ PyArray_DTypeMeta *dtype = NPY_DTYPE(descr);
+ Py_INCREF(dtype);
+ Py_DECREF(descr);
+ return dtype;
+}
+
NPY_NO_EXPORT int
dtypemeta_wrap_legacy_descriptor(PyArray_Descr *dtypem);
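
PyArray_DTypeFromTypeNum fetches the DType class behind a type number by taking the type of its singleton descriptor. At the Python level the same relationship is simply the type of a dtype instance; a small, hedged illustration of behavior at this development stage:

    import numpy as np

    # The DType class of a descriptor is its Python type; the C helper
    # reaches the same object from a type number via PyArray_DescrFromType.
    Int64DType = type(np.dtype(np.int64))
    assert isinstance(np.dtype("int64"), Int64DType)
    assert issubclass(Int64DType, np.dtype)
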
diff --git a/numpy/core/src/multiarray/usertypes.c b/numpy/core/src/multiarray/usertypes.c
index d75fd3130..3727567e0 100644
--- a/numpy/core/src/multiarray/usertypes.c
+++ b/numpy/core/src/multiarray/usertypes.c
@@ -38,6 +38,7 @@ maintainer email: oliphant.travis@ieee.org
#include "usertypes.h"
#include "dtypemeta.h"
+#include "scalartypes.h"
NPY_NO_EXPORT PyArray_Descr **userdescrs=NULL;
@@ -350,3 +351,123 @@ PyArray_RegisterCanCast(PyArray_Descr *descr, int totype,
return _append_new(&descr->f->cancastscalarkindto[scalar], totype);
}
}
+
+
+/*
+ * Legacy user DTypes implemented the common DType operation
+ * (as used in type promotion/result_type, and e.g. the type for
+ * concatenation), by using "safe cast" logic.
+ *
+ * New DTypes do have this behaviour generally, but we use can-cast
+ * when legacy user dtypes are involved.
+ */
+NPY_NO_EXPORT PyArray_DTypeMeta *
+legacy_userdtype_common_dtype_function(
+ PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
+{
+ int skind1 = NPY_NOSCALAR, skind2 = NPY_NOSCALAR, skind;
+
+ if (!other->legacy) {
+ /* legacy DTypes can always defer to new style ones */
+ Py_INCREF(Py_NotImplemented);
+ return (PyArray_DTypeMeta *)Py_NotImplemented;
+ }
+ /* Defer so that only one of the types handles the cast */
+ if (cls->type_num < other->type_num) {
+ Py_INCREF(Py_NotImplemented);
+ return (PyArray_DTypeMeta *)Py_NotImplemented;
+ }
+
+ /* Check whether casting is possible from one type to the other */
+ if (PyArray_CanCastSafely(cls->type_num, other->type_num)) {
+ Py_INCREF(other);
+ return other;
+ }
+ if (PyArray_CanCastSafely(other->type_num, cls->type_num)) {
+ Py_INCREF(cls);
+ return cls;
+ }
+
+ /*
+ * The following code used to be part of PyArray_PromoteTypes().
+ * We can expect that this code is never used.
+ * In principle, it allows for promotion of two different user dtypes
+ * to a single NumPy dtype of the same "kind". In practice
+ * using the same `kind` as NumPy was never possible due to a
+ * simplification where `PyArray_EquivTypes(descr1, descr2)` will
+ * return True if both kind and element size match (e.g. bfloat16 and
+ * float16 would be equivalent).
+ * The option is also very obscure and not used in the examples.
+ */
+
+ /* Convert the 'kind' char into a scalar kind */
+ switch (cls->kind) {
+ case 'b':
+ skind1 = NPY_BOOL_SCALAR;
+ break;
+ case 'u':
+ skind1 = NPY_INTPOS_SCALAR;
+ break;
+ case 'i':
+ skind1 = NPY_INTNEG_SCALAR;
+ break;
+ case 'f':
+ skind1 = NPY_FLOAT_SCALAR;
+ break;
+ case 'c':
+ skind1 = NPY_COMPLEX_SCALAR;
+ break;
+ }
+ switch (other->kind) {
+ case 'b':
+ skind2 = NPY_BOOL_SCALAR;
+ break;
+ case 'u':
+ skind2 = NPY_INTPOS_SCALAR;
+ break;
+ case 'i':
+ skind2 = NPY_INTNEG_SCALAR;
+ break;
+ case 'f':
+ skind2 = NPY_FLOAT_SCALAR;
+ break;
+ case 'c':
+ skind2 = NPY_COMPLEX_SCALAR;
+ break;
+ }
+
+ /* If both are scalars, there may be a promotion possible */
+ if (skind1 != NPY_NOSCALAR && skind2 != NPY_NOSCALAR) {
+
+ /* Start with the larger scalar kind */
+ skind = (skind1 > skind2) ? skind1 : skind2;
+ int ret_type_num = _npy_smallest_type_of_kind_table[skind];
+
+ for (;;) {
+
+ /* If there is no larger type of this kind, try a larger kind */
+ if (ret_type_num < 0) {
+ ++skind;
+ /* Use -1 to signal no promoted type found */
+ if (skind < NPY_NSCALARKINDS) {
+ ret_type_num = _npy_smallest_type_of_kind_table[skind];
+ }
+ else {
+ break;
+ }
+ }
+
+ /* If we found a type to which we can promote both, done! */
+ if (PyArray_CanCastSafely(cls->type_num, ret_type_num) &&
+ PyArray_CanCastSafely(other->type_num, ret_type_num)) {
+ return PyArray_DTypeFromTypeNum(ret_type_num);
+ }
+
+ /* Try the next larger type of this kind */
+ ret_type_num = _npy_next_larger_type_table[ret_type_num];
+ }
+ }
+
+ Py_INCREF(Py_NotImplemented);
+ return (PyArray_DTypeMeta *)Py_NotImplemented;
+}
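
Legacy user DTypes thus get a common-DType operator built on safe-cast queries, preserving the old PyArray_PromoteTypes behavior for registered user types such as the rational test dtype. A hedged illustration (exactly which built-in types promote to rational depends on the casts that test dtype registers):

    import numpy as np
    from numpy.core._rational_tests import rational

    r = np.dtype(rational)
    # Identical user dtypes take the fast path.
    assert np.promote_types(r, r) == r
    # Mixed user/built-in promotion falls back to "can cast safely" queries;
    # whether a given int promotes to rational depends on registered casts.
    print(np.can_cast(np.int32, r))
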
diff --git a/numpy/core/src/multiarray/usertypes.h b/numpy/core/src/multiarray/usertypes.h
index b3e386c5c..1b323d458 100644
--- a/numpy/core/src/multiarray/usertypes.h
+++ b/numpy/core/src/multiarray/usertypes.h
@@ -17,4 +17,8 @@ NPY_NO_EXPORT int
PyArray_RegisterCastFunc(PyArray_Descr *descr, int totype,
PyArray_VectorUnaryFunc *castfunc);
+NPY_NO_EXPORT PyArray_DTypeMeta *
+legacy_userdtype_common_dtype_function(
+ PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other);
+
#endif
diff --git a/numpy/core/src/umath/ufunc_type_resolution.c b/numpy/core/src/umath/ufunc_type_resolution.c
index aa6f34d59..3abeb2c5a 100644
--- a/numpy/core/src/umath/ufunc_type_resolution.c
+++ b/numpy/core/src/umath/ufunc_type_resolution.c
@@ -236,21 +236,6 @@ PyUFunc_ValidateCasting(PyUFuncObject *ufunc,
return 0;
}
-/*
- * Returns a new reference to type if it is already NBO, otherwise
- * returns a copy converted to NBO.
- */
-static PyArray_Descr *
-ensure_dtype_nbo(PyArray_Descr *type)
-{
- if (PyArray_ISNBO(type->byteorder)) {
- Py_INCREF(type);
- return type;
- }
- else {
- return PyArray_DescrNewByteorder(type, NPY_NATIVE);
- }
-}
/*UFUNC_API
*
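
With the duplicate removed, the single ensure_dtype_nbo in convert_datatype.c (now declared in its header) keeps promotion results in native byte order. For example (illustration only):

    import numpy as np

    # Promotion results are normalized to native byte order
    # (ensure_dtype_nbo), even when an input is non-native.
    res = np.promote_types(">i4", "<i4")
    assert res.byteorder in "=|"      # native (or byte order not applicable)
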
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index ae5ee4c88..f5428f98c 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -941,7 +941,7 @@ class TestTypes:
return
res = np.promote_types(dtype, dtype)
- if res.char in "?bhilqpBHILQPefdgFDGOmM":
+ if res.char in "?bhilqpBHILQPefdgFDGOmM" or dtype.type is rational:
# Metadata is lost for simple promotions (they create a new dtype)
assert res.metadata is None
else:
@@ -976,41 +976,20 @@ class TestTypes:
# Promotion failed, this test only checks metadata
return
- # The rules for when metadata is preserved and which dtypes metadta
- # will be used are very confusing and depend on multiple paths.
- # This long if statement attempts to reproduce this:
- if dtype1.type is rational or dtype2.type is rational:
- # User dtype promotion preserves byte-order here:
- if np.can_cast(res, dtype1):
- assert res.metadata == dtype1.metadata
- else:
- assert res.metadata == dtype2.metadata
-
- elif res.char in "?bhilqpBHILQPefdgFDGOmM":
+ if res.char in "?bhilqpBHILQPefdgFDGOmM" or res.type is rational:
# All simple types lose metadata (due to using promotion table):
assert res.metadata is None
- elif res.kind in "SU" and dtype1 == dtype2:
- # Strings give precedence to the second dtype:
- assert res is dtype2
elif res == dtype1:
# If one result is the result, it is usually returned unchanged:
assert res is dtype1
elif res == dtype2:
- # If one result is the result, it is usually returned unchanged:
- assert res is dtype2
- elif dtype1.kind == "S" and dtype2.kind == "U":
- # Promotion creates a new unicode dtype from scratch
- assert res.metadata is None
- elif dtype1.kind == "U" and dtype2.kind == "S":
- # Promotion creates a new unicode dtype from scratch
- assert res.metadata is None
- elif res.kind in "SU" and dtype2.kind != res.kind:
- # We build on top of dtype1:
- assert res.metadata == dtype1.metadata
- elif res.kind in "SU" and res.kind == dtype1.kind:
- assert res.metadata == dtype1.metadata
- elif res.kind in "SU" and res.kind == dtype2.kind:
- assert res.metadata == dtype2.metadata
+ # dtype1 may have been cast to the same type/kind as dtype2.
+ # If the resulting dtype is identical we currently pick the cast
+ # version of dtype1, which lost the metadata:
+ if np.promote_types(dtype1, dtype2.kind) == dtype2:
+ assert res.metadata is None
+ else:
+ assert res.metadata == metadata2
else:
assert res.metadata is None
@@ -1025,6 +1004,24 @@ class TestTypes:
assert res_bs == res
assert res_bs.metadata == res.metadata
+ @pytest.mark.parametrize(["dtype1", "dtype2"],
+ [[np.dtype("V6"), np.dtype("V10")],
+ [np.dtype([("name1", "i8")]), np.dtype([("name2", "i8")])],
+ [np.dtype("i8,i8"), np.dtype("i4,i4")],
+ ])
+ def test_invalid_void_promotion(self, dtype1, dtype2):
+ # Mainly test structured void promotion, which currently allows
+ # byte-swapping, but nothing else:
+ with pytest.raises(TypeError):
+ np.promote_types(dtype1, dtype2)
+
+ @pytest.mark.parametrize(["dtype1", "dtype2"],
+ [[np.dtype("V10"), np.dtype("V10")],
+ [np.dtype([("name1", "<i8")]), np.dtype([("name1", ">i8")])],
+ [np.dtype("i8,i8"), np.dtype("i8,>i8")],
+ ])
+ def test_valid_void_promotion(self, dtype1, dtype2):
+ assert np.promote_types(dtype1, dtype2) is dtype1
def test_can_cast(self):
assert_(np.can_cast(np.int32, np.int64))
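
The rewritten metadata test above pins down when metadata survives promotion. A quick illustration of the same rules (not part of this commit's test suite):

    import numpy as np

    # When one input already is the promoted result, it is returned
    # unchanged and keeps its metadata ...
    d1 = np.dtype("S5", metadata={"a": 1})
    res = np.promote_types(d1, "S3")
    assert res is d1 and res.metadata == {"a": 1}
    # ... while simple numeric promotion goes through the type-number
    # table / default descriptor and drops metadata.
    i = np.dtype("int8", metadata={"a": 1})
    assert np.promote_types(i, np.int16).metadata is None
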