diff options
author | Tyler Reddy <tyler.je.reddy@gmail.com> | 2018-10-15 10:24:24 -0700 |
---|---|---|
committer | Tyler Reddy <tyler.je.reddy@gmail.com> | 2018-10-15 10:24:24 -0700 |
commit | c9a6b02c347960f016ef28088ca8c63e0f2fe2f5 (patch) | |
tree | 4f0ff588d0b4b759065c953bdeaccb6655d4f3b0 /numpy | |
parent | 86ebcffb482afb67c2f6ec4f396d9017ea610bf1 (diff) | |
download | numpy-c9a6b02c347960f016ef28088ca8c63e0f2fe2f5.tar.gz |
ENH: add timedelta modulus
* added support for modulus operator
with timedelta operands; type signature
is mm->m
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/code_generators/generate_umath.py | 3 | ||||
-rw-r--r-- | numpy/core/src/umath/loops.c.src | 28 | ||||
-rw-r--r-- | numpy/core/src/umath/loops.h.src | 3 | ||||
-rw-r--r-- | numpy/core/src/umath/ufunc_type_resolution.c | 51 | ||||
-rw-r--r-- | numpy/core/src/umath/ufunc_type_resolution.h | 7 | ||||
-rw-r--r-- | numpy/core/tests/test_datetime.py | 71 |
6 files changed, 162 insertions, 1 deletions
diff --git a/numpy/core/code_generators/generate_umath.py b/numpy/core/code_generators/generate_umath.py index 6dc01877b..199ad831b 100644 --- a/numpy/core/code_generators/generate_umath.py +++ b/numpy/core/code_generators/generate_umath.py @@ -791,8 +791,9 @@ defdict = { 'remainder': Ufunc(2, 1, None, docstrings.get('numpy.core.umath.remainder'), - None, + 'PyUFunc_RemainderTypeResolver', TD(intflt), + [TypeDescription('m', FullTypeDescr, 'mm', 'm')], TD(O, f='PyNumber_Remainder'), ), 'divmod': diff --git a/numpy/core/src/umath/loops.c.src b/numpy/core/src/umath/loops.c.src index e62942efd..8599d644a 100644 --- a/numpy/core/src/umath/loops.c.src +++ b/numpy/core/src/umath/loops.c.src @@ -1591,6 +1591,34 @@ TIMEDELTA_mm_d_divide(char **args, npy_intp *dimensions, npy_intp *steps, void * } } +NPY_NO_EXPORT void +TIMEDELTA_mm_m_remainder(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func)) +{ + BINARY_LOOP { + const npy_timedelta in1 = *(npy_timedelta *)ip1; + const npy_timedelta in2 = *(npy_timedelta *)ip2; + if (in1 == NPY_DATETIME_NAT || in2 == NPY_DATETIME_NAT) { + *((npy_timedelta *)op1) = NPY_DATETIME_NAT; + } + else { + if (in2 == 0) { + npy_set_floatstatus_divbyzero(); + *((npy_timedelta *)op1) = 0; + } + else { + /* handle mixed case the way Python does */ + const npy_timedelta rem = in1 % in2; + if ((in1 > 0) == (in2 > 0) || rem == 0) { + *((npy_timedelta *)op1) = rem; + } + else { + *((npy_timedelta *)op1) = rem + in2; + } + } + } + } +} + /* ***************************************************************************** ** FLOAT LOOPS ** diff --git a/numpy/core/src/umath/loops.h.src b/numpy/core/src/umath/loops.h.src index 5c2b2c22c..9b6327308 100644 --- a/numpy/core/src/umath/loops.h.src +++ b/numpy/core/src/umath/loops.h.src @@ -473,6 +473,9 @@ TIMEDELTA_md_m_divide(char **args, npy_intp *dimensions, npy_intp *steps, void * NPY_NO_EXPORT void TIMEDELTA_mm_d_divide(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func)); +NPY_NO_EXPORT void +TIMEDELTA_mm_m_remainder(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func)); + /* Special case equivalents to above functions */ #define TIMEDELTA_mq_m_true_divide TIMEDELTA_mq_m_divide diff --git a/numpy/core/src/umath/ufunc_type_resolution.c b/numpy/core/src/umath/ufunc_type_resolution.c index 5ddfe29ef..6b042d837 100644 --- a/numpy/core/src/umath/ufunc_type_resolution.c +++ b/numpy/core/src/umath/ufunc_type_resolution.c @@ -1173,6 +1173,57 @@ PyUFunc_DivisionTypeResolver(PyUFuncObject *ufunc, } +NPY_NO_EXPORT int +PyUFunc_RemainderTypeResolver(PyUFuncObject *ufunc, + NPY_CASTING casting, + PyArrayObject **operands, + PyObject *type_tup, + PyArray_Descr **out_dtypes) +{ + int type_num1, type_num2; + int i; + + type_num1 = PyArray_DESCR(operands[0])->type_num; + type_num2 = PyArray_DESCR(operands[1])->type_num; + + /* Use the default when datetime and timedelta are not involved */ + if (!PyTypeNum_ISDATETIME(type_num1) && !PyTypeNum_ISDATETIME(type_num2)) { + return PyUFunc_DefaultTypeResolver(ufunc, casting, operands, + type_tup, out_dtypes); + } + if (type_num1 == NPY_TIMEDELTA) { + if (type_num2 == NPY_TIMEDELTA) { + out_dtypes[0] = PyArray_PromoteTypes(PyArray_DESCR(operands[0]), + PyArray_DESCR(operands[1])); + if (out_dtypes[0] == NULL) { + return -1; + } + out_dtypes[1] = out_dtypes[0]; + Py_INCREF(out_dtypes[1]); + out_dtypes[2] = out_dtypes[0]; + Py_INCREF(out_dtypes[2]); + } + else { + return raise_binary_type_reso_error(ufunc, operands); + } + } + else { + return raise_binary_type_reso_error(ufunc, operands); + } + + /* Check against the casting rules */ + if (PyUFunc_ValidateCasting(ufunc, casting, operands, out_dtypes) < 0) { + for (i = 0; i < 3; ++i) { + Py_DECREF(out_dtypes[i]); + out_dtypes[i] = NULL; + } + return -1; + } + + return 0; +} + + /* * True division should return float64 results when both inputs are integer * types. The PyUFunc_DefaultTypeResolver promotes 8 bit integers to float16 diff --git a/numpy/core/src/umath/ufunc_type_resolution.h b/numpy/core/src/umath/ufunc_type_resolution.h index fa9f1dbfa..bb4823d24 100644 --- a/numpy/core/src/umath/ufunc_type_resolution.h +++ b/numpy/core/src/umath/ufunc_type_resolution.h @@ -92,6 +92,13 @@ PyUFunc_DivisionTypeResolver(PyUFuncObject *ufunc, PyObject *type_tup, PyArray_Descr **out_dtypes); +NPY_NO_EXPORT int +PyUFunc_RemainderTypeResolver(PyUFuncObject *ufunc, + NPY_CASTING casting, + PyArrayObject **operands, + PyObject *type_tup, + PyArray_Descr **out_dtypes); + /* * Does a linear search for the best inner loop of the ufunc. * diff --git a/numpy/core/tests/test_datetime.py b/numpy/core/tests/test_datetime.py index fe0e425fd..4ff5359ce 100644 --- a/numpy/core/tests/test_datetime.py +++ b/numpy/core/tests/test_datetime.py @@ -7,6 +7,7 @@ import datetime import pytest from numpy.testing import ( assert_, assert_equal, assert_raises, assert_warns, suppress_warnings, + assert_raises_regex, ) from numpy.core.numeric import pickle @@ -1611,6 +1612,76 @@ class TestDateTime(object): assert_raises(TypeError, np.arange, np.timedelta64(0, 'Y'), np.timedelta64(5, 'D')) + @pytest.mark.parametrize("val1, val2, expected", [ + # case from gh-12092 + (np.timedelta64(7, 's'), + np.timedelta64(3, 's'), + np.timedelta64(1, 's')), + # negative value cases + (np.timedelta64(3, 's'), + np.timedelta64(-2, 's'), + np.timedelta64(-1, 's')), + (np.timedelta64(-3, 's'), + np.timedelta64(2, 's'), + np.timedelta64(1, 's')), + # larger value cases + (np.timedelta64(17, 's'), + np.timedelta64(22, 's'), + np.timedelta64(17, 's')), + (np.timedelta64(22, 's'), + np.timedelta64(17, 's'), + np.timedelta64(5, 's')), + # different units + (np.timedelta64(1, 'm'), + np.timedelta64(57, 's'), + np.timedelta64(3, 's')), + (np.timedelta64(1, 'us'), + np.timedelta64(727, 'ns'), + np.timedelta64(273, 'ns')), + # NaT is propagated + (np.timedelta64('NaT'), + np.timedelta64(50, 'ns'), + np.timedelta64('NaT')), + # Y % M works + (np.timedelta64(2, 'Y'), + np.timedelta64(22, 'M'), + np.timedelta64(2, 'M')), + ]) + def test_timedelta_modulus(self, val1, val2, expected): + assert_equal(val1 % val2, expected) + + @pytest.mark.parametrize("val1, val2", [ + # years and months sometimes can't be unambiguously + # divided for modulus operation + (np.timedelta64(7, 'Y'), + np.timedelta64(3, 's')), + (np.timedelta64(7, 'M'), + np.timedelta64(1, 'D')), + ]) + def test_timedelta_modulus_error(self, val1, val2): + with assert_raises_regex(TypeError, "common metadata divisor"): + val1 % val2 + + def test_timedelta_modulus_div_by_zero(self): + with assert_warns(RuntimeWarning): + actual = np.timedelta64(10, 's') % np.timedelta64(0, 's') + assert_equal(actual, np.timedelta64(0, 's')) + + @pytest.mark.parametrize("val1, val2", [ + # cases where one operand is not + # timedelta64 + (np.timedelta64(7, 'Y'), + 15,), + (7.5, + np.timedelta64(1, 'D')), + ]) + def test_timedelta_modulus_type_resolution(self, val1, val2): + # NOTE: some of the operations may be supported + # in the future + with assert_raises_regex(TypeError, + "remainder cannot use operands with types"): + val1 % val2 + def test_timedelta_arange_no_dtype(self): d = np.array(5, dtype="m8[D]") assert_equal(np.arange(d, d + 1), d) |