diff options
-rw-r--r-- | numpy/core/src/multiarray/convert_datatype.c | 92 | ||||
-rw-r--r-- | numpy/core/src/multiarray/convert_datatype.h | 4 | ||||
-rw-r--r-- | numpy/core/src/umath/ufunc_type_resolution.c | 72 |
3 files changed, 60 insertions, 108 deletions
diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c index 8e24d6361..025c66013 100644 --- a/numpy/core/src/multiarray/convert_datatype.c +++ b/numpy/core/src/multiarray/convert_datatype.c @@ -1767,46 +1767,22 @@ dtype_kind_to_simplified_ordering(char kind) } } -/*NUMPY_API - * Produces the result type of a bunch of inputs, using the UFunc - * type promotion rules. Use this function when you have a set of - * input arrays, and need to determine an output array dtype. - * - * If all the inputs are scalars (have 0 dimensions) or the maximum "kind" - * of the scalars is greater than the maximum "kind" of the arrays, does - * a regular type promotion. - * - * Otherwise, does a type promotion on the MinScalarType - * of all the inputs. Data types passed directly are treated as array - * types. - * + +/* + * Determine if there is a mix of scalars and arrays/dtypes. + * If this is the case, the scalars should be handled as the minimum type + * capable of holding the value when the maximum "category" of the scalars + * surpasses the maximum "category" of the arrays/dtypes. + * If the scalars are of a lower or same category as the arrays, they may be + * demoted to a lower type within their category (the lowest type they can + * be cast to safely according to scalar casting rules). */ -NPY_NO_EXPORT PyArray_Descr * -PyArray_ResultType(npy_intp narrs, PyArrayObject **arr, - npy_intp ndtypes, PyArray_Descr **dtypes) +NPY_NO_EXPORT int +should_use_min_scalar(npy_intp narrs, PyArrayObject **arr, + npy_intp ndtypes, PyArray_Descr **dtypes) { - npy_intp i; - int use_min_scalar; - - /* If there's just one type, pass it through */ - if (narrs + ndtypes == 1) { - PyArray_Descr *ret = NULL; - if (narrs == 1) { - ret = PyArray_DESCR(arr[0]); - } - else { - ret = dtypes[0]; - } - Py_INCREF(ret); - return ret; - } + int use_min_scalar = 0; - /* - * Determine if there are any scalars, and if so, whether - * the maximum "kind" of the scalars surpasses the maximum - * "kind" of the arrays - */ - use_min_scalar = 0; if (narrs > 0) { int all_scalars; int max_scalar_kind = -1; @@ -1815,7 +1791,7 @@ PyArray_ResultType(npy_intp narrs, PyArrayObject **arr, all_scalars = (ndtypes > 0) ? 0 : 1; /* Compute the maximum "kinds" and whether everything is scalar */ - for (i = 0; i < narrs; ++i) { + for (npy_intp i = 0; i < narrs; ++i) { if (PyArray_NDIM(arr[i]) == 0) { int kind = dtype_kind_to_simplified_ordering( PyArray_DESCR(arr[i])->kind); @@ -1836,7 +1812,7 @@ PyArray_ResultType(npy_intp narrs, PyArrayObject **arr, * If the max scalar kind is bigger than the max array kind, * finish computing the max array kind */ - for (i = 0; i < ndtypes; ++i) { + for (npy_intp i = 0; i < ndtypes; ++i) { int kind = dtype_kind_to_simplified_ordering(dtypes[i]->kind); if (kind > max_array_kind) { max_array_kind = kind; @@ -1848,6 +1824,44 @@ PyArray_ResultType(npy_intp narrs, PyArrayObject **arr, use_min_scalar = 1; } } + return use_min_scalar; +} + + +/*NUMPY_API + * Produces the result type of a bunch of inputs, using the UFunc + * type promotion rules. Use this function when you have a set of + * input arrays, and need to determine an output array dtype. + * + * If all the inputs are scalars (have 0 dimensions) or the maximum "kind" + * of the scalars is greater than the maximum "kind" of the arrays, does + * a regular type promotion. + * + * Otherwise, does a type promotion on the MinScalarType + * of all the inputs. Data types passed directly are treated as array + * types. + * + */ +NPY_NO_EXPORT PyArray_Descr * +PyArray_ResultType(npy_intp narrs, PyArrayObject **arr, + npy_intp ndtypes, PyArray_Descr **dtypes) +{ + npy_intp i; + + /* If there's just one type, pass it through */ + if (narrs + ndtypes == 1) { + PyArray_Descr *ret = NULL; + if (narrs == 1) { + ret = PyArray_DESCR(arr[0]); + } + else { + ret = dtypes[0]; + } + Py_INCREF(ret); + return ret; + } + + int use_min_scalar = should_use_min_scalar(narrs, arr, ndtypes, dtypes); /* Loop through all the types, promoting them */ if (!use_min_scalar) { diff --git a/numpy/core/src/multiarray/convert_datatype.h b/numpy/core/src/multiarray/convert_datatype.h index 653557161..72867ead8 100644 --- a/numpy/core/src/multiarray/convert_datatype.h +++ b/numpy/core/src/multiarray/convert_datatype.h @@ -18,6 +18,10 @@ NPY_NO_EXPORT npy_bool can_cast_scalar_to(PyArray_Descr *scal_type, char *scal_data, PyArray_Descr *to, NPY_CASTING casting); +NPY_NO_EXPORT int +should_use_min_scalar(npy_intp narrs, PyArrayObject **arr, + npy_intp ndtypes, PyArray_Descr **dtypes); + /* * This function calls Py_DECREF on flex_dtype, and replaces it with * a new dtype that has been adapted based on the values in data_dtype diff --git a/numpy/core/src/umath/ufunc_type_resolution.c b/numpy/core/src/umath/ufunc_type_resolution.c index 58f915c6e..25dd002ac 100644 --- a/numpy/core/src/umath/ufunc_type_resolution.c +++ b/numpy/core/src/umath/ufunc_type_resolution.c @@ -24,6 +24,7 @@ #include "ufunc_type_resolution.h" #include "ufunc_object.h" #include "common.h" +#include "convert_datatype.h" #include "mem_overlap.h" #if defined(HAVE_CBLAS) @@ -1938,73 +1939,6 @@ type_tuple_userloop_type_resolver(PyUFuncObject *self, return 0; } -/* - * Provides an ordering for the dtype 'kind' character codes, to help - * determine when to use the min_scalar_type function. This groups - * 'kind' into boolean, integer, floating point, and everything else. - */ - -static int -dtype_kind_to_simplified_ordering(char kind) -{ - switch (kind) { - /* Boolean kind */ - case 'b': - return 0; - /* Unsigned int kind */ - case 'u': - /* Signed int kind */ - case 'i': - return 1; - /* Float kind */ - case 'f': - /* Complex kind */ - case 'c': - return 2; - /* Anything else */ - default: - return 3; - } -} - -static int -should_use_min_scalar(PyArrayObject **op, int nop) -{ - int i, use_min_scalar, kind; - int all_scalars = 1, max_scalar_kind = -1, max_array_kind = -1; - - /* - * Determine if there are any scalars, and if so, whether - * the maximum "kind" of the scalars surpasses the maximum - * "kind" of the arrays - */ - use_min_scalar = 0; - if (nop > 1) { - for(i = 0; i < nop; ++i) { - kind = dtype_kind_to_simplified_ordering( - PyArray_DESCR(op[i])->kind); - if (PyArray_NDIM(op[i]) == 0) { - if (kind > max_scalar_kind) { - max_scalar_kind = kind; - } - } - else { - all_scalars = 0; - if (kind > max_array_kind) { - max_array_kind = kind; - } - - } - } - - /* Indicate whether to use the min_scalar_type function */ - if (!all_scalars && max_array_kind >= max_scalar_kind) { - use_min_scalar = 1; - } - } - - return use_min_scalar; -} /* * Does a linear search for the best inner loop of the ufunc. @@ -2030,7 +1964,7 @@ linear_search_type_resolver(PyUFuncObject *self, ufunc_name = ufunc_get_name_cstr(self); - use_min_scalar = should_use_min_scalar(op, nin); + use_min_scalar = should_use_min_scalar(nin, op, 0, NULL); /* If the ufunc has userloops, search for them. */ if (self->userloops) { @@ -2139,7 +2073,7 @@ type_tuple_type_resolver(PyUFuncObject *self, ufunc_name = ufunc_get_name_cstr(self); - use_min_scalar = should_use_min_scalar(op, nin); + use_min_scalar = should_use_min_scalar(nin, op, 0, NULL); /* Fill in specified_types from the tuple or string */ if (PyTuple_Check(type_tup)) { |