summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMark Wiebe <mwwiebe@gmail.com>2011-04-13 12:56:30 -0700
committerMark Wiebe <mwwiebe@gmail.com>2011-04-22 14:04:33 -0700
commit492ae119cd4a5b57065cf15c4bb05d8730ce6d36 (patch)
tree475624f69c8a5bb9de1496f4c4ce62aab5434a73
parentffa8786846f9ec604a9f17d6bff0bebbeb1fdd3f (diff)
downloadnumpy-492ae119cd4a5b57065cf15c4bb05d8730ce6d36.tar.gz
BUG: Fix type promotion regression for the result_type function (Ticket #1798)
-rw-r--r--numpy/core/src/multiarray/convert_datatype.c120
1 files changed, 105 insertions, 15 deletions
diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c
index 3359a5573..1c64af85e 100644
--- a/numpy/core/src/multiarray/convert_datatype.c
+++ b/numpy/core/src/multiarray/convert_datatype.c
@@ -1072,13 +1072,45 @@ PyArray_MinScalarType(PyArrayObject *arr)
}
}
+/*
+ * Provides an ordering for the dtype 'kind' character codes, to help
+ * determine when to use the min_scalar_type function. This groups
+ * 'kind' into boolean, integer, floating point, and everything else.
+ */
+static int
+dtype_kind_to_simplified_ordering(char kind)
+{
+ switch (kind) {
+ /* Boolean kind */
+ case 'b':
+ return 0;
+ /* Unsigned int kind */
+ case 'u':
+ /* Signed int kind */
+ case 'i':
+ return 1;
+ /* Float kind */
+ case 'f':
+ /* Complex kind */
+ case 'c':
+ return 2;
+ /* Anything else */
+ default:
+ return 3;
+ }
+}
+
/*NUMPY_API
* Produces the result type of a bunch of inputs, using the UFunc
- * type promotion rules.
+ * type promotion rules. Use this function when you have a set of
+ * input arrays, and need to determine an output array dtype.
*
- * If all the inputs are scalars (have 0 dimensions), does a regular
- * type promotion. Otherwise, does a type promotion on the MinScalarType
- * of all the inputs. Data types passed directly are treated as vector
+ * If all the inputs are scalars (have 0 dimensions) or the maximum "kind"
+ * of the scalars is greater than the maximum "kind" of the arrays, does
+ * a regular type promotion.
+ *
+ * Otherwise, does a type promotion on the MinScalarType
+ * of all the inputs. Data types passed directly are treated as array
* types.
*
*/
@@ -1087,7 +1119,7 @@ PyArray_ResultType(npy_intp narrs, PyArrayObject **arr,
npy_intp ndtypes, PyArray_Descr **dtypes)
{
npy_intp i;
- int all_scalar;
+ int use_min_scalar = 0;
PyArray_Descr *ret = NULL, *tmpret;
int ret_is_small_unsigned = 0;
@@ -1103,22 +1135,54 @@ PyArray_ResultType(npy_intp narrs, PyArrayObject **arr,
return ret;
}
- /* Determine if there are any scalars */
- if (ndtypes > 0) {
- all_scalar = 0;
- }
- else {
- all_scalar = 1;
+ /*
+ * Determine if there are any scalars, and if so, whether
+ * the maximum "kind" of the scalars surpasses the maximum
+ * "kind" of the arrays
+ */
+ if (narrs > 0) {
+ int all_scalars, max_scalar_kind = -1, max_array_kind = -1;
+ int kind;
+
+ all_scalars = (ndtypes > 0) ? 0 : 1;
+
+ /* Compute the maximum "kinds" and whether everything is scalar */
for (i = 0; i < narrs; ++i) {
- if (PyArray_NDIM(arr[i]) != 0) {
- all_scalar = 0;
- break;
+ if (PyArray_NDIM(arr[i]) == 0) {
+ kind = dtype_kind_to_simplified_ordering(
+ PyArray_DESCR(arr[i])->kind);
+ if (kind > max_scalar_kind) {
+ max_scalar_kind = kind;
+ }
+ }
+ else {
+ all_scalars = 0;
+ kind = dtype_kind_to_simplified_ordering(
+ PyArray_DESCR(arr[i])->kind);
+ if (kind > max_array_kind) {
+ max_array_kind = kind;
+ }
+ }
+ }
+ /*
+ * If the max scalar kind is bigger than the max array kind,
+ * finish computing the max array kind
+ */
+ for (i = 0; i < ndtypes; ++i) {
+ kind = dtype_kind_to_simplified_ordering(dtypes[i]->kind);
+ if (kind > max_array_kind) {
+ max_array_kind = kind;
}
}
+
+ /* Indicate whether to use the min_scalar_type function */
+ if (!all_scalars && max_array_kind >= max_scalar_kind) {
+ use_min_scalar = 1;
+ }
}
/* Loop through all the types, promoting them */
- if (all_scalar) {
+ if (!use_min_scalar) {
for (i = 0; i < narrs; ++i) {
PyArray_Descr *tmp = PyArray_DESCR(arr[i]);
/* Combine it with the existing type */
@@ -1132,6 +1196,29 @@ PyArray_ResultType(npy_intp narrs, PyArrayObject **arr,
tmpret = PyArray_PromoteTypes(tmp, ret);
Py_DECREF(ret);
ret = tmpret;
+ if (ret == NULL) {
+ return NULL;
+ }
+ }
+ }
+ }
+
+ for (i = 0; i < ndtypes; ++i) {
+ PyArray_Descr *tmp = dtypes[i];
+ /* Combine it with the existing type */
+ if (ret == NULL) {
+ ret = tmp;
+ Py_INCREF(ret);
+ }
+ else {
+ /* Only call promote if the types aren't the same dtype */
+ if (tmp != ret || !PyArray_ISNBO(tmp->byteorder)) {
+ tmpret = PyArray_PromoteTypes(tmp, ret);
+ Py_DECREF(ret);
+ ret = tmpret;
+ if (ret == NULL) {
+ return NULL;
+ }
}
}
}
@@ -1229,6 +1316,9 @@ PyArray_ResultType(npy_intp narrs, PyArrayObject **arr,
}
Py_DECREF(ret);
ret = tmpret;
+ if (ret == NULL) {
+ return NULL;
+ }
}
}
}