summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/code_generators/generate_umath.py24
-rw-r--r--numpy/core/code_generators/numpy_api.py3
-rw-r--r--numpy/core/src/umath/loops.c.src3
-rw-r--r--numpy/core/src/umath/ufunc_object.c126
-rw-r--r--numpy/core/src/umath/ufunc_object.h18
-rw-r--r--numpy/core/tests/test_datetime.py183
6 files changed, 334 insertions, 23 deletions
diff --git a/numpy/core/code_generators/generate_umath.py b/numpy/core/code_generators/generate_umath.py
index eae75f7c4..db29b467e 100644
--- a/numpy/core/code_generators/generate_umath.py
+++ b/numpy/core/code_generators/generate_umath.py
@@ -375,7 +375,7 @@ defdict = {
'greater' :
Ufunc(2, 1, None,
docstrings.get('numpy.core.umath.greater'),
- 'PyUFunc_BinaryComparisonTypeResolution',
+ 'PyUFunc_SimpleBinaryComparisonTypeResolution',
TD(all, out='?'),
),
'greater_equal' :
@@ -387,31 +387,31 @@ defdict = {
'less' :
Ufunc(2, 1, None,
docstrings.get('numpy.core.umath.less'),
- 'PyUFunc_BinaryComparisonTypeResolution',
+ 'PyUFunc_SimpleBinaryComparisonTypeResolution',
TD(all, out='?'),
),
'less_equal' :
Ufunc(2, 1, None,
docstrings.get('numpy.core.umath.less_equal'),
- 'PyUFunc_BinaryComparisonTypeResolution',
+ 'PyUFunc_SimpleBinaryComparisonTypeResolution',
TD(all, out='?'),
),
'equal' :
Ufunc(2, 1, None,
docstrings.get('numpy.core.umath.equal'),
- 'PyUFunc_BinaryComparisonTypeResolution',
+ 'PyUFunc_SimpleBinaryComparisonTypeResolution',
TD(all, out='?'),
),
'not_equal' :
Ufunc(2, 1, None,
docstrings.get('numpy.core.umath.not_equal'),
- 'PyUFunc_BinaryComparisonTypeResolution',
+ 'PyUFunc_SimpleBinaryComparisonTypeResolution',
TD(all, out='?'),
),
'logical_and' :
Ufunc(2, 1, One,
docstrings.get('numpy.core.umath.logical_and'),
- 'PyUFunc_BinaryComparisonTypeResolution',
+ 'PyUFunc_SimpleBinaryComparisonTypeResolution',
TD(nodatetime_or_obj, out='?'),
TD(P, f='logical_and'),
),
@@ -425,42 +425,42 @@ defdict = {
'logical_or' :
Ufunc(2, 1, Zero,
docstrings.get('numpy.core.umath.logical_or'),
- 'PyUFunc_BinaryComparisonTypeResolution',
+ 'PyUFunc_SimpleBinaryComparisonTypeResolution',
TD(nodatetime_or_obj, out='?'),
TD(P, f='logical_or'),
),
'logical_xor' :
Ufunc(2, 1, None,
docstrings.get('numpy.core.umath.logical_xor'),
- 'PyUFunc_BinaryComparisonTypeResolution',
+ 'PyUFunc_SimpleBinaryComparisonTypeResolution',
TD(nodatetime_or_obj, out='?'),
TD(P, f='logical_xor'),
),
'maximum' :
Ufunc(2, 1, None,
docstrings.get('numpy.core.umath.maximum'),
- None,
+ 'PyUFunc_SimpleBinaryOperationTypeResolution',
TD(noobj),
TD(O, f='npy_ObjectMax')
),
'minimum' :
Ufunc(2, 1, None,
docstrings.get('numpy.core.umath.minimum'),
- None,
+ 'PyUFunc_SimpleBinaryOperationTypeResolution',
TD(noobj),
TD(O, f='npy_ObjectMin')
),
'fmax' :
Ufunc(2, 1, None,
docstrings.get('numpy.core.umath.fmax'),
- None,
+ 'PyUFunc_SimpleBinaryOperationTypeResolution',
TD(noobj),
TD(O, f='npy_ObjectMax')
),
'fmin' :
Ufunc(2, 1, None,
docstrings.get('numpy.core.umath.fmin'),
- None,
+ 'PyUFunc_SimpleBinaryOperationTypeResolution',
TD(noobj),
TD(O, f='npy_ObjectMin')
),
diff --git a/numpy/core/code_generators/numpy_api.py b/numpy/core/code_generators/numpy_api.py
index 3bff3a31e..a789ae683 100644
--- a/numpy/core/code_generators/numpy_api.py
+++ b/numpy/core/code_generators/numpy_api.py
@@ -363,8 +363,7 @@ ufunc_funcs_api = {
'PyUFunc_ee_e_As_dd_d': 38,
# End 1.6 API
'PyUFunc_DefaultTypeResolution': 39,
- 'PyUFunc_BinaryComparisonTypeResolution': 40,
- 'PyUFunc_ValidateCasting': 41,
+ 'PyUFunc_ValidateCasting': 40,
}
# List of all the dicts which define the C API
diff --git a/numpy/core/src/umath/loops.c.src b/numpy/core/src/umath/loops.c.src
index aaad5b1ba..60383bc45 100644
--- a/numpy/core/src/umath/loops.c.src
+++ b/numpy/core/src/umath/loops.c.src
@@ -1058,9 +1058,6 @@ NPY_NO_EXPORT void
/**end repeat**/
-/* FIXME: implement the following correctly using the metadata: data is the
- sequence of ndarrays in the same order as args.
- */
NPY_NO_EXPORT void
DATETIME_Mm_M_add(char **args, intp *dimensions, intp *steps, void *NPY_UNUSED(data))
{
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index 73f465264..95c8463ad 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -1756,8 +1756,7 @@ PyUFunc_DefaultTypeResolution(PyUFuncObject *ufunc,
return retval;
}
-/*UFUNC_API
- *
+/*
* This function applies special type resolution rules for the case
* where all the functions have the pattern XX->bool, using
* PyArray_ResultType instead of a linear search to get the best
@@ -1770,7 +1769,7 @@ PyUFunc_DefaultTypeResolution(PyUFuncObject *ufunc,
* Returns 0 on success, -1 on error.
*/
NPY_NO_EXPORT int
-PyUFunc_BinaryComparisonTypeResolution(PyUFuncObject *ufunc,
+PyUFunc_SimpleBinaryComparisonTypeResolution(PyUFuncObject *ufunc,
NPY_CASTING casting,
PyArrayObject **operands,
PyObject *type_tup,
@@ -1881,6 +1880,123 @@ PyUFunc_BinaryComparisonTypeResolution(PyUFuncObject *ufunc,
}
/*
+ * This function applies special type resolution rules for the case
+ * where all the functions have the pattern XX->X, using
+ * PyArray_ResultType instead of a linear search to get the best
+ * loop.
+ *
+ * Note that a simpler linear search through the functions loop
+ * is still done, but switching to a simple array lookup for
+ * built-in types would be better at some point.
+ *
+ * Returns 0 on success, -1 on error.
+ */
+NPY_NO_EXPORT int
+PyUFunc_SimpleBinaryOperationTypeResolution(PyUFuncObject *ufunc,
+ NPY_CASTING casting,
+ PyArrayObject **operands,
+ PyObject *type_tup,
+ PyArray_Descr **out_dtypes,
+ PyUFuncGenericFunction *out_innerloop,
+ void **out_innerloopdata)
+{
+ int i, type_num, type_num1, type_num2;
+ char *ufunc_name;
+
+ ufunc_name = ufunc->name ? ufunc->name : "<unnamed ufunc>";
+
+ if (ufunc->nin != 2 || ufunc->nout != 1) {
+ PyErr_Format(PyExc_RuntimeError, "ufunc %s is configured "
+ "to use binary operation type resolution but has "
+ "the wrong number of inputs or outputs",
+ ufunc_name);
+ return -1;
+ }
+
+ /*
+ * Use the default type resolution if there's a custom data type
+ * or object arrays.
+ */
+ type_num1 = PyArray_DESCR(operands[0])->type_num;
+ type_num2 = PyArray_DESCR(operands[1])->type_num;
+ if (type_num1 >= NPY_NTYPES || type_num2 >= NPY_NTYPES ||
+ type_num1 == NPY_OBJECT || type_num2 == NPY_OBJECT) {
+ return PyUFunc_DefaultTypeResolution(ufunc, casting, operands,
+ type_tup, out_dtypes, out_innerloop, out_innerloopdata);
+ }
+
+ if (type_tup == NULL) {
+ /* Input types are the result type */
+ out_dtypes[0] = PyArray_ResultType(2, operands, 0, NULL);
+ if (out_dtypes[0] == NULL) {
+ return -1;
+ }
+ out_dtypes[1] = out_dtypes[0];
+ Py_INCREF(out_dtypes[1]);
+ out_dtypes[2] = out_dtypes[0];
+ Py_INCREF(out_dtypes[2]);
+ }
+ else {
+ /*
+ * If the type tuple isn't a single-element tuple, let the
+ * default type resolution handle this one.
+ */
+ if (!PyTuple_Check(type_tup) || PyTuple_GET_SIZE(type_tup) != 1) {
+ return PyUFunc_DefaultTypeResolution(ufunc, casting, operands,
+ type_tup, out_dtypes, out_innerloop, out_innerloopdata);
+ }
+
+ if (!PyArray_DescrCheck(PyTuple_GET_ITEM(type_tup, 0))) {
+ PyErr_SetString(PyExc_ValueError,
+ "require data type in the type tuple");
+ return -1;
+ }
+
+ out_dtypes[0] = (PyArray_Descr *)PyTuple_GET_ITEM(type_tup, 0);
+ Py_INCREF(out_dtypes[0]);
+ out_dtypes[1] = out_dtypes[0];
+ Py_INCREF(out_dtypes[1]);
+ out_dtypes[2] = out_dtypes[0];
+ Py_INCREF(out_dtypes[2]);
+ }
+
+ /* Check against the casting rules */
+ if (PyUFunc_ValidateCasting(ufunc, casting, operands, out_dtypes) < 0) {
+ for (i = 0; i < 3; ++i) {
+ Py_DECREF(out_dtypes[i]);
+ out_dtypes[i] = NULL;
+ }
+ return -1;
+ }
+
+ type_num = out_dtypes[0]->type_num;
+
+ /* If we have a built-in type, search in the functions list */
+ if (type_num < NPY_NTYPES) {
+ char *types = ufunc->types;
+ int n = ufunc->ntypes;
+
+ for (i = 0; i < n; ++i) {
+ if (types[3*i] == type_num) {
+ *out_innerloop = ufunc->functions[i];
+ *out_innerloopdata = ufunc->data[i];
+ return 0;
+ }
+ }
+
+ PyErr_Format(PyExc_TypeError,
+ "ufunc '%s' not supported for the input types",
+ ufunc_name);
+ return -1;
+ }
+ else {
+ PyErr_SetString(PyExc_RuntimeError,
+ "user type shouldn't have resulted from type promotion");
+ return -1;
+ }
+}
+
+/*
* This function returns the a new reference to the
* capsule with the datetime metadata.
*
@@ -2087,7 +2203,7 @@ PyUFunc_AdditionTypeResolution(PyUFuncObject *ufunc,
else if (PyTypeNum_ISINTEGER(type_num1)) {
/* int + m8[<A>] => m8[<A>] + m8[<A>] */
if (type_num2 == NPY_TIMEDELTA) {
- out_dtypes[0] = PyArray_DESCR(operands[0]);
+ out_dtypes[0] = PyArray_DESCR(operands[1]);
Py_INCREF(out_dtypes[0]);
out_dtypes[1] = out_dtypes[0];
Py_INCREF(out_dtypes[1]);
@@ -2282,7 +2398,7 @@ PyUFunc_SubtractionTypeResolution(PyUFuncObject *ufunc,
else if (PyTypeNum_ISINTEGER(type_num1)) {
/* int - m8[<A>] => m8[<A>] - m8[<A>] */
if (type_num2 == NPY_TIMEDELTA) {
- out_dtypes[0] = PyArray_DESCR(operands[0]);
+ out_dtypes[0] = PyArray_DESCR(operands[1]);
Py_INCREF(out_dtypes[0]);
out_dtypes[1] = out_dtypes[0];
Py_INCREF(out_dtypes[1]);
diff --git a/numpy/core/src/umath/ufunc_object.h b/numpy/core/src/umath/ufunc_object.h
index 3ddcb842f..828cdb226 100644
--- a/numpy/core/src/umath/ufunc_object.h
+++ b/numpy/core/src/umath/ufunc_object.h
@@ -8,6 +8,24 @@ NPY_NO_EXPORT PyObject *
ufunc_seterr(PyObject *NPY_UNUSED(dummy), PyObject *args);
NPY_NO_EXPORT int
+PyUFunc_SimpleBinaryComparisonTypeResolution(PyUFuncObject *ufunc,
+ NPY_CASTING casting,
+ PyArrayObject **operands,
+ PyObject *type_tup,
+ PyArray_Descr **out_dtypes,
+ PyUFuncGenericFunction *out_innerloop,
+ void **out_innerloopdata);
+
+NPY_NO_EXPORT int
+PyUFunc_SimpleBinaryOperationTypeResolution(PyUFuncObject *ufunc,
+ NPY_CASTING casting,
+ PyArrayObject **operands,
+ PyObject *type_tup,
+ PyArray_Descr **out_dtypes,
+ PyUFuncGenericFunction *out_innerloop,
+ void **out_innerloopdata);
+
+NPY_NO_EXPORT int
PyUFunc_AdditionTypeResolution(PyUFuncObject *ufunc,
NPY_CASTING casting,
PyArrayObject **operands,
diff --git a/numpy/core/tests/test_datetime.py b/numpy/core/tests/test_datetime.py
index fde277d4f..5b8aff14f 100644
--- a/numpy/core/tests/test_datetime.py
+++ b/numpy/core/tests/test_datetime.py
@@ -65,7 +65,7 @@ class TestDateTime(TestCase):
assert_equal(np.array('2001-03-22', dtype='M8[D]').astype('i8'),
(2000 - 1970)*365 + (2000 - 1972)/4 + 366 + 31 + 28 + 21)
- def test_days_to_pydatetime(self):
+ def test_days_to_pydate(self):
assert_equal(np.array('1599', dtype='M8[D]').astype('O'),
datetime.date(1599, 1, 1))
assert_equal(np.array('1600', dtype='M8[D]').astype('O'),
@@ -222,6 +222,187 @@ class TestDateTime(TestCase):
assert_equal(np.array('today', dtype=dt1),
np.array('today', dtype=dt2))
+ def test_datetime_add(self):
+ dta = np.array('2012-12-21', dtype='M8[D]')
+ dtb = np.array('2012-12-24', dtype='M8[D]')
+ dtc = np.array('1940-12-24', dtype='M8[D]')
+ tda = np.array(3, dtype='m8[D]')
+ tdb = np.array(11, dtype='m8[h]')
+ tdc = np.array(3*24 + 11, dtype='m8[h]')
+
+ # m8 + m8
+ assert_equal(tda + tdb, tdc)
+ assert_equal((tda + tdb).dtype, np.dtype('m8[h]'))
+ # m8 + int
+ assert_equal(tdb + 3*24, tdc)
+ assert_equal((tdb + 3*24).dtype, np.dtype('m8[h]'))
+ # int + m8
+ assert_equal(3*24 + tdb, tdc)
+ assert_equal((3*24 + tdb).dtype, np.dtype('m8[h]'))
+ # M8 + int
+ assert_equal(dta + 3, dtb)
+ assert_equal((dta + 3).dtype, np.dtype('M8[D]'))
+ # int + M8
+ assert_equal(3 + dta, dtb)
+ assert_equal((3 + dta).dtype, np.dtype('M8[D]'))
+ # M8 + m8
+ assert_equal(dta + tda, dtb)
+ assert_equal((dta + tda).dtype, np.dtype('M8[D]'))
+ # m8 + M8
+ assert_equal(tda + dta, dtb)
+ assert_equal((tda + dta).dtype, np.dtype('M8[D]'))
+
+ # In M8 + m8, the M8 controls the result type
+ assert_equal(dta + tdb, dta)
+ assert_equal((dta + tdb).dtype, np.dtype('M8[D]'))
+ assert_equal(dtc + tdb, dtc)
+ assert_equal((dtc + tdb).dtype, np.dtype('M8[D]'))
+ assert_equal(tdb + dta, dta)
+ assert_equal((tdb + dta).dtype, np.dtype('M8[D]'))
+ assert_equal(tdb + dtc, dtc)
+ assert_equal((tdb + dtc).dtype, np.dtype('M8[D]'))
+
+ # M8 + M8
+ assert_raises(TypeError, np.add, dta, dtb)
+
+ def test_datetime_subtract(self):
+ dta = np.array('2012-12-21', dtype='M8[D]')
+ dtb = np.array('2012-12-24', dtype='M8[D]')
+ dtc = np.array('1940-12-24', dtype='M8[D]')
+ dtd = np.array('1940-12-24', dtype='M8[h]')
+ tda = np.array(3, dtype='m8[D]')
+ tdb = np.array(11, dtype='m8[h]')
+ tdc = np.array(3*24 - 11, dtype='m8[h]')
+
+ # m8 - m8
+ assert_equal(tda - tdb, tdc)
+ assert_equal((tda - tdb).dtype, np.dtype('m8[h]'))
+ assert_equal(tdb - tda, -tdc)
+ assert_equal((tdb - tda).dtype, np.dtype('m8[h]'))
+ # m8 - int
+ assert_equal(tdc - 3*24, -tdb)
+ assert_equal((tdc - 3*24).dtype, np.dtype('m8[h]'))
+ # int - m8
+ assert_equal(3*24 - tdb, tdc)
+ assert_equal((3*24 - tdb).dtype, np.dtype('m8[h]'))
+ # M8 - int
+ assert_equal(dtb - 3, dta)
+ assert_equal((dtb - 3).dtype, np.dtype('M8[D]'))
+ # M8 - m8
+ assert_equal(dtb - tda, dta)
+ assert_equal((dtb - tda).dtype, np.dtype('M8[D]'))
+
+ # In M8 - m8, the M8 controls the result type
+ assert_equal(dta - tdb, dta)
+ assert_equal((dta - tdb).dtype, np.dtype('M8[D]'))
+ assert_equal(dtc - tdb, dtc)
+ assert_equal((dtc - tdb).dtype, np.dtype('M8[D]'))
+
+ # M8 - M8 with different metadata
+ assert_raises(TypeError, np.subtract, dtc, dtd)
+ # m8 - M8
+ assert_raises(TypeError, np.subtract, tda, dta)
+ # int - M8
+ assert_raises(TypeError, np.subtract, 3, dta)
+
+ def test_datetime_multiply(self):
+ dta = np.array('2012-12-21', dtype='M8[D]')
+ tda = np.array(6, dtype='m8[h]')
+ tdb = np.array(9, dtype='m8[h]')
+ tdc = np.array(12, dtype='m8[h]')
+
+ # m8 * int
+ assert_equal(tda * 2, tdc)
+ assert_equal((tda * 2).dtype, np.dtype('m8[h]'))
+ # int * m8
+ assert_equal(2 * tda, tdc)
+ assert_equal((2 * tda).dtype, np.dtype('m8[h]'))
+ # m8 * float
+ assert_equal(tda * 1.5, tdb)
+ assert_equal((tda * 1.5).dtype, np.dtype('m8[h]'))
+ # float * m8
+ assert_equal(1.5 * tda, tdb)
+ assert_equal((1.5 * tda).dtype, np.dtype('m8[h]'))
+
+ # m8 * m8
+ assert_raises(TypeError, np.multiply, tda, tdb)
+ # m8 * M8
+ assert_raises(TypeError, np.multiply, dta, tda)
+ # M8 * m8
+ assert_raises(TypeError, np.multiply, tda, dta)
+ # M8 * int
+ assert_raises(TypeError, np.multiply, dta, 2)
+ # int * M8
+ assert_raises(TypeError, np.multiply, 2, dta)
+ # M8 * float
+ assert_raises(TypeError, np.multiply, dta, 1.5)
+ # float * M8
+ assert_raises(TypeError, np.multiply, 1.5, dta)
+
+ def test_datetime_divide(self):
+ dta = np.array('2012-12-21', dtype='M8[D]')
+ tda = np.array(6, dtype='m8[h]')
+ tdb = np.array(9, dtype='m8[h]')
+ tdc = np.array(12, dtype='m8[h]')
+
+ # m8 / int
+ assert_equal(tdc / 2, tda)
+ assert_equal((tdc / 2).dtype, np.dtype('m8[h]'))
+ # m8 / float
+ assert_equal(tda / 0.5, tdc)
+ assert_equal((tda / 0.5).dtype, np.dtype('m8[h]'))
+
+ # int / m8
+ assert_raises(TypeError, np.divide, 2, tdb)
+ # float / m8
+ assert_raises(TypeError, np.divide, 0.5, tdb)
+ # m8 / m8
+ assert_raises(TypeError, np.divide, tda, tdb)
+ # m8 / M8
+ assert_raises(TypeError, np.divide, dta, tda)
+ # M8 / m8
+ assert_raises(TypeError, np.divide, tda, dta)
+ # M8 / int
+ assert_raises(TypeError, np.divide, dta, 2)
+ # int / M8
+ assert_raises(TypeError, np.divide, 2, dta)
+ # M8 / float
+ assert_raises(TypeError, np.divide, dta, 1.5)
+ # float / M8
+ assert_raises(TypeError, np.divide, 1.5, dta)
+
+ def test_datetime_minmax(self):
+ # The metadata of the result should become the GCD
+ # of the operand metadata
+ a = np.array('1999-03-12T13Z', dtype='M8[2m]')
+ b = np.array('1999-03-12T12Z', dtype='M8[s]')
+ assert_equal(np.minimum(a,b), b)
+ assert_equal(np.minimum(a,b).dtype, np.dtype('M8[s]'))
+ assert_equal(np.fmin(a,b), b)
+ assert_equal(np.fmin(a,b).dtype, np.dtype('M8[s]'))
+ assert_equal(np.maximum(a,b), a)
+ assert_equal(np.maximum(a,b).dtype, np.dtype('M8[s]'))
+ assert_equal(np.fmax(a,b), a)
+ assert_equal(np.fmax(a,b).dtype, np.dtype('M8[s]'))
+ # Viewed as integers, the comparison is opposite because
+ # of the units chosen
+ assert_equal(np.minimum(a.view('i8'),b.view('i8')), a.view('i8'))
+
+ # Also do timedelta
+ a = np.array(3, dtype='m8[h]')
+ b = np.array(3*3600 - 3, dtype='m8[s]')
+ assert_equal(np.minimum(a,b), b)
+ assert_equal(np.minimum(a,b).dtype, np.dtype('m8[s]'))
+ assert_equal(np.fmin(a,b), b)
+ assert_equal(np.fmin(a,b).dtype, np.dtype('m8[s]'))
+ assert_equal(np.maximum(a,b), a)
+ assert_equal(np.maximum(a,b).dtype, np.dtype('m8[s]'))
+ assert_equal(np.fmax(a,b), a)
+ assert_equal(np.fmax(a,b).dtype, np.dtype('m8[s]'))
+ # Viewed as integers, the comparison is opposite because
+ # of the units chosen
+ assert_equal(np.minimum(a.view('i8'),b.view('i8')), a.view('i8'))
+
def test_hours(self):
t = np.ones(3, dtype='M8[s]')
t[0] = 60*60*24 + 60*60*10