diff options
-rw-r--r-- | numpy/core/code_generators/generate_umath.py | 4 | ||||
-rw-r--r-- | numpy/core/src/umath/ufunc_object.c | 262 | ||||
-rw-r--r-- | numpy/core/src/umath/ufunc_object.h | 17 |
3 files changed, 280 insertions, 3 deletions
diff --git a/numpy/core/code_generators/generate_umath.py b/numpy/core/code_generators/generate_umath.py index a0ce914c1..f6e5eb109 100644 --- a/numpy/core/code_generators/generate_umath.py +++ b/numpy/core/code_generators/generate_umath.py @@ -259,14 +259,14 @@ defdict = { 'multiply' : Ufunc(2, 1, One, docstrings.get('numpy.core.umath.multiply'), - None, + 'PyUFunc_MultiplicationTypeResolution', TD(notimes_or_obj), TD(O, f='PyNumber_Multiply'), ), 'divide' : Ufunc(2, 1, One, docstrings.get('numpy.core.umath.divide'), - None, + 'PyUFunc_DivisionTypeResolution', TD(intfltcmplx), TD(O, f='PyNumber_Divide'), ), diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index 1df28e2f0..73f465264 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -2160,7 +2160,7 @@ type_reso_error: { } /* - * This function applies the type resolution rules for addition. + * This function applies the type resolution rules for subtraction. * In particular, there are a number of special cases with datetime: * m8[<A>] - m8[<B>] => m8[gcd(<A>,<B>)] - m8[gcd(<A>,<B>)] * m8[<A>] - int => m8[<A>] - m8[<A>] @@ -2340,6 +2340,266 @@ type_reso_error: { } } +/* + * This function applies the type resolution rules for multiplication. + * In particular, there are a number of special cases with datetime: + * int## * m8[<A>] => int64 * m8[<A>] + * m8[<A>] * int## => m8[<A>] * int64 + * float## * m8[<A>] => float64 * m8[<A>] + * m8[<A>] * float## => m8[<A>] * float64 + */ +NPY_NO_EXPORT int +PyUFunc_MultiplicationTypeResolution(PyUFuncObject *ufunc, + NPY_CASTING casting, + PyArrayObject **operands, + PyObject *type_tup, + PyArray_Descr **out_dtypes, + PyUFuncGenericFunction *out_innerloop, + void **out_innerloopdata) +{ + int type_num1, type_num2; + char *types; + int i, n; + char *ufunc_name; + + ufunc_name = ufunc->name ? ufunc->name : "<unnamed ufunc>"; + + type_num1 = PyArray_DESCR(operands[0])->type_num; + type_num2 = PyArray_DESCR(operands[1])->type_num; + + /* Use the default when datetime and timedelta are not involved */ + if (!PyTypeNum_ISDATETIME(type_num1) && !PyTypeNum_ISDATETIME(type_num2)) { + return PyUFunc_DefaultTypeResolution(ufunc, casting, operands, + type_tup, out_dtypes, out_innerloop, out_innerloopdata); + } + + if (type_num1 == NPY_TIMEDELTA) { + /* m8[<A>] * int## => m8[<A>] * int64 */ + if (PyTypeNum_ISINTEGER(type_num2)) { + out_dtypes[0] = PyArray_DESCR(operands[0]); + Py_INCREF(out_dtypes[0]); + out_dtypes[1] = PyArray_DescrNewFromType(NPY_INT64); + if (out_dtypes[1] == NULL) { + Py_DECREF(out_dtypes[0]); + out_dtypes[0] = NULL; + return -1; + } + out_dtypes[2] = out_dtypes[0]; + Py_INCREF(out_dtypes[2]); + + type_num2 = NPY_INT64; + } + /* m8[<A>] * float## => m8[<A>] * float64 */ + else if (PyTypeNum_ISFLOAT(type_num2)) { + out_dtypes[0] = PyArray_DESCR(operands[0]); + Py_INCREF(out_dtypes[0]); + out_dtypes[1] = PyArray_DescrNewFromType(NPY_DOUBLE); + if (out_dtypes[1] == NULL) { + Py_DECREF(out_dtypes[0]); + out_dtypes[0] = NULL; + return -1; + } + out_dtypes[2] = out_dtypes[0]; + Py_INCREF(out_dtypes[2]); + + type_num2 = NPY_DOUBLE; + } + else { + goto type_reso_error; + } + } + else if (PyTypeNum_ISINTEGER(type_num1)) { + /* int## * m8[<A>] => int64 * m8[<A>] */ + if (type_num2 == NPY_TIMEDELTA) { + out_dtypes[0] = PyArray_DescrNewFromType(NPY_INT64); + if (out_dtypes[0] == NULL) { + return -1; + } + out_dtypes[1] = PyArray_DESCR(operands[1]); + Py_INCREF(out_dtypes[1]); + out_dtypes[2] = out_dtypes[1]; + Py_INCREF(out_dtypes[2]); + + type_num1 = NPY_INT64; + } + else { + goto type_reso_error; + } + } + else if (PyTypeNum_ISFLOAT(type_num1)) { + /* float## * m8[<A>] => float64 * m8[<A>] */ + if (type_num2 == NPY_TIMEDELTA) { + out_dtypes[0] = PyArray_DescrNewFromType(NPY_DOUBLE); + if (out_dtypes[0] == NULL) { + return -1; + } + out_dtypes[1] = PyArray_DESCR(operands[1]); + Py_INCREF(out_dtypes[1]); + out_dtypes[2] = out_dtypes[1]; + Py_INCREF(out_dtypes[2]); + + type_num1 = NPY_DOUBLE; + } + else { + goto type_reso_error; + } + } + else { + goto type_reso_error; + } + + /* Check against the casting rules */ + if (PyUFunc_ValidateCasting(ufunc, casting, operands, out_dtypes) < 0) { + for (i = 0; i < 3; ++i) { + Py_DECREF(out_dtypes[i]); + out_dtypes[i] = NULL; + } + return -1; + } + + /* Search in the functions list */ + types = ufunc->types; + n = ufunc->ntypes; + + for (i = 0; i < n; ++i) { + if (types[3*i] == type_num1 && types[3*i+1] == type_num2) { + *out_innerloop = ufunc->functions[i]; + *out_innerloopdata = ufunc->data[i]; + return 0; + } + } + + PyErr_Format(PyExc_TypeError, + "internal error: could not find appropriate datetime " + "inner loop in %s ufunc", ufunc_name); + return -1; + +type_reso_error: { + PyObject *errmsg; + errmsg = PyUString_FromFormat("ufunc %s cannot use operands " + "with types ", ufunc_name); + PyUString_ConcatAndDel(&errmsg, + PyObject_Repr((PyObject *)PyArray_DESCR(operands[0]))); + PyUString_ConcatAndDel(&errmsg, + PyUString_FromString(" and ")); + PyUString_ConcatAndDel(&errmsg, + PyObject_Repr((PyObject *)PyArray_DESCR(operands[1]))); + PyErr_SetObject(PyExc_TypeError, errmsg); + return -1; + } +} + +/* + * This function applies the type resolution rules for division. + * In particular, there are a number of special cases with datetime: + * m8[<A>] / int## => m8[<A>] / int64 + * m8[<A>] / float## => m8[<A>] / float64 + */ +NPY_NO_EXPORT int +PyUFunc_DivisionTypeResolution(PyUFuncObject *ufunc, + NPY_CASTING casting, + PyArrayObject **operands, + PyObject *type_tup, + PyArray_Descr **out_dtypes, + PyUFuncGenericFunction *out_innerloop, + void **out_innerloopdata) +{ + int type_num1, type_num2; + char *types; + int i, n; + char *ufunc_name; + + ufunc_name = ufunc->name ? ufunc->name : "<unnamed ufunc>"; + + type_num1 = PyArray_DESCR(operands[0])->type_num; + type_num2 = PyArray_DESCR(operands[1])->type_num; + + /* Use the default when datetime and timedelta are not involved */ + if (!PyTypeNum_ISDATETIME(type_num1) && !PyTypeNum_ISDATETIME(type_num2)) { + return PyUFunc_DefaultTypeResolution(ufunc, casting, operands, + type_tup, out_dtypes, out_innerloop, out_innerloopdata); + } + + if (type_num1 == NPY_TIMEDELTA) { + /* m8[<A>] / int## => m8[<A>] / int64 */ + if (PyTypeNum_ISINTEGER(type_num2)) { + out_dtypes[0] = PyArray_DESCR(operands[0]); + Py_INCREF(out_dtypes[0]); + out_dtypes[1] = PyArray_DescrNewFromType(NPY_INT64); + if (out_dtypes[1] == NULL) { + Py_DECREF(out_dtypes[0]); + out_dtypes[0] = NULL; + return -1; + } + out_dtypes[2] = out_dtypes[0]; + Py_INCREF(out_dtypes[2]); + + type_num2 = NPY_INT64; + } + /* m8[<A>] / float## => m8[<A>] / float64 */ + else if (PyTypeNum_ISFLOAT(type_num2)) { + out_dtypes[0] = PyArray_DESCR(operands[0]); + Py_INCREF(out_dtypes[0]); + out_dtypes[1] = PyArray_DescrNewFromType(NPY_DOUBLE); + if (out_dtypes[1] == NULL) { + Py_DECREF(out_dtypes[0]); + out_dtypes[0] = NULL; + return -1; + } + out_dtypes[2] = out_dtypes[0]; + Py_INCREF(out_dtypes[2]); + + type_num2 = NPY_DOUBLE; + } + else { + goto type_reso_error; + } + } + else { + goto type_reso_error; + } + + /* Check against the casting rules */ + if (PyUFunc_ValidateCasting(ufunc, casting, operands, out_dtypes) < 0) { + for (i = 0; i < 3; ++i) { + Py_DECREF(out_dtypes[i]); + out_dtypes[i] = NULL; + } + return -1; + } + + /* Search in the functions list */ + types = ufunc->types; + n = ufunc->ntypes; + + for (i = 0; i < n; ++i) { + if (types[3*i] == type_num1 && types[3*i+1] == type_num2) { + *out_innerloop = ufunc->functions[i]; + *out_innerloopdata = ufunc->data[i]; + return 0; + } + } + + PyErr_Format(PyExc_TypeError, + "internal error: could not find appropriate datetime " + "inner loop in %s ufunc", ufunc_name); + return -1; + +type_reso_error: { + PyObject *errmsg; + errmsg = PyUString_FromFormat("ufunc %s cannot use operands " + "with types ", ufunc_name); + PyUString_ConcatAndDel(&errmsg, + PyObject_Repr((PyObject *)PyArray_DESCR(operands[0]))); + PyUString_ConcatAndDel(&errmsg, + PyUString_FromString(" and ")); + PyUString_ConcatAndDel(&errmsg, + PyObject_Repr((PyObject *)PyArray_DESCR(operands[1]))); + PyErr_SetObject(PyExc_TypeError, errmsg); + return -1; + } +} + /*UFUNC_API * * Validates that the input operands can be cast to diff --git a/numpy/core/src/umath/ufunc_object.h b/numpy/core/src/umath/ufunc_object.h index 9f5496e8a..3ddcb842f 100644 --- a/numpy/core/src/umath/ufunc_object.h +++ b/numpy/core/src/umath/ufunc_object.h @@ -25,4 +25,21 @@ PyUFunc_SubtractionTypeResolution(PyUFuncObject *ufunc, PyUFuncGenericFunction *out_innerloop, void **out_innerloopdata); +NPY_NO_EXPORT int +PyUFunc_MultiplicationTypeResolution(PyUFuncObject *ufunc, + NPY_CASTING casting, + PyArrayObject **operands, + PyObject *type_tup, + PyArray_Descr **out_dtypes, + PyUFuncGenericFunction *out_innerloop, + void **out_innerloopdata); +NPY_NO_EXPORT int +PyUFunc_DivisionTypeResolution(PyUFuncObject *ufunc, + NPY_CASTING casting, + PyArrayObject **operands, + PyObject *type_tup, + PyArray_Descr **out_dtypes, + PyUFuncGenericFunction *out_innerloop, + void **out_innerloopdata); + #endif |