diff options
-rw-r--r-- | numpy/core/src/multiarray/datetime_busdaycal.c | 123 | ||||
-rw-r--r-- | numpy/core/tests/test_datetime.py | 20 |
2 files changed, 87 insertions, 56 deletions
diff --git a/numpy/core/src/multiarray/datetime_busdaycal.c b/numpy/core/src/multiarray/datetime_busdaycal.c index ae35212ea..018912ee8 100644 --- a/numpy/core/src/multiarray/datetime_busdaycal.c +++ b/numpy/core/src/multiarray/datetime_busdaycal.c @@ -46,6 +46,7 @@ PyArray_WeekMaskConverter(PyObject *weekmask_in, npy_bool *weekmask) if (PyBytes_Check(obj)) { char *str; Py_ssize_t len; + int i; if (PyBytes_AsStringAndSize(obj, &str, &len) < 0) { Py_DECREF(obj); @@ -54,7 +55,6 @@ PyArray_WeekMaskConverter(PyObject *weekmask_in, npy_bool *weekmask) /* Length 7 is a string like "1111100" */ if (len == 7) { - int i; for (i = 0; i < 7; ++i) { switch(str[i]) { case '0': @@ -64,70 +64,81 @@ PyArray_WeekMaskConverter(PyObject *weekmask_in, npy_bool *weekmask) weekmask[i] = 1; break; default: - goto invalid_weekmask_string; + goto general_weekmask_string; } } goto finish; } - /* Length divisible by 3 is a string like "Mon" or "MonWedFri" */ - else if (len % 3 == 0) { - int i; - memset(weekmask, 0, 7); - for (i = 0; i < len; i += 3) { - switch (str[i]) { - case 'M': - if (str[i+1] == 'o' && str[i+2] == 'n') { - weekmask[0] = 1; - } - else { - goto invalid_weekmask_string; - } - break; - case 'T': - if (str[i+1] == 'u' && str[i+2] == 'e') { - weekmask[1] = 1; - } - else if (str[i+1] == 'h' && str[i+2] == 'u') { - weekmask[3] = 1; - } - else { - goto invalid_weekmask_string; - } - break; - case 'W': - if (str[i+1] == 'e' && str[i+2] == 'd') { - weekmask[2] = 1; - } - else { - goto invalid_weekmask_string; - } - break; - case 'F': - if (str[i+1] == 'r' && str[i+2] == 'i') { - weekmask[4] = 1; - } - else { - goto invalid_weekmask_string; - } - break; - case 'S': - if (str[i+1] == 'a' && str[i+2] == 't') { - weekmask[5] = 1; - } - else if (str[i+1] == 'u' && str[i+2] == 'n') { - weekmask[6] = 1; - } - else { - goto invalid_weekmask_string; - } - break; - } + +general_weekmask_string: + /* a string like "SatSun" or "Mon Tue Wed" */ + memset(weekmask, 0, 7); + for (i = 0; i < len; i += 3) { + while (isspace(str[i])) + ++i; + + if (i == len) { + goto finish; + } + else if (i + 2 >= len) { + goto invalid_weekmask_string; } - goto finish; + switch (str[i]) { + case 'M': + if (str[i+1] == 'o' && str[i+2] == 'n') { + weekmask[0] = 1; + } + else { + goto invalid_weekmask_string; + } + break; + case 'T': + if (str[i+1] == 'u' && str[i+2] == 'e') { + weekmask[1] = 1; + } + else if (str[i+1] == 'h' && str[i+2] == 'u') { + weekmask[3] = 1; + } + else { + goto invalid_weekmask_string; + } + break; + case 'W': + if (str[i+1] == 'e' && str[i+2] == 'd') { + weekmask[2] = 1; + } + else { + goto invalid_weekmask_string; + } + break; + case 'F': + if (str[i+1] == 'r' && str[i+2] == 'i') { + weekmask[4] = 1; + } + else { + goto invalid_weekmask_string; + } + break; + case 'S': + if (str[i+1] == 'a' && str[i+2] == 't') { + weekmask[5] = 1; + } + else if (str[i+1] == 'u' && str[i+2] == 'n') { + weekmask[6] = 1; + } + else { + goto invalid_weekmask_string; + } + break; + default: + goto invalid_weekmask_string; + } } + goto finish; + invalid_weekmask_string: PyErr_Format(PyExc_ValueError, "Invalid business day weekmask string \"%s\"", diff --git a/numpy/core/tests/test_datetime.py b/numpy/core/tests/test_datetime.py index 2770364f0..563f9e87c 100644 --- a/numpy/core/tests/test_datetime.py +++ b/numpy/core/tests/test_datetime.py @@ -1440,8 +1440,28 @@ class TestDateTime(TestCase): # Default M-F weekmask assert_equal(bdd.weekmask, np.array([1,1,1,1,1,0,0], dtype='?')) + # Check string weekmask with varying whitespace. + bdd = np.busdaycalendar(weekmask="Sun TueWed Thu\tFri") + assert_equal(bdd.weekmask, np.array([0,1,1,1,1,0,1], dtype='?')) + + # Check length 7 0/1 string + bdd = np.busdaycalendar(weekmask="0011001") + assert_equal(bdd.weekmask, np.array([0,0,1,1,0,0,1], dtype='?')) + + # Check length 7 string weekmask. + bdd = np.busdaycalendar(weekmask="Mon Tue") + assert_equal(bdd.weekmask, np.array([1,1,0,0,0,0,0], dtype='?')) + # All-zeros weekmask should raise assert_raises(ValueError, np.busdaycalendar, weekmask=[0,0,0,0,0,0,0]) + # weekday names must be correct case + assert_raises(ValueError, np.busdaycalendar, weekmask="satsun") + # All-zeros weekmask should raise + assert_raises(ValueError, np.busdaycalendar, weekmask="") + # Invalid weekday name codes should raise + assert_raises(ValueError, np.busdaycalendar, weekmask="Mon Tue We") + assert_raises(ValueError, np.busdaycalendar, weekmask="Max") + assert_raises(ValueError, np.busdaycalendar, weekmask="Monday Tue") def test_datetime_busday_holidays_offset(self): # With exactly one holiday |