summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorMark Wiebe <mwiebe@enthought.com>2011-05-20 11:40:40 -0500
committerMark Wiebe <mwiebe@enthought.com>2011-05-20 11:45:33 -0500
commit8949fe24c057973ece3940617d396addbcbf3875 (patch)
treeb6f50fe6f07dbd7a925a541c312f777298c75e7b /numpy
parent7045cbc7e65dc13a1bb0d5ca866d455022e29f24 (diff)
downloadnumpy-8949fe24c057973ece3940617d396addbcbf3875.tar.gz
ENH: Tighten up dtype parsing in general, to catch some more invalid datetime results
There is also a problem with 'O4' and 'O8', which are platform-specific. I changed it to still accept both, but produce a deprecation warning.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/numerictypes.py4
-rw-r--r--numpy/core/src/multiarray/conversion_utils.c213
-rw-r--r--numpy/core/src/multiarray/descriptor.c43
-rw-r--r--numpy/core/tests/test_datetime.py21
-rw-r--r--numpy/core/tests/test_dtype.py8
-rw-r--r--numpy/ma/core.py4
6 files changed, 169 insertions, 124 deletions
diff --git a/numpy/core/numerictypes.py b/numpy/core/numerictypes.py
index 7bdfd98c1..08907541a 100644
--- a/numpy/core/numerictypes.py
+++ b/numpy/core/numerictypes.py
@@ -257,6 +257,10 @@ def bitname(obj):
char = 'O'
base = 'object'
bits = 0
+ elif name=='datetime64':
+ char = 'M'
+ elif name=='timedelta64':
+ char = 'm'
if sys.version_info[0] >= 3:
if name=='bytes_':
diff --git a/numpy/core/src/multiarray/conversion_utils.c b/numpy/core/src/multiarray/conversion_utils.c
index e8ac99273..12f2c731b 100644
--- a/numpy/core/src/multiarray/conversion_utils.c
+++ b/numpy/core/src/multiarray/conversion_utils.c
@@ -628,119 +628,148 @@ PyArray_IntpFromSequence(PyObject *seq, npy_intp *vals, int maxvals)
NPY_NO_EXPORT int
PyArray_TypestrConvert(int itemsize, int gentype)
{
- int newtype = gentype;
+ int newtype = NPY_NOTYPE;
- if (gentype == PyArray_GENBOOLLTR) {
- if (itemsize == 1) {
- newtype = PyArray_BOOL;
- }
- else {
- newtype = PyArray_NOTYPE;
- }
- }
- else if (gentype == PyArray_SIGNEDLTR) {
- switch(itemsize) {
- case 1:
- newtype = PyArray_INT8;
- break;
- case 2:
- newtype = PyArray_INT16;
- break;
- case 4:
- newtype = PyArray_INT32;
- break;
- case 8:
- newtype = PyArray_INT64;
+ switch (gentype) {
+ case NPY_GENBOOLLTR:
+ if (itemsize == 1) {
+ newtype = NPY_BOOL;
+ }
break;
+
+ case NPY_SIGNEDLTR:
+ switch(itemsize) {
+ case 1:
+ newtype = NPY_INT8;
+ break;
+ case 2:
+ newtype = NPY_INT16;
+ break;
+ case 4:
+ newtype = NPY_INT32;
+ break;
+ case 8:
+ newtype = NPY_INT64;
+ break;
#ifdef PyArray_INT128
- case 16:
- newtype = PyArray_INT128;
- break;
+ case 16:
+ newtype = NPY_INT128;
+ break;
#endif
- default:
- newtype = PyArray_NOTYPE;
- }
- }
- else if (gentype == PyArray_UNSIGNEDLTR) {
- switch(itemsize) {
- case 1:
- newtype = PyArray_UINT8;
- break;
- case 2:
- newtype = PyArray_UINT16;
- break;
- case 4:
- newtype = PyArray_UINT32;
- break;
- case 8:
- newtype = PyArray_UINT64;
+ }
break;
+
+ case NPY_UNSIGNEDLTR:
+ switch(itemsize) {
+ case 1:
+ newtype = NPY_UINT8;
+ break;
+ case 2:
+ newtype = NPY_UINT16;
+ break;
+ case 4:
+ newtype = NPY_UINT32;
+ break;
+ case 8:
+ newtype = NPY_UINT64;
+ break;
#ifdef PyArray_INT128
- case 16:
- newtype = PyArray_UINT128;
- break;
+ case 16:
+ newtype = NPY_UINT128;
+ break;
#endif
- default:
- newtype = PyArray_NOTYPE;
- break;
- }
- }
- else if (gentype == PyArray_FLOATINGLTR) {
- switch(itemsize) {
- case 2:
- newtype = PyArray_FLOAT16;
- break;
- case 4:
- newtype = PyArray_FLOAT32;
- break;
- case 8:
- newtype = PyArray_FLOAT64;
+ }
break;
+
+ case NPY_FLOATINGLTR:
+ switch(itemsize) {
+ case 2:
+ newtype = NPY_FLOAT16;
+ break;
+ case 4:
+ newtype = NPY_FLOAT32;
+ break;
+ case 8:
+ newtype = NPY_FLOAT64;
+ break;
#ifdef PyArray_FLOAT80
- case 10:
- newtype = PyArray_FLOAT80;
- break;
+ case 10:
+ newtype = NPY_FLOAT80;
+ break;
#endif
#ifdef PyArray_FLOAT96
- case 12:
- newtype = PyArray_FLOAT96;
- break;
+ case 12:
+ newtype = NPY_FLOAT96;
+ break;
#endif
#ifdef PyArray_FLOAT128
- case 16:
- newtype = PyArray_FLOAT128;
- break;
+ case 16:
+ newtype = NPY_FLOAT128;
+ break;
#endif
- default:
- newtype = PyArray_NOTYPE;
- }
- }
- else if (gentype == PyArray_COMPLEXLTR) {
- switch(itemsize) {
- case 8:
- newtype = PyArray_COMPLEX64;
- break;
- case 16:
- newtype = PyArray_COMPLEX128;
+ }
break;
+
+ case NPY_COMPLEXLTR:
+ switch(itemsize) {
+ case 8:
+ newtype = NPY_COMPLEX64;
+ break;
+ case 16:
+ newtype = NPY_COMPLEX128;
+ break;
#ifdef PyArray_FLOAT80
- case 20:
- newtype = PyArray_COMPLEX160;
- break;
+ case 20:
+ newtype = NPY_COMPLEX160;
+ break;
#endif
#ifdef PyArray_FLOAT96
- case 24:
- newtype = PyArray_COMPLEX192;
- break;
+ case 24:
+ newtype = NPY_COMPLEX192;
+ break;
#endif
#ifdef PyArray_FLOAT128
- case 32:
- newtype = PyArray_COMPLEX256;
- break;
+ case 32:
+ newtype = NPY_COMPLEX256;
+ break;
#endif
- default:
- newtype = PyArray_NOTYPE;
- }
+ }
+ break;
+
+ case NPY_OBJECTLTR:
+ if (PyErr_WarnEx(PyExc_DeprecationWarning,
+ "DType strings 'O4' and 'O8' are deprecated "
+ "because they are platform specific. Use "
+ "'O' instead", 0) == 0 &&
+ (itemsize == 4 || itemsize == 8)) {
+ newtype = NPY_OBJECT;
+ }
+ break;
+
+ case NPY_STRINGLTR:
+ case NPY_STRINGLTR2:
+ newtype = NPY_STRING;
+ break;
+
+ case NPY_UNICODELTR:
+ newtype = NPY_UNICODE;
+ break;
+
+ case NPY_VOIDLTR:
+ newtype = NPY_VOID;
+ break;
+
+ case NPY_DATETIMELTR:
+ if (itemsize == 8) {
+ newtype = NPY_DATETIME;
+ }
+ break;
+
+ case NPY_TIMEDELTALTR:
+ if (itemsize == 8) {
+ newtype = NPY_TIMEDELTA;
+ }
+ break;
}
return newtype;
}
diff --git a/numpy/core/src/multiarray/descriptor.c b/numpy/core/src/multiarray/descriptor.c
index 1fe497e67..40572e908 100644
--- a/numpy/core/src/multiarray/descriptor.c
+++ b/numpy/core/src/multiarray/descriptor.c
@@ -604,11 +604,11 @@ datetime_unit_from_string(char *str, Py_ssize_t len, char *metastr)
/* If nothing matched, it's an error */
if (metastr == NULL) {
- PyErr_SetString(PyExc_ValueError,
+ PyErr_SetString(PyExc_TypeError,
"Invalid datetime unit in metadata");
}
else {
- PyErr_Format(PyExc_ValueError,
+ PyErr_Format(PyExc_TypeError,
"Invalid datetime unit in metadata string \"%s\"",
metastr);
}
@@ -727,7 +727,6 @@ static PyObject *
_convert_datetime_tuple_to_cobj(PyObject *tuple)
{
PyArray_DatetimeMetaData *dt_data;
- PyObject *ret;
char *basestr = NULL;
Py_ssize_t len = 0;
@@ -763,7 +762,6 @@ datetime_metacobj_from_metastr(char *metastr, Py_ssize_t len)
{
PyArray_DatetimeMetaData *dt_data;
char *substr = metastr, *substrend = NULL;
- int sublen = 0;
dt_data = _pya_malloc(sizeof(PyArray_DatetimeMetaData));
if (dt_data == NULL) {
@@ -842,7 +840,7 @@ datetime_metacobj_from_metastr(char *metastr, Py_ssize_t len)
if (dt_data->den > 1) {
if (convert_datetime_divisor_to_multiple(dt_data, metastr) < 0) {
- goto bad_input;
+ goto error;
}
}
}
@@ -850,9 +848,16 @@ datetime_metacobj_from_metastr(char *metastr, Py_ssize_t len)
return NpyCapsule_FromVoidPtr((void *)dt_data, simple_capsule_dtor);
bad_input:
- PyErr_Format(PyExc_ValueError,
- "Invalid datetime metadata string \"%s\" at position %d",
- metastr, (int)(substr-metastr));
+ if (substr != metastr) {
+ PyErr_Format(PyExc_TypeError,
+ "Invalid datetime metadata string \"%s\" at position %d",
+ metastr, (int)(substr-metastr));
+ }
+ else {
+ PyErr_Format(PyExc_TypeError,
+ "Invalid datetime metadata string \"%s\"",
+ metastr);
+ }
error:
_pya_free(dt_data);
return NULL;
@@ -872,7 +877,7 @@ dtype_from_datetime_typestr(char *typestr, Py_ssize_t len)
PyObject *metacobj = NULL;
if (len < 2) {
- PyErr_Format(PyExc_ValueError,
+ PyErr_Format(PyExc_TypeError,
"Invalid datetime typestr \"%s\"",
typestr);
return NULL;
@@ -903,7 +908,7 @@ dtype_from_datetime_typestr(char *typestr, Py_ssize_t len)
metalen = len - 10;
}
else {
- PyErr_Format(PyExc_ValueError,
+ PyErr_Format(PyExc_TypeError,
"Invalid datetime typestr \"%s\"",
typestr);
return NULL;
@@ -1449,18 +1454,12 @@ PyArray_DescrConverter(PyObject *obj, PyArray_Descr **at)
/* check for datetime format */
if (is_datetime_typestr(type, len)) {
*at = dtype_from_datetime_typestr(type, len);
- if (*at) {
- return PY_SUCCEED;
- }
- return PY_FAIL;
+ return (*at) ? PY_SUCCEED : PY_FAIL;
}
/* check for commas present or first (or second) element a digit */
if (_check_for_commastring(type, len)) {
*at = _convert_from_commastring(obj, 0);
- if (*at) {
- return PY_SUCCEED;
- }
- return PY_FAIL;
+ return (*at) ? PY_SUCCEED : PY_FAIL;
}
check_num = (int) type[0];
if ((char) check_num == '>'
@@ -1597,7 +1596,13 @@ PyArray_DescrConverter(PyObject *obj, PyArray_Descr **at)
return PY_SUCCEED;
fail:
- PyErr_SetString(PyExc_TypeError, "data type not understood");
+ if (PyBytes_Check(obj)) {
+ PyErr_Format(PyExc_TypeError, "data type \"%s\" not understood",
+ PyBytes_AS_STRING(obj));
+ }
+ else {
+ PyErr_SetString(PyExc_TypeError, "data type not understood");
+ }
*at = NULL;
return PY_FAIL;
}
diff --git a/numpy/core/tests/test_datetime.py b/numpy/core/tests/test_datetime.py
index 12be38358..33006b737 100644
--- a/numpy/core/tests/test_datetime.py
+++ b/numpy/core/tests/test_datetime.py
@@ -13,17 +13,16 @@ class TestDateTime(TestCase):
assert_(dt2 == np.dtype('timedelta64[%s]' % unit))
# Check that the parser rejects bad datetime types
- assert_raises(ValueError, np.dtype, 'M8[badunit]')
- assert_raises(ValueError, np.dtype, 'm8[badunit]')
- assert_raises(ValueError, np.dtype, 'm8[badunit]')
- assert_raises(ValueError, np.dtype, 'M8[YY]')
- assert_raises(ValueError, np.dtype, 'm8[YY]')
- assert_raises(ValueError, np.dtype, 'M4')
- assert_raises(ValueError, np.dtype, 'm4')
- assert_raises(ValueError, np.dtype, 'M7')
- assert_raises(ValueError, np.dtype, 'm7')
- assert_raises(ValueError, np.dtype, 'M16')
- assert_raises(ValueError, np.dtype, 'm16')
+ assert_raises(TypeError, np.dtype, 'M8[badunit]')
+ assert_raises(TypeError, np.dtype, 'm8[badunit]')
+ assert_raises(TypeError, np.dtype, 'M8[YY]')
+ assert_raises(TypeError, np.dtype, 'm8[YY]')
+ assert_raises(TypeError, np.dtype, 'M4')
+ assert_raises(TypeError, np.dtype, 'm4')
+ assert_raises(TypeError, np.dtype, 'M7')
+ assert_raises(TypeError, np.dtype, 'm7')
+ assert_raises(TypeError, np.dtype, 'M16')
+ assert_raises(TypeError, np.dtype, 'm16')
def test_hours(self):
diff --git a/numpy/core/tests/test_dtype.py b/numpy/core/tests/test_dtype.py
index 40c3a3eea..c79b755be 100644
--- a/numpy/core/tests/test_dtype.py
+++ b/numpy/core/tests/test_dtype.py
@@ -33,6 +33,14 @@ class TestBuiltin(TestCase):
self.assertTrue(dt.byteorder != dt3.byteorder, "bogus test")
assert_dtype_equal(dt, dt3)
+ def test_invalid_types(self):
+ # Make sure invalid type strings raise exceptions
+ for typestr in ['O3', 'O5', 'O7', 'b3', 'h4', 'I5', 'l4', 'l8',
+ 'L4', 'L8', 'q8', 'q16', 'Q8', 'Q16', 'e3',
+ 'f5', 'd8', 't8', 'g12', 'g16']:
+ #print typestr
+ assert_raises(TypeError, np.dtype, typestr)
+
class TestRecord(TestCase):
def test_equivalent_record(self):
"""Test whether equivalent record dtypes hash the same."""
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index bbd855cf8..2cb888d55 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -3527,10 +3527,10 @@ class MaskedArray(ndarray):
# convert to object array to make filled work
names = self.dtype.names
if names is None:
- res = self._data.astype("|O8")
+ res = self._data.astype("O")
res[m] = f
else:
- rdtype = _recursive_make_descr(self.dtype, "|O8")
+ rdtype = _recursive_make_descr(self.dtype, "O")
res = self._data.astype(rdtype)
_recursive_printoption(res, m, f)
else: