summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/umath/loops.c.src26
-rw-r--r--numpy/core/tests/test_datetime.py34
2 files changed, 53 insertions, 7 deletions
diff --git a/numpy/core/src/umath/loops.c.src b/numpy/core/src/umath/loops.c.src
index 36046d9b8..e57dd5bd0 100644
--- a/numpy/core/src/umath/loops.c.src
+++ b/numpy/core/src/umath/loops.c.src
@@ -1286,11 +1286,17 @@ TIMEDELTA_md_m_multiply(char **args, npy_intp *dimensions, npy_intp *steps, void
BINARY_LOOP {
const npy_timedelta in1 = *(npy_timedelta *)ip1;
const double in2 = *(double *)ip2;
- if (in1 == NPY_DATETIME_NAT || npy_isnan(in2)) {
+ if (in1 == NPY_DATETIME_NAT) {
*((npy_timedelta *)op1) = NPY_DATETIME_NAT;
}
else {
- *((npy_timedelta *)op1) = (npy_timedelta)(in1 * in2);
+ double result = in1 * in2;
+ if (npy_isfinite(result)) {
+ *((npy_timedelta *)op1) = (npy_timedelta)result;
+ }
+ else {
+ *((npy_timedelta *)op1) = NPY_DATETIME_NAT;
+ }
}
}
}
@@ -1301,11 +1307,17 @@ TIMEDELTA_dm_m_multiply(char **args, npy_intp *dimensions, npy_intp *steps, void
BINARY_LOOP {
const double in1 = *(double *)ip1;
const npy_timedelta in2 = *(npy_timedelta *)ip2;
- if (npy_isnan(in1) || in2 == NPY_DATETIME_NAT) {
+ if (in2 == NPY_DATETIME_NAT) {
*((npy_timedelta *)op1) = NPY_DATETIME_NAT;
}
else {
- *((npy_timedelta *)op1) = (npy_timedelta)(in1 * in2);
+ double result = in1 * in2;
+ if (npy_isfinite(result)) {
+ *((npy_timedelta *)op1) = (npy_timedelta)result;
+ }
+ else {
+ *((npy_timedelta *)op1) = NPY_DATETIME_NAT;
+ }
}
}
}
@@ -1337,11 +1349,11 @@ TIMEDELTA_md_m_divide(char **args, npy_intp *dimensions, npy_intp *steps, void *
}
else {
double result = in1 / in2;
- if (npy_isnan(result)) {
- *((npy_timedelta *)op1) = NPY_DATETIME_NAT;
+ if (npy_isfinite(result)) {
+ *((npy_timedelta *)op1) = (npy_timedelta)result;
}
else {
- *((npy_timedelta *)op1) = (npy_timedelta)(result);
+ *((npy_timedelta *)op1) = NPY_DATETIME_NAT;
}
}
}
diff --git a/numpy/core/tests/test_datetime.py b/numpy/core/tests/test_datetime.py
index 956993c4e..5fa281867 100644
--- a/numpy/core/tests/test_datetime.py
+++ b/numpy/core/tests/test_datetime.py
@@ -1,6 +1,7 @@
from __future__ import division, absolute_import, print_function
import pickle
+import warnings
import numpy
import numpy as np
@@ -961,6 +962,21 @@ class TestDateTime(TestCase):
# float * M8
assert_raises(TypeError, np.multiply, 1.5, dta)
+ # NaTs
+ with warnings.catch_warnings():
+ warnings.filterwarnings('ignore', category=RuntimeWarning)
+ nat = np.timedelta64('NaT')
+ def check(a, b, res):
+ assert_equal(a * b, res)
+ assert_equal(b * a, res)
+ for tp in (int, float):
+ check(nat, tp(2), nat)
+ check(nat, tp(0), nat)
+ for f in (float('inf'), float('nan')):
+ check(np.timedelta64(1), f, nat)
+ check(np.timedelta64(0), f, nat)
+ check(nat, f, nat)
+
def test_datetime_divide(self):
for dta, tda, tdb, tdc, tdd in \
[
@@ -1010,6 +1026,24 @@ class TestDateTime(TestCase):
# float / M8
assert_raises(TypeError, np.divide, 1.5, dta)
+ # NaTs
+ with warnings.catch_warnings():
+ warnings.filterwarnings('ignore', category=RuntimeWarning)
+ nat = np.timedelta64('NaT')
+ for tp in (int, float):
+ assert_equal(np.timedelta64(1) / tp(0), nat)
+ assert_equal(np.timedelta64(0) / tp(0), nat)
+ assert_equal(nat / tp(0), nat)
+ assert_equal(nat / tp(2), nat)
+ # Division by inf
+ assert_equal(np.timedelta64(1) / float('inf'), np.timedelta64(0))
+ assert_equal(np.timedelta64(0) / float('inf'), np.timedelta64(0))
+ assert_equal(nat / float('inf'), nat)
+ # Division by nan
+ assert_equal(np.timedelta64(1) / float('nan'), nat)
+ assert_equal(np.timedelta64(0) / float('nan'), nat)
+ assert_equal(nat / float('nan'), nat)
+
def test_datetime_compare(self):
# Test all the comparison operators
a = np.datetime64('2000-03-12T18:00:00.000000-0600')