summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMark Wiebe <mwiebe@enthought.com>2011-05-27 18:41:40 -0500
committerMark Wiebe <mwiebe@enthought.com>2011-05-27 18:45:53 -0500
commit7fa5c2a1718a15cf5527a54a9c7fe0d71cf3eabe (patch)
tree3c874dd4dee15bc9c8eacfe8bae22ebb0e238c02
parent9f2b299f6b1dbc91c254f88ed46075387d9d6181 (diff)
downloadnumpy-7fa5c2a1718a15cf5527a54a9c7fe0d71cf3eabe.tar.gz
ENH: Add ufunc datetime-aware type resolution functions for multiply and divide
-rw-r--r--numpy/core/code_generators/generate_umath.py4
-rw-r--r--numpy/core/src/umath/ufunc_object.c262
-rw-r--r--numpy/core/src/umath/ufunc_object.h17
3 files changed, 280 insertions, 3 deletions
diff --git a/numpy/core/code_generators/generate_umath.py b/numpy/core/code_generators/generate_umath.py
index a0ce914c1..f6e5eb109 100644
--- a/numpy/core/code_generators/generate_umath.py
+++ b/numpy/core/code_generators/generate_umath.py
@@ -259,14 +259,14 @@ defdict = {
'multiply' :
Ufunc(2, 1, One,
docstrings.get('numpy.core.umath.multiply'),
- None,
+ 'PyUFunc_MultiplicationTypeResolution',
TD(notimes_or_obj),
TD(O, f='PyNumber_Multiply'),
),
'divide' :
Ufunc(2, 1, One,
docstrings.get('numpy.core.umath.divide'),
- None,
+ 'PyUFunc_DivisionTypeResolution',
TD(intfltcmplx),
TD(O, f='PyNumber_Divide'),
),
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index 1df28e2f0..73f465264 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -2160,7 +2160,7 @@ type_reso_error: {
}
/*
- * This function applies the type resolution rules for addition.
+ * This function applies the type resolution rules for subtraction.
* In particular, there are a number of special cases with datetime:
* m8[<A>] - m8[<B>] => m8[gcd(<A>,<B>)] - m8[gcd(<A>,<B>)]
* m8[<A>] - int => m8[<A>] - m8[<A>]
@@ -2340,6 +2340,266 @@ type_reso_error: {
}
}
+/*
+ * This function applies the type resolution rules for multiplication.
+ * In particular, there are a number of special cases with datetime:
+ * int## * m8[<A>] => int64 * m8[<A>]
+ * m8[<A>] * int## => m8[<A>] * int64
+ * float## * m8[<A>] => float64 * m8[<A>]
+ * m8[<A>] * float## => m8[<A>] * float64
+ */
+NPY_NO_EXPORT int
+PyUFunc_MultiplicationTypeResolution(PyUFuncObject *ufunc,
+ NPY_CASTING casting,
+ PyArrayObject **operands,
+ PyObject *type_tup,
+ PyArray_Descr **out_dtypes,
+ PyUFuncGenericFunction *out_innerloop,
+ void **out_innerloopdata)
+{
+ int type_num1, type_num2;
+ char *types;
+ int i, n;
+ char *ufunc_name;
+
+ ufunc_name = ufunc->name ? ufunc->name : "<unnamed ufunc>";
+
+ type_num1 = PyArray_DESCR(operands[0])->type_num;
+ type_num2 = PyArray_DESCR(operands[1])->type_num;
+
+ /* Use the default when datetime and timedelta are not involved */
+ if (!PyTypeNum_ISDATETIME(type_num1) && !PyTypeNum_ISDATETIME(type_num2)) {
+ return PyUFunc_DefaultTypeResolution(ufunc, casting, operands,
+ type_tup, out_dtypes, out_innerloop, out_innerloopdata);
+ }
+
+ if (type_num1 == NPY_TIMEDELTA) {
+ /* m8[<A>] * int## => m8[<A>] * int64 */
+ if (PyTypeNum_ISINTEGER(type_num2)) {
+ out_dtypes[0] = PyArray_DESCR(operands[0]);
+ Py_INCREF(out_dtypes[0]);
+ out_dtypes[1] = PyArray_DescrNewFromType(NPY_INT64);
+ if (out_dtypes[1] == NULL) {
+ Py_DECREF(out_dtypes[0]);
+ out_dtypes[0] = NULL;
+ return -1;
+ }
+ out_dtypes[2] = out_dtypes[0];
+ Py_INCREF(out_dtypes[2]);
+
+ type_num2 = NPY_INT64;
+ }
+ /* m8[<A>] * float## => m8[<A>] * float64 */
+ else if (PyTypeNum_ISFLOAT(type_num2)) {
+ out_dtypes[0] = PyArray_DESCR(operands[0]);
+ Py_INCREF(out_dtypes[0]);
+ out_dtypes[1] = PyArray_DescrNewFromType(NPY_DOUBLE);
+ if (out_dtypes[1] == NULL) {
+ Py_DECREF(out_dtypes[0]);
+ out_dtypes[0] = NULL;
+ return -1;
+ }
+ out_dtypes[2] = out_dtypes[0];
+ Py_INCREF(out_dtypes[2]);
+
+ type_num2 = NPY_DOUBLE;
+ }
+ else {
+ goto type_reso_error;
+ }
+ }
+ else if (PyTypeNum_ISINTEGER(type_num1)) {
+ /* int## * m8[<A>] => int64 * m8[<A>] */
+ if (type_num2 == NPY_TIMEDELTA) {
+ out_dtypes[0] = PyArray_DescrNewFromType(NPY_INT64);
+ if (out_dtypes[0] == NULL) {
+ return -1;
+ }
+ out_dtypes[1] = PyArray_DESCR(operands[1]);
+ Py_INCREF(out_dtypes[1]);
+ out_dtypes[2] = out_dtypes[1];
+ Py_INCREF(out_dtypes[2]);
+
+ type_num1 = NPY_INT64;
+ }
+ else {
+ goto type_reso_error;
+ }
+ }
+ else if (PyTypeNum_ISFLOAT(type_num1)) {
+ /* float## * m8[<A>] => float64 * m8[<A>] */
+ if (type_num2 == NPY_TIMEDELTA) {
+ out_dtypes[0] = PyArray_DescrNewFromType(NPY_DOUBLE);
+ if (out_dtypes[0] == NULL) {
+ return -1;
+ }
+ out_dtypes[1] = PyArray_DESCR(operands[1]);
+ Py_INCREF(out_dtypes[1]);
+ out_dtypes[2] = out_dtypes[1];
+ Py_INCREF(out_dtypes[2]);
+
+ type_num1 = NPY_DOUBLE;
+ }
+ else {
+ goto type_reso_error;
+ }
+ }
+ else {
+ goto type_reso_error;
+ }
+
+ /* Check against the casting rules */
+ if (PyUFunc_ValidateCasting(ufunc, casting, operands, out_dtypes) < 0) {
+ for (i = 0; i < 3; ++i) {
+ Py_DECREF(out_dtypes[i]);
+ out_dtypes[i] = NULL;
+ }
+ return -1;
+ }
+
+ /* Search in the functions list */
+ types = ufunc->types;
+ n = ufunc->ntypes;
+
+ for (i = 0; i < n; ++i) {
+ if (types[3*i] == type_num1 && types[3*i+1] == type_num2) {
+ *out_innerloop = ufunc->functions[i];
+ *out_innerloopdata = ufunc->data[i];
+ return 0;
+ }
+ }
+
+ PyErr_Format(PyExc_TypeError,
+ "internal error: could not find appropriate datetime "
+ "inner loop in %s ufunc", ufunc_name);
+ return -1;
+
+type_reso_error: {
+ PyObject *errmsg;
+ errmsg = PyUString_FromFormat("ufunc %s cannot use operands "
+ "with types ", ufunc_name);
+ PyUString_ConcatAndDel(&errmsg,
+ PyObject_Repr((PyObject *)PyArray_DESCR(operands[0])));
+ PyUString_ConcatAndDel(&errmsg,
+ PyUString_FromString(" and "));
+ PyUString_ConcatAndDel(&errmsg,
+ PyObject_Repr((PyObject *)PyArray_DESCR(operands[1])));
+ PyErr_SetObject(PyExc_TypeError, errmsg);
+ return -1;
+ }
+}
+
+/*
+ * This function applies the type resolution rules for division.
+ * In particular, there are a number of special cases with datetime:
+ * m8[<A>] / int## => m8[<A>] / int64
+ * m8[<A>] / float## => m8[<A>] / float64
+ */
+NPY_NO_EXPORT int
+PyUFunc_DivisionTypeResolution(PyUFuncObject *ufunc,
+ NPY_CASTING casting,
+ PyArrayObject **operands,
+ PyObject *type_tup,
+ PyArray_Descr **out_dtypes,
+ PyUFuncGenericFunction *out_innerloop,
+ void **out_innerloopdata)
+{
+ int type_num1, type_num2;
+ char *types;
+ int i, n;
+ char *ufunc_name;
+
+ ufunc_name = ufunc->name ? ufunc->name : "<unnamed ufunc>";
+
+ type_num1 = PyArray_DESCR(operands[0])->type_num;
+ type_num2 = PyArray_DESCR(operands[1])->type_num;
+
+ /* Use the default when datetime and timedelta are not involved */
+ if (!PyTypeNum_ISDATETIME(type_num1) && !PyTypeNum_ISDATETIME(type_num2)) {
+ return PyUFunc_DefaultTypeResolution(ufunc, casting, operands,
+ type_tup, out_dtypes, out_innerloop, out_innerloopdata);
+ }
+
+ if (type_num1 == NPY_TIMEDELTA) {
+ /* m8[<A>] / int## => m8[<A>] / int64 */
+ if (PyTypeNum_ISINTEGER(type_num2)) {
+ out_dtypes[0] = PyArray_DESCR(operands[0]);
+ Py_INCREF(out_dtypes[0]);
+ out_dtypes[1] = PyArray_DescrNewFromType(NPY_INT64);
+ if (out_dtypes[1] == NULL) {
+ Py_DECREF(out_dtypes[0]);
+ out_dtypes[0] = NULL;
+ return -1;
+ }
+ out_dtypes[2] = out_dtypes[0];
+ Py_INCREF(out_dtypes[2]);
+
+ type_num2 = NPY_INT64;
+ }
+ /* m8[<A>] / float## => m8[<A>] / float64 */
+ else if (PyTypeNum_ISFLOAT(type_num2)) {
+ out_dtypes[0] = PyArray_DESCR(operands[0]);
+ Py_INCREF(out_dtypes[0]);
+ out_dtypes[1] = PyArray_DescrNewFromType(NPY_DOUBLE);
+ if (out_dtypes[1] == NULL) {
+ Py_DECREF(out_dtypes[0]);
+ out_dtypes[0] = NULL;
+ return -1;
+ }
+ out_dtypes[2] = out_dtypes[0];
+ Py_INCREF(out_dtypes[2]);
+
+ type_num2 = NPY_DOUBLE;
+ }
+ else {
+ goto type_reso_error;
+ }
+ }
+ else {
+ goto type_reso_error;
+ }
+
+ /* Check against the casting rules */
+ if (PyUFunc_ValidateCasting(ufunc, casting, operands, out_dtypes) < 0) {
+ for (i = 0; i < 3; ++i) {
+ Py_DECREF(out_dtypes[i]);
+ out_dtypes[i] = NULL;
+ }
+ return -1;
+ }
+
+ /* Search in the functions list */
+ types = ufunc->types;
+ n = ufunc->ntypes;
+
+ for (i = 0; i < n; ++i) {
+ if (types[3*i] == type_num1 && types[3*i+1] == type_num2) {
+ *out_innerloop = ufunc->functions[i];
+ *out_innerloopdata = ufunc->data[i];
+ return 0;
+ }
+ }
+
+ PyErr_Format(PyExc_TypeError,
+ "internal error: could not find appropriate datetime "
+ "inner loop in %s ufunc", ufunc_name);
+ return -1;
+
+type_reso_error: {
+ PyObject *errmsg;
+ errmsg = PyUString_FromFormat("ufunc %s cannot use operands "
+ "with types ", ufunc_name);
+ PyUString_ConcatAndDel(&errmsg,
+ PyObject_Repr((PyObject *)PyArray_DESCR(operands[0])));
+ PyUString_ConcatAndDel(&errmsg,
+ PyUString_FromString(" and "));
+ PyUString_ConcatAndDel(&errmsg,
+ PyObject_Repr((PyObject *)PyArray_DESCR(operands[1])));
+ PyErr_SetObject(PyExc_TypeError, errmsg);
+ return -1;
+ }
+}
+
/*UFUNC_API
*
* Validates that the input operands can be cast to
diff --git a/numpy/core/src/umath/ufunc_object.h b/numpy/core/src/umath/ufunc_object.h
index 9f5496e8a..3ddcb842f 100644
--- a/numpy/core/src/umath/ufunc_object.h
+++ b/numpy/core/src/umath/ufunc_object.h
@@ -25,4 +25,21 @@ PyUFunc_SubtractionTypeResolution(PyUFuncObject *ufunc,
PyUFuncGenericFunction *out_innerloop,
void **out_innerloopdata);
+NPY_NO_EXPORT int
+PyUFunc_MultiplicationTypeResolution(PyUFuncObject *ufunc,
+ NPY_CASTING casting,
+ PyArrayObject **operands,
+ PyObject *type_tup,
+ PyArray_Descr **out_dtypes,
+ PyUFuncGenericFunction *out_innerloop,
+ void **out_innerloopdata);
+NPY_NO_EXPORT int
+PyUFunc_DivisionTypeResolution(PyUFuncObject *ufunc,
+ NPY_CASTING casting,
+ PyArrayObject **operands,
+ PyObject *type_tup,
+ PyArray_Descr **out_dtypes,
+ PyUFuncGenericFunction *out_innerloop,
+ void **out_innerloopdata);
+
#endif