diff options
-rw-r--r-- | numpy/core/code_generators/generate_umath.py | 8 | ||||
-rw-r--r-- | numpy/core/src/multiarray/_datetime.h | 6 | ||||
-rw-r--r-- | numpy/core/src/multiarray/convert_datatype.c | 36 | ||||
-rw-r--r-- | numpy/core/src/multiarray/datetime.c | 37 | ||||
-rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.c | 30 | ||||
-rw-r--r-- | numpy/core/src/umath/ufunc_object.c | 137 | ||||
-rw-r--r-- | numpy/core/src/umath/ufunc_object.h | 18 | ||||
-rw-r--r-- | numpy/core/tests/test_datetime.py | 34 |
8 files changed, 264 insertions, 42 deletions
diff --git a/numpy/core/code_generators/generate_umath.py b/numpy/core/code_generators/generate_umath.py index 945816ad6..b018e1948 100644 --- a/numpy/core/code_generators/generate_umath.py +++ b/numpy/core/code_generators/generate_umath.py @@ -332,7 +332,7 @@ defdict = { 'ones_like' : Ufunc(1, 1, None, docstrings.get('numpy.core.umath.ones_like'), - None, + 'PyUFunc_SimpleUnaryOperationTypeResolution', TD(noobj), TD(O, f='Py_get_one'), ), @@ -347,7 +347,7 @@ defdict = { 'absolute' : Ufunc(1, 1, None, docstrings.get('numpy.core.umath.absolute'), - None, + 'PyUFunc_AbsoluteTypeResolution', TD(bints+flts+timedeltaonly), TD(cmplx, out=('f', 'd', 'g')), TD(O, f='PyNumber_Absolute'), @@ -361,7 +361,7 @@ defdict = { 'negative' : Ufunc(1, 1, None, docstrings.get('numpy.core.umath.negative'), - None, + 'PyUFunc_SimpleUnaryOperationTypeResolution', TD(bints+flts+timedeltaonly), TD(cmplx, f='neg'), TD(O, f='PyNumber_Negative'), @@ -369,7 +369,7 @@ defdict = { 'sign' : Ufunc(1, 1, None, docstrings.get('numpy.core.umath.sign'), - None, + 'PyUFunc_SimpleUnaryOperationTypeResolution', TD(nobool_or_datetime), ), 'greater' : diff --git a/numpy/core/src/multiarray/_datetime.h b/numpy/core/src/multiarray/_datetime.h index dbff004b9..48903ebcf 100644 --- a/numpy/core/src/multiarray/_datetime.h +++ b/numpy/core/src/multiarray/_datetime.h @@ -223,4 +223,10 @@ convert_datetimestruct_to_datetime(PyArray_DatetimeMetaData *meta, const npy_datetimestruct *dts, npy_datetime *out); +/* + * Returns true if the datetime metadata matches + */ +NPY_NO_EXPORT npy_bool +has_equivalent_datetime_metadata(PyArray_Descr *type1, PyArray_Descr *type2); + #endif diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c index a0d1c76d3..73d65c12d 100644 --- a/numpy/core/src/multiarray/convert_datatype.c +++ b/numpy/core/src/multiarray/convert_datatype.c @@ -373,15 +373,35 @@ PyArray_CanCastTypeTo(PyArray_Descr *from, PyArray_Descr *to, return ret; } - switch (casting) { - case NPY_NO_CASTING: - return PyArray_EquivTypes(from, to); - case NPY_EQUIV_CASTING: - return (from->elsize == to->elsize); - case NPY_SAFE_CASTING: - return (from->elsize <= to->elsize); + switch (from->type_num) { + case NPY_DATETIME: + case NPY_TIMEDELTA: + switch (casting) { + case NPY_NO_CASTING: + return PyArray_ISNBO(from->byteorder) == + PyArray_ISNBO(to->byteorder) && + has_equivalent_datetime_metadata(from, to); + case NPY_EQUIV_CASTING: + return has_equivalent_datetime_metadata(from, to); + case NPY_SAFE_CASTING: + return datetime_metadata_divides(from, to, + from->type_num == NPY_TIMEDELTA); + default: + return 1; + } + break; default: - return 1; + switch (casting) { + case NPY_NO_CASTING: + return PyArray_EquivTypes(from, to); + case NPY_EQUIV_CASTING: + return (from->elsize == to->elsize); + case NPY_SAFE_CASTING: + return (from->elsize <= to->elsize); + default: + return 1; + } + break; } } /* If safe or same-kind casts are allowed */ diff --git a/numpy/core/src/multiarray/datetime.c b/numpy/core/src/multiarray/datetime.c index b62784707..40c696593 100644 --- a/numpy/core/src/multiarray/datetime.c +++ b/numpy/core/src/multiarray/datetime.c @@ -1646,7 +1646,12 @@ datetime_metadata_divides( } } - return (num2 % num1) == 0; + /* Crude, incomplete check for overflow */ + if (num1&0xff00000000000000LL || num2&0xff00000000000000LL ) { + return 0; + } + + return (num1 % num2) == 0; } @@ -3025,4 +3030,34 @@ convert_datetime_to_pyobject(npy_datetime dt, PyArray_DatetimeMetaData *meta) } } +/* + * Returns true if the datetime metadata matches + */ +NPY_NO_EXPORT npy_bool +has_equivalent_datetime_metadata(PyArray_Descr *type1, PyArray_Descr *type2) +{ + PyArray_DatetimeMetaData *meta1, *meta2; + + if ((type1->type_num != NPY_DATETIME && + type1->type_num != NPY_TIMEDELTA) || + (type2->type_num != NPY_DATETIME && + type2->type_num != NPY_TIMEDELTA)) { + return 0; + } + + meta1 = get_datetime_metadata_from_dtype(type1); + if (meta1 == NULL) { + PyErr_Clear(); + return 0; + } + meta2 = get_datetime_metadata_from_dtype(type2); + if (meta2 == NULL) { + PyErr_Clear(); + return 0; + } + + return meta1->base == meta2->base && + meta1->num == meta2->num && + meta1->events == meta2->events; +} diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index 4547ce070..523cb0581 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -1402,34 +1402,6 @@ _equivalent_fields(PyObject *field1, PyObject *field2) { } /* - * Compare the metadata for two date-times. - * Return 1 if they are the same or 0 if not. - */ -static int -_equivalent_datetime_units(PyArray_Descr *dtype1, PyArray_Descr *dtype2) -{ - PyArray_DatetimeMetaData *data1, *data2; - - data1 = get_datetime_metadata_from_dtype(dtype1); - data2 = get_datetime_metadata_from_dtype(dtype2); - - /* If there's a metadata problem, it doesn't match */ - if (data1 == NULL || data2 == NULL) { - PyErr_Clear(); - return 0; - } - - /* Same meta object */ - if (data1 == data2) { - return 1; - } - - return ((data1->base == data2->base) - && (data1->num == data2->num) - && (data1->events == data2->events)); -} - -/* * Compare the subarray data for two types. * Return 1 if they are the same, 0 if not. */ @@ -1500,7 +1472,7 @@ PyArray_EquivTypes(PyArray_Descr *type1, PyArray_Descr *type2) || type_num2 == NPY_TIMEDELTA || type_num2 == NPY_TIMEDELTA) { return ((type_num1 == type_num2) - && _equivalent_datetime_units(type1, type2)); + && has_equivalent_datetime_metadata(type1, type2)); } return type1->kind == type2->kind; } diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index b9e4c302e..36dea2c83 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -1881,6 +1881,114 @@ PyUFunc_SimpleBinaryComparisonTypeResolution(PyUFuncObject *ufunc, /* * This function applies special type resolution rules for the case + * where all the functions have the pattern X->X, copying + * the input descr directly so that metadata is maintained. + * + * Note that a simpler linear search through the functions loop + * is still done, but switching to a simple array lookup for + * built-in types would be better at some point. + * + * Returns 0 on success, -1 on error. + */ +NPY_NO_EXPORT int +PyUFunc_SimpleUnaryOperationTypeResolution(PyUFuncObject *ufunc, + NPY_CASTING casting, + PyArrayObject **operands, + PyObject *type_tup, + PyArray_Descr **out_dtypes, + PyUFuncGenericFunction *out_innerloop, + void **out_innerloopdata) +{ + int i, type_num, type_num1; + char *ufunc_name; + + ufunc_name = ufunc->name ? ufunc->name : "<unnamed ufunc>"; + + if (ufunc->nin != 1 || ufunc->nout != 1) { + PyErr_Format(PyExc_RuntimeError, "ufunc %s is configured " + "to use unary operation type resolution but has " + "the wrong number of inputs or outputs", + ufunc_name); + return -1; + } + + /* + * Use the default type resolution if there's a custom data type + * or object arrays. + */ + type_num1 = PyArray_DESCR(operands[0])->type_num; + if (type_num1 >= NPY_NTYPES || type_num1 == NPY_OBJECT) { + return PyUFunc_DefaultTypeResolution(ufunc, casting, operands, + type_tup, out_dtypes, out_innerloop, out_innerloopdata); + } + + if (type_tup == NULL) { + /* Input types are the result type */ + out_dtypes[0] = PyArray_DESCR(operands[0]); + Py_INCREF(out_dtypes[0]); + out_dtypes[1] = out_dtypes[0]; + Py_INCREF(out_dtypes[1]); + } + else { + /* + * If the type tuple isn't a single-element tuple, let the + * default type resolution handle this one. + */ + if (!PyTuple_Check(type_tup) || PyTuple_GET_SIZE(type_tup) != 1) { + return PyUFunc_DefaultTypeResolution(ufunc, casting, operands, + type_tup, out_dtypes, out_innerloop, out_innerloopdata); + } + + if (!PyArray_DescrCheck(PyTuple_GET_ITEM(type_tup, 0))) { + PyErr_SetString(PyExc_ValueError, + "require data type in the type tuple"); + return -1; + } + + out_dtypes[0] = (PyArray_Descr *)PyTuple_GET_ITEM(type_tup, 0); + Py_INCREF(out_dtypes[0]); + out_dtypes[1] = out_dtypes[0]; + Py_INCREF(out_dtypes[1]); + } + + /* Check against the casting rules */ + if (PyUFunc_ValidateCasting(ufunc, casting, operands, out_dtypes) < 0) { + for (i = 0; i < 2; ++i) { + Py_DECREF(out_dtypes[i]); + out_dtypes[i] = NULL; + } + return -1; + } + + type_num = out_dtypes[0]->type_num; + + /* If we have a built-in type, search in the functions list */ + if (type_num < NPY_NTYPES) { + char *types = ufunc->types; + int n = ufunc->ntypes; + + for (i = 0; i < n; ++i) { + if (types[2*i] == type_num) { + *out_innerloop = ufunc->functions[i]; + *out_innerloopdata = ufunc->data[i]; + return 0; + } + } + + PyErr_Format(PyExc_TypeError, + "ufunc '%s' not supported for the input types", + ufunc_name); + return -1; + } + else { + PyErr_SetString(PyExc_RuntimeError, + "user type shouldn't have resulted from type promotion"); + return -1; + } +} + +/* + * This function applies special type resolution rules for the case * where all the functions have the pattern XX->X, using * PyArray_ResultType instead of a linear search to get the best * loop. @@ -1997,6 +2105,35 @@ PyUFunc_SimpleBinaryOperationTypeResolution(PyUFuncObject *ufunc, } /* + * This function applies special type resolution rules for the absolute + * ufunc. This ufunc converts complex -> float, so isn't covered + * by the simple unary type resolution. + * + * Returns 0 on success, -1 on error. + */ +NPY_NO_EXPORT int +PyUFunc_AbsoluteTypeResolution(PyUFuncObject *ufunc, + NPY_CASTING casting, + PyArrayObject **operands, + PyObject *type_tup, + PyArray_Descr **out_dtypes, + PyUFuncGenericFunction *out_innerloop, + void **out_innerloopdata) +{ + /* Use the default for complex types, to find the loop producing float */ + if (PyTypeNum_ISCOMPLEX(PyArray_DESCR(operands[0])->type_num)) { + return PyUFunc_DefaultTypeResolution(ufunc, casting, operands, + type_tup, out_dtypes, out_innerloop, out_innerloopdata); + } + else { + return PyUFunc_SimpleUnaryOperationTypeResolution(ufunc, casting, + operands, type_tup, out_dtypes, out_innerloop, + out_innerloopdata); + } +} + + +/* * This function returns the a new reference to the * capsule with the datetime metadata. * diff --git a/numpy/core/src/umath/ufunc_object.h b/numpy/core/src/umath/ufunc_object.h index 828cdb226..2a5fd63a1 100644 --- a/numpy/core/src/umath/ufunc_object.h +++ b/numpy/core/src/umath/ufunc_object.h @@ -17,6 +17,15 @@ PyUFunc_SimpleBinaryComparisonTypeResolution(PyUFuncObject *ufunc, void **out_innerloopdata); NPY_NO_EXPORT int +PyUFunc_SimpleUnaryOperationTypeResolution(PyUFuncObject *ufunc, + NPY_CASTING casting, + PyArrayObject **operands, + PyObject *type_tup, + PyArray_Descr **out_dtypes, + PyUFuncGenericFunction *out_innerloop, + void **out_innerloopdata); + +NPY_NO_EXPORT int PyUFunc_SimpleBinaryOperationTypeResolution(PyUFuncObject *ufunc, NPY_CASTING casting, PyArrayObject **operands, @@ -26,6 +35,15 @@ PyUFunc_SimpleBinaryOperationTypeResolution(PyUFuncObject *ufunc, void **out_innerloopdata); NPY_NO_EXPORT int +PyUFunc_AbsoluteTypeResolution(PyUFuncObject *ufunc, + NPY_CASTING casting, + PyArrayObject **operands, + PyObject *type_tup, + PyArray_Descr **out_dtypes, + PyUFuncGenericFunction *out_innerloop, + void **out_innerloopdata); + +NPY_NO_EXPORT int PyUFunc_AdditionTypeResolution(PyUFuncObject *ufunc, NPY_CASTING casting, PyArrayObject **operands, diff --git a/numpy/core/tests/test_datetime.py b/numpy/core/tests/test_datetime.py index 82681d58f..cd0c4b3df 100644 --- a/numpy/core/tests/test_datetime.py +++ b/numpy/core/tests/test_datetime.py @@ -287,6 +287,40 @@ class TestDateTime(TestCase): #b = np.array(3, dtype='m8[D]') #assert_raises(TypeError, np.less, a, b, casting='same_kind') + def test_datetime_like(self): + a = np.array([3], dtype='m8[4D]//6') + b = np.array(['2012-12-21'], dtype='M8[D]//3') + + assert_equal(np.ones_like(a).dtype, a.dtype) + assert_equal(np.zeros_like(a).dtype, a.dtype) + assert_equal(np.empty_like(a).dtype, a.dtype) + assert_equal(np.ones_like(b).dtype, b.dtype) + assert_equal(np.zeros_like(b).dtype, b.dtype) + assert_equal(np.empty_like(b).dtype, b.dtype) + + def test_datetime_unary(self): + tda = np.array(3, dtype='m8[D]') + tdb = np.array(-3, dtype='m8[D]') + tdzero = np.array(0, dtype='m8[D]') + tdone = np.array(1, dtype='m8[D]') + tdmone = np.array(-1, dtype='m8[D]') + + # negative ufunc + assert_equal(-tdb, tda) + assert_equal((-tdb).dtype, tda.dtype) + assert_equal(np.negative(tdb), tda) + assert_equal(np.negative(tdb).dtype, tda.dtype) + + # absolute ufunc + assert_equal(np.absolute(tdb), tda) + assert_equal(np.absolute(tdb).dtype, tda.dtype) + + # sign ufunc + assert_equal(np.sign(tda), tdone) + assert_equal(np.sign(tdb), tdmone) + assert_equal(np.sign(tdzero), tdzero) + assert_equal(np.sign(tda).dtype, tda.dtype) + def test_datetime_add(self): dta = np.array('2012-12-21', dtype='M8[D]') dtb = np.array('2012-12-24', dtype='M8[D]') |