| field | value | date |
|---|---|---|
| author | Matti Picus <matti.picus@gmail.com> | 2020-09-22 15:54:18 +0300 |
| committer | GitHub <noreply@github.com> | 2020-09-22 15:54:18 +0300 |
| commit | 0d2536650ff1b750a9008cfd2893dcc1db58f94f (patch) | |
| tree | c4c54da90acf8e3919119c704fcaf4b620678bb5 | |
| parent | f981a7ff4b02204f215eb089ce77f9a3f3906b1e (diff) | |
| parent | b40f6bb22d7e71533e0b450493530e8fdd08afa5 (diff) | |
| download | numpy-0d2536650ff1b750a9008cfd2893dcc1db58f94f.tar.gz | |
Merge pull request #17137 from seberg/restructure-dtype-promotion
API,MAINT: Rewrite promotion using common DType and common instance
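
This merge replaces the long hard-coded `switch` logic in `PyArray_PromoteTypes` with a two-step protocol: a per-DType `common_dtype` slot first resolves the common DType *class*, and for parametric DTypes a `common_instance` slot then resolves the concrete descriptor (itemsize, unit, byte order). For builtin dtypes the result of `np.promote_types` is intended to stay the same; a small illustration using only public API (assuming a NumPy build containing this branch):

```python
import numpy as np

# Builtin promotion keeps its usual results; only the machinery moves
# into per-DType ``common_dtype``/``common_instance`` slots (see diff).
assert np.promote_types(np.int16, np.float32) == np.dtype(np.float32)

# Parametric example: the common DType of "S" and "U" is "U" (the class);
# the common *instance* then picks an itemsize large enough for both.
assert np.promote_types("S5", "U3") == np.dtype("U5")
```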
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | numpy/core/include/numpy/ndarraytypes.h | 6 |
| -rw-r--r-- | numpy/core/src/multiarray/array_coercion.c | 73 |
| -rw-r--r-- | numpy/core/src/multiarray/convert_datatype.c | 413 |
| -rw-r--r-- | numpy/core/src/multiarray/convert_datatype.h | 9 |
| -rw-r--r-- | numpy/core/src/multiarray/dtypemeta.c | 145 |
| -rw-r--r-- | numpy/core/src/multiarray/dtypemeta.h | 16 |
| -rw-r--r-- | numpy/core/src/multiarray/usertypes.c | 121 |
| -rw-r--r-- | numpy/core/src/multiarray/usertypes.h | 4 |
| -rw-r--r-- | numpy/core/src/umath/ufunc_type_resolution.c | 15 |
| -rw-r--r-- | numpy/core/tests/test_numeric.py | 57 |
10 files changed, 451 insertions, 408 deletions
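
Before the per-file diff, here is a deliberately simplified, self-contained Python sketch of the deferral protocol the new `common_dtype` slot follows (mirroring `PyArray_CommonDType` in `convert_datatype.c` below). The classes and method names are illustrative stand-ins, not NumPy API:

```python
class DType:
    """Toy stand-in for PyArray_DTypeMeta (illustrative only)."""
    parametric = False

    def common_dtype(self, other):
        # Default: defer to the other operand, like the C slot
        # returning Py_NotImplemented.
        return NotImplemented


class Int64(DType):
    def common_dtype(self, other):
        if isinstance(other, Int64):
            return Int64()
        return NotImplemented      # let the "larger" DType decide


class String(DType):
    parametric = True

    def common_dtype(self, other):
        # The toy string DType accepts everything here; the real code
        # consults the promotion table / can-cast logic instead.
        return String()


def common_dtype(dt1, dt2):
    """Symmetric lookup, mirroring PyArray_CommonDType()."""
    res = dt1.common_dtype(dt2)
    if res is NotImplemented:
        res = dt2.common_dtype(dt1)
    if res is NotImplemented:
        raise TypeError("no common DType (only `object` could hold both)")
    return res


print(type(common_dtype(Int64(), String())).__name__)   # -> String
```

For a parametric result, `PyArray_PromoteTypes` then casts both descriptors over to the common DType via `PyArray_CastDescrToDType` and merges them with the `common_instance` slot, as the `convert_datatype.c` hunk below shows.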
diff --git a/numpy/core/include/numpy/ndarraytypes.h b/numpy/core/include/numpy/ndarraytypes.h index 6eca4afdb..6bf54938f 100644 --- a/numpy/core/include/numpy/ndarraytypes.h +++ b/numpy/core/include/numpy/ndarraytypes.h @@ -1839,6 +1839,10 @@ typedef void (PyDataMem_EventHookFunc)(void *inp, void *outp, size_t size, PyArray_DTypeMeta *cls, PyTypeObject *obj); typedef PyArray_Descr *(default_descr_function)(PyArray_DTypeMeta *cls); + typedef PyArray_DTypeMeta *(common_dtype_function)( + PyArray_DTypeMeta *dtype1, PyArray_DTypeMeta *dtyep2); + typedef PyArray_Descr *(common_instance_function)( + PyArray_Descr *dtype1, PyArray_Descr *dtyep2); /* * While NumPy DTypes would not need to be heap types the plan is to @@ -1894,6 +1898,8 @@ typedef void (PyDataMem_EventHookFunc)(void *inp, void *outp, size_t size, discover_descr_from_pyobject_function *discover_descr_from_pyobject; is_known_scalar_type_function *is_known_scalar_type; default_descr_function *default_descr; + common_dtype_function *common_dtype; + common_instance_function *common_instance; }; #endif /* NPY_INTERNAL_BUILD */ diff --git a/numpy/core/src/multiarray/array_coercion.c b/numpy/core/src/multiarray/array_coercion.c index 3f3fd1387..aae8d5141 100644 --- a/numpy/core/src/multiarray/array_coercion.c +++ b/numpy/core/src/multiarray/array_coercion.c @@ -306,51 +306,6 @@ discover_dtype_from_pyobject( } -/* - * This function should probably become public API eventually. At this - * time it is implemented by falling back to `PyArray_AdaptFlexibleDType`. - * We will use `CastingImpl[from, to].adjust_descriptors(...)` to implement - * this logic. - */ -static NPY_INLINE PyArray_Descr * -cast_descriptor_to_fixed_dtype( - PyArray_Descr *descr, PyArray_DTypeMeta *fixed_DType) -{ - if (fixed_DType == NULL) { - /* Nothing to do, we only need to promote the new dtype */ - Py_INCREF(descr); - return descr; - } - - if (!fixed_DType->parametric) { - /* - * Don't actually do anything, the default is always the result - * of any cast. - */ - return fixed_DType->default_descr(fixed_DType); - } - if (PyObject_TypeCheck((PyObject *)descr, (PyTypeObject *)fixed_DType)) { - Py_INCREF(descr); - return descr; - } - /* - * TODO: When this is implemented for all dtypes, the special cases - * can be removed... - */ - if (fixed_DType->legacy && fixed_DType->parametric && - NPY_DTYPE(descr)->legacy) { - PyArray_Descr *flex_dtype = PyArray_DescrFromType(fixed_DType->type_num); - return PyArray_AdaptFlexibleDType(descr, flex_dtype); - } - - PyErr_SetString(PyExc_NotImplementedError, - "Must use casting to find the correct dtype, this is " - "not yet implemented! " - "(It should not be possible to hit this code currently!)"); - return NULL; -} - - /** * Discover the correct descriptor from a known DType class and scalar. 
* If the fixed DType can discover a dtype instance/descr all is fine, @@ -392,7 +347,7 @@ find_scalar_descriptor( return descr; } - Py_SETREF(descr, cast_descriptor_to_fixed_dtype(descr, fixed_DType)); + Py_SETREF(descr, PyArray_CastDescrToDType(descr, fixed_DType)); return descr; } @@ -727,8 +682,13 @@ find_descriptor_from_array( enum _dtype_discovery_flags flags = 0; *out_descr = NULL; - if (NPY_UNLIKELY(DType != NULL && DType->parametric && - PyArray_ISOBJECT(arr))) { + if (DType == NULL) { + *out_descr = PyArray_DESCR(arr); + Py_INCREF(*out_descr); + return 0; + } + + if (NPY_UNLIKELY(DType->parametric && PyArray_ISOBJECT(arr))) { /* * We have one special case, if (and only if) the input array is of * object DType and the dtype is not fixed already but parametric. @@ -777,7 +737,7 @@ find_descriptor_from_array( } Py_DECREF(iter); } - else if (DType != NULL && NPY_UNLIKELY(DType->type_num == NPY_DATETIME) && + else if (NPY_UNLIKELY(DType->type_num == NPY_DATETIME) && PyArray_ISSTRING(arr)) { /* * TODO: This branch should be deprecated IMO, the workaround is @@ -806,8 +766,7 @@ find_descriptor_from_array( * If this is not an object array figure out the dtype cast, * or simply use the returned DType. */ - *out_descr = cast_descriptor_to_fixed_dtype( - PyArray_DESCR(arr), DType); + *out_descr = PyArray_CastDescrToDType(PyArray_DESCR(arr), DType); if (*out_descr == NULL) { return -1; } @@ -1325,15 +1284,9 @@ PyArray_DiscoverDTypeAndShape( * the correct default. */ if (fixed_DType != NULL) { - if (fixed_DType->default_descr == NULL) { - Py_INCREF(fixed_DType->singleton); - *out_descr = fixed_DType->singleton; - } - else { - *out_descr = fixed_DType->default_descr(fixed_DType); - if (*out_descr == NULL) { - goto fail; - } + *out_descr = fixed_DType->default_descr(fixed_DType); + if (*out_descr == NULL) { + goto fail; } } } diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c index d9121707b..f700bdc99 100644 --- a/numpy/core/src/multiarray/convert_datatype.c +++ b/numpy/core/src/multiarray/convert_datatype.c @@ -1019,7 +1019,7 @@ promote_types(PyArray_Descr *type1, PyArray_Descr *type2, * Returns a new reference to type if it is already NBO, otherwise * returns a copy converted to NBO. */ -static PyArray_Descr * +NPY_NO_EXPORT PyArray_Descr * ensure_dtype_nbo(PyArray_Descr *type) { if (PyArray_ISNBO(type->byteorder)) { @@ -1031,327 +1031,148 @@ ensure_dtype_nbo(PyArray_Descr *type) } } -/*NUMPY_API - * Produces the smallest size and lowest kind type to which both - * input types can be cast. + +/** + * This function should possibly become public API eventually. At this + * time it is implemented by falling back to `PyArray_AdaptFlexibleDType`. + * We will use `CastingImpl[from, to].adjust_descriptors(...)` to implement + * this logic. + * Before that, the API needs to be reviewed though. + * + * WARNING: This function currently does not guarantee that `descr` can + * actually be cast to the given DType. + * + * @param descr The dtype instance to adapt "cast" + * @param given_DType The DType class for which we wish to find an instance able + * to represent `descr`. + * @returns Instance of `given_DType`. If `given_DType` is parametric the + * descr may be adapted to hold it. */ NPY_NO_EXPORT PyArray_Descr * -PyArray_PromoteTypes(PyArray_Descr *type1, PyArray_Descr *type2) +PyArray_CastDescrToDType(PyArray_Descr *descr, PyArray_DTypeMeta *given_DType) { - int type_num1, type_num2, ret_type_num; - - /* - * Fast path for identical dtypes. 
- * - * Non-native-byte-order types are converted to native ones below, so we - * can't quit early. - */ - if (type1 == type2 && PyArray_ISNBO(type1->byteorder)) { - Py_INCREF(type1); - return type1; + if (NPY_DTYPE(descr) == given_DType) { + Py_INCREF(descr); + return descr; } - - type_num1 = type1->type_num; - type_num2 = type2->type_num; - - /* If they're built-in types, use the promotion table */ - if (type_num1 < NPY_NTYPES && type_num2 < NPY_NTYPES) { - ret_type_num = _npy_type_promotion_table[type_num1][type_num2]; + if (!given_DType->parametric) { /* - * The table doesn't handle string/unicode/void/datetime/timedelta, - * so check the result + * Don't actually do anything, the default is always the result + * of any cast. */ - if (ret_type_num >= 0) { - return PyArray_DescrFromType(ret_type_num); - } + return given_DType->default_descr(given_DType); + } + if (PyObject_TypeCheck((PyObject *)descr, (PyTypeObject *)given_DType)) { + Py_INCREF(descr); + return descr; } - /* If one or both are user defined, calculate it */ - else { - int skind1 = NPY_NOSCALAR, skind2 = NPY_NOSCALAR, skind; - if (PyArray_CanCastTo(type2, type1)) { - /* Promoted types are always native byte order */ - return ensure_dtype_nbo(type1); - } - else if (PyArray_CanCastTo(type1, type2)) { - /* Promoted types are always native byte order */ - return ensure_dtype_nbo(type2); - } + if (!given_DType->legacy) { + PyErr_SetString(PyExc_NotImplementedError, + "Must use casting to find the correct DType for a parametric " + "user DType. This is not yet implemented (this error should be " + "unreachable)."); + return NULL; + } - /* Convert the 'kind' char into a scalar kind */ - switch (type1->kind) { - case 'b': - skind1 = NPY_BOOL_SCALAR; - break; - case 'u': - skind1 = NPY_INTPOS_SCALAR; - break; - case 'i': - skind1 = NPY_INTNEG_SCALAR; - break; - case 'f': - skind1 = NPY_FLOAT_SCALAR; - break; - case 'c': - skind1 = NPY_COMPLEX_SCALAR; - break; - } - switch (type2->kind) { - case 'b': - skind2 = NPY_BOOL_SCALAR; - break; - case 'u': - skind2 = NPY_INTPOS_SCALAR; - break; - case 'i': - skind2 = NPY_INTNEG_SCALAR; - break; - case 'f': - skind2 = NPY_FLOAT_SCALAR; - break; - case 'c': - skind2 = NPY_COMPLEX_SCALAR; - break; - } + PyArray_Descr *flex_dtype = PyArray_DescrNew(given_DType->singleton); + return PyArray_AdaptFlexibleDType(descr, flex_dtype); +} - /* If both are scalars, there may be a promotion possible */ - if (skind1 != NPY_NOSCALAR && skind2 != NPY_NOSCALAR) { - /* Start with the larger scalar kind */ - skind = (skind1 > skind2) ? skind1 : skind2; - ret_type_num = _npy_smallest_type_of_kind_table[skind]; +/** + * This function defines the common DType operator. + * + * Note that the common DType will not be "object" (unless one of the dtypes + * is object), even though object can technically represent all values + * correctly. + * + * TODO: Before exposure, we should review the return value (e.g. no error + * when no common DType is found). + * + * @param dtype1 DType class to find the common type for. + * @param dtype2 Second DType class. 
+ * @return The common DType or NULL with an error set + */ +NPY_NO_EXPORT PyArray_DTypeMeta * +PyArray_CommonDType(PyArray_DTypeMeta *dtype1, PyArray_DTypeMeta *dtype2) +{ + if (dtype1 == dtype2) { + Py_INCREF(dtype1); + return dtype1; + } - for (;;) { + PyArray_DTypeMeta *common_dtype; - /* If there is no larger type of this kind, try a larger kind */ - if (ret_type_num < 0) { - ++skind; - /* Use -1 to signal no promoted type found */ - if (skind < NPY_NSCALARKINDS) { - ret_type_num = _npy_smallest_type_of_kind_table[skind]; - } - else { - break; - } - } + common_dtype = dtype1->common_dtype(dtype1, dtype2); + if (common_dtype == (PyArray_DTypeMeta *)Py_NotImplemented) { + Py_DECREF(common_dtype); + common_dtype = dtype2->common_dtype(dtype2, dtype1); + } + if (common_dtype == NULL) { + return NULL; + } + if (common_dtype == (PyArray_DTypeMeta *)Py_NotImplemented) { + Py_DECREF(Py_NotImplemented); + PyErr_Format(PyExc_TypeError, + "The DTypes %S and %S do not have a common DType. " + "For example they cannot be stored in a single array unless " + "the dtype is `object`.", dtype1, dtype2); + return NULL; + } + return common_dtype; +} - /* If we found a type to which we can promote both, done! */ - if (PyArray_CanCastSafely(type_num1, ret_type_num) && - PyArray_CanCastSafely(type_num2, ret_type_num)) { - return PyArray_DescrFromType(ret_type_num); - } - /* Try the next larger type of this kind */ - ret_type_num = _npy_next_larger_type_table[ret_type_num]; - } +/*NUMPY_API + * Produces the smallest size and lowest kind type to which both + * input types can be cast. + */ +NPY_NO_EXPORT PyArray_Descr * +PyArray_PromoteTypes(PyArray_Descr *type1, PyArray_Descr *type2) +{ + PyArray_DTypeMeta *common_dtype; + PyArray_Descr *res; - } + /* Fast path for identical inputs (NOTE: This path preserves metadata!) 
*/ + if (type1 == type2 && PyArray_ISNBO(type1->byteorder)) { + Py_INCREF(type1); + return type1; + } - PyErr_SetString(PyExc_TypeError, - "invalid type promotion with custom data type"); + common_dtype = PyArray_CommonDType(NPY_DTYPE(type1), NPY_DTYPE(type2)); + if (common_dtype == NULL) { return NULL; } - switch (type_num1) { - /* BOOL can convert to anything except datetime/void */ - case NPY_BOOL: - if (type_num2 == NPY_STRING || type_num2 == NPY_UNICODE) { - int char_size = 1; - if (type_num2 == NPY_UNICODE) { - char_size = 4; - } - if (type2->elsize < 5 * char_size) { - PyArray_Descr *ret = NULL; - PyArray_Descr *temp = PyArray_DescrNew(type2); - ret = ensure_dtype_nbo(temp); - ret->elsize = 5 * char_size; - Py_DECREF(temp); - return ret; - } - return ensure_dtype_nbo(type2); - } - else if (type_num2 != NPY_DATETIME && type_num2 != NPY_VOID) { - return ensure_dtype_nbo(type2); - } - break; - /* For strings and unicodes, take the larger size */ - case NPY_STRING: - if (type_num2 == NPY_STRING) { - if (type1->elsize > type2->elsize) { - return ensure_dtype_nbo(type1); - } - else { - return ensure_dtype_nbo(type2); - } - } - else if (type_num2 == NPY_UNICODE) { - if (type2->elsize >= type1->elsize * 4) { - return ensure_dtype_nbo(type2); - } - else { - PyArray_Descr *d = PyArray_DescrNewFromType(NPY_UNICODE); - if (d == NULL) { - return NULL; - } - d->elsize = type1->elsize * 4; - return d; - } - } - /* Allow NUMBER -> STRING */ - else if (PyTypeNum_ISNUMBER(type_num2)) { - PyArray_Descr *ret = NULL; - PyArray_Descr *temp = PyArray_DescrNew(type1); - PyDataType_MAKEUNSIZED(temp); - - temp = PyArray_AdaptFlexibleDType(type2, temp); - if (temp == NULL) { - return NULL; - } - if (temp->elsize > type1->elsize) { - ret = ensure_dtype_nbo(temp); - } - else { - ret = ensure_dtype_nbo(type1); - } - Py_DECREF(temp); - return ret; - } - break; - case NPY_UNICODE: - if (type_num2 == NPY_UNICODE) { - if (type1->elsize > type2->elsize) { - return ensure_dtype_nbo(type1); - } - else { - return ensure_dtype_nbo(type2); - } - } - else if (type_num2 == NPY_STRING) { - if (type1->elsize >= type2->elsize * 4) { - return ensure_dtype_nbo(type1); - } - else { - PyArray_Descr *d = PyArray_DescrNewFromType(NPY_UNICODE); - if (d == NULL) { - return NULL; - } - d->elsize = type2->elsize * 4; - return d; - } - } - /* Allow NUMBER -> UNICODE */ - else if (PyTypeNum_ISNUMBER(type_num2)) { - PyArray_Descr *ret = NULL; - PyArray_Descr *temp = PyArray_DescrNew(type1); - PyDataType_MAKEUNSIZED(temp); - temp = PyArray_AdaptFlexibleDType(type2, temp); - if (temp == NULL) { - return NULL; - } - if (temp->elsize > type1->elsize) { - ret = ensure_dtype_nbo(temp); - } - else { - ret = ensure_dtype_nbo(type1); - } - Py_DECREF(temp); - return ret; - } - break; - case NPY_DATETIME: - case NPY_TIMEDELTA: - if (type_num2 == NPY_DATETIME || type_num2 == NPY_TIMEDELTA) { - return datetime_type_promotion(type1, type2); - } - break; + if (!common_dtype->parametric) { + res = common_dtype->default_descr(common_dtype); + Py_DECREF(common_dtype); + return res; } - switch (type_num2) { - /* BOOL can convert to almost anything */ - case NPY_BOOL: - if (type_num2 == NPY_STRING || type_num2 == NPY_UNICODE) { - int char_size = 1; - if (type_num2 == NPY_UNICODE) { - char_size = 4; - } - if (type2->elsize < 5 * char_size) { - PyArray_Descr *ret = NULL; - PyArray_Descr *temp = PyArray_DescrNew(type2); - ret = ensure_dtype_nbo(temp); - ret->elsize = 5 * char_size; - Py_DECREF(temp); - return ret; - } - return ensure_dtype_nbo(type2); - } - else 
if (type_num1 != NPY_DATETIME && type_num1 != NPY_TIMEDELTA && - type_num1 != NPY_VOID) { - return ensure_dtype_nbo(type1); - } - break; - case NPY_STRING: - /* Allow NUMBER -> STRING */ - if (PyTypeNum_ISNUMBER(type_num1)) { - PyArray_Descr *ret = NULL; - PyArray_Descr *temp = PyArray_DescrNew(type2); - PyDataType_MAKEUNSIZED(temp); - temp = PyArray_AdaptFlexibleDType(type1, temp); - if (temp == NULL) { - return NULL; - } - if (temp->elsize > type2->elsize) { - ret = ensure_dtype_nbo(temp); - } - else { - ret = ensure_dtype_nbo(type2); - } - Py_DECREF(temp); - return ret; - } - break; - case NPY_UNICODE: - /* Allow NUMBER -> UNICODE */ - if (PyTypeNum_ISNUMBER(type_num1)) { - PyArray_Descr *ret = NULL; - PyArray_Descr *temp = PyArray_DescrNew(type2); - PyDataType_MAKEUNSIZED(temp); - temp = PyArray_AdaptFlexibleDType(type1, temp); - if (temp == NULL) { - return NULL; - } - if (temp->elsize > type2->elsize) { - ret = ensure_dtype_nbo(temp); - } - else { - ret = ensure_dtype_nbo(type2); - } - Py_DECREF(temp); - return ret; - } - break; - case NPY_TIMEDELTA: - if (PyTypeNum_ISSIGNED(type_num1)) { - return ensure_dtype_nbo(type2); - } - break; + /* Cast the input types to the common DType if necessary */ + type1 = PyArray_CastDescrToDType(type1, common_dtype); + if (type1 == NULL) { + Py_DECREF(common_dtype); + return NULL; } - - /* For types equivalent up to endianness, can return either */ - if (PyArray_CanCastTypeTo(type1, type2, NPY_EQUIV_CASTING)) { - return ensure_dtype_nbo(type1); + type2 = PyArray_CastDescrToDType(type2, common_dtype); + if (type2 == NULL) { + Py_DECREF(type1); + Py_DECREF(common_dtype); + return NULL; } - /* TODO: Also combine fields, subarrays, strings, etc */ - /* - printf("invalid type promotion: "); - PyObject_Print(type1, stdout, 0); - printf(" "); - PyObject_Print(type2, stdout, 0); - printf("\n"); - */ - PyErr_SetString(PyExc_TypeError, "invalid type promotion"); - return NULL; + * And find the common instance of the two inputs + * NOTE: Common instance preserves metadata (normally and of one input) + */ + res = common_dtype->common_instance(type1, type2); + Py_DECREF(type1); + Py_DECREF(type2); + Py_DECREF(common_dtype); + return res; } /* diff --git a/numpy/core/src/multiarray/convert_datatype.h b/numpy/core/src/multiarray/convert_datatype.h index 9b7f39db2..a2b36b497 100644 --- a/numpy/core/src/multiarray/convert_datatype.h +++ b/numpy/core/src/multiarray/convert_datatype.h @@ -10,6 +10,9 @@ PyArray_ObjectType(PyObject *op, int minimum_type); NPY_NO_EXPORT PyArrayObject ** PyArray_ConvertToCommonType(PyObject *op, int *retn); +NPY_NO_EXPORT PyArray_DTypeMeta * +PyArray_CommonDType(PyArray_DTypeMeta *dtype1, PyArray_DTypeMeta *dtype2); + NPY_NO_EXPORT int PyArray_ValidType(int type); @@ -18,6 +21,9 @@ NPY_NO_EXPORT npy_bool can_cast_scalar_to(PyArray_Descr *scal_type, char *scal_data, PyArray_Descr *to, NPY_CASTING casting); +NPY_NO_EXPORT PyArray_Descr * +ensure_dtype_nbo(PyArray_Descr *type); + NPY_NO_EXPORT int should_use_min_scalar(npy_intp narrs, PyArrayObject **arr, npy_intp ndtypes, PyArray_Descr **dtypes); @@ -49,4 +55,7 @@ npy_set_invalid_cast_error( NPY_NO_EXPORT PyArray_Descr * PyArray_AdaptFlexibleDType(PyArray_Descr *data_dtype, PyArray_Descr *flex_dtype); +NPY_NO_EXPORT PyArray_Descr * +PyArray_CastDescrToDType(PyArray_Descr *descr, PyArray_DTypeMeta *given_DType); + #endif diff --git a/numpy/core/src/multiarray/dtypemeta.c b/numpy/core/src/multiarray/dtypemeta.c index d07dc700d..531f746d8 100644 --- a/numpy/core/src/multiarray/dtypemeta.c +++ 
b/numpy/core/src/multiarray/dtypemeta.c @@ -15,6 +15,9 @@ #include "dtypemeta.h" #include "_datetime.h" #include "array_coercion.h" +#include "scalartypes.h" +#include "convert_datatype.h" +#include "usertypes.h" static void @@ -194,6 +197,14 @@ discover_datetime_and_timedelta_from_pyobject( static PyArray_Descr * +nonparametric_default_descr(PyArray_DTypeMeta *cls) +{ + Py_INCREF(cls->singleton); + return cls->singleton; +} + + +static PyArray_Descr * flexible_default_descr(PyArray_DTypeMeta *cls) { PyArray_Descr *res = PyArray_DescrNewFromType(cls->type_num); @@ -208,6 +219,34 @@ flexible_default_descr(PyArray_DTypeMeta *cls) } +static PyArray_Descr * +string_unicode_common_instance(PyArray_Descr *descr1, PyArray_Descr *descr2) +{ + if (descr1->elsize >= descr2->elsize) { + return ensure_dtype_nbo(descr1); + } + else { + return ensure_dtype_nbo(descr2); + } +} + + +static PyArray_Descr * +void_common_instance(PyArray_Descr *descr1, PyArray_Descr *descr2) +{ + /* + * We currently do not support promotion of void types unless they + * are equivalent. + */ + if (!PyArray_CanCastTypeTo(descr1, descr2, NPY_EQUIV_CASTING)) { + PyErr_SetString(PyExc_TypeError, + "invalid type promotion with structured or void datatype(s)."); + return NULL; + } + Py_INCREF(descr1); + return descr1; +} + static int python_builtins_are_known_scalar_types( PyArray_DTypeMeta *NPY_UNUSED(cls), PyTypeObject *pytype) @@ -281,6 +320,86 @@ string_known_scalar_types( } +/* + * The following set of functions define the common dtype operator for + * the builtin types. + */ +static PyArray_DTypeMeta * +default_builtin_common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other) +{ + assert(cls->type_num < NPY_NTYPES); + if (!other->legacy || other->type_num > cls->type_num) { + /* Let the more generic (larger type number) DType handle this */ + Py_INCREF(Py_NotImplemented); + return (PyArray_DTypeMeta *)Py_NotImplemented; + } + + /* + * Note: The use of the promotion table should probably be revised at + * some point. It may be most useful to remove it entirely and then + * consider adding a fast path/cache `PyArray_CommonDType()` itself. + */ + int common_num = _npy_type_promotion_table[cls->type_num][other->type_num]; + if (common_num < 0) { + Py_INCREF(Py_NotImplemented); + return (PyArray_DTypeMeta *)Py_NotImplemented; + } + return PyArray_DTypeFromTypeNum(common_num); +} + + +static PyArray_DTypeMeta * +string_unicode_common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other) +{ + assert(cls->type_num < NPY_NTYPES); + if (!other->legacy || other->type_num > cls->type_num || + other->type_num == NPY_OBJECT) { + /* Let the more generic (larger type number) DType handle this */ + Py_INCREF(Py_NotImplemented); + return (PyArray_DTypeMeta *)Py_NotImplemented; + } + /* + * The builtin types are ordered by complexity (aside from object) here. + * Arguably, we should not consider numbers and strings "common", but + * we currently do. + */ + Py_INCREF(cls); + return cls; +} + +static PyArray_DTypeMeta * +datetime_common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other) +{ + if (cls->type_num == NPY_DATETIME && other->type_num == NPY_TIMEDELTA) { + /* + * TODO: We actually currently do allow promotion here. This is + * currently relied on within `np.add(datetime, timedelta)`, + * while for concatenation the cast step will fail. 
+ */ + Py_INCREF(cls); + return cls; + } + return default_builtin_common_dtype(cls, other); +} + + + +static PyArray_DTypeMeta * +object_common_dtype( + PyArray_DTypeMeta *cls, PyArray_DTypeMeta *NPY_UNUSED(other)) +{ + /* + * The object DType is special in that it can represent everything, + * including all potential user DTypes. + * One reason to defer (or error) here might be if the other DType + * does not support scalars so that e.g. `arr1d[0]` returns a 0-D array + * and `arr.astype(object)` would fail. But object casts are special. + */ + Py_INCREF(cls); + return cls; +} + + /** * This function takes a PyArray_Descr and replaces its base class with * a newly created dtype subclass (DTypeMeta instances). @@ -398,14 +517,28 @@ dtypemeta_wrap_legacy_descriptor(PyArray_Descr *descr) dtype_class->f = descr->f; dtype_class->kind = descr->kind; - /* Strings and voids have (strange) logic around scalars. */ + /* Set default functions (correct for most dtypes, override below) */ + dtype_class->default_descr = nonparametric_default_descr; + dtype_class->discover_descr_from_pyobject = ( + nonparametric_discover_descr_from_pyobject); dtype_class->is_known_scalar_type = python_builtins_are_known_scalar_types; + dtype_class->common_dtype = default_builtin_common_dtype; + dtype_class->common_instance = NULL; - if (PyTypeNum_ISDATETIME(descr->type_num)) { + if (PyTypeNum_ISUSERDEF(descr->type_num)) { + dtype_class->common_dtype = legacy_userdtype_common_dtype_function; + } + else if (descr->type_num == NPY_OBJECT) { + dtype_class->common_dtype = object_common_dtype; + } + else if (PyTypeNum_ISDATETIME(descr->type_num)) { /* Datetimes are flexible, but were not considered previously */ dtype_class->parametric = NPY_TRUE; + dtype_class->default_descr = flexible_default_descr; dtype_class->discover_descr_from_pyobject = ( discover_datetime_and_timedelta_from_pyobject); + dtype_class->common_dtype = datetime_common_dtype; + dtype_class->common_instance = datetime_type_promotion; if (descr->type_num == NPY_DATETIME) { dtype_class->is_known_scalar_type = datetime_known_scalar_types; } @@ -416,18 +549,16 @@ dtypemeta_wrap_legacy_descriptor(PyArray_Descr *descr) if (descr->type_num == NPY_VOID) { dtype_class->discover_descr_from_pyobject = ( void_discover_descr_from_pyobject); + dtype_class->common_instance = void_common_instance; } else { dtype_class->is_known_scalar_type = string_known_scalar_types; dtype_class->discover_descr_from_pyobject = ( string_discover_descr_from_pyobject); + dtype_class->common_dtype = string_unicode_common_dtype; + dtype_class->common_instance = string_unicode_common_instance; } } - else { - /* nonparametric case */ - dtype_class->discover_descr_from_pyobject = ( - nonparametric_discover_descr_from_pyobject); - } if (_PyArray_MapPyTypeToDType(dtype_class, descr->typeobj, PyTypeNum_ISUSERDEF(dtype_class->type_num)) < 0) { diff --git a/numpy/core/src/multiarray/dtypemeta.h b/numpy/core/src/multiarray/dtypemeta.h index e0909a7eb..83cf7c07e 100644 --- a/numpy/core/src/multiarray/dtypemeta.h +++ b/numpy/core/src/multiarray/dtypemeta.h @@ -2,6 +2,22 @@ #define _NPY_DTYPEMETA_H #define NPY_DTYPE(descr) ((PyArray_DTypeMeta *)Py_TYPE(descr)) +/* + * This function will hopefully be phased out or replaced, but was convenient + * for incremental implementation of new DTypes based on DTypeMeta. + * (Error checking is not required for DescrFromType, assuming that the + * type is valid.) 
+ */ +static NPY_INLINE PyArray_DTypeMeta * +PyArray_DTypeFromTypeNum(int typenum) +{ + PyArray_Descr *descr = PyArray_DescrFromType(typenum); + PyArray_DTypeMeta *dtype = NPY_DTYPE(descr); + Py_INCREF(dtype); + Py_DECREF(descr); + return dtype; +} + NPY_NO_EXPORT int dtypemeta_wrap_legacy_descriptor(PyArray_Descr *dtypem); diff --git a/numpy/core/src/multiarray/usertypes.c b/numpy/core/src/multiarray/usertypes.c index d75fd3130..3727567e0 100644 --- a/numpy/core/src/multiarray/usertypes.c +++ b/numpy/core/src/multiarray/usertypes.c @@ -38,6 +38,7 @@ maintainer email: oliphant.travis@ieee.org #include "usertypes.h" #include "dtypemeta.h" +#include "scalartypes.h" NPY_NO_EXPORT PyArray_Descr **userdescrs=NULL; @@ -350,3 +351,123 @@ PyArray_RegisterCanCast(PyArray_Descr *descr, int totype, return _append_new(&descr->f->cancastscalarkindto[scalar], totype); } } + + +/* + * Legacy user DTypes implemented the common DType operation + * (as used in type promotion/result_type, and e.g. the type for + * concatenation), by using "safe cast" logic. + * + * New DTypes do have this behaviour generally, but we use can-cast + * when legacy user dtypes are involved. + */ +NPY_NO_EXPORT PyArray_DTypeMeta * +legacy_userdtype_common_dtype_function( + PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other) +{ + int skind1 = NPY_NOSCALAR, skind2 = NPY_NOSCALAR, skind; + + if (!other->legacy) { + /* legacy DTypes can always defer to new style ones */ + Py_INCREF(Py_NotImplemented); + return (PyArray_DTypeMeta *)Py_NotImplemented; + } + /* Defer so that only one of the types handles the cast */ + if (cls->type_num < other->type_num) { + Py_INCREF(Py_NotImplemented); + return (PyArray_DTypeMeta *)Py_NotImplemented; + } + + /* Check whether casting is possible from one type to the other */ + if (PyArray_CanCastSafely(cls->type_num, other->type_num)) { + Py_INCREF(other); + return other; + } + if (PyArray_CanCastSafely(other->type_num, cls->type_num)) { + Py_INCREF(cls); + return cls; + } + + /* + * The following code used to be part of PyArray_PromoteTypes(). + * We can expect that this code is never used. + * In principle, it allows for promotion of two different user dtypes + * to a single NumPy dtype of the same "kind". In practice + * using the same `kind` as NumPy was never possible due to an + * simplification where `PyArray_EquivTypes(descr1, descr2)` will + * return True if both kind and element size match (e.g. bfloat16 and + * float16 would be equivalent). + * The option is also very obscure and not used in the examples. + */ + + /* Convert the 'kind' char into a scalar kind */ + switch (cls->kind) { + case 'b': + skind1 = NPY_BOOL_SCALAR; + break; + case 'u': + skind1 = NPY_INTPOS_SCALAR; + break; + case 'i': + skind1 = NPY_INTNEG_SCALAR; + break; + case 'f': + skind1 = NPY_FLOAT_SCALAR; + break; + case 'c': + skind1 = NPY_COMPLEX_SCALAR; + break; + } + switch (other->kind) { + case 'b': + skind2 = NPY_BOOL_SCALAR; + break; + case 'u': + skind2 = NPY_INTPOS_SCALAR; + break; + case 'i': + skind2 = NPY_INTNEG_SCALAR; + break; + case 'f': + skind2 = NPY_FLOAT_SCALAR; + break; + case 'c': + skind2 = NPY_COMPLEX_SCALAR; + break; + } + + /* If both are scalars, there may be a promotion possible */ + if (skind1 != NPY_NOSCALAR && skind2 != NPY_NOSCALAR) { + + /* Start with the larger scalar kind */ + skind = (skind1 > skind2) ? 
skind1 : skind2; + int ret_type_num = _npy_smallest_type_of_kind_table[skind]; + + for (;;) { + + /* If there is no larger type of this kind, try a larger kind */ + if (ret_type_num < 0) { + ++skind; + /* Use -1 to signal no promoted type found */ + if (skind < NPY_NSCALARKINDS) { + ret_type_num = _npy_smallest_type_of_kind_table[skind]; + } + else { + break; + } + } + + /* If we found a type to which we can promote both, done! */ + if (PyArray_CanCastSafely(cls->type_num, ret_type_num) && + PyArray_CanCastSafely(other->type_num, ret_type_num)) { + return PyArray_DTypeFromTypeNum(ret_type_num); + } + + /* Try the next larger type of this kind */ + ret_type_num = _npy_next_larger_type_table[ret_type_num]; + } + } + + Py_INCREF(Py_NotImplemented); + return (PyArray_DTypeMeta *)Py_NotImplemented; +} diff --git a/numpy/core/src/multiarray/usertypes.h b/numpy/core/src/multiarray/usertypes.h index b3e386c5c..1b323d458 100644 --- a/numpy/core/src/multiarray/usertypes.h +++ b/numpy/core/src/multiarray/usertypes.h @@ -17,4 +17,8 @@ NPY_NO_EXPORT int PyArray_RegisterCastFunc(PyArray_Descr *descr, int totype, PyArray_VectorUnaryFunc *castfunc); +NPY_NO_EXPORT PyArray_DTypeMeta * +legacy_userdtype_common_dtype_function( + PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other); + #endif diff --git a/numpy/core/src/umath/ufunc_type_resolution.c b/numpy/core/src/umath/ufunc_type_resolution.c index aa6f34d59..3abeb2c5a 100644 --- a/numpy/core/src/umath/ufunc_type_resolution.c +++ b/numpy/core/src/umath/ufunc_type_resolution.c @@ -236,21 +236,6 @@ PyUFunc_ValidateCasting(PyUFuncObject *ufunc, return 0; } -/* - * Returns a new reference to type if it is already NBO, otherwise - * returns a copy converted to NBO. - */ -static PyArray_Descr * -ensure_dtype_nbo(PyArray_Descr *type) -{ - if (PyArray_ISNBO(type->byteorder)) { - Py_INCREF(type); - return type; - } - else { - return PyArray_DescrNewByteorder(type, NPY_NATIVE); - } -} /*UFUNC_API * diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py index ae5ee4c88..f5428f98c 100644 --- a/numpy/core/tests/test_numeric.py +++ b/numpy/core/tests/test_numeric.py @@ -941,7 +941,7 @@ class TestTypes: return res = np.promote_types(dtype, dtype) - if res.char in "?bhilqpBHILQPefdgFDGOmM": + if res.char in "?bhilqpBHILQPefdgFDGOmM" or dtype.type is rational: # Metadata is lost for simple promotions (they create a new dtype) assert res.metadata is None else: @@ -976,41 +976,20 @@ class TestTypes: # Promotion failed, this test only checks metadata return - # The rules for when metadata is preserved and which dtypes metadta - # will be used are very confusing and depend on multiple paths. 
- # This long if statement attempts to reproduce this: - if dtype1.type is rational or dtype2.type is rational: - # User dtype promotion preserves byte-order here: - if np.can_cast(res, dtype1): - assert res.metadata == dtype1.metadata - else: - assert res.metadata == dtype2.metadata - - elif res.char in "?bhilqpBHILQPefdgFDGOmM": + if res.char in "?bhilqpBHILQPefdgFDGOmM" or res.type is rational: # All simple types lose metadata (due to using promotion table): assert res.metadata is None - elif res.kind in "SU" and dtype1 == dtype2: - # Strings give precedence to the second dtype: - assert res is dtype2 elif res == dtype1: # If one result is the result, it is usually returned unchanged: assert res is dtype1 elif res == dtype2: - # If one result is the result, it is usually returned unchanged: - assert res is dtype2 - elif dtype1.kind == "S" and dtype2.kind == "U": - # Promotion creates a new unicode dtype from scratch - assert res.metadata is None - elif dtype1.kind == "U" and dtype2.kind == "S": - # Promotion creates a new unicode dtype from scratch - assert res.metadata is None - elif res.kind in "SU" and dtype2.kind != res.kind: - # We build on top of dtype1: - assert res.metadata == dtype1.metadata - elif res.kind in "SU" and res.kind == dtype1.kind: - assert res.metadata == dtype1.metadata - elif res.kind in "SU" and res.kind == dtype2.kind: - assert res.metadata == dtype2.metadata + # dtype1 may have been cast to the same type/kind as dtype2. + # If the resulting dtype is identical we currently pick the cast + # version of dtype1, which lost the metadata: + if np.promote_types(dtype1, dtype2.kind) == dtype2: + res.metadata is None + else: + res.metadata == metadata2 else: assert res.metadata is None @@ -1025,6 +1004,24 @@ class TestTypes: assert res_bs == res assert res_bs.metadata == res.metadata + @pytest.mark.parametrize(["dtype1", "dtype2"], + [[np.dtype("V6"), np.dtype("V10")], + [np.dtype([("name1", "i8")]), np.dtype([("name2", "i8")])], + [np.dtype("i8,i8"), np.dtype("i4,i4")], + ]) + def test_invalid_void_promotion(self, dtype1, dtype2): + # Mainly test structured void promotion, which currently allows + # byte-swapping, but nothing else: + with pytest.raises(TypeError): + np.promote_types(dtype1, dtype2) + + @pytest.mark.parametrize(["dtype1", "dtype2"], + [[np.dtype("V10"), np.dtype("V10")], + [np.dtype([("name1", "<i8")]), np.dtype([("name1", ">i8")])], + [np.dtype("i8,i8"), np.dtype("i8,>i8")], + ]) + def test_valid_void_promotion(self, dtype1, dtype2): + assert np.promote_types(dtype1, dtype2) is dtype1 def test_can_cast(self): assert_(np.can_cast(np.int32, np.int64)) |
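
The new `test_valid_void_promotion` / `test_invalid_void_promotion` cases at the end of the hunk pin down structured/void promotion: only byte-order differences are reconciled, everything else raises. A quick check against the public API, matching those test cases (assuming a build of this branch):

```python
import numpy as np

# Equivalent structured dtypes (differing only in byte order) promote:
assert np.promote_types("i8,i8", "i8,>i8") == np.dtype("i8,i8")

# Non-equivalent void/structured dtypes are rejected with a TypeError:
try:
    np.promote_types(np.dtype("V6"), np.dtype("V10"))
except TypeError as exc:
    print("rejected:", exc)
```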