summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMark Wiebe <mwiebe@enthought.com>2011-06-01 13:54:13 -0500
committerMark Wiebe <mwiebe@enthought.com>2011-06-01 13:54:13 -0500
commit33babc9304569a7773df31bf5dd265be35ece449 (patch)
tree1ffd5b130d36d7431a01787a2d29a5eab3ba3d3e
parentda6391a0590d51d2901e20b97db309aed62a7e83 (diff)
downloadnumpy-33babc9304569a7773df31bf5dd265be35ece449.tar.gz
ENH: datetime: Add more tests and type resolution for datetime
-rw-r--r--numpy/core/code_generators/generate_umath.py8
-rw-r--r--numpy/core/src/multiarray/_datetime.h6
-rw-r--r--numpy/core/src/multiarray/convert_datatype.c36
-rw-r--r--numpy/core/src/multiarray/datetime.c37
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.c30
-rw-r--r--numpy/core/src/umath/ufunc_object.c137
-rw-r--r--numpy/core/src/umath/ufunc_object.h18
-rw-r--r--numpy/core/tests/test_datetime.py34
8 files changed, 264 insertions, 42 deletions
diff --git a/numpy/core/code_generators/generate_umath.py b/numpy/core/code_generators/generate_umath.py
index 945816ad6..b018e1948 100644
--- a/numpy/core/code_generators/generate_umath.py
+++ b/numpy/core/code_generators/generate_umath.py
@@ -332,7 +332,7 @@ defdict = {
'ones_like' :
Ufunc(1, 1, None,
docstrings.get('numpy.core.umath.ones_like'),
- None,
+ 'PyUFunc_SimpleUnaryOperationTypeResolution',
TD(noobj),
TD(O, f='Py_get_one'),
),
@@ -347,7 +347,7 @@ defdict = {
'absolute' :
Ufunc(1, 1, None,
docstrings.get('numpy.core.umath.absolute'),
- None,
+ 'PyUFunc_AbsoluteTypeResolution',
TD(bints+flts+timedeltaonly),
TD(cmplx, out=('f', 'd', 'g')),
TD(O, f='PyNumber_Absolute'),
@@ -361,7 +361,7 @@ defdict = {
'negative' :
Ufunc(1, 1, None,
docstrings.get('numpy.core.umath.negative'),
- None,
+ 'PyUFunc_SimpleUnaryOperationTypeResolution',
TD(bints+flts+timedeltaonly),
TD(cmplx, f='neg'),
TD(O, f='PyNumber_Negative'),
@@ -369,7 +369,7 @@ defdict = {
'sign' :
Ufunc(1, 1, None,
docstrings.get('numpy.core.umath.sign'),
- None,
+ 'PyUFunc_SimpleUnaryOperationTypeResolution',
TD(nobool_or_datetime),
),
'greater' :
diff --git a/numpy/core/src/multiarray/_datetime.h b/numpy/core/src/multiarray/_datetime.h
index dbff004b9..48903ebcf 100644
--- a/numpy/core/src/multiarray/_datetime.h
+++ b/numpy/core/src/multiarray/_datetime.h
@@ -223,4 +223,10 @@ convert_datetimestruct_to_datetime(PyArray_DatetimeMetaData *meta,
const npy_datetimestruct *dts,
npy_datetime *out);
+/*
+ * Returns true if the datetime metadata matches
+ */
+NPY_NO_EXPORT npy_bool
+has_equivalent_datetime_metadata(PyArray_Descr *type1, PyArray_Descr *type2);
+
#endif
diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c
index a0d1c76d3..73d65c12d 100644
--- a/numpy/core/src/multiarray/convert_datatype.c
+++ b/numpy/core/src/multiarray/convert_datatype.c
@@ -373,15 +373,35 @@ PyArray_CanCastTypeTo(PyArray_Descr *from, PyArray_Descr *to,
return ret;
}
- switch (casting) {
- case NPY_NO_CASTING:
- return PyArray_EquivTypes(from, to);
- case NPY_EQUIV_CASTING:
- return (from->elsize == to->elsize);
- case NPY_SAFE_CASTING:
- return (from->elsize <= to->elsize);
+ switch (from->type_num) {
+ case NPY_DATETIME:
+ case NPY_TIMEDELTA:
+ switch (casting) {
+ case NPY_NO_CASTING:
+ return PyArray_ISNBO(from->byteorder) ==
+ PyArray_ISNBO(to->byteorder) &&
+ has_equivalent_datetime_metadata(from, to);
+ case NPY_EQUIV_CASTING:
+ return has_equivalent_datetime_metadata(from, to);
+ case NPY_SAFE_CASTING:
+ return datetime_metadata_divides(from, to,
+ from->type_num == NPY_TIMEDELTA);
+ default:
+ return 1;
+ }
+ break;
default:
- return 1;
+ switch (casting) {
+ case NPY_NO_CASTING:
+ return PyArray_EquivTypes(from, to);
+ case NPY_EQUIV_CASTING:
+ return (from->elsize == to->elsize);
+ case NPY_SAFE_CASTING:
+ return (from->elsize <= to->elsize);
+ default:
+ return 1;
+ }
+ break;
}
}
/* If safe or same-kind casts are allowed */
diff --git a/numpy/core/src/multiarray/datetime.c b/numpy/core/src/multiarray/datetime.c
index b62784707..40c696593 100644
--- a/numpy/core/src/multiarray/datetime.c
+++ b/numpy/core/src/multiarray/datetime.c
@@ -1646,7 +1646,12 @@ datetime_metadata_divides(
}
}
- return (num2 % num1) == 0;
+ /* Crude, incomplete check for overflow */
+ if (num1&0xff00000000000000LL || num2&0xff00000000000000LL ) {
+ return 0;
+ }
+
+ return (num1 % num2) == 0;
}
@@ -3025,4 +3030,34 @@ convert_datetime_to_pyobject(npy_datetime dt, PyArray_DatetimeMetaData *meta)
}
}
+/*
+ * Returns true if the datetime metadata matches
+ */
+NPY_NO_EXPORT npy_bool
+has_equivalent_datetime_metadata(PyArray_Descr *type1, PyArray_Descr *type2)
+{
+ PyArray_DatetimeMetaData *meta1, *meta2;
+
+ if ((type1->type_num != NPY_DATETIME &&
+ type1->type_num != NPY_TIMEDELTA) ||
+ (type2->type_num != NPY_DATETIME &&
+ type2->type_num != NPY_TIMEDELTA)) {
+ return 0;
+ }
+
+ meta1 = get_datetime_metadata_from_dtype(type1);
+ if (meta1 == NULL) {
+ PyErr_Clear();
+ return 0;
+ }
+ meta2 = get_datetime_metadata_from_dtype(type2);
+ if (meta2 == NULL) {
+ PyErr_Clear();
+ return 0;
+ }
+
+ return meta1->base == meta2->base &&
+ meta1->num == meta2->num &&
+ meta1->events == meta2->events;
+}
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index 4547ce070..523cb0581 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -1402,34 +1402,6 @@ _equivalent_fields(PyObject *field1, PyObject *field2) {
}
/*
- * Compare the metadata for two date-times.
- * Return 1 if they are the same or 0 if not.
- */
-static int
-_equivalent_datetime_units(PyArray_Descr *dtype1, PyArray_Descr *dtype2)
-{
- PyArray_DatetimeMetaData *data1, *data2;
-
- data1 = get_datetime_metadata_from_dtype(dtype1);
- data2 = get_datetime_metadata_from_dtype(dtype2);
-
- /* If there's a metadata problem, it doesn't match */
- if (data1 == NULL || data2 == NULL) {
- PyErr_Clear();
- return 0;
- }
-
- /* Same meta object */
- if (data1 == data2) {
- return 1;
- }
-
- return ((data1->base == data2->base)
- && (data1->num == data2->num)
- && (data1->events == data2->events));
-}
-
-/*
* Compare the subarray data for two types.
* Return 1 if they are the same, 0 if not.
*/
@@ -1500,7 +1472,7 @@ PyArray_EquivTypes(PyArray_Descr *type1, PyArray_Descr *type2)
|| type_num2 == NPY_TIMEDELTA
|| type_num2 == NPY_TIMEDELTA) {
return ((type_num1 == type_num2)
- && _equivalent_datetime_units(type1, type2));
+ && has_equivalent_datetime_metadata(type1, type2));
}
return type1->kind == type2->kind;
}
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index b9e4c302e..36dea2c83 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -1881,6 +1881,114 @@ PyUFunc_SimpleBinaryComparisonTypeResolution(PyUFuncObject *ufunc,
/*
* This function applies special type resolution rules for the case
+ * where all the functions have the pattern X->X, copying
+ * the input descr directly so that metadata is maintained.
+ *
+ * Note that a simpler linear search through the functions loop
+ * is still done, but switching to a simple array lookup for
+ * built-in types would be better at some point.
+ *
+ * Returns 0 on success, -1 on error.
+ */
+NPY_NO_EXPORT int
+PyUFunc_SimpleUnaryOperationTypeResolution(PyUFuncObject *ufunc,
+ NPY_CASTING casting,
+ PyArrayObject **operands,
+ PyObject *type_tup,
+ PyArray_Descr **out_dtypes,
+ PyUFuncGenericFunction *out_innerloop,
+ void **out_innerloopdata)
+{
+ int i, type_num, type_num1;
+ char *ufunc_name;
+
+ ufunc_name = ufunc->name ? ufunc->name : "<unnamed ufunc>";
+
+ if (ufunc->nin != 1 || ufunc->nout != 1) {
+ PyErr_Format(PyExc_RuntimeError, "ufunc %s is configured "
+ "to use unary operation type resolution but has "
+ "the wrong number of inputs or outputs",
+ ufunc_name);
+ return -1;
+ }
+
+ /*
+ * Use the default type resolution if there's a custom data type
+ * or object arrays.
+ */
+ type_num1 = PyArray_DESCR(operands[0])->type_num;
+ if (type_num1 >= NPY_NTYPES || type_num1 == NPY_OBJECT) {
+ return PyUFunc_DefaultTypeResolution(ufunc, casting, operands,
+ type_tup, out_dtypes, out_innerloop, out_innerloopdata);
+ }
+
+ if (type_tup == NULL) {
+ /* Input types are the result type */
+ out_dtypes[0] = PyArray_DESCR(operands[0]);
+ Py_INCREF(out_dtypes[0]);
+ out_dtypes[1] = out_dtypes[0];
+ Py_INCREF(out_dtypes[1]);
+ }
+ else {
+ /*
+ * If the type tuple isn't a single-element tuple, let the
+ * default type resolution handle this one.
+ */
+ if (!PyTuple_Check(type_tup) || PyTuple_GET_SIZE(type_tup) != 1) {
+ return PyUFunc_DefaultTypeResolution(ufunc, casting, operands,
+ type_tup, out_dtypes, out_innerloop, out_innerloopdata);
+ }
+
+ if (!PyArray_DescrCheck(PyTuple_GET_ITEM(type_tup, 0))) {
+ PyErr_SetString(PyExc_ValueError,
+ "require data type in the type tuple");
+ return -1;
+ }
+
+ out_dtypes[0] = (PyArray_Descr *)PyTuple_GET_ITEM(type_tup, 0);
+ Py_INCREF(out_dtypes[0]);
+ out_dtypes[1] = out_dtypes[0];
+ Py_INCREF(out_dtypes[1]);
+ }
+
+ /* Check against the casting rules */
+ if (PyUFunc_ValidateCasting(ufunc, casting, operands, out_dtypes) < 0) {
+ for (i = 0; i < 2; ++i) {
+ Py_DECREF(out_dtypes[i]);
+ out_dtypes[i] = NULL;
+ }
+ return -1;
+ }
+
+ type_num = out_dtypes[0]->type_num;
+
+ /* If we have a built-in type, search in the functions list */
+ if (type_num < NPY_NTYPES) {
+ char *types = ufunc->types;
+ int n = ufunc->ntypes;
+
+ for (i = 0; i < n; ++i) {
+ if (types[2*i] == type_num) {
+ *out_innerloop = ufunc->functions[i];
+ *out_innerloopdata = ufunc->data[i];
+ return 0;
+ }
+ }
+
+ PyErr_Format(PyExc_TypeError,
+ "ufunc '%s' not supported for the input types",
+ ufunc_name);
+ return -1;
+ }
+ else {
+ PyErr_SetString(PyExc_RuntimeError,
+ "user type shouldn't have resulted from type promotion");
+ return -1;
+ }
+}
+
+/*
+ * This function applies special type resolution rules for the case
* where all the functions have the pattern XX->X, using
* PyArray_ResultType instead of a linear search to get the best
* loop.
@@ -1997,6 +2105,35 @@ PyUFunc_SimpleBinaryOperationTypeResolution(PyUFuncObject *ufunc,
}
/*
+ * This function applies special type resolution rules for the absolute
+ * ufunc. This ufunc converts complex -> float, so isn't covered
+ * by the simple unary type resolution.
+ *
+ * Returns 0 on success, -1 on error.
+ */
+NPY_NO_EXPORT int
+PyUFunc_AbsoluteTypeResolution(PyUFuncObject *ufunc,
+ NPY_CASTING casting,
+ PyArrayObject **operands,
+ PyObject *type_tup,
+ PyArray_Descr **out_dtypes,
+ PyUFuncGenericFunction *out_innerloop,
+ void **out_innerloopdata)
+{
+ /* Use the default for complex types, to find the loop producing float */
+ if (PyTypeNum_ISCOMPLEX(PyArray_DESCR(operands[0])->type_num)) {
+ return PyUFunc_DefaultTypeResolution(ufunc, casting, operands,
+ type_tup, out_dtypes, out_innerloop, out_innerloopdata);
+ }
+ else {
+ return PyUFunc_SimpleUnaryOperationTypeResolution(ufunc, casting,
+ operands, type_tup, out_dtypes, out_innerloop,
+ out_innerloopdata);
+ }
+}
+
+
+/*
* This function returns the a new reference to the
* capsule with the datetime metadata.
*
diff --git a/numpy/core/src/umath/ufunc_object.h b/numpy/core/src/umath/ufunc_object.h
index 828cdb226..2a5fd63a1 100644
--- a/numpy/core/src/umath/ufunc_object.h
+++ b/numpy/core/src/umath/ufunc_object.h
@@ -17,6 +17,15 @@ PyUFunc_SimpleBinaryComparisonTypeResolution(PyUFuncObject *ufunc,
void **out_innerloopdata);
NPY_NO_EXPORT int
+PyUFunc_SimpleUnaryOperationTypeResolution(PyUFuncObject *ufunc,
+ NPY_CASTING casting,
+ PyArrayObject **operands,
+ PyObject *type_tup,
+ PyArray_Descr **out_dtypes,
+ PyUFuncGenericFunction *out_innerloop,
+ void **out_innerloopdata);
+
+NPY_NO_EXPORT int
PyUFunc_SimpleBinaryOperationTypeResolution(PyUFuncObject *ufunc,
NPY_CASTING casting,
PyArrayObject **operands,
@@ -26,6 +35,15 @@ PyUFunc_SimpleBinaryOperationTypeResolution(PyUFuncObject *ufunc,
void **out_innerloopdata);
NPY_NO_EXPORT int
+PyUFunc_AbsoluteTypeResolution(PyUFuncObject *ufunc,
+ NPY_CASTING casting,
+ PyArrayObject **operands,
+ PyObject *type_tup,
+ PyArray_Descr **out_dtypes,
+ PyUFuncGenericFunction *out_innerloop,
+ void **out_innerloopdata);
+
+NPY_NO_EXPORT int
PyUFunc_AdditionTypeResolution(PyUFuncObject *ufunc,
NPY_CASTING casting,
PyArrayObject **operands,
diff --git a/numpy/core/tests/test_datetime.py b/numpy/core/tests/test_datetime.py
index 82681d58f..cd0c4b3df 100644
--- a/numpy/core/tests/test_datetime.py
+++ b/numpy/core/tests/test_datetime.py
@@ -287,6 +287,40 @@ class TestDateTime(TestCase):
#b = np.array(3, dtype='m8[D]')
#assert_raises(TypeError, np.less, a, b, casting='same_kind')
+ def test_datetime_like(self):
+ a = np.array([3], dtype='m8[4D]//6')
+ b = np.array(['2012-12-21'], dtype='M8[D]//3')
+
+ assert_equal(np.ones_like(a).dtype, a.dtype)
+ assert_equal(np.zeros_like(a).dtype, a.dtype)
+ assert_equal(np.empty_like(a).dtype, a.dtype)
+ assert_equal(np.ones_like(b).dtype, b.dtype)
+ assert_equal(np.zeros_like(b).dtype, b.dtype)
+ assert_equal(np.empty_like(b).dtype, b.dtype)
+
+ def test_datetime_unary(self):
+ tda = np.array(3, dtype='m8[D]')
+ tdb = np.array(-3, dtype='m8[D]')
+ tdzero = np.array(0, dtype='m8[D]')
+ tdone = np.array(1, dtype='m8[D]')
+ tdmone = np.array(-1, dtype='m8[D]')
+
+ # negative ufunc
+ assert_equal(-tdb, tda)
+ assert_equal((-tdb).dtype, tda.dtype)
+ assert_equal(np.negative(tdb), tda)
+ assert_equal(np.negative(tdb).dtype, tda.dtype)
+
+ # absolute ufunc
+ assert_equal(np.absolute(tdb), tda)
+ assert_equal(np.absolute(tdb).dtype, tda.dtype)
+
+ # sign ufunc
+ assert_equal(np.sign(tda), tdone)
+ assert_equal(np.sign(tdb), tdmone)
+ assert_equal(np.sign(tdzero), tdzero)
+ assert_equal(np.sign(tda).dtype, tda.dtype)
+
def test_datetime_add(self):
dta = np.array('2012-12-21', dtype='M8[D]')
dtb = np.array('2012-12-24', dtype='M8[D]')