diff options
-rw-r--r-- | numpy/core/src/umath/loops.c.src | 26 | ||||
-rw-r--r-- | numpy/core/tests/test_datetime.py | 34 |
2 files changed, 53 insertions, 7 deletions
diff --git a/numpy/core/src/umath/loops.c.src b/numpy/core/src/umath/loops.c.src index 36046d9b8..e57dd5bd0 100644 --- a/numpy/core/src/umath/loops.c.src +++ b/numpy/core/src/umath/loops.c.src @@ -1286,11 +1286,17 @@ TIMEDELTA_md_m_multiply(char **args, npy_intp *dimensions, npy_intp *steps, void BINARY_LOOP { const npy_timedelta in1 = *(npy_timedelta *)ip1; const double in2 = *(double *)ip2; - if (in1 == NPY_DATETIME_NAT || npy_isnan(in2)) { + if (in1 == NPY_DATETIME_NAT) { *((npy_timedelta *)op1) = NPY_DATETIME_NAT; } else { - *((npy_timedelta *)op1) = (npy_timedelta)(in1 * in2); + double result = in1 * in2; + if (npy_isfinite(result)) { + *((npy_timedelta *)op1) = (npy_timedelta)result; + } + else { + *((npy_timedelta *)op1) = NPY_DATETIME_NAT; + } } } } @@ -1301,11 +1307,17 @@ TIMEDELTA_dm_m_multiply(char **args, npy_intp *dimensions, npy_intp *steps, void BINARY_LOOP { const double in1 = *(double *)ip1; const npy_timedelta in2 = *(npy_timedelta *)ip2; - if (npy_isnan(in1) || in2 == NPY_DATETIME_NAT) { + if (in2 == NPY_DATETIME_NAT) { *((npy_timedelta *)op1) = NPY_DATETIME_NAT; } else { - *((npy_timedelta *)op1) = (npy_timedelta)(in1 * in2); + double result = in1 * in2; + if (npy_isfinite(result)) { + *((npy_timedelta *)op1) = (npy_timedelta)result; + } + else { + *((npy_timedelta *)op1) = NPY_DATETIME_NAT; + } } } } @@ -1337,11 +1349,11 @@ TIMEDELTA_md_m_divide(char **args, npy_intp *dimensions, npy_intp *steps, void * } else { double result = in1 / in2; - if (npy_isnan(result)) { - *((npy_timedelta *)op1) = NPY_DATETIME_NAT; + if (npy_isfinite(result)) { + *((npy_timedelta *)op1) = (npy_timedelta)result; } else { - *((npy_timedelta *)op1) = (npy_timedelta)(result); + *((npy_timedelta *)op1) = NPY_DATETIME_NAT; } } } diff --git a/numpy/core/tests/test_datetime.py b/numpy/core/tests/test_datetime.py index 956993c4e..5fa281867 100644 --- a/numpy/core/tests/test_datetime.py +++ b/numpy/core/tests/test_datetime.py @@ -1,6 +1,7 @@ from __future__ import division, absolute_import, print_function import pickle +import warnings import numpy import numpy as np @@ -961,6 +962,21 @@ class TestDateTime(TestCase): # float * M8 assert_raises(TypeError, np.multiply, 1.5, dta) + # NaTs + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=RuntimeWarning) + nat = np.timedelta64('NaT') + def check(a, b, res): + assert_equal(a * b, res) + assert_equal(b * a, res) + for tp in (int, float): + check(nat, tp(2), nat) + check(nat, tp(0), nat) + for f in (float('inf'), float('nan')): + check(np.timedelta64(1), f, nat) + check(np.timedelta64(0), f, nat) + check(nat, f, nat) + def test_datetime_divide(self): for dta, tda, tdb, tdc, tdd in \ [ @@ -1010,6 +1026,24 @@ class TestDateTime(TestCase): # float / M8 assert_raises(TypeError, np.divide, 1.5, dta) + # NaTs + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=RuntimeWarning) + nat = np.timedelta64('NaT') + for tp in (int, float): + assert_equal(np.timedelta64(1) / tp(0), nat) + assert_equal(np.timedelta64(0) / tp(0), nat) + assert_equal(nat / tp(0), nat) + assert_equal(nat / tp(2), nat) + # Division by inf + assert_equal(np.timedelta64(1) / float('inf'), np.timedelta64(0)) + assert_equal(np.timedelta64(0) / float('inf'), np.timedelta64(0)) + assert_equal(nat / float('inf'), nat) + # Division by nan + assert_equal(np.timedelta64(1) / float('nan'), nat) + assert_equal(np.timedelta64(0) / float('nan'), nat) + assert_equal(nat / float('nan'), nat) + def test_datetime_compare(self): # Test all the comparison operators a = np.datetime64('2000-03-12T18:00:00.000000-0600') |