summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/multiarray/convert_datatype.c92
-rw-r--r--numpy/core/src/multiarray/convert_datatype.h4
-rw-r--r--numpy/core/src/umath/ufunc_type_resolution.c72
3 files changed, 60 insertions, 108 deletions
diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c
index 8e24d6361..025c66013 100644
--- a/numpy/core/src/multiarray/convert_datatype.c
+++ b/numpy/core/src/multiarray/convert_datatype.c
@@ -1767,46 +1767,22 @@ dtype_kind_to_simplified_ordering(char kind)
}
}
-/*NUMPY_API
- * Produces the result type of a bunch of inputs, using the UFunc
- * type promotion rules. Use this function when you have a set of
- * input arrays, and need to determine an output array dtype.
- *
- * If all the inputs are scalars (have 0 dimensions) or the maximum "kind"
- * of the scalars is greater than the maximum "kind" of the arrays, does
- * a regular type promotion.
- *
- * Otherwise, does a type promotion on the MinScalarType
- * of all the inputs. Data types passed directly are treated as array
- * types.
- *
+
+/*
+ * Determine if there is a mix of scalars and arrays/dtypes.
+ * If this is the case, the scalars should be handled as the minimum type
+ * capable of holding the value when the maximum "category" of the scalars
+ * surpasses the maximum "category" of the arrays/dtypes.
+ * If the scalars are of a lower or same category as the arrays, they may be
+ * demoted to a lower type within their category (the lowest type they can
+ * be cast to safely according to scalar casting rules).
*/
-NPY_NO_EXPORT PyArray_Descr *
-PyArray_ResultType(npy_intp narrs, PyArrayObject **arr,
- npy_intp ndtypes, PyArray_Descr **dtypes)
+NPY_NO_EXPORT int
+should_use_min_scalar(npy_intp narrs, PyArrayObject **arr,
+ npy_intp ndtypes, PyArray_Descr **dtypes)
{
- npy_intp i;
- int use_min_scalar;
-
- /* If there's just one type, pass it through */
- if (narrs + ndtypes == 1) {
- PyArray_Descr *ret = NULL;
- if (narrs == 1) {
- ret = PyArray_DESCR(arr[0]);
- }
- else {
- ret = dtypes[0];
- }
- Py_INCREF(ret);
- return ret;
- }
+ int use_min_scalar = 0;
- /*
- * Determine if there are any scalars, and if so, whether
- * the maximum "kind" of the scalars surpasses the maximum
- * "kind" of the arrays
- */
- use_min_scalar = 0;
if (narrs > 0) {
int all_scalars;
int max_scalar_kind = -1;
@@ -1815,7 +1791,7 @@ PyArray_ResultType(npy_intp narrs, PyArrayObject **arr,
all_scalars = (ndtypes > 0) ? 0 : 1;
/* Compute the maximum "kinds" and whether everything is scalar */
- for (i = 0; i < narrs; ++i) {
+ for (npy_intp i = 0; i < narrs; ++i) {
if (PyArray_NDIM(arr[i]) == 0) {
int kind = dtype_kind_to_simplified_ordering(
PyArray_DESCR(arr[i])->kind);
@@ -1836,7 +1812,7 @@ PyArray_ResultType(npy_intp narrs, PyArrayObject **arr,
* If the max scalar kind is bigger than the max array kind,
* finish computing the max array kind
*/
- for (i = 0; i < ndtypes; ++i) {
+ for (npy_intp i = 0; i < ndtypes; ++i) {
int kind = dtype_kind_to_simplified_ordering(dtypes[i]->kind);
if (kind > max_array_kind) {
max_array_kind = kind;
@@ -1848,6 +1824,44 @@ PyArray_ResultType(npy_intp narrs, PyArrayObject **arr,
use_min_scalar = 1;
}
}
+ return use_min_scalar;
+}
+
+
+/*NUMPY_API
+ * Produces the result type of a bunch of inputs, using the UFunc
+ * type promotion rules. Use this function when you have a set of
+ * input arrays, and need to determine an output array dtype.
+ *
+ * If all the inputs are scalars (have 0 dimensions) or the maximum "kind"
+ * of the scalars is greater than the maximum "kind" of the arrays, does
+ * a regular type promotion.
+ *
+ * Otherwise, does a type promotion on the MinScalarType
+ * of all the inputs. Data types passed directly are treated as array
+ * types.
+ *
+ */
+NPY_NO_EXPORT PyArray_Descr *
+PyArray_ResultType(npy_intp narrs, PyArrayObject **arr,
+ npy_intp ndtypes, PyArray_Descr **dtypes)
+{
+ npy_intp i;
+
+ /* If there's just one type, pass it through */
+ if (narrs + ndtypes == 1) {
+ PyArray_Descr *ret = NULL;
+ if (narrs == 1) {
+ ret = PyArray_DESCR(arr[0]);
+ }
+ else {
+ ret = dtypes[0];
+ }
+ Py_INCREF(ret);
+ return ret;
+ }
+
+ int use_min_scalar = should_use_min_scalar(narrs, arr, ndtypes, dtypes);
/* Loop through all the types, promoting them */
if (!use_min_scalar) {
diff --git a/numpy/core/src/multiarray/convert_datatype.h b/numpy/core/src/multiarray/convert_datatype.h
index 653557161..72867ead8 100644
--- a/numpy/core/src/multiarray/convert_datatype.h
+++ b/numpy/core/src/multiarray/convert_datatype.h
@@ -18,6 +18,10 @@ NPY_NO_EXPORT npy_bool
can_cast_scalar_to(PyArray_Descr *scal_type, char *scal_data,
PyArray_Descr *to, NPY_CASTING casting);
+NPY_NO_EXPORT int
+should_use_min_scalar(npy_intp narrs, PyArrayObject **arr,
+ npy_intp ndtypes, PyArray_Descr **dtypes);
+
/*
* This function calls Py_DECREF on flex_dtype, and replaces it with
* a new dtype that has been adapted based on the values in data_dtype
diff --git a/numpy/core/src/umath/ufunc_type_resolution.c b/numpy/core/src/umath/ufunc_type_resolution.c
index 58f915c6e..25dd002ac 100644
--- a/numpy/core/src/umath/ufunc_type_resolution.c
+++ b/numpy/core/src/umath/ufunc_type_resolution.c
@@ -24,6 +24,7 @@
#include "ufunc_type_resolution.h"
#include "ufunc_object.h"
#include "common.h"
+#include "convert_datatype.h"
#include "mem_overlap.h"
#if defined(HAVE_CBLAS)
@@ -1938,73 +1939,6 @@ type_tuple_userloop_type_resolver(PyUFuncObject *self,
return 0;
}
-/*
- * Provides an ordering for the dtype 'kind' character codes, to help
- * determine when to use the min_scalar_type function. This groups
- * 'kind' into boolean, integer, floating point, and everything else.
- */
-
-static int
-dtype_kind_to_simplified_ordering(char kind)
-{
- switch (kind) {
- /* Boolean kind */
- case 'b':
- return 0;
- /* Unsigned int kind */
- case 'u':
- /* Signed int kind */
- case 'i':
- return 1;
- /* Float kind */
- case 'f':
- /* Complex kind */
- case 'c':
- return 2;
- /* Anything else */
- default:
- return 3;
- }
-}
-
-static int
-should_use_min_scalar(PyArrayObject **op, int nop)
-{
- int i, use_min_scalar, kind;
- int all_scalars = 1, max_scalar_kind = -1, max_array_kind = -1;
-
- /*
- * Determine if there are any scalars, and if so, whether
- * the maximum "kind" of the scalars surpasses the maximum
- * "kind" of the arrays
- */
- use_min_scalar = 0;
- if (nop > 1) {
- for(i = 0; i < nop; ++i) {
- kind = dtype_kind_to_simplified_ordering(
- PyArray_DESCR(op[i])->kind);
- if (PyArray_NDIM(op[i]) == 0) {
- if (kind > max_scalar_kind) {
- max_scalar_kind = kind;
- }
- }
- else {
- all_scalars = 0;
- if (kind > max_array_kind) {
- max_array_kind = kind;
- }
-
- }
- }
-
- /* Indicate whether to use the min_scalar_type function */
- if (!all_scalars && max_array_kind >= max_scalar_kind) {
- use_min_scalar = 1;
- }
- }
-
- return use_min_scalar;
-}
/*
* Does a linear search for the best inner loop of the ufunc.
@@ -2030,7 +1964,7 @@ linear_search_type_resolver(PyUFuncObject *self,
ufunc_name = ufunc_get_name_cstr(self);
- use_min_scalar = should_use_min_scalar(op, nin);
+ use_min_scalar = should_use_min_scalar(nin, op, 0, NULL);
/* If the ufunc has userloops, search for them. */
if (self->userloops) {
@@ -2139,7 +2073,7 @@ type_tuple_type_resolver(PyUFuncObject *self,
ufunc_name = ufunc_get_name_cstr(self);
- use_min_scalar = should_use_min_scalar(op, nin);
+ use_min_scalar = should_use_min_scalar(nin, op, 0, NULL);
/* Fill in specified_types from the tuple or string */
if (PyTuple_Check(type_tup)) {