summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMark Wiebe <mwiebe@enthought.com>2011-05-26 11:53:10 -0500
committerMark Wiebe <mwiebe@enthought.com>2011-05-26 11:53:10 -0500
commit3729edf672540be7f17b62c4984b49f69212f706 (patch)
tree9a00f6844c0f508ad73b759c5b10fb3a0c50f34a
parenta12f0d1d2308044e87c9902c78d809a8fcb465f1 (diff)
downloadnumpy-3729edf672540be7f17b62c4984b49f69212f706.tar.gz
ENH: umath: Add a binary comparison type resolution function
Also remove some DATETIME ufuncs that didn't make sense (since the datetime type has no zero, these functions still make sense for TIMEDELTA).
-rw-r--r--numpy/core/code_generators/generate_umath.py17
-rw-r--r--numpy/core/code_generators/numpy_api.py4
-rw-r--r--numpy/core/src/multiarray/convert_datatype.c97
-rw-r--r--numpy/core/src/multiarray/datetime.c96
-rw-r--r--numpy/core/src/umath/loops.c.src96
-rw-r--r--numpy/core/src/umath/ufunc_object.c187
6 files changed, 361 insertions, 136 deletions
diff --git a/numpy/core/code_generators/generate_umath.py b/numpy/core/code_generators/generate_umath.py
index 9382b1fae..f81291395 100644
--- a/numpy/core/code_generators/generate_umath.py
+++ b/numpy/core/code_generators/generate_umath.py
@@ -193,6 +193,7 @@ O = 'O'
P = 'P'
ints = 'bBhHiIlLqQ'
times = 'Mm'
+timedeltaonly = 'm'
intsO = ints + O
bints = '?' + ints
bintsO = bints + O
@@ -209,12 +210,14 @@ allP = bints+times+flts+cmplxP
nobool = all[1:]
noobj = all[:-3]+all[-2:]
nobool_or_obj = all[1:-3]+all[-2:]
+nobool_or_datetime = all[1:-2]+all[-1:]
intflt = ints+flts
intfltcmplx = ints+flts+cmplx
nocmplx = bints+times+flts
nocmplxO = nocmplx+O
nocmplxP = nocmplx+P
notimes_or_obj = bints + inexact
+nodatetime_or_obj = bints + inexact
# Find which code corresponds to int64.
int64 = ''
@@ -317,7 +320,7 @@ defdict = {
'absolute' :
Ufunc(1, 1, None,
docstrings.get('numpy.core.umath.absolute'),
- TD(bints+flts+times),
+ TD(bints+flts+timedeltaonly),
TD(cmplx, out=('f', 'd', 'g')),
TD(O, f='PyNumber_Absolute'),
),
@@ -329,14 +332,14 @@ defdict = {
'negative' :
Ufunc(1, 1, None,
docstrings.get('numpy.core.umath.negative'),
- TD(bints+flts+times),
+ TD(bints+flts+timedeltaonly),
TD(cmplx, f='neg'),
TD(O, f='PyNumber_Negative'),
),
'sign' :
Ufunc(1, 1, None,
docstrings.get('numpy.core.umath.sign'),
- TD(nobool),
+ TD(nobool_or_datetime),
),
'greater' :
Ufunc(2, 1, None,
@@ -371,25 +374,25 @@ defdict = {
'logical_and' :
Ufunc(2, 1, One,
docstrings.get('numpy.core.umath.logical_and'),
- TD(noobj, out='?'),
+ TD(nodatetime_or_obj, out='?'),
TD(P, f='logical_and'),
),
'logical_not' :
Ufunc(1, 1, None,
docstrings.get('numpy.core.umath.logical_not'),
- TD(noobj, out='?'),
+ TD(nodatetime_or_obj, out='?'),
TD(P, f='logical_not'),
),
'logical_or' :
Ufunc(2, 1, Zero,
docstrings.get('numpy.core.umath.logical_or'),
- TD(noobj, out='?'),
+ TD(nodatetime_or_obj, out='?'),
TD(P, f='logical_or'),
),
'logical_xor' :
Ufunc(2, 1, None,
docstrings.get('numpy.core.umath.logical_xor'),
- TD(noobj, out='?'),
+ TD(nodatetime_or_obj, out='?'),
TD(P, f='logical_xor'),
),
'maximum' :
diff --git a/numpy/core/code_generators/numpy_api.py b/numpy/core/code_generators/numpy_api.py
index db2c368dd..3bff3a31e 100644
--- a/numpy/core/code_generators/numpy_api.py
+++ b/numpy/core/code_generators/numpy_api.py
@@ -361,6 +361,10 @@ ufunc_funcs_api = {
'PyUFunc_ee_e': 36,
'PyUFunc_ee_e_As_ff_f': 37,
'PyUFunc_ee_e_As_dd_d': 38,
+ # End 1.6 API
+ 'PyUFunc_DefaultTypeResolution': 39,
+ 'PyUFunc_BinaryComparisonTypeResolution': 40,
+ 'PyUFunc_ValidateCasting': 41,
}
# List of all the dicts which define the C API
diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c
index 50ea8d711..3db500195 100644
--- a/numpy/core/src/multiarray/convert_datatype.c
+++ b/numpy/core/src/multiarray/convert_datatype.c
@@ -724,91 +724,17 @@ PyArray_PromoteTypes(PyArray_Descr *type1, PyArray_Descr *type2)
}
break;
case NPY_DATETIME:
- /* 'M[A],'M[B]' -> M[A] when A==B, error otherwise */
- if (type_num2 == NPY_DATETIME) {
- PyArray_DatetimeMetaData *meta1, *meta2;
- meta1 = get_datetime_metadata_from_dtype(type1);
- if (meta1 == NULL) {
- return NULL;
- }
- meta2 = get_datetime_metadata_from_dtype(type2);
- if (meta2 == NULL) {
- return NULL;
- }
-
- if (meta1->base == meta2->base &&
- meta1->num == meta2->num &&
- meta1->events == meta2->events) {
- Py_INCREF(type1);
- return type1;
- }
- else {
- PyObject *errmsg;
- errmsg = PyUString_FromString("Cannot promote "
- "datetime types ");
- PyUString_ConcatAndDel(&errmsg,
- PyObject_Repr((PyObject *)type1));
- PyUString_ConcatAndDel(&errmsg,
- PyUString_FromString(" and "));
- PyUString_ConcatAndDel(&errmsg,
- PyObject_Repr((PyObject *)type2));
- PyUString_ConcatAndDel(&errmsg,
- PyUString_FromString(" because they have "
- "different units metadata"));
- PyErr_SetObject(PyExc_TypeError, errmsg);
- return NULL;
- }
+ if (type_num2 == NPY_DATETIME || type_num2 == NPY_TIMEDELTA) {
+ return datetime_type_promotion(type1, type2);
}
- /* 'M[A]','m[B]' -> 'M[A]' */
- else if (type_num2 == NPY_TIMEDELTA) {
+ else if (PyTypeNum_ISINTEGER(type_num2)) {
Py_INCREF(type1);
return type1;
}
break;
case NPY_TIMEDELTA:
- /* 'm[A]','M[B]' -> 'M[B]' */
- if (type_num2 == NPY_DATETIME) {
- Py_INCREF(type2);
- return type2;
- }
- /* 'm[A]','m[B]' -> 'm[gcd(A,B)]' */
- else if (type_num2 == NPY_TIMEDELTA) {
- PyObject *gcdmeta;
- PyArray_Descr *dtype;
-
- /* Get the metadata GCD */
- gcdmeta = compute_datetime_metadata_greatest_common_divisor(
- type1, type2);
- if (gcdmeta == NULL) {
- return NULL;
- }
-
- /* Create a TIMEDELTA dtype */
- dtype = PyArray_DescrNewFromType(PyArray_TIMEDELTA);
- if (dtype == NULL) {
- Py_DECREF(gcdmeta);
- return NULL;
- }
-
- /* Replace the metadata dictionary */
- Py_XDECREF(dtype->metadata);
- dtype->metadata = PyDict_New();
- if (dtype->metadata == NULL) {
- Py_DECREF(dtype);
- Py_DECREF(gcdmeta);
- return NULL;
- }
-
- /* Set the metadata object in the dictionary. */
- if (PyDict_SetItemString(dtype->metadata, NPY_METADATA_DTSTR,
- gcdmeta) < 0) {
- Py_DECREF(dtype);
- Py_DECREF(gcdmeta);
- return NULL;
- }
- Py_DECREF(gcdmeta);
-
- return dtype;
+ if (type_num2 == NPY_DATETIME || type_num2 == NPY_TIMEDELTA) {
+ return datetime_type_promotion(type1, type2);
}
else if (PyTypeNum_ISINTEGER(type_num2) ||
PyTypeNum_ISFLOAT(type_num2)) {
@@ -842,6 +768,19 @@ PyArray_PromoteTypes(PyArray_Descr *type1, PyArray_Descr *type2)
return type2;
}
break;
+ case NPY_DATETIME:
+ if (PyTypeNum_ISINTEGER(type_num1)) {
+ Py_INCREF(type2);
+ return type2;
+ }
+ break;
+ case NPY_TIMEDELTA:
+ if (PyTypeNum_ISINTEGER(type_num1) ||
+ PyTypeNum_ISFLOAT(type_num1)) {
+ Py_INCREF(type2);
+ return type2;
+ }
+ break;
}
/* For equivalent types we can return either */
diff --git a/numpy/core/src/multiarray/datetime.c b/numpy/core/src/multiarray/datetime.c
index a73e55e6c..2dca07e10 100644
--- a/numpy/core/src/multiarray/datetime.c
+++ b/numpy/core/src/multiarray/datetime.c
@@ -1581,6 +1581,88 @@ units_overflow: {
}
/*
+ * Uses type1's type_num and the gcd of the metadata to create
+ * the result type.
+ */
+static PyArray_Descr *
+datetime_gcd_type_promotion(PyArray_Descr *type1, PyArray_Descr *type2)
+{
+ PyObject *gcdmeta;
+ PyArray_Descr *dtype;
+
+ /* Get the metadata GCD */
+ gcdmeta = compute_datetime_metadata_greatest_common_divisor(
+ type1, type2);
+ if (gcdmeta == NULL) {
+ return NULL;
+ }
+
+ /* Create a DATETIME or TIMEDELTA dtype */
+ dtype = PyArray_DescrNewFromType(type1->type_num);
+ if (dtype == NULL) {
+ Py_DECREF(gcdmeta);
+ return NULL;
+ }
+
+ /* Replace the metadata dictionary */
+ Py_XDECREF(dtype->metadata);
+ dtype->metadata = PyDict_New();
+ if (dtype->metadata == NULL) {
+ Py_DECREF(dtype);
+ Py_DECREF(gcdmeta);
+ return NULL;
+ }
+
+ /* Set the metadata object in the dictionary. */
+ if (PyDict_SetItemString(dtype->metadata, NPY_METADATA_DTSTR,
+ gcdmeta) < 0) {
+ Py_DECREF(dtype);
+ Py_DECREF(gcdmeta);
+ return NULL;
+ }
+ Py_DECREF(gcdmeta);
+
+ return dtype;
+}
+
+/*
+ * Both type1 and type2 must be either NPY_DATETIME or NPY_TIMEDELTA.
+ * Applies the type promotion rules between the two types, returning
+ * the promoted type.
+ */
+NPY_NO_EXPORT PyArray_Descr *
+datetime_type_promotion(PyArray_Descr *type1, PyArray_Descr *type2)
+{
+ int type_num1, type_num2;
+
+ type_num1 = type1->type_num;
+ type_num2 = type2->type_num;
+
+ if (type_num1 == NPY_DATETIME) {
+ if (type_num2 == NPY_DATETIME) {
+ return datetime_gcd_type_promotion(type1, type2);
+ }
+ else if (type_num2 == NPY_TIMEDELTA) {
+ Py_INCREF(type1);
+ return type1;
+ }
+ }
+ else if (type_num1 == NPY_TIMEDELTA) {
+ if (type_num2 == NPY_DATETIME) {
+ Py_INCREF(type2);
+ return type2;
+ }
+ else if (type_num2 == NPY_TIMEDELTA) {
+ return datetime_gcd_type_promotion(type1, type2);
+ }
+ }
+
+ PyErr_SetString(PyExc_RuntimeError,
+ "Called datetime_type_promotion on non-datetype type");
+ return NULL;
+}
+
+/*
* Converts a substring given by 'str' and 'len' into
* a date time unit enum value. The 'metastr' parameter
* is used for error messages, and may be NULL.
@@ -2164,15 +2246,17 @@ parse_iso_8601_date(char *str, int len, npy_datetimestruct *out)
parse_timezone:
if (sublen == 0) {
- /* Only do this timezone adjustment for recent and future years */
+ /*
+ * ISO 8601 states to treat date-times without a timezone offset
+ * or 'Z' for UTC as local time. The C standard libary functions
+ * mktime and gmtime allow us to do this conversion.
+ *
+ * Only do this timezone adjustment for recent and future years.
+ */
if (out->year > 1900 && out->year < 10000) {
time_t rawtime = 0;
struct tm tm_;
- /*
- * ISO 8601 states to treat date-times without a timezone offset
- * or 'Z' for UTC as local time. The C standard libary functions
- * mktime and gmtime allow us to do this conversion.
- */
+
tm_.tm_sec = out->sec;
tm_.tm_min = out->min;
tm_.tm_hour = out->hour;
diff --git a/numpy/core/src/umath/loops.c.src b/numpy/core/src/umath/loops.c.src
index 54e5ac984..89b130fda 100644
--- a/numpy/core/src/umath/loops.c.src
+++ b/numpy/core/src/umath/loops.c.src
@@ -941,64 +941,96 @@ U@TYPE@_remainder(char **args, intp *dimensions, intp *steps, void *NPY_UNUSED(f
*****************************************************************************
*/
-/**begin repeat
- * #type = datetime, timedelta#
- * #TYPE = DATETIME, TIMEDELTA#
- * #ftype = double, double#
- */
+NPY_NO_EXPORT void
+TIMEDELTA_negative(char **args, intp *dimensions, intp *steps, void *NPY_UNUSED(func))
+{
+ UNARY_LOOP {
+ const npy_timedelta in1 = *(npy_timedelta *)ip1;
+ *((npy_timedelta *)op1) = (npy_timedelta)(-(npy_timedelta)in1);
+ }
+}
NPY_NO_EXPORT void
-@TYPE@_ones_like(char **args, intp *dimensions, intp *steps, void *NPY_UNUSED(data))
+TIMEDELTA_absolute(char **args, intp *dimensions, intp *steps, void *NPY_UNUSED(func))
{
- OUTPUT_LOOP {
- *((@type@ *)op1) = 1;
+ UNARY_LOOP {
+ const npy_timedelta in1 = *(npy_timedelta *)ip1;
+ *((npy_timedelta *)op1) = (in1 >= 0) ? in1 : -in1;
}
}
NPY_NO_EXPORT void
-@TYPE@_negative(char **args, intp *dimensions, intp *steps, void *NPY_UNUSED(func))
+TIMEDELTA_sign(char **args, intp *dimensions, intp *steps, void *NPY_UNUSED(func))
{
UNARY_LOOP {
- const @type@ in1 = *(@type@ *)ip1;
- *((@type@ *)op1) = (@type@)(-(@type@)in1);
+ const npy_timedelta in1 = *(npy_timedelta *)ip1;
+ *((npy_timedelta *)op1) = in1 > 0 ? 1 : (in1 < 0 ? -1 : 0);
}
}
NPY_NO_EXPORT void
-@TYPE@_logical_not(char **args, intp *dimensions, intp *steps, void *NPY_UNUSED(func))
+TIMEDELTA_logical_not(char **args, intp *dimensions, intp *steps, void *NPY_UNUSED(func))
{
UNARY_LOOP {
- const @type@ in1 = *(@type@ *)ip1;
+ const npy_timedelta in1 = *(npy_timedelta *)ip1;
*((Bool *)op1) = !in1;
}
}
+NPY_NO_EXPORT void
+TIMEDELTA_logical_xor(char **args, intp *dimensions, intp *steps, void *NPY_UNUSED(func))
+{
+ BINARY_LOOP {
+ const npy_timedelta in1 = *(npy_timedelta *)ip1;
+ const npy_timedelta in2 = *(npy_timedelta *)ip2;
+ *((Bool *)op1)= (in1 && !in2) || (!in1 && in2);
+ }
+}
+
-/**begin repeat1
- * #kind = equal, not_equal, greater, greater_equal, less, less_equal,
- * logical_and, logical_or#
- * #OP = ==, !=, >, >=, <, <=, &&, ||#
+/**begin repeat
+ * #kind = logical_and, logical_or#
+ * #OP = &&, ||#
*/
NPY_NO_EXPORT void
-@TYPE@_@kind@(char **args, intp *dimensions, intp *steps, void *NPY_UNUSED(func))
+TIMEDELTA_@kind@(char **args, intp *dimensions, intp *steps, void *NPY_UNUSED(func))
{
BINARY_LOOP {
- const @type@ in1 = *(@type@ *)ip1;
- const @type@ in2 = *(@type@ *)ip2;
+ const npy_timedelta in1 = *(npy_timedelta *)ip1;
+ const npy_timedelta in2 = *(npy_timedelta *)ip2;
*((Bool *)op1) = in1 @OP@ in2;
}
}
-/**end repeat1**/
+/**end repeat**/
+
+
+/**begin repeat
+ * #type = datetime, timedelta#
+ * #TYPE = DATETIME, TIMEDELTA#
+ */
NPY_NO_EXPORT void
-@TYPE@_logical_xor(char **args, intp *dimensions, intp *steps, void *NPY_UNUSED(func))
+@TYPE@_ones_like(char **args, intp *dimensions, intp *steps, void *NPY_UNUSED(data))
+{
+ OUTPUT_LOOP {
+ *((@type@ *)op1) = 1;
+ }
+}
+
+/**begin repeat1
+ * #kind = equal, not_equal, greater, greater_equal, less, less_equal#
+ * #OP = ==, !=, >, >=, <, <=#
+ */
+NPY_NO_EXPORT void
+@TYPE@_@kind@(char **args, intp *dimensions, intp *steps, void *NPY_UNUSED(func))
{
BINARY_LOOP {
const @type@ in1 = *(@type@ *)ip1;
const @type@ in2 = *(@type@ *)ip2;
- *((Bool *)op1)= (in1 && !in2) || (!in1 && in2);
+ *((Bool *)op1) = in1 @OP@ in2;
}
}
+/**end repeat1**/
/**begin repeat1
* #kind = maximum, minimum#
@@ -1024,24 +1056,6 @@ NPY_NO_EXPORT void
}
/**end repeat1**/
-NPY_NO_EXPORT void
-@TYPE@_absolute(char **args, intp *dimensions, intp *steps, void *NPY_UNUSED(func))
-{
- UNARY_LOOP {
- const @type@ in1 = *(@type@ *)ip1;
- *((@type@ *)op1) = (in1 >= 0) ? in1 : -in1;
- }
-}
-
-NPY_NO_EXPORT void
-@TYPE@_sign(char **args, intp *dimensions, intp *steps, void *NPY_UNUSED(func))
-{
- UNARY_LOOP {
- const @type@ in1 = *(@type@ *)ip1;
- *((@type@ *)op1) = in1 > 0 ? 1 : (in1 < 0 ? -1 : 0);
- }
-}
-
/**end repeat**/
/* FIXME: implement the following correctly using the metadata: data is the
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index d93375b8e..693732562 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -1223,7 +1223,7 @@ find_ufunc_specified_userloop(PyUFuncObject *self,
PyUFuncGenericFunction *out_innerloop,
void **out_innerloopdata)
{
- npy_intp i, j, nin = self->nin, nop = nin + self->nout;
+ int i, j, nin = self->nin, nop = nin + self->nout;
PyUFunc_Loop1d *funcdata;
/* Use this to try to avoid repeating the same userdef loop search */
@@ -1705,7 +1705,16 @@ find_specified_ufunc_inner_loop(PyUFuncObject *self,
return -1;
}
-int generic_ufunc_type_resolution(PyUFuncObject *ufunc,
+/*UFUNC_API
+ *
+ * This function applies the default type resolution rules
+ * for the provided ufunc, filling out_dtypes, out_innerloop,
+ * and out_innerloopdata.
+ *
+ * Returns 0 on success, -1 on error.
+ */
+NPY_NO_EXPORT int
+PyUFunc_DefaultTypeResolution(PyUFuncObject *ufunc,
NPY_CASTING casting,
PyArrayObject **operands,
PyObject *type_tup,
@@ -1747,6 +1756,178 @@ int generic_ufunc_type_resolution(PyUFuncObject *ufunc,
return retval;
}
+/*UFUNC_API
+ *
+ * This function applies special type resolution rules for the case
+ * where all the functions have the pattern XX->bool, using
+ * PyArray_ResultType instead of a linear search to get the best
+ * loop.
+ *
+ * Note that a simpler linear search through the functions loop
+ * is still done, but switching to a simple array lookup for
+ * built-in types would be better at some point.
+ *
+ * Returns 0 on success, -1 on error.
+ */
+NPY_NO_EXPORT int
+PyUFunc_BinaryComparisonTypeResolution(PyUFuncObject *ufunc,
+ NPY_CASTING casting,
+ PyArrayObject **operands,
+ PyObject *type_tup,
+ PyArray_Descr **out_dtypes,
+ PyUFuncGenericFunction *out_innerloop,
+ void **out_innerloopdata)
+{
+ int i, type_num;
+ char *ufunc_name;
+
+ ufunc_name = ufunc->name ? ufunc->name : "<unnamed ufunc>";
+
+ /* Use the default type resolution if there's a custom data type */
+ if (PyArray_DESCR(operands[0])->type_num >= NPY_NTYPES ||
+ PyArray_DESCR(operands[1])->type_num >= NPY_NTYPES) {
+ return PyUFunc_DefaultTypeResolution(ufunc, casting, operands,
+ type_tup, out_dtypes, out_innerloop, out_innerloopdata);
+ }
+
+ if (type_tup == NULL) {
+ /* Input types are the result type */
+ out_dtypes[0] = PyArray_ResultType(2, operands, 0, NULL);
+ if (out_dtypes[0] == NULL) {
+ return -1;
+ }
+ out_dtypes[1] = out_dtypes[0];
+ Py_INCREF(out_dtypes[1]);
+
+ }
+ else {
+ /*
+ * If the type tuple isn't a single-element tuple, let the
+ * default type resolution handle this one.
+ */
+ if (!PyTuple_Check(type_tup) || PyTuple_GET_SIZE(type_tup) != 1) {
+ return PyUFunc_DefaultTypeResolution(ufunc, casting, operands,
+ type_tup, out_dtypes, out_innerloop, out_innerloopdata);
+ }
+
+ if (!PyArray_DescrCheck(PyTuple_GET_ITEM(type_tup, 0))) {
+ PyErr_SetString(PyExc_ValueError,
+ "require data type in the type tuple");
+ return -1;
+ }
+
+ out_dtypes[0] = (PyArray_Descr *)PyTuple_GET_ITEM(type_tup, 0);
+ out_dtypes[1] = out_dtypes[0];
+ Py_INCREF(out_dtypes[0]);
+ Py_INCREF(out_dtypes[1]);
+ }
+
+ /* Output type is always boolean */
+ out_dtypes[2] = PyArray_DescrFromType(NPY_BOOL);
+ if (out_dtypes[2] == NULL) {
+ for (i = 0; i < 2; ++i) {
+ Py_DECREF(out_dtypes[i]);
+ out_dtypes[i] = NULL;
+ }
+ return -1;
+ }
+
+ /* Check against the casting rules */
+ if (PyUFunc_ValidateCasting(ufunc, casting, operands, out_dtypes) < 0) {
+ for (i = 0; i < 3; ++i) {
+ Py_DECREF(out_dtypes[i]);
+ out_dtypes[i] = NULL;
+ }
+ return -1;
+ }
+
+ type_num = out_dtypes[0]->type_num;
+
+ /* If we have a built-in type, search in the functions list */
+ if (type_num < NPY_NTYPES) {
+ char *types = ufunc->types;
+ int n = ufunc->ntypes;
+
+ for (i = 0; i < n; ++i) {
+ if (types[3*i] == type_num) {
+ *out_innerloop = ufunc->functions[i];
+ *out_innerloopdata = ufunc->data[i];
+ return 0;
+ }
+ }
+
+ PyErr_Format(PyExc_TypeError,
+ "ufunc '%s' not supported for the input types",
+ ufunc_name);
+ return -1;
+ }
+ else {
+ PyErr_SetString(PyExc_RuntimeError,
+ "user type shouldn't have resulted from type promotion");
+ return -1;
+ }
+}
+
+/*UFUNC_API
+ *
+ * Validates that the input operands can be cast to
+ * the input types, and the output types can be cast to
+ * the output operands where provided.
+ *
+ * Returns 0 on success, -1 (with exception raised) on validation failure.
+ */
+NPY_NO_EXPORT int
+PyUFunc_ValidateCasting(PyUFuncObject *ufunc,
+ NPY_CASTING casting,
+ PyArrayObject **operands,
+ PyArray_Descr **dtypes)
+{
+ int i, nin = ufunc->nin, nop = nin + ufunc->nout;
+ char *ufunc_name;
+
+ ufunc_name = ufunc->name ? ufunc->name : "<unnamed ufunc>";
+
+ for (i = 0; i < nop; ++i) {
+ if (i < nin) {
+ if (!PyArray_CanCastArrayTo(operands[i], dtypes[i], casting)) {
+ PyObject *errmsg;
+ errmsg = PyUString_FromFormat("Cannot cast ufunc %s "
+ "input from ", ufunc_name);
+ PyUString_ConcatAndDel(&errmsg,
+ PyObject_Repr((PyObject *)PyArray_DESCR(operands[i])));
+ PyUString_ConcatAndDel(&errmsg,
+ PyUString_FromString(" to "));
+ PyUString_ConcatAndDel(&errmsg,
+ PyObject_Repr((PyObject *)dtypes[i]));
+ PyUString_ConcatAndDel(&errmsg,
+ PyUString_FromFormat(" with casting rule %s",
+ _casting_to_string(casting)));
+ PyErr_SetObject(PyExc_TypeError, errmsg);
+ return -1;
+ }
+ } else if (operands[i] != NULL) {
+ if (!PyArray_CanCastTypeTo(dtypes[i],
+ PyArray_DESCR(operands[i]), casting)) {
+ PyObject *errmsg;
+ errmsg = PyUString_FromFormat("Cannot cast ufunc %s "
+ "output from ", ufunc_name);
+ PyUString_ConcatAndDel(&errmsg,
+ PyObject_Repr((PyObject *)dtypes[i]));
+ PyUString_ConcatAndDel(&errmsg,
+ PyUString_FromString(" to "));
+ PyUString_ConcatAndDel(&errmsg,
+ PyObject_Repr((PyObject *)PyArray_DESCR(operands[i])));
+ PyUString_ConcatAndDel(&errmsg,
+ PyUString_FromFormat(" with casting rule %s",
+ _casting_to_string(casting)));
+ PyErr_SetObject(PyExc_TypeError, errmsg);
+ return -1;
+ }
+ }
+ }
+
+ return 0;
+}
static void
trivial_two_operand_loop(PyArrayObject **op,
@@ -4458,7 +4639,7 @@ PyUFunc_FromFuncAndDataAndSignature(PyUFuncGenericFunction *func, void **data,
self->obj = NULL;
self->userloops=NULL;
- self->type_resolution_function = &generic_ufunc_type_resolution;
+ self->type_resolution_function = &PyUFunc_DefaultTypeResolution;
if (name == NULL) {
self->name = "?";