summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/multiarray/datetime_busdaycal.c123
-rw-r--r--numpy/core/tests/test_datetime.py20
2 files changed, 87 insertions, 56 deletions
diff --git a/numpy/core/src/multiarray/datetime_busdaycal.c b/numpy/core/src/multiarray/datetime_busdaycal.c
index ae35212ea..018912ee8 100644
--- a/numpy/core/src/multiarray/datetime_busdaycal.c
+++ b/numpy/core/src/multiarray/datetime_busdaycal.c
@@ -46,6 +46,7 @@ PyArray_WeekMaskConverter(PyObject *weekmask_in, npy_bool *weekmask)
if (PyBytes_Check(obj)) {
char *str;
Py_ssize_t len;
+ int i;
if (PyBytes_AsStringAndSize(obj, &str, &len) < 0) {
Py_DECREF(obj);
@@ -54,7 +55,6 @@ PyArray_WeekMaskConverter(PyObject *weekmask_in, npy_bool *weekmask)
/* Length 7 is a string like "1111100" */
if (len == 7) {
- int i;
for (i = 0; i < 7; ++i) {
switch(str[i]) {
case '0':
@@ -64,70 +64,81 @@ PyArray_WeekMaskConverter(PyObject *weekmask_in, npy_bool *weekmask)
weekmask[i] = 1;
break;
default:
- goto invalid_weekmask_string;
+ goto general_weekmask_string;
}
}
goto finish;
}
- /* Length divisible by 3 is a string like "Mon" or "MonWedFri" */
- else if (len % 3 == 0) {
- int i;
- memset(weekmask, 0, 7);
- for (i = 0; i < len; i += 3) {
- switch (str[i]) {
- case 'M':
- if (str[i+1] == 'o' && str[i+2] == 'n') {
- weekmask[0] = 1;
- }
- else {
- goto invalid_weekmask_string;
- }
- break;
- case 'T':
- if (str[i+1] == 'u' && str[i+2] == 'e') {
- weekmask[1] = 1;
- }
- else if (str[i+1] == 'h' && str[i+2] == 'u') {
- weekmask[3] = 1;
- }
- else {
- goto invalid_weekmask_string;
- }
- break;
- case 'W':
- if (str[i+1] == 'e' && str[i+2] == 'd') {
- weekmask[2] = 1;
- }
- else {
- goto invalid_weekmask_string;
- }
- break;
- case 'F':
- if (str[i+1] == 'r' && str[i+2] == 'i') {
- weekmask[4] = 1;
- }
- else {
- goto invalid_weekmask_string;
- }
- break;
- case 'S':
- if (str[i+1] == 'a' && str[i+2] == 't') {
- weekmask[5] = 1;
- }
- else if (str[i+1] == 'u' && str[i+2] == 'n') {
- weekmask[6] = 1;
- }
- else {
- goto invalid_weekmask_string;
- }
- break;
- }
+
+general_weekmask_string:
+ /* a string like "SatSun" or "Mon Tue Wed" */
+ memset(weekmask, 0, 7);
+ for (i = 0; i < len; i += 3) {
+ while (isspace(str[i]))
+ ++i;
+
+ if (i == len) {
+ goto finish;
+ }
+ else if (i + 2 >= len) {
+ goto invalid_weekmask_string;
}
- goto finish;
+ switch (str[i]) {
+ case 'M':
+ if (str[i+1] == 'o' && str[i+2] == 'n') {
+ weekmask[0] = 1;
+ }
+ else {
+ goto invalid_weekmask_string;
+ }
+ break;
+ case 'T':
+ if (str[i+1] == 'u' && str[i+2] == 'e') {
+ weekmask[1] = 1;
+ }
+ else if (str[i+1] == 'h' && str[i+2] == 'u') {
+ weekmask[3] = 1;
+ }
+ else {
+ goto invalid_weekmask_string;
+ }
+ break;
+ case 'W':
+ if (str[i+1] == 'e' && str[i+2] == 'd') {
+ weekmask[2] = 1;
+ }
+ else {
+ goto invalid_weekmask_string;
+ }
+ break;
+ case 'F':
+ if (str[i+1] == 'r' && str[i+2] == 'i') {
+ weekmask[4] = 1;
+ }
+ else {
+ goto invalid_weekmask_string;
+ }
+ break;
+ case 'S':
+ if (str[i+1] == 'a' && str[i+2] == 't') {
+ weekmask[5] = 1;
+ }
+ else if (str[i+1] == 'u' && str[i+2] == 'n') {
+ weekmask[6] = 1;
+ }
+ else {
+ goto invalid_weekmask_string;
+ }
+ break;
+ default:
+ goto invalid_weekmask_string;
+ }
}
+ goto finish;
+
invalid_weekmask_string:
PyErr_Format(PyExc_ValueError,
"Invalid business day weekmask string \"%s\"",
diff --git a/numpy/core/tests/test_datetime.py b/numpy/core/tests/test_datetime.py
index 2770364f0..563f9e87c 100644
--- a/numpy/core/tests/test_datetime.py
+++ b/numpy/core/tests/test_datetime.py
@@ -1440,8 +1440,28 @@ class TestDateTime(TestCase):
# Default M-F weekmask
assert_equal(bdd.weekmask, np.array([1,1,1,1,1,0,0], dtype='?'))
+ # Check string weekmask with varying whitespace.
+ bdd = np.busdaycalendar(weekmask="Sun TueWed Thu\tFri")
+ assert_equal(bdd.weekmask, np.array([0,1,1,1,1,0,1], dtype='?'))
+
+ # Check length 7 0/1 string
+ bdd = np.busdaycalendar(weekmask="0011001")
+ assert_equal(bdd.weekmask, np.array([0,0,1,1,0,0,1], dtype='?'))
+
+ # Check length 7 string weekmask.
+ bdd = np.busdaycalendar(weekmask="Mon Tue")
+ assert_equal(bdd.weekmask, np.array([1,1,0,0,0,0,0], dtype='?'))
+
# All-zeros weekmask should raise
assert_raises(ValueError, np.busdaycalendar, weekmask=[0,0,0,0,0,0,0])
+ # weekday names must be correct case
+ assert_raises(ValueError, np.busdaycalendar, weekmask="satsun")
+ # All-zeros weekmask should raise
+ assert_raises(ValueError, np.busdaycalendar, weekmask="")
+ # Invalid weekday name codes should raise
+ assert_raises(ValueError, np.busdaycalendar, weekmask="Mon Tue We")
+ assert_raises(ValueError, np.busdaycalendar, weekmask="Max")
+ assert_raises(ValueError, np.busdaycalendar, weekmask="Monday Tue")
def test_datetime_busday_holidays_offset(self):
# With exactly one holiday