diff options
author | Mark Wiebe <mwwiebe@gmail.com> | 2011-04-13 12:56:30 -0700 |
---|---|---|
committer | Mark Wiebe <mwwiebe@gmail.com> | 2011-04-22 14:04:33 -0700 |
commit | 492ae119cd4a5b57065cf15c4bb05d8730ce6d36 (patch) | |
tree | 475624f69c8a5bb9de1496f4c4ce62aab5434a73 | |
parent | ffa8786846f9ec604a9f17d6bff0bebbeb1fdd3f (diff) | |
download | numpy-492ae119cd4a5b57065cf15c4bb05d8730ce6d36.tar.gz |
BUG: Fix type promotion regression for the result_type function (Ticket #1798)
-rw-r--r-- | numpy/core/src/multiarray/convert_datatype.c | 120 |
1 files changed, 105 insertions, 15 deletions
diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c index 3359a5573..1c64af85e 100644 --- a/numpy/core/src/multiarray/convert_datatype.c +++ b/numpy/core/src/multiarray/convert_datatype.c @@ -1072,13 +1072,45 @@ PyArray_MinScalarType(PyArrayObject *arr) } } +/* + * Provides an ordering for the dtype 'kind' character codes, to help + * determine when to use the min_scalar_type function. This groups + * 'kind' into boolean, integer, floating point, and everything else. + */ +static int +dtype_kind_to_simplified_ordering(char kind) +{ + switch (kind) { + /* Boolean kind */ + case 'b': + return 0; + /* Unsigned int kind */ + case 'u': + /* Signed int kind */ + case 'i': + return 1; + /* Float kind */ + case 'f': + /* Complex kind */ + case 'c': + return 2; + /* Anything else */ + default: + return 3; + } +} + /*NUMPY_API * Produces the result type of a bunch of inputs, using the UFunc - * type promotion rules. + * type promotion rules. Use this function when you have a set of + * input arrays, and need to determine an output array dtype. * - * If all the inputs are scalars (have 0 dimensions), does a regular - * type promotion. Otherwise, does a type promotion on the MinScalarType - * of all the inputs. Data types passed directly are treated as vector + * If all the inputs are scalars (have 0 dimensions) or the maximum "kind" + * of the scalars is greater than the maximum "kind" of the arrays, does + * a regular type promotion. + * + * Otherwise, does a type promotion on the MinScalarType + * of all the inputs. Data types passed directly are treated as array * types. * */ @@ -1087,7 +1119,7 @@ PyArray_ResultType(npy_intp narrs, PyArrayObject **arr, npy_intp ndtypes, PyArray_Descr **dtypes) { npy_intp i; - int all_scalar; + int use_min_scalar = 0; PyArray_Descr *ret = NULL, *tmpret; int ret_is_small_unsigned = 0; @@ -1103,22 +1135,54 @@ PyArray_ResultType(npy_intp narrs, PyArrayObject **arr, return ret; } - /* Determine if there are any scalars */ - if (ndtypes > 0) { - all_scalar = 0; - } - else { - all_scalar = 1; + /* + * Determine if there are any scalars, and if so, whether + * the maximum "kind" of the scalars surpasses the maximum + * "kind" of the arrays + */ + if (narrs > 0) { + int all_scalars, max_scalar_kind = -1, max_array_kind = -1; + int kind; + + all_scalars = (ndtypes > 0) ? 0 : 1; + + /* Compute the maximum "kinds" and whether everything is scalar */ for (i = 0; i < narrs; ++i) { - if (PyArray_NDIM(arr[i]) != 0) { - all_scalar = 0; - break; + if (PyArray_NDIM(arr[i]) == 0) { + kind = dtype_kind_to_simplified_ordering( + PyArray_DESCR(arr[i])->kind); + if (kind > max_scalar_kind) { + max_scalar_kind = kind; + } + } + else { + all_scalars = 0; + kind = dtype_kind_to_simplified_ordering( + PyArray_DESCR(arr[i])->kind); + if (kind > max_array_kind) { + max_array_kind = kind; + } + } + } + /* + * If the max scalar kind is bigger than the max array kind, + * finish computing the max array kind + */ + for (i = 0; i < ndtypes; ++i) { + kind = dtype_kind_to_simplified_ordering(dtypes[i]->kind); + if (kind > max_array_kind) { + max_array_kind = kind; } } + + /* Indicate whether to use the min_scalar_type function */ + if (!all_scalars && max_array_kind >= max_scalar_kind) { + use_min_scalar = 1; + } } /* Loop through all the types, promoting them */ - if (all_scalar) { + if (!use_min_scalar) { for (i = 0; i < narrs; ++i) { PyArray_Descr *tmp = PyArray_DESCR(arr[i]); /* Combine it with the existing type */ @@ -1132,6 +1196,29 @@ PyArray_ResultType(npy_intp narrs, PyArrayObject **arr, tmpret = PyArray_PromoteTypes(tmp, ret); Py_DECREF(ret); ret = tmpret; + if (ret == NULL) { + return NULL; + } + } + } + } + + for (i = 0; i < ndtypes; ++i) { + PyArray_Descr *tmp = dtypes[i]; + /* Combine it with the existing type */ + if (ret == NULL) { + ret = tmp; + Py_INCREF(ret); + } + else { + /* Only call promote if the types aren't the same dtype */ + if (tmp != ret || !PyArray_ISNBO(tmp->byteorder)) { + tmpret = PyArray_PromoteTypes(tmp, ret); + Py_DECREF(ret); + ret = tmpret; + if (ret == NULL) { + return NULL; + } } } } @@ -1229,6 +1316,9 @@ PyArray_ResultType(npy_intp narrs, PyArrayObject **arr, } Py_DECREF(ret); ret = tmpret; + if (ret == NULL) { + return NULL; + } } } } |