summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorMark Wiebe <mwiebe@enthought.com>2011-05-19 17:09:10 -0500
committerMark Wiebe <mwiebe@enthought.com>2011-05-19 17:13:50 -0500
commit7045cbc7e65dc13a1bb0d5ca866d455022e29f24 (patch)
tree8d15072b4bade50f1fd42ceb947eed588386550b /numpy
parent25bcc658a2031d709a052cdf62c8329a59b76b20 (diff)
downloadnumpy-7045cbc7e65dc13a1bb0d5ca866d455022e29f24.tar.gz
ENH: Reimplement datetime dtype string parser (with error checking)
The previous implementation returned the default datetime type in many cases instead of raising an exception. I also changed the hybrid C + Python implementation to be purely in C.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/_internal.py30
-rw-r--r--numpy/core/src/multiarray/descriptor.c359
-rw-r--r--numpy/core/tests/test_datetime.py13
3 files changed, 272 insertions, 130 deletions
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py
index 5298f412b..a5b6d117a 100644
--- a/numpy/core/_internal.py
+++ b/numpy/core/_internal.py
@@ -166,36 +166,6 @@ def _split(input):
return newlist
-format_datetime = re.compile(asbytes(r"""
- (?P<typecode>M8|m8|datetime64|timedelta64)
- ([[]
- ((?P<num>\d+)?
- (?P<baseunit>Y|M|W|B|D|h|m|s|ms|us|ns|ps|fs|as)
- (/(?P<den>\d+))?
- []])
- (//(?P<events>\d+))?)?"""), re.X)
-
-# Return (baseunit, num, den, events), datetime
-# from date-time string
-def _datetimestring(astr):
- res = format_datetime.match(astr)
- if res is None:
- raise ValueError("Incorrect date-time string.")
- typecode = res.group('typecode')
- datetime = (typecode == asbytes('M8') or typecode == asbytes('datetime64'))
- defaults = [asbytes('us'), 1, 1, 1]
- names = ['baseunit', 'num', 'den', 'events']
- func = [bytes, int, int, int]
- dt_tuple = []
- for i, name in enumerate(names):
- value = res.group(name)
- if value:
- dt_tuple.append(func[i](value))
- else:
- dt_tuple.append(defaults[i])
-
- return tuple(dt_tuple), datetime
-
format_re = re.compile(asbytes(r'(?P<order1>[<>|=]?)(?P<repeats> *[(]?[ ,0-9]*[)]? *)(?P<order2>[<>|=]?)(?P<dtype>[A-Za-z0-9.]*)'))
# astr is a string (perhaps comma separated)
diff --git a/numpy/core/src/multiarray/descriptor.c b/numpy/core/src/multiarray/descriptor.c
index bf4ea27a6..1fe497e67 100644
--- a/numpy/core/src/multiarray/descriptor.c
+++ b/numpy/core/src/multiarray/descriptor.c
@@ -102,9 +102,9 @@ array_set_typeDict(PyObject *NPY_UNUSED(ignored), PyObject *args)
}
static int
-_check_for_commastring(char *type, int len)
+_check_for_commastring(char *type, Py_ssize_t len)
{
- int i;
+ Py_ssize_t i;
/* Check for ints at start of string */
if ((type[0] >= '0'
@@ -135,9 +135,9 @@ _check_for_commastring(char *type, int len)
}
static int
-_check_for_datetime(char *type, int len)
+is_datetime_typestr(char *type, Py_ssize_t len)
{
- if (len < 1) {
+ if (len < 2) {
return 0;
}
if (type[1] == '8' && (type[0] == 'M' || type[0] == 'm')) {
@@ -553,27 +553,66 @@ NPY_NO_EXPORT char *_datetime_strings[] = {
NPY_STR_as
};
+/*
+ * Converts a substring given by 'str' and 'len' into
+ * a date time unit enum value. The 'metastr' parameter
+ * is used for error messages, and may be NULL.
+ *
+ * Returns -1 if there is an error.
+ */
static NPY_DATETIMEUNIT
- _unit_from_str(char *base)
+datetime_unit_from_string(char *str, Py_ssize_t len, char *metastr)
{
- NPY_DATETIMEUNIT unit;
-
- if (base == NULL) {
- return NPY_DATETIME_DEFAULTUNIT;
- }
-
- unit = NPY_FR_Y;
- while (unit < NPY_DATETIME_NUMUNITS) {
- if (strcmp(base, _datetime_strings[unit]) == 0) {
- break;
- }
- unit++;
+ /* Use switch statements so the compiler can make it fast */
+ if (len == 1) {
+ switch (str[0]) {
+ case 'Y':
+ return NPY_FR_Y;
+ case 'M':
+ return NPY_FR_M;
+ case 'W':
+ return NPY_FR_W;
+ case 'B':
+ return NPY_FR_B;
+ case 'D':
+ return NPY_FR_D;
+ case 'h':
+ return NPY_FR_h;
+ case 'm':
+ return NPY_FR_m;
+ case 's':
+ return NPY_FR_s;
+ }
+ }
+ /* All the two-letter units are variants of seconds */
+ else if (len == 2 && str[1] == 's') {
+ switch (str[0]) {
+ case 'm':
+ return NPY_FR_ms;
+ case 'u':
+ return NPY_FR_us;
+ case 'n':
+ return NPY_FR_ns;
+ case 'p':
+ return NPY_FR_ps;
+ case 'f':
+ return NPY_FR_fs;
+ case 'a':
+ return NPY_FR_as;
+ }
+ }
+
+ /* If nothing matched, it's an error */
+ if (metastr == NULL) {
+ PyErr_SetString(PyExc_ValueError,
+ "Invalid datetime unit in metadata");
}
- if (unit == NPY_DATETIME_NUMUNITS) {
- return NPY_DATETIME_DEFAULTUNIT;
+ else {
+ PyErr_Format(PyExc_ValueError,
+ "Invalid datetime unit in metadata string \"%s\"",
+ metastr);
}
-
- return unit;
+ return -1;
}
static NPY_DATETIMEUNIT _multiples_table[16][4] = {
@@ -596,9 +635,14 @@ static NPY_DATETIMEUNIT _multiples_table[16][4] = {
};
-/* Translate divisors into multiples of smaller units */
+/*
+ * Translate divisors into multiples of smaller units.
+ * 'metastr' is used for the error message if the divisor doesn't work,
+ * and can be NULL if the metadata didn't come from a string.
+ */
static int
-_convert_divisor_to_multiple(PyArray_DatetimeMetaData *meta)
+convert_datetime_divisor_to_multiple(PyArray_DatetimeMetaData *meta,
+ char *metastr)
{
int i, num, ind;
NPY_DATETIMEUNIT *totry;
@@ -638,8 +682,16 @@ _convert_divisor_to_multiple(PyArray_DatetimeMetaData *meta)
}
}
if (i == num) {
- PyErr_Format(PyExc_ValueError,
- "divisor (%d) is not a multiple of a lower-unit", meta->den);
+ if (metastr == NULL) {
+ PyErr_Format(PyExc_ValueError,
+ "divisor (%d) is not a multiple of a lower-unit "
+ "in datetime metadata", meta->den);
+ }
+ else {
+ PyErr_Format(PyExc_ValueError,
+ "divisor (%d) is not a multiple of a lower-unit "
+ "in datetime metadata \"%s\"", meta->den, metastr);
+ }
return -1;
}
meta->base = baseunit[i];
@@ -676,10 +728,20 @@ _convert_datetime_tuple_to_cobj(PyObject *tuple)
{
PyArray_DatetimeMetaData *dt_data;
PyObject *ret;
+ char *basestr = NULL;
+ Py_ssize_t len = 0;
+
+ if (PyBytes_AsStringAndSize(PyTuple_GET_ITEM(tuple, 0),
+ &basestr, &len) < 0) {
+ return NULL;
+ }
dt_data = _pya_malloc(sizeof(PyArray_DatetimeMetaData));
- dt_data->base = _unit_from_str(
- PyBytes_AsString(PyTuple_GET_ITEM(tuple, 0)));
+ dt_data->base = datetime_unit_from_string(basestr, len, NULL);
+ if (dt_data->base == -1) {
+ _pya_free(dt_data);
+ return NULL;
+ }
/* Assumes other objects are Python integers */
dt_data->num = PyInt_AS_LONG(PyTuple_GET_ITEM(tuple, 1));
@@ -687,108 +749,205 @@ _convert_datetime_tuple_to_cobj(PyObject *tuple)
dt_data->events = PyInt_AS_LONG(PyTuple_GET_ITEM(tuple, 3));
if (dt_data->den > 1) {
- if (_convert_divisor_to_multiple(dt_data) < 0) {
+ if (convert_datetime_divisor_to_multiple(dt_data, NULL) < 0) {
+ _pya_free(dt_data);
return NULL;
}
}
-/* FIXME
- * There is no error handling here.
- */
- ret = NpyCapsule_FromVoidPtr((void *)dt_data, simple_capsule_dtor);
- return ret;
+ return NpyCapsule_FromVoidPtr((void *)dt_data, simple_capsule_dtor);
}
+static PyObject *
+datetime_metacobj_from_metastr(char *metastr, Py_ssize_t len)
+{
+ PyArray_DatetimeMetaData *dt_data;
+ char *substr = metastr, *substrend = NULL;
+ int sublen = 0;
+
+ dt_data = _pya_malloc(sizeof(PyArray_DatetimeMetaData));
+ if (dt_data == NULL) {
+ return PyErr_NoMemory();
+ }
+
+ /* If there's no metastr, use the default */
+ if (len == 0) {
+ dt_data->num = 1;
+ dt_data->base = NPY_DATETIME_DEFAULTUNIT;
+ dt_data->den = 1;
+ dt_data->events = 1;
+ }
+ else {
+
+ /* The metadata string must start with a '[' */
+ if (len < 3 || *substr++ != '[') {
+ goto bad_input;
+ }
+
+ /* First comes an optional integer multiplier */
+ dt_data->num = (int)strtol(substr, &substrend, 10);
+ if (substr == substrend) {
+ dt_data->num = 1;
+ }
+ substr = substrend;
+
+ /* Next comes the unit itself, followed by either '/' or ']' */
+ substrend = substr;
+ while (*substrend != '\0' && *substrend != '/' && *substrend != ']') {
+ ++substrend;
+ }
+ if (*substrend == '\0') {
+ goto bad_input;
+ }
+ dt_data->base = datetime_unit_from_string(substr,
+ substrend-substr, metastr);
+ if (dt_data->base == -1) {
+ goto error;
+ }
+ substr = substrend;
+
+ /* Next comes an optional integer denominator */
+ if (*substr == '/') {
+ substr++;
+ dt_data->den = (int)strtol(substr, &substrend, 10);
+ /* If the '/' exists, there must be a number followed by ']' */
+ if (substr == substrend || *substrend != ']') {
+ goto bad_input;
+ }
+ substr = substrend + 1;
+ }
+ else if (*substr == ']') {
+ dt_data->den = 1;
+ substr++;
+ }
+ else {
+ goto bad_input;
+ }
+
+ /* Finally comes an optional number of events */
+ if (substr[0] == '/' && substr[1] == '/') {
+ substr += 2;
+
+ dt_data->events = (int)strtol(substr, &substrend, 10);
+ if (substr == substrend || *substrend != '\0') {
+ goto bad_input;
+ }
+ }
+ else if (*substr != '\0') {
+ goto bad_input;
+ }
+ else {
+ dt_data->events = 1;
+ }
+
+ if (dt_data->den > 1) {
+ if (convert_datetime_divisor_to_multiple(dt_data, metastr) < 0) {
+ goto bad_input;
+ }
+ }
+ }
+
+ return NpyCapsule_FromVoidPtr((void *)dt_data, simple_capsule_dtor);
+
+bad_input:
+ PyErr_Format(PyExc_ValueError,
+ "Invalid datetime metadata string \"%s\" at position %d",
+ metastr, (int)(substr-metastr));
+error:
+ _pya_free(dt_data);
+ return NULL;
+}
+
+/*
+ * Converts a datetype dtype string into a dtype descr object.
+ * The "type" string should be NULL-terminated.
+ */
static PyArray_Descr *
-_convert_from_datetime_tuple(PyObject *obj)
+dtype_from_datetime_typestr(char *typestr, Py_ssize_t len)
{
- PyArray_Descr *new;
- PyObject *dt_tuple;
- PyObject *dt_cobj;
- PyObject *datetime_flag;
+ PyArray_Descr *dtype = NULL;
+ char *metastr = NULL;
+ int is_timedelta = 0;
+ Py_ssize_t metalen = 0;
+ PyObject *metacobj = NULL;
- if (!PyTuple_Check(obj) || PyTuple_GET_SIZE(obj)!=2) {
- PyErr_SetString(PyExc_RuntimeError,
- "_datetimestring is not returning a tuple with length 2");
+ if (len < 2) {
+ PyErr_Format(PyExc_ValueError,
+ "Invalid datetime typestr \"%s\"",
+ typestr);
return NULL;
}
- dt_tuple = PyTuple_GET_ITEM(obj, 0);
- datetime_flag = PyTuple_GET_ITEM(obj, 1);
- if (!PyTuple_Check(dt_tuple)
- || PyTuple_GET_SIZE(dt_tuple) != 4
- || !PyInt_Check(datetime_flag)) {
- PyErr_SetString(PyExc_RuntimeError,
- "_datetimestring is not returning a length 4 tuple"\
- " and an integer");
+ /*
+ * First validate that the root is correct,
+ * and get the metadata string address
+ */
+ if (typestr[0] == 'm' && typestr[1] == '8') {
+ is_timedelta = 1;
+ metastr = typestr + 2;
+ metalen = len - 2;
+ }
+ else if (typestr[0] == 'M' && typestr[1] == '8') {
+ is_timedelta = 0;
+ metastr = typestr + 2;
+ metalen = len - 2;
+ }
+ else if (len >= 11 && strncmp(typestr, "timedelta64", 11) == 0) {
+ is_timedelta = 1;
+ metastr = typestr + 11;
+ metalen = len - 11;
+ }
+ else if (len >= 10 && strncmp(typestr, "datetime64", 10) == 0) {
+ is_timedelta = 0;
+ metastr = typestr + 10;
+ metalen = len - 10;
+ }
+ else {
+ PyErr_Format(PyExc_ValueError,
+ "Invalid datetime typestr \"%s\"",
+ typestr);
return NULL;
}
- /* Create new timedelta or datetime dtype */
- if (PyObject_IsTrue(datetime_flag)) {
- new = PyArray_DescrNewFromType(PyArray_DATETIME);
+ /* Create a default datetime or timedelta */
+ if (is_timedelta) {
+ dtype = PyArray_DescrNewFromType(PyArray_TIMEDELTA);
}
else {
- new = PyArray_DescrNewFromType(PyArray_TIMEDELTA);
+ dtype = PyArray_DescrNewFromType(PyArray_DATETIME);
}
-
- if (new == NULL) {
+ if (dtype == NULL) {
return NULL;
}
+
/*
* Remove any reference to old metadata dictionary
* And create a new one for this new dtype
*/
- Py_XDECREF(new->metadata);
- if ((new->metadata = PyDict_New()) == NULL) {
+ Py_XDECREF(dtype->metadata);
+ dtype->metadata = PyDict_New();
+ if (dtype->metadata == NULL) {
+ Py_DECREF(dtype);
return NULL;
}
- dt_cobj = _convert_datetime_tuple_to_cobj(dt_tuple);
- if (dt_cobj == NULL) {
- /* Failure in conversion */
- Py_DECREF(new);
- return NULL;
- }
-
- /* Assume this sets a new reference to dt_cobj */
- PyDict_SetItemString(new->metadata, NPY_METADATA_DTSTR, dt_cobj);
- Py_DECREF(dt_cobj);
- return new;
-}
-
-
-static PyArray_Descr *
-_convert_from_datetime(PyObject *obj)
-{
- PyObject *tupleobj;
- PyArray_Descr *res;
- PyObject *_numpy_internal;
- if (!PyBytes_Check(obj)) {
+ /* Parse the metadata string into a metadata CObject */
+ metacobj = datetime_metacobj_from_metastr(metastr, metalen);
+ if (metacobj == NULL) {
+ Py_DECREF(dtype);
return NULL;
}
- _numpy_internal = PyImport_ImportModule("numpy.core._internal");
- if (_numpy_internal == NULL) {
- return NULL;
- }
- tupleobj = PyObject_CallMethod(_numpy_internal,
- "_datetimestring", "O", obj);
- Py_DECREF(_numpy_internal);
- if (!tupleobj) {
- return NULL;
- }
- /*
- * tuple of a standard tuple (baseunit, num, den, events) and a timedelta
- * boolean
- */
- res = _convert_from_datetime_tuple(tupleobj);
- Py_DECREF(tupleobj);
- if (!res && !PyErr_Occurred()) {
- PyErr_SetString(PyExc_ValueError,
- "invalid data-type");
+
+ /* Set the metadata object in the dictionary. */
+ if (PyDict_SetItemString(dtype->metadata, NPY_METADATA_DTSTR,
+ metacobj) < 0) {
+ Py_DECREF(dtype);
+ Py_DECREF(metacobj);
return NULL;
}
- return res;
+ Py_DECREF(metacobj);
+
+ return dtype;
}
@@ -1288,8 +1447,8 @@ PyArray_DescrConverter(PyObject *obj, PyArray_Descr **at)
goto fail;
}
/* check for datetime format */
- if ((len > 1) && _check_for_datetime(type, len)) {
- *at = _convert_from_datetime(obj);
+ if (is_datetime_typestr(type, len)) {
+ *at = dtype_from_datetime_typestr(type, len);
if (*at) {
return PY_SUCCEED;
}
diff --git a/numpy/core/tests/test_datetime.py b/numpy/core/tests/test_datetime.py
index 97fb1151f..12be38358 100644
--- a/numpy/core/tests/test_datetime.py
+++ b/numpy/core/tests/test_datetime.py
@@ -11,6 +11,19 @@ class TestDateTime(TestCase):
assert_(dt1 == np.dtype('datetime64[750%s]' % unit))
dt2 = np.dtype('m8[%s]' % unit)
assert_(dt2 == np.dtype('timedelta64[%s]' % unit))
+
+ # Check that the parser rejects bad datetime types
+ assert_raises(ValueError, np.dtype, 'M8[badunit]')
+ assert_raises(ValueError, np.dtype, 'm8[badunit]')
+ assert_raises(ValueError, np.dtype, 'm8[badunit]')
+ assert_raises(ValueError, np.dtype, 'M8[YY]')
+ assert_raises(ValueError, np.dtype, 'm8[YY]')
+ assert_raises(ValueError, np.dtype, 'M4')
+ assert_raises(ValueError, np.dtype, 'm4')
+ assert_raises(ValueError, np.dtype, 'M7')
+ assert_raises(ValueError, np.dtype, 'm7')
+ assert_raises(ValueError, np.dtype, 'M16')
+ assert_raises(ValueError, np.dtype, 'm16')
def test_hours(self):