summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMark Wiebe <mwiebe@enthought.com>2011-05-27 15:40:22 -0500
committerMark Wiebe <mwiebe@enthought.com>2011-05-27 15:40:22 -0500
commitcb21fd9d3f4cbfc326d105ccea9bd603651cf12f (patch)
tree51260153ab090d9f405c884db19be8eb964f0430
parent4f2a2d9a1c9cae22c1e48ae3e91e958789fe7bfb (diff)
downloadnumpy-cb21fd9d3f4cbfc326d105ccea9bd603651cf12f.tar.gz
ENH: Create type resolution function for 'add' with special datetime rules
-rw-r--r--numpy/core/code_generators/generate_umath.py2
-rw-r--r--numpy/core/src/multiarray/_datetime.h14
-rw-r--r--numpy/core/src/multiarray/ctors.c3
-rw-r--r--numpy/core/src/multiarray/datetime.c86
-rw-r--r--numpy/core/src/umath/ufunc_object.c246
-rw-r--r--numpy/core/src/umath/ufunc_object.h9
6 files changed, 340 insertions, 20 deletions
diff --git a/numpy/core/code_generators/generate_umath.py b/numpy/core/code_generators/generate_umath.py
index d151b6619..05472f3c7 100644
--- a/numpy/core/code_generators/generate_umath.py
+++ b/numpy/core/code_generators/generate_umath.py
@@ -237,7 +237,7 @@ defdict = {
'add' :
Ufunc(2, 1, Zero,
docstrings.get('numpy.core.umath.add'),
- None,
+ 'PyUFunc_AdditionTypeResolution',
TD(notimes_or_obj),
[TypeDescription('M', UsesArraysAsData, 'Mm', 'M'),
TypeDescription('m', UsesArraysAsData, 'mm', 'm'),
diff --git a/numpy/core/src/multiarray/_datetime.h b/numpy/core/src/multiarray/_datetime.h
index 3a7e3426d..0ad7b3404 100644
--- a/numpy/core/src/multiarray/_datetime.h
+++ b/numpy/core/src/multiarray/_datetime.h
@@ -19,6 +19,13 @@ NPY_NO_EXPORT npy_datetime
PyArray_TimedeltaStructToTimedelta(NPY_DATETIMEUNIT fr, npy_timedeltastruct *d);
/*
+ * This function returns the a new reference to the
+ * capsule with the datetime metadata.
+ */
+NPY_NO_EXPORT PyObject *
+get_datetime_metacobj_from_dtype(PyArray_Descr *dtype);
+
+/*
* This function returns a pointer to the DateTimeMetaData
* contained within the provided datetime dtype.
*/
@@ -60,6 +67,13 @@ NPY_NO_EXPORT PyArray_Descr *
parse_dtype_from_datetime_typestr(char *typestr, Py_ssize_t len);
/*
+ * Creates a new NPY_TIMEDELTA dtype, copying the datetime metadata
+ * from the given dtype.
+ */
+NPY_NO_EXPORT PyArray_Descr *
+timedelta_dtype_with_copied_meta(PyArray_Descr *dtype);
+
+/*
* Converts a substring given by 'str' and 'len' into
* a date time unit enum value. The 'metastr' parameter
* is used for error messages, and may be NULL.
diff --git a/numpy/core/src/multiarray/ctors.c b/numpy/core/src/multiarray/ctors.c
index bbbf91f36..e3891af9f 100644
--- a/numpy/core/src/multiarray/ctors.c
+++ b/numpy/core/src/multiarray/ctors.c
@@ -950,6 +950,9 @@ PyArray_NewFromDescr(PyTypeObject *subtype, PyArray_Descr *descr, int nd,
return NULL;
}
PyArray_DESCR_REPLACE(descr);
+ if (descr == NULL) {
+ return NULL;
+ }
if (descr->type_num == NPY_STRING) {
sd = descr->elsize = 1;
}
diff --git a/numpy/core/src/multiarray/datetime.c b/numpy/core/src/multiarray/datetime.c
index f3fa0534c..729cc2c41 100644
--- a/numpy/core/src/multiarray/datetime.c
+++ b/numpy/core/src/multiarray/datetime.c
@@ -69,14 +69,6 @@ NPY_NO_EXPORT char *_datetime_strings[] = {
* license version 1.0.0
*/
-#define Py_AssertWithArg(x,errortype,errorstr,a1) \
- { \
- if (!(x)) { \
- PyErr_Format(errortype,errorstr,a1); \
- goto onError; \
- } \
- }
-
/* Table of number of days in a month (0-based, without and with leap) */
static int days_in_month[2][12] = {
{ 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31 },
@@ -1013,14 +1005,13 @@ PyArray_TimedeltaToTimedeltaStruct(npy_timedelta val, NPY_DATETIMEUNIT fr,
}
/*
- * This function returns a pointer to the DateTimeMetaData
- * contained within the provided datetime dtype.
+ * This function returns the a new reference to the
+ * capsule with the datetime metadata.
*/
-NPY_NO_EXPORT PyArray_DatetimeMetaData *
-get_datetime_metadata_from_dtype(PyArray_Descr *dtype)
+NPY_NO_EXPORT PyObject *
+get_datetime_metacobj_from_dtype(PyArray_Descr *dtype)
{
- PyObject *tmp;
- PyArray_DatetimeMetaData *meta = NULL;
+ PyObject *metacobj;
/* Check that the dtype has metadata */
if (dtype->metadata == NULL) {
@@ -1030,14 +1021,34 @@ get_datetime_metadata_from_dtype(PyArray_Descr *dtype)
}
/* Check that the dtype has unit metadata */
- tmp = PyDict_GetItemString(dtype->metadata, NPY_METADATA_DTSTR);
- if (tmp == NULL) {
+ metacobj = PyDict_GetItemString(dtype->metadata, NPY_METADATA_DTSTR);
+ if (metacobj == NULL) {
PyErr_SetString(PyExc_TypeError,
"Datetime type object is invalid, lacks unit metadata");
return NULL;
}
+
+ Py_INCREF(metacobj);
+ return metacobj;
+}
+
+/*
+ * This function returns a pointer to the DateTimeMetaData
+ * contained within the provided datetime dtype.
+ */
+NPY_NO_EXPORT PyArray_DatetimeMetaData *
+get_datetime_metadata_from_dtype(PyArray_Descr *dtype)
+{
+ PyObject *metacobj;
+ PyArray_DatetimeMetaData *meta = NULL;
+
+ metacobj = get_datetime_metacobj_from_dtype(dtype);
+ if (metacobj == NULL) {
+ return NULL;
+ }
+
/* Check that the dtype has an NpyCapsule for the metadata */
- meta = (PyArray_DatetimeMetaData *)NpyCapsule_AsVoidPtr(tmp);
+ meta = (PyArray_DatetimeMetaData *)NpyCapsule_AsVoidPtr(metacobj);
if (meta == NULL) {
PyErr_SetString(PyExc_TypeError,
"Datetime type object is invalid, unit metadata is corrupt");
@@ -1206,10 +1217,10 @@ parse_dtype_from_datetime_typestr(char *typestr, Py_ssize_t len)
/* Create a default datetime or timedelta */
if (is_timedelta) {
- dtype = PyArray_DescrNewFromType(PyArray_TIMEDELTA);
+ dtype = PyArray_DescrNewFromType(NPY_TIMEDELTA);
}
else {
- dtype = PyArray_DescrNewFromType(PyArray_DATETIME);
+ dtype = PyArray_DescrNewFromType(NPY_DATETIME);
}
if (dtype == NULL) {
return NULL;
@@ -1245,6 +1256,43 @@ parse_dtype_from_datetime_typestr(char *typestr, Py_ssize_t len)
return dtype;
}
+/*
+ * Creates a new NPY_TIMEDELTA dtype, copying the datetime metadata
+ * from the given dtype.
+ */
+NPY_NO_EXPORT PyArray_Descr *
+timedelta_dtype_with_copied_meta(PyArray_Descr *dtype)
+{
+ PyArray_Descr *ret;
+ PyObject *metacobj;
+
+ ret = PyArray_DescrNewFromType(NPY_TIMEDELTA);
+ if (ret == NULL) {
+ return NULL;
+ }
+ Py_XDECREF(ret->metadata);
+ ret->metadata = PyDict_New();
+ if (ret->metadata == NULL) {
+ Py_DECREF(ret);
+ return NULL;
+ }
+
+ metacobj = get_datetime_metacobj_from_dtype(dtype);
+ if (metacobj == NULL) {
+ Py_DECREF(ret);
+ return NULL;
+ }
+
+ if (PyDict_SetItemString(ret->metadata, NPY_METADATA_DTSTR,
+ metacobj) < 0) {
+ Py_DECREF(metacobj);
+ Py_DECREF(ret);
+ return NULL;
+ }
+
+ return ret;
+}
+
static NPY_DATETIMEUNIT _multiples_table[16][4] = {
{12, 52, 365}, /* NPY_FR_Y */
{NPY_FR_M, NPY_FR_W, NPY_FR_D},
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index fe2718620..2e5db4103 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -1880,6 +1880,252 @@ PyUFunc_BinaryComparisonTypeResolution(PyUFuncObject *ufunc,
}
}
+/*
+ * This function returns the a new reference to the
+ * capsule with the datetime metadata.
+ *
+ * NOTE: This function is copied from datetime.c in multiarray,
+ * in order
+ */
+static PyObject *
+get_datetime_metacobj_from_dtype(PyArray_Descr *dtype)
+{
+ PyObject *metacobj;
+
+ /* Check that the dtype has metadata */
+ if (dtype->metadata == NULL) {
+ PyErr_SetString(PyExc_TypeError,
+ "Datetime type object is invalid, lacks metadata");
+ return NULL;
+ }
+
+ /* Check that the dtype has unit metadata */
+ metacobj = PyDict_GetItemString(dtype->metadata, NPY_METADATA_DTSTR);
+ if (metacobj == NULL) {
+ PyErr_SetString(PyExc_TypeError,
+ "Datetime type object is invalid, lacks unit metadata");
+ return NULL;
+ }
+
+ Py_INCREF(metacobj);
+ return metacobj;
+}
+
+/*
+ * Creates a new NPY_TIMEDELTA dtype, copying the datetime metadata
+ * from the given dtype.
+ *
+ * NOTE: This function is copied from datetime.c in multiarray,
+ * in order
+ */
+static PyArray_Descr *
+timedelta_dtype_with_copied_meta(PyArray_Descr *dtype)
+{
+ PyArray_Descr *ret;
+ PyObject *metacobj;
+
+ ret = PyArray_DescrNewFromType(NPY_TIMEDELTA);
+ if (ret == NULL) {
+ return NULL;
+ }
+ Py_XDECREF(ret->metadata);
+ ret->metadata = PyDict_New();
+ if (ret->metadata == NULL) {
+ Py_DECREF(ret);
+ return NULL;
+ }
+
+ metacobj = get_datetime_metacobj_from_dtype(dtype);
+ if (metacobj == NULL) {
+ Py_DECREF(ret);
+ return NULL;
+ }
+
+ if (PyDict_SetItemString(ret->metadata, NPY_METADATA_DTSTR,
+ metacobj) < 0) {
+ Py_DECREF(metacobj);
+ Py_DECREF(ret);
+ return NULL;
+ }
+
+ return ret;
+}
+
+
+
+/*
+ * This function applies the type resolution rules for addition.
+ * In particular, there are a number of special cases with datetime:
+ * m8[<A>] + m8[<B>] => m8[gcd(<A>,<B>)] + m8[gcd(<A>,<B>)]
+ * m8[<A>] + int => m8[<A>] + m8[<A>]
+ * int + m8[<A>] => m8[<A>] + m8[<A>]
+ * M8[<A>] + int => M8[<A>] + m8[<A>]
+ * int + M8[<A>] => m8[<A>] + M8[<A>]
+ * M8[<A>] + m8[<B>] => M8[<A>] + m8[<A>]
+ * m8[<A>] + M8[<B>] => m8[<B>] + M8[<B>]
+ * TODO: Non-linear time unit cases require highly special-cased loops
+ * M8[<A>] + m8[Y|M|B]
+ * m8[Y|M|B] + M8[<A>]
+ */
+NPY_NO_EXPORT int
+PyUFunc_AdditionTypeResolution(PyUFuncObject *ufunc,
+ NPY_CASTING casting,
+ PyArrayObject **operands,
+ PyObject *type_tup,
+ PyArray_Descr **out_dtypes,
+ PyUFuncGenericFunction *out_innerloop,
+ void **out_innerloopdata)
+{
+ int type_num1, type_num2;
+ char *types;
+ int i, n;
+
+ type_num1 = PyArray_DESCR(operands[0])->type_num;
+ type_num2 = PyArray_DESCR(operands[1])->type_num;
+
+ /* Use the default when datetime and timedelta are not involved */
+ if (!PyTypeNum_ISDATETIME(type_num1) && !PyTypeNum_ISDATETIME(type_num2)) {
+ return PyUFunc_DefaultTypeResolution(ufunc, casting, operands,
+ type_tup, out_dtypes, out_innerloop, out_innerloopdata);
+ }
+
+ if (type_num1 == NPY_TIMEDELTA) {
+ /* m8[<A>] + m8[<B>] => m8[gcd(<A>,<B>)] + m8[gcd(<A>,<B>)] */
+ if (type_num2 == NPY_TIMEDELTA) {
+ out_dtypes[0] = PyArray_PromoteTypes(PyArray_DESCR(operands[0]),
+ PyArray_DESCR(operands[1]));
+ if (out_dtypes[0] == NULL) {
+ return -1;
+ }
+ out_dtypes[1] = out_dtypes[0];
+ Py_INCREF(out_dtypes[1]);
+ out_dtypes[2] = out_dtypes[0];
+ Py_INCREF(out_dtypes[2]);
+ }
+ /* m8[<A>] + M8[<B>] => m8[<B>] + M8[<B>] */
+ else if (type_num2 == NPY_DATETIME) {
+ /* Make a new NPY_TIMEDELTA, and copy type2's metadata */
+ out_dtypes[0] = timedelta_dtype_with_copied_meta(
+ PyArray_DESCR(operands[1]));
+ if (out_dtypes[0] == NULL) {
+ return -1;
+ }
+ out_dtypes[1] = PyArray_DESCR(operands[1]);
+ Py_INCREF(out_dtypes[1]);
+ out_dtypes[2] = out_dtypes[1];
+ Py_INCREF(out_dtypes[2]);
+ }
+ /* m8[<A>] + int => m8[<A>] + m8[<A>] */
+ else if (PyTypeNum_ISINTEGER(type_num2)) {
+ out_dtypes[0] = PyArray_DESCR(operands[0]);
+ Py_INCREF(out_dtypes[0]);
+ out_dtypes[1] = out_dtypes[0];
+ Py_INCREF(out_dtypes[1]);
+ out_dtypes[2] = out_dtypes[0];
+ Py_INCREF(out_dtypes[2]);
+
+ type_num2 = NPY_TIMEDELTA;
+ }
+ else {
+ goto type_reso_error;
+ }
+ }
+ else if (type_num1 == NPY_DATETIME) {
+ /* M8[<A>] + m8[<B>] => M8[<A>] + m8[<A>] */
+ /* M8[<A>] + int => M8[<A>] + m8[<A>] */
+ if (type_num2 == NPY_TIMEDELTA ||
+ PyTypeNum_ISINTEGER(type_num2)) {
+ /* Make a new NPY_TIMEDELTA, and copy type1's metadata */
+ out_dtypes[1] = timedelta_dtype_with_copied_meta(
+ PyArray_DESCR(operands[0]));
+ if (out_dtypes[1] == NULL) {
+ return -1;
+ }
+ out_dtypes[0] = PyArray_DESCR(operands[0]);
+ Py_INCREF(out_dtypes[0]);
+ out_dtypes[2] = out_dtypes[0];
+ Py_INCREF(out_dtypes[2]);
+
+ type_num2 = NPY_TIMEDELTA;
+ }
+ else {
+ goto type_reso_error;
+ }
+ }
+ else if (PyTypeNum_ISINTEGER(type_num1)) {
+ /* int + m8[<A>] => m8[<A>] + m8[<A>] */
+ if (type_num2 == NPY_TIMEDELTA) {
+ out_dtypes[0] = PyArray_DESCR(operands[0]);
+ Py_INCREF(out_dtypes[0]);
+ out_dtypes[1] = out_dtypes[0];
+ Py_INCREF(out_dtypes[1]);
+ out_dtypes[2] = out_dtypes[0];
+ Py_INCREF(out_dtypes[2]);
+
+ type_num1 = NPY_TIMEDELTA;
+ }
+ else if (type_num2 == NPY_DATETIME) {
+ /* Make a new NPY_TIMEDELTA, and copy type2's metadata */
+ out_dtypes[0] = timedelta_dtype_with_copied_meta(
+ PyArray_DESCR(operands[1]));
+ if (out_dtypes[0] == NULL) {
+ return -1;
+ }
+ out_dtypes[1] = PyArray_DESCR(operands[1]);
+ Py_INCREF(out_dtypes[1]);
+ out_dtypes[2] = out_dtypes[1];
+ Py_INCREF(out_dtypes[2]);
+
+ type_num1 = NPY_TIMEDELTA;
+ }
+ else {
+ goto type_reso_error;
+ }
+ }
+ else {
+ goto type_reso_error;
+ }
+
+ /* Check against the casting rules */
+ if (PyUFunc_ValidateCasting(ufunc, casting, operands, out_dtypes) < 0) {
+ for (i = 0; i < 3; ++i) {
+ Py_DECREF(out_dtypes[i]);
+ out_dtypes[i] = NULL;
+ }
+ return -1;
+ }
+
+ /* Search in the functions list */
+ types = ufunc->types;
+ n = ufunc->ntypes;
+
+ for (i = 0; i < n; ++i) {
+ if (types[3*i] == type_num1 && types[3*i+1] == type_num2) {
+ *out_innerloop = ufunc->functions[i];
+ *out_innerloopdata = ufunc->data[i];
+ return 0;
+ }
+ }
+
+ PyErr_SetString(PyExc_TypeError,
+ "internal error: could not find appropriate datetime "
+ "inner loop in add ufunc");
+ return -1;
+
+type_reso_error: {
+ PyObject *errmsg;
+ errmsg = PyUString_FromString("Cannot add operands with types ");
+ PyUString_ConcatAndDel(&errmsg,
+ PyObject_Repr((PyObject *)PyArray_DESCR(operands[0])));
+ PyUString_ConcatAndDel(&errmsg,
+ PyUString_FromString(" and "));
+ PyUString_ConcatAndDel(&errmsg,
+ PyObject_Repr((PyObject *)PyArray_DESCR(operands[1])));
+ PyErr_SetObject(PyExc_TypeError, errmsg);
+ return -1;
+ }
+}
+
/*UFUNC_API
*
* Validates that the input operands can be cast to
diff --git a/numpy/core/src/umath/ufunc_object.h b/numpy/core/src/umath/ufunc_object.h
index a8886be05..e078aab27 100644
--- a/numpy/core/src/umath/ufunc_object.h
+++ b/numpy/core/src/umath/ufunc_object.h
@@ -7,4 +7,13 @@ ufunc_geterr(PyObject *NPY_UNUSED(dummy), PyObject *args);
NPY_NO_EXPORT PyObject *
ufunc_seterr(PyObject *NPY_UNUSED(dummy), PyObject *args);
+NPY_NO_EXPORT int
+PyUFunc_AdditionTypeResolution(PyUFuncObject *ufunc,
+ NPY_CASTING casting,
+ PyArrayObject **operands,
+ PyObject *type_tup,
+ PyArray_Descr **out_dtypes,
+ PyUFuncGenericFunction *out_innerloop,
+ void **out_innerloopdata);
+
#endif