diff options
| field | value | date |
|---|---|---|
| author | Mark Wiebe <mwiebe@enthought.com> | 2011-05-19 17:09:10 -0500 |
| committer | Mark Wiebe <mwiebe@enthought.com> | 2011-05-19 17:13:50 -0500 |
| commit | 7045cbc7e65dc13a1bb0d5ca866d455022e29f24 (patch) | |
| tree | 8d15072b4bade50f1fd42ceb947eed588386550b /numpy | |
| parent | 25bcc658a2031d709a052cdf62c8329a59b76b20 (diff) | |
| download | numpy-7045cbc7e65dc13a1bb0d5ca866d455022e29f24.tar.gz | |
ENH: Reimplement datetime dtype string parser (with error checking)
The previous implementation returned the default datetime type in
many cases instead of raising an exception. I also changed the
hybrid C + Python implementation to be purely in C.
Diffstat (limited to 'numpy')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | numpy/core/_internal.py | 30 |
| -rw-r--r-- | numpy/core/src/multiarray/descriptor.c | 359 |
| -rw-r--r-- | numpy/core/tests/test_datetime.py | 13 |
3 files changed, 272 insertions, 130 deletions
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py index 5298f412b..a5b6d117a 100644 --- a/numpy/core/_internal.py +++ b/numpy/core/_internal.py @@ -166,36 +166,6 @@ def _split(input): return newlist -format_datetime = re.compile(asbytes(r""" - (?P<typecode>M8|m8|datetime64|timedelta64) - ([[] - ((?P<num>\d+)? - (?P<baseunit>Y|M|W|B|D|h|m|s|ms|us|ns|ps|fs|as) - (/(?P<den>\d+))? - []]) - (//(?P<events>\d+))?)?"""), re.X) - -# Return (baseunit, num, den, events), datetime -# from date-time string -def _datetimestring(astr): - res = format_datetime.match(astr) - if res is None: - raise ValueError("Incorrect date-time string.") - typecode = res.group('typecode') - datetime = (typecode == asbytes('M8') or typecode == asbytes('datetime64')) - defaults = [asbytes('us'), 1, 1, 1] - names = ['baseunit', 'num', 'den', 'events'] - func = [bytes, int, int, int] - dt_tuple = [] - for i, name in enumerate(names): - value = res.group(name) - if value: - dt_tuple.append(func[i](value)) - else: - dt_tuple.append(defaults[i]) - - return tuple(dt_tuple), datetime - format_re = re.compile(asbytes(r'(?P<order1>[<>|=]?)(?P<repeats> *[(]?[ ,0-9]*[)]? 
*)(?P<order2>[<>|=]?)(?P<dtype>[A-Za-z0-9.]*)')) # astr is a string (perhaps comma separated) diff --git a/numpy/core/src/multiarray/descriptor.c b/numpy/core/src/multiarray/descriptor.c index bf4ea27a6..1fe497e67 100644 --- a/numpy/core/src/multiarray/descriptor.c +++ b/numpy/core/src/multiarray/descriptor.c @@ -102,9 +102,9 @@ array_set_typeDict(PyObject *NPY_UNUSED(ignored), PyObject *args) } static int -_check_for_commastring(char *type, int len) +_check_for_commastring(char *type, Py_ssize_t len) { - int i; + Py_ssize_t i; /* Check for ints at start of string */ if ((type[0] >= '0' @@ -135,9 +135,9 @@ _check_for_commastring(char *type, int len) } static int -_check_for_datetime(char *type, int len) +is_datetime_typestr(char *type, Py_ssize_t len) { - if (len < 1) { + if (len < 2) { return 0; } if (type[1] == '8' && (type[0] == 'M' || type[0] == 'm')) { @@ -553,27 +553,66 @@ NPY_NO_EXPORT char *_datetime_strings[] = { NPY_STR_as }; +/* + * Converts a substring given by 'str' and 'len' into + * a date time unit enum value. The 'metastr' parameter + * is used for error messages, and may be NULL. + * + * Returns -1 if there is an error. 
+ */ static NPY_DATETIMEUNIT - _unit_from_str(char *base) +datetime_unit_from_string(char *str, Py_ssize_t len, char *metastr) { - NPY_DATETIMEUNIT unit; - - if (base == NULL) { - return NPY_DATETIME_DEFAULTUNIT; - } - - unit = NPY_FR_Y; - while (unit < NPY_DATETIME_NUMUNITS) { - if (strcmp(base, _datetime_strings[unit]) == 0) { - break; - } - unit++; + /* Use switch statements so the compiler can make it fast */ + if (len == 1) { + switch (str[0]) { + case 'Y': + return NPY_FR_Y; + case 'M': + return NPY_FR_M; + case 'W': + return NPY_FR_W; + case 'B': + return NPY_FR_B; + case 'D': + return NPY_FR_D; + case 'h': + return NPY_FR_h; + case 'm': + return NPY_FR_m; + case 's': + return NPY_FR_s; + } + } + /* All the two-letter units are variants of seconds */ + else if (len == 2 && str[1] == 's') { + switch (str[0]) { + case 'm': + return NPY_FR_ms; + case 'u': + return NPY_FR_us; + case 'n': + return NPY_FR_ns; + case 'p': + return NPY_FR_ps; + case 'f': + return NPY_FR_fs; + case 'a': + return NPY_FR_as; + } + } + + /* If nothing matched, it's an error */ + if (metastr == NULL) { + PyErr_SetString(PyExc_ValueError, + "Invalid datetime unit in metadata"); } - if (unit == NPY_DATETIME_NUMUNITS) { - return NPY_DATETIME_DEFAULTUNIT; + else { + PyErr_Format(PyExc_ValueError, + "Invalid datetime unit in metadata string \"%s\"", + metastr); } - - return unit; + return -1; } static NPY_DATETIMEUNIT _multiples_table[16][4] = { @@ -596,9 +635,14 @@ static NPY_DATETIMEUNIT _multiples_table[16][4] = { }; -/* Translate divisors into multiples of smaller units */ +/* + * Translate divisors into multiples of smaller units. + * 'metastr' is used for the error message if the divisor doesn't work, + * and can be NULL if the metadata didn't come from a string. 
+ */ static int -_convert_divisor_to_multiple(PyArray_DatetimeMetaData *meta) +convert_datetime_divisor_to_multiple(PyArray_DatetimeMetaData *meta, + char *metastr) { int i, num, ind; NPY_DATETIMEUNIT *totry; @@ -638,8 +682,16 @@ _convert_divisor_to_multiple(PyArray_DatetimeMetaData *meta) } } if (i == num) { - PyErr_Format(PyExc_ValueError, - "divisor (%d) is not a multiple of a lower-unit", meta->den); + if (metastr == NULL) { + PyErr_Format(PyExc_ValueError, + "divisor (%d) is not a multiple of a lower-unit " + "in datetime metadata", meta->den); + } + else { + PyErr_Format(PyExc_ValueError, + "divisor (%d) is not a multiple of a lower-unit " + "in datetime metadata \"%s\"", meta->den, metastr); + } return -1; } meta->base = baseunit[i]; @@ -676,10 +728,20 @@ _convert_datetime_tuple_to_cobj(PyObject *tuple) { PyArray_DatetimeMetaData *dt_data; PyObject *ret; + char *basestr = NULL; + Py_ssize_t len = 0; + + if (PyBytes_AsStringAndSize(PyTuple_GET_ITEM(tuple, 0), + &basestr, &len) < 0) { + return NULL; + } dt_data = _pya_malloc(sizeof(PyArray_DatetimeMetaData)); - dt_data->base = _unit_from_str( - PyBytes_AsString(PyTuple_GET_ITEM(tuple, 0))); + dt_data->base = datetime_unit_from_string(basestr, len, NULL); + if (dt_data->base == -1) { + _pya_free(dt_data); + return NULL; + } /* Assumes other objects are Python integers */ dt_data->num = PyInt_AS_LONG(PyTuple_GET_ITEM(tuple, 1)); @@ -687,108 +749,205 @@ _convert_datetime_tuple_to_cobj(PyObject *tuple) dt_data->events = PyInt_AS_LONG(PyTuple_GET_ITEM(tuple, 3)); if (dt_data->den > 1) { - if (_convert_divisor_to_multiple(dt_data) < 0) { + if (convert_datetime_divisor_to_multiple(dt_data, NULL) < 0) { + _pya_free(dt_data); return NULL; } } -/* FIXME - * There is no error handling here. 
- */ - ret = NpyCapsule_FromVoidPtr((void *)dt_data, simple_capsule_dtor); - return ret; + return NpyCapsule_FromVoidPtr((void *)dt_data, simple_capsule_dtor); } +static PyObject * +datetime_metacobj_from_metastr(char *metastr, Py_ssize_t len) +{ + PyArray_DatetimeMetaData *dt_data; + char *substr = metastr, *substrend = NULL; + int sublen = 0; + + dt_data = _pya_malloc(sizeof(PyArray_DatetimeMetaData)); + if (dt_data == NULL) { + return PyErr_NoMemory(); + } + + /* If there's no metastr, use the default */ + if (len == 0) { + dt_data->num = 1; + dt_data->base = NPY_DATETIME_DEFAULTUNIT; + dt_data->den = 1; + dt_data->events = 1; + } + else { + + /* The metadata string must start with a '[' */ + if (len < 3 || *substr++ != '[') { + goto bad_input; + } + + /* First comes an optional integer multiplier */ + dt_data->num = (int)strtol(substr, &substrend, 10); + if (substr == substrend) { + dt_data->num = 1; + } + substr = substrend; + + /* Next comes the unit itself, followed by either '/' or ']' */ + substrend = substr; + while (*substrend != '\0' && *substrend != '/' && *substrend != ']') { + ++substrend; + } + if (*substrend == '\0') { + goto bad_input; + } + dt_data->base = datetime_unit_from_string(substr, + substrend-substr, metastr); + if (dt_data->base == -1) { + goto error; + } + substr = substrend; + + /* Next comes an optional integer denominator */ + if (*substr == '/') { + substr++; + dt_data->den = (int)strtol(substr, &substrend, 10); + /* If the '/' exists, there must be a number followed by ']' */ + if (substr == substrend || *substrend != ']') { + goto bad_input; + } + substr = substrend + 1; + } + else if (*substr == ']') { + dt_data->den = 1; + substr++; + } + else { + goto bad_input; + } + + /* Finally comes an optional number of events */ + if (substr[0] == '/' && substr[1] == '/') { + substr += 2; + + dt_data->events = (int)strtol(substr, &substrend, 10); + if (substr == substrend || *substrend != '\0') { + goto bad_input; + } + } + else if 
(*substr != '\0') { + goto bad_input; + } + else { + dt_data->events = 1; + } + + if (dt_data->den > 1) { + if (convert_datetime_divisor_to_multiple(dt_data, metastr) < 0) { + goto bad_input; + } + } + } + + return NpyCapsule_FromVoidPtr((void *)dt_data, simple_capsule_dtor); + +bad_input: + PyErr_Format(PyExc_ValueError, + "Invalid datetime metadata string \"%s\" at position %d", + metastr, (int)(substr-metastr)); +error: + _pya_free(dt_data); + return NULL; +} + +/* + * Converts a datetype dtype string into a dtype descr object. + * The "type" string should be NULL-terminated. + */ static PyArray_Descr * -_convert_from_datetime_tuple(PyObject *obj) +dtype_from_datetime_typestr(char *typestr, Py_ssize_t len) { - PyArray_Descr *new; - PyObject *dt_tuple; - PyObject *dt_cobj; - PyObject *datetime_flag; + PyArray_Descr *dtype = NULL; + char *metastr = NULL; + int is_timedelta = 0; + Py_ssize_t metalen = 0; + PyObject *metacobj = NULL; - if (!PyTuple_Check(obj) || PyTuple_GET_SIZE(obj)!=2) { - PyErr_SetString(PyExc_RuntimeError, - "_datetimestring is not returning a tuple with length 2"); + if (len < 2) { + PyErr_Format(PyExc_ValueError, + "Invalid datetime typestr \"%s\"", + typestr); return NULL; } - dt_tuple = PyTuple_GET_ITEM(obj, 0); - datetime_flag = PyTuple_GET_ITEM(obj, 1); - if (!PyTuple_Check(dt_tuple) - || PyTuple_GET_SIZE(dt_tuple) != 4 - || !PyInt_Check(datetime_flag)) { - PyErr_SetString(PyExc_RuntimeError, - "_datetimestring is not returning a length 4 tuple"\ - " and an integer"); + /* + * First validate that the root is correct, + * and get the metadata string address + */ + if (typestr[0] == 'm' && typestr[1] == '8') { + is_timedelta = 1; + metastr = typestr + 2; + metalen = len - 2; + } + else if (typestr[0] == 'M' && typestr[1] == '8') { + is_timedelta = 0; + metastr = typestr + 2; + metalen = len - 2; + } + else if (len >= 11 && strncmp(typestr, "timedelta64", 11) == 0) { + is_timedelta = 1; + metastr = typestr + 11; + metalen = len - 11; + } + 
else if (len >= 10 && strncmp(typestr, "datetime64", 10) == 0) { + is_timedelta = 0; + metastr = typestr + 10; + metalen = len - 10; + } + else { + PyErr_Format(PyExc_ValueError, + "Invalid datetime typestr \"%s\"", + typestr); return NULL; } - /* Create new timedelta or datetime dtype */ - if (PyObject_IsTrue(datetime_flag)) { - new = PyArray_DescrNewFromType(PyArray_DATETIME); + /* Create a default datetime or timedelta */ + if (is_timedelta) { + dtype = PyArray_DescrNewFromType(PyArray_TIMEDELTA); } else { - new = PyArray_DescrNewFromType(PyArray_TIMEDELTA); + dtype = PyArray_DescrNewFromType(PyArray_DATETIME); } - - if (new == NULL) { + if (dtype == NULL) { return NULL; } + /* * Remove any reference to old metadata dictionary * And create a new one for this new dtype */ - Py_XDECREF(new->metadata); - if ((new->metadata = PyDict_New()) == NULL) { + Py_XDECREF(dtype->metadata); + dtype->metadata = PyDict_New(); + if (dtype->metadata == NULL) { + Py_DECREF(dtype); return NULL; } - dt_cobj = _convert_datetime_tuple_to_cobj(dt_tuple); - if (dt_cobj == NULL) { - /* Failure in conversion */ - Py_DECREF(new); - return NULL; - } - - /* Assume this sets a new reference to dt_cobj */ - PyDict_SetItemString(new->metadata, NPY_METADATA_DTSTR, dt_cobj); - Py_DECREF(dt_cobj); - return new; -} - - -static PyArray_Descr * -_convert_from_datetime(PyObject *obj) -{ - PyObject *tupleobj; - PyArray_Descr *res; - PyObject *_numpy_internal; - if (!PyBytes_Check(obj)) { + /* Parse the metadata string into a metadata CObject */ + metacobj = datetime_metacobj_from_metastr(metastr, metalen); + if (metacobj == NULL) { + Py_DECREF(dtype); return NULL; } - _numpy_internal = PyImport_ImportModule("numpy.core._internal"); - if (_numpy_internal == NULL) { - return NULL; - } - tupleobj = PyObject_CallMethod(_numpy_internal, - "_datetimestring", "O", obj); - Py_DECREF(_numpy_internal); - if (!tupleobj) { - return NULL; - } - /* - * tuple of a standard tuple (baseunit, num, den, events) and a 
timedelta - * boolean - */ - res = _convert_from_datetime_tuple(tupleobj); - Py_DECREF(tupleobj); - if (!res && !PyErr_Occurred()) { - PyErr_SetString(PyExc_ValueError, - "invalid data-type"); + + /* Set the metadata object in the dictionary. */ + if (PyDict_SetItemString(dtype->metadata, NPY_METADATA_DTSTR, + metacobj) < 0) { + Py_DECREF(dtype); + Py_DECREF(metacobj); return NULL; } - return res; + Py_DECREF(metacobj); + + return dtype; } @@ -1288,8 +1447,8 @@ PyArray_DescrConverter(PyObject *obj, PyArray_Descr **at) goto fail; } /* check for datetime format */ - if ((len > 1) && _check_for_datetime(type, len)) { - *at = _convert_from_datetime(obj); + if (is_datetime_typestr(type, len)) { + *at = dtype_from_datetime_typestr(type, len); if (*at) { return PY_SUCCEED; } diff --git a/numpy/core/tests/test_datetime.py b/numpy/core/tests/test_datetime.py index 97fb1151f..12be38358 100644 --- a/numpy/core/tests/test_datetime.py +++ b/numpy/core/tests/test_datetime.py @@ -11,6 +11,19 @@ class TestDateTime(TestCase): assert_(dt1 == np.dtype('datetime64[750%s]' % unit)) dt2 = np.dtype('m8[%s]' % unit) assert_(dt2 == np.dtype('timedelta64[%s]' % unit)) + + # Check that the parser rejects bad datetime types + assert_raises(ValueError, np.dtype, 'M8[badunit]') + assert_raises(ValueError, np.dtype, 'm8[badunit]') + assert_raises(ValueError, np.dtype, 'm8[badunit]') + assert_raises(ValueError, np.dtype, 'M8[YY]') + assert_raises(ValueError, np.dtype, 'm8[YY]') + assert_raises(ValueError, np.dtype, 'M4') + assert_raises(ValueError, np.dtype, 'm4') + assert_raises(ValueError, np.dtype, 'M7') + assert_raises(ValueError, np.dtype, 'm7') + assert_raises(ValueError, np.dtype, 'M16') + assert_raises(ValueError, np.dtype, 'm16') def test_hours(self): |