diff options
-rw-r--r-- | numpy/core/code_generators/generate_umath.py | 2 | ||||
-rw-r--r-- | numpy/core/src/multiarray/_datetime.h | 14 | ||||
-rw-r--r-- | numpy/core/src/multiarray/ctors.c | 3 | ||||
-rw-r--r-- | numpy/core/src/multiarray/datetime.c | 86 | ||||
-rw-r--r-- | numpy/core/src/umath/ufunc_object.c | 246 | ||||
-rw-r--r-- | numpy/core/src/umath/ufunc_object.h | 9 |
6 files changed, 340 insertions, 20 deletions
diff --git a/numpy/core/code_generators/generate_umath.py b/numpy/core/code_generators/generate_umath.py index d151b6619..05472f3c7 100644 --- a/numpy/core/code_generators/generate_umath.py +++ b/numpy/core/code_generators/generate_umath.py @@ -237,7 +237,7 @@ defdict = { 'add' : Ufunc(2, 1, Zero, docstrings.get('numpy.core.umath.add'), - None, + 'PyUFunc_AdditionTypeResolution', TD(notimes_or_obj), [TypeDescription('M', UsesArraysAsData, 'Mm', 'M'), TypeDescription('m', UsesArraysAsData, 'mm', 'm'), diff --git a/numpy/core/src/multiarray/_datetime.h b/numpy/core/src/multiarray/_datetime.h index 3a7e3426d..0ad7b3404 100644 --- a/numpy/core/src/multiarray/_datetime.h +++ b/numpy/core/src/multiarray/_datetime.h @@ -19,6 +19,13 @@ NPY_NO_EXPORT npy_datetime PyArray_TimedeltaStructToTimedelta(NPY_DATETIMEUNIT fr, npy_timedeltastruct *d); /* + * This function returns the a new reference to the + * capsule with the datetime metadata. + */ +NPY_NO_EXPORT PyObject * +get_datetime_metacobj_from_dtype(PyArray_Descr *dtype); + +/* * This function returns a pointer to the DateTimeMetaData * contained within the provided datetime dtype. */ @@ -60,6 +67,13 @@ NPY_NO_EXPORT PyArray_Descr * parse_dtype_from_datetime_typestr(char *typestr, Py_ssize_t len); /* + * Creates a new NPY_TIMEDELTA dtype, copying the datetime metadata + * from the given dtype. + */ +NPY_NO_EXPORT PyArray_Descr * +timedelta_dtype_with_copied_meta(PyArray_Descr *dtype); + +/* * Converts a substring given by 'str' and 'len' into * a date time unit enum value. The 'metastr' parameter * is used for error messages, and may be NULL. diff --git a/numpy/core/src/multiarray/ctors.c b/numpy/core/src/multiarray/ctors.c index bbbf91f36..e3891af9f 100644 --- a/numpy/core/src/multiarray/ctors.c +++ b/numpy/core/src/multiarray/ctors.c @@ -950,6 +950,9 @@ PyArray_NewFromDescr(PyTypeObject *subtype, PyArray_Descr *descr, int nd, return NULL; } PyArray_DESCR_REPLACE(descr); + if (descr == NULL) { + return NULL; + } if (descr->type_num == NPY_STRING) { sd = descr->elsize = 1; } diff --git a/numpy/core/src/multiarray/datetime.c b/numpy/core/src/multiarray/datetime.c index f3fa0534c..729cc2c41 100644 --- a/numpy/core/src/multiarray/datetime.c +++ b/numpy/core/src/multiarray/datetime.c @@ -69,14 +69,6 @@ NPY_NO_EXPORT char *_datetime_strings[] = { * license version 1.0.0 */ -#define Py_AssertWithArg(x,errortype,errorstr,a1) \ - { \ - if (!(x)) { \ - PyErr_Format(errortype,errorstr,a1); \ - goto onError; \ - } \ - } - /* Table of number of days in a month (0-based, without and with leap) */ static int days_in_month[2][12] = { { 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31 }, @@ -1013,14 +1005,13 @@ PyArray_TimedeltaToTimedeltaStruct(npy_timedelta val, NPY_DATETIMEUNIT fr, } /* - * This function returns a pointer to the DateTimeMetaData - * contained within the provided datetime dtype. + * This function returns the a new reference to the + * capsule with the datetime metadata. */ -NPY_NO_EXPORT PyArray_DatetimeMetaData * -get_datetime_metadata_from_dtype(PyArray_Descr *dtype) +NPY_NO_EXPORT PyObject * +get_datetime_metacobj_from_dtype(PyArray_Descr *dtype) { - PyObject *tmp; - PyArray_DatetimeMetaData *meta = NULL; + PyObject *metacobj; /* Check that the dtype has metadata */ if (dtype->metadata == NULL) { @@ -1030,14 +1021,34 @@ get_datetime_metadata_from_dtype(PyArray_Descr *dtype) } /* Check that the dtype has unit metadata */ - tmp = PyDict_GetItemString(dtype->metadata, NPY_METADATA_DTSTR); - if (tmp == NULL) { + metacobj = PyDict_GetItemString(dtype->metadata, NPY_METADATA_DTSTR); + if (metacobj == NULL) { PyErr_SetString(PyExc_TypeError, "Datetime type object is invalid, lacks unit metadata"); return NULL; } + + Py_INCREF(metacobj); + return metacobj; +} + +/* + * This function returns a pointer to the DateTimeMetaData + * contained within the provided datetime dtype. + */ +NPY_NO_EXPORT PyArray_DatetimeMetaData * +get_datetime_metadata_from_dtype(PyArray_Descr *dtype) +{ + PyObject *metacobj; + PyArray_DatetimeMetaData *meta = NULL; + + metacobj = get_datetime_metacobj_from_dtype(dtype); + if (metacobj == NULL) { + return NULL; + } + /* Check that the dtype has an NpyCapsule for the metadata */ - meta = (PyArray_DatetimeMetaData *)NpyCapsule_AsVoidPtr(tmp); + meta = (PyArray_DatetimeMetaData *)NpyCapsule_AsVoidPtr(metacobj); if (meta == NULL) { PyErr_SetString(PyExc_TypeError, "Datetime type object is invalid, unit metadata is corrupt"); @@ -1206,10 +1217,10 @@ parse_dtype_from_datetime_typestr(char *typestr, Py_ssize_t len) /* Create a default datetime or timedelta */ if (is_timedelta) { - dtype = PyArray_DescrNewFromType(PyArray_TIMEDELTA); + dtype = PyArray_DescrNewFromType(NPY_TIMEDELTA); } else { - dtype = PyArray_DescrNewFromType(PyArray_DATETIME); + dtype = PyArray_DescrNewFromType(NPY_DATETIME); } if (dtype == NULL) { return NULL; @@ -1245,6 +1256,43 @@ parse_dtype_from_datetime_typestr(char *typestr, Py_ssize_t len) return dtype; } +/* + * Creates a new NPY_TIMEDELTA dtype, copying the datetime metadata + * from the given dtype. + */ +NPY_NO_EXPORT PyArray_Descr * +timedelta_dtype_with_copied_meta(PyArray_Descr *dtype) +{ + PyArray_Descr *ret; + PyObject *metacobj; + + ret = PyArray_DescrNewFromType(NPY_TIMEDELTA); + if (ret == NULL) { + return NULL; + } + Py_XDECREF(ret->metadata); + ret->metadata = PyDict_New(); + if (ret->metadata == NULL) { + Py_DECREF(ret); + return NULL; + } + + metacobj = get_datetime_metacobj_from_dtype(dtype); + if (metacobj == NULL) { + Py_DECREF(ret); + return NULL; + } + + if (PyDict_SetItemString(ret->metadata, NPY_METADATA_DTSTR, + metacobj) < 0) { + Py_DECREF(metacobj); + Py_DECREF(ret); + return NULL; + } + + return ret; +} + static NPY_DATETIMEUNIT _multiples_table[16][4] = { {12, 52, 365}, /* NPY_FR_Y */ {NPY_FR_M, NPY_FR_W, NPY_FR_D}, diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index fe2718620..2e5db4103 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -1880,6 +1880,252 @@ PyUFunc_BinaryComparisonTypeResolution(PyUFuncObject *ufunc, } } +/* + * This function returns the a new reference to the + * capsule with the datetime metadata. + * + * NOTE: This function is copied from datetime.c in multiarray, + * in order + */ +static PyObject * +get_datetime_metacobj_from_dtype(PyArray_Descr *dtype) +{ + PyObject *metacobj; + + /* Check that the dtype has metadata */ + if (dtype->metadata == NULL) { + PyErr_SetString(PyExc_TypeError, + "Datetime type object is invalid, lacks metadata"); + return NULL; + } + + /* Check that the dtype has unit metadata */ + metacobj = PyDict_GetItemString(dtype->metadata, NPY_METADATA_DTSTR); + if (metacobj == NULL) { + PyErr_SetString(PyExc_TypeError, + "Datetime type object is invalid, lacks unit metadata"); + return NULL; + } + + Py_INCREF(metacobj); + return metacobj; +} + +/* + * Creates a new NPY_TIMEDELTA dtype, copying the datetime metadata + * from the given dtype. + * + * NOTE: This function is copied from datetime.c in multiarray, + * in order + */ +static PyArray_Descr * +timedelta_dtype_with_copied_meta(PyArray_Descr *dtype) +{ + PyArray_Descr *ret; + PyObject *metacobj; + + ret = PyArray_DescrNewFromType(NPY_TIMEDELTA); + if (ret == NULL) { + return NULL; + } + Py_XDECREF(ret->metadata); + ret->metadata = PyDict_New(); + if (ret->metadata == NULL) { + Py_DECREF(ret); + return NULL; + } + + metacobj = get_datetime_metacobj_from_dtype(dtype); + if (metacobj == NULL) { + Py_DECREF(ret); + return NULL; + } + + if (PyDict_SetItemString(ret->metadata, NPY_METADATA_DTSTR, + metacobj) < 0) { + Py_DECREF(metacobj); + Py_DECREF(ret); + return NULL; + } + + return ret; +} + + + +/* + * This function applies the type resolution rules for addition. + * In particular, there are a number of special cases with datetime: + * m8[<A>] + m8[<B>] => m8[gcd(<A>,<B>)] + m8[gcd(<A>,<B>)] + * m8[<A>] + int => m8[<A>] + m8[<A>] + * int + m8[<A>] => m8[<A>] + m8[<A>] + * M8[<A>] + int => M8[<A>] + m8[<A>] + * int + M8[<A>] => m8[<A>] + M8[<A>] + * M8[<A>] + m8[<B>] => M8[<A>] + m8[<A>] + * m8[<A>] + M8[<B>] => m8[<B>] + M8[<B>] + * TODO: Non-linear time unit cases require highly special-cased loops + * M8[<A>] + m8[Y|M|B] + * m8[Y|M|B] + M8[<A>] + */ +NPY_NO_EXPORT int +PyUFunc_AdditionTypeResolution(PyUFuncObject *ufunc, + NPY_CASTING casting, + PyArrayObject **operands, + PyObject *type_tup, + PyArray_Descr **out_dtypes, + PyUFuncGenericFunction *out_innerloop, + void **out_innerloopdata) +{ + int type_num1, type_num2; + char *types; + int i, n; + + type_num1 = PyArray_DESCR(operands[0])->type_num; + type_num2 = PyArray_DESCR(operands[1])->type_num; + + /* Use the default when datetime and timedelta are not involved */ + if (!PyTypeNum_ISDATETIME(type_num1) && !PyTypeNum_ISDATETIME(type_num2)) { + return PyUFunc_DefaultTypeResolution(ufunc, casting, operands, + type_tup, out_dtypes, out_innerloop, out_innerloopdata); + } + + if (type_num1 == NPY_TIMEDELTA) { + /* m8[<A>] + m8[<B>] => m8[gcd(<A>,<B>)] + m8[gcd(<A>,<B>)] */ + if (type_num2 == NPY_TIMEDELTA) { + out_dtypes[0] = PyArray_PromoteTypes(PyArray_DESCR(operands[0]), + PyArray_DESCR(operands[1])); + if (out_dtypes[0] == NULL) { + return -1; + } + out_dtypes[1] = out_dtypes[0]; + Py_INCREF(out_dtypes[1]); + out_dtypes[2] = out_dtypes[0]; + Py_INCREF(out_dtypes[2]); + } + /* m8[<A>] + M8[<B>] => m8[<B>] + M8[<B>] */ + else if (type_num2 == NPY_DATETIME) { + /* Make a new NPY_TIMEDELTA, and copy type2's metadata */ + out_dtypes[0] = timedelta_dtype_with_copied_meta( + PyArray_DESCR(operands[1])); + if (out_dtypes[0] == NULL) { + return -1; + } + out_dtypes[1] = PyArray_DESCR(operands[1]); + Py_INCREF(out_dtypes[1]); + out_dtypes[2] = out_dtypes[1]; + Py_INCREF(out_dtypes[2]); + } + /* m8[<A>] + int => m8[<A>] + m8[<A>] */ + else if (PyTypeNum_ISINTEGER(type_num2)) { + out_dtypes[0] = PyArray_DESCR(operands[0]); + Py_INCREF(out_dtypes[0]); + out_dtypes[1] = out_dtypes[0]; + Py_INCREF(out_dtypes[1]); + out_dtypes[2] = out_dtypes[0]; + Py_INCREF(out_dtypes[2]); + + type_num2 = NPY_TIMEDELTA; + } + else { + goto type_reso_error; + } + } + else if (type_num1 == NPY_DATETIME) { + /* M8[<A>] + m8[<B>] => M8[<A>] + m8[<A>] */ + /* M8[<A>] + int => M8[<A>] + m8[<A>] */ + if (type_num2 == NPY_TIMEDELTA || + PyTypeNum_ISINTEGER(type_num2)) { + /* Make a new NPY_TIMEDELTA, and copy type1's metadata */ + out_dtypes[1] = timedelta_dtype_with_copied_meta( + PyArray_DESCR(operands[0])); + if (out_dtypes[1] == NULL) { + return -1; + } + out_dtypes[0] = PyArray_DESCR(operands[0]); + Py_INCREF(out_dtypes[0]); + out_dtypes[2] = out_dtypes[0]; + Py_INCREF(out_dtypes[2]); + + type_num2 = NPY_TIMEDELTA; + } + else { + goto type_reso_error; + } + } + else if (PyTypeNum_ISINTEGER(type_num1)) { + /* int + m8[<A>] => m8[<A>] + m8[<A>] */ + if (type_num2 == NPY_TIMEDELTA) { + out_dtypes[0] = PyArray_DESCR(operands[0]); + Py_INCREF(out_dtypes[0]); + out_dtypes[1] = out_dtypes[0]; + Py_INCREF(out_dtypes[1]); + out_dtypes[2] = out_dtypes[0]; + Py_INCREF(out_dtypes[2]); + + type_num1 = NPY_TIMEDELTA; + } + else if (type_num2 == NPY_DATETIME) { + /* Make a new NPY_TIMEDELTA, and copy type2's metadata */ + out_dtypes[0] = timedelta_dtype_with_copied_meta( + PyArray_DESCR(operands[1])); + if (out_dtypes[0] == NULL) { + return -1; + } + out_dtypes[1] = PyArray_DESCR(operands[1]); + Py_INCREF(out_dtypes[1]); + out_dtypes[2] = out_dtypes[1]; + Py_INCREF(out_dtypes[2]); + + type_num1 = NPY_TIMEDELTA; + } + else { + goto type_reso_error; + } + } + else { + goto type_reso_error; + } + + /* Check against the casting rules */ + if (PyUFunc_ValidateCasting(ufunc, casting, operands, out_dtypes) < 0) { + for (i = 0; i < 3; ++i) { + Py_DECREF(out_dtypes[i]); + out_dtypes[i] = NULL; + } + return -1; + } + + /* Search in the functions list */ + types = ufunc->types; + n = ufunc->ntypes; + + for (i = 0; i < n; ++i) { + if (types[3*i] == type_num1 && types[3*i+1] == type_num2) { + *out_innerloop = ufunc->functions[i]; + *out_innerloopdata = ufunc->data[i]; + return 0; + } + } + + PyErr_SetString(PyExc_TypeError, + "internal error: could not find appropriate datetime " + "inner loop in add ufunc"); + return -1; + +type_reso_error: { + PyObject *errmsg; + errmsg = PyUString_FromString("Cannot add operands with types "); + PyUString_ConcatAndDel(&errmsg, + PyObject_Repr((PyObject *)PyArray_DESCR(operands[0]))); + PyUString_ConcatAndDel(&errmsg, + PyUString_FromString(" and ")); + PyUString_ConcatAndDel(&errmsg, + PyObject_Repr((PyObject *)PyArray_DESCR(operands[1]))); + PyErr_SetObject(PyExc_TypeError, errmsg); + return -1; + } +} + /*UFUNC_API * * Validates that the input operands can be cast to diff --git a/numpy/core/src/umath/ufunc_object.h b/numpy/core/src/umath/ufunc_object.h index a8886be05..e078aab27 100644 --- a/numpy/core/src/umath/ufunc_object.h +++ b/numpy/core/src/umath/ufunc_object.h @@ -7,4 +7,13 @@ ufunc_geterr(PyObject *NPY_UNUSED(dummy), PyObject *args); NPY_NO_EXPORT PyObject * ufunc_seterr(PyObject *NPY_UNUSED(dummy), PyObject *args); +NPY_NO_EXPORT int +PyUFunc_AdditionTypeResolution(PyUFuncObject *ufunc, + NPY_CASTING casting, + PyArrayObject **operands, + PyObject *type_tup, + PyArray_Descr **out_dtypes, + PyUFuncGenericFunction *out_innerloop, + void **out_innerloopdata); + #endif |