summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/descriptor.c15
-rw-r--r--numpy/core/tests/test_multiarray.py13
2 files changed, 22 insertions, 6 deletions
diff --git a/numpy/core/src/multiarray/descriptor.c b/numpy/core/src/multiarray/descriptor.c
index b4a0ce37d..8d983ffc9 100644
--- a/numpy/core/src/multiarray/descriptor.c
+++ b/numpy/core/src/multiarray/descriptor.c
@@ -437,7 +437,7 @@ _convert_from_array_descr(PyObject *obj, int align)
goto fail;
}
name = PyTuple_GET_ITEM(item, 0);
- if (PyUString_Check(name)) {
+ if (PyBaseString_Check(name)) {
title = NULL;
}
else if (PyTuple_Check(name)) {
@@ -446,7 +446,7 @@ _convert_from_array_descr(PyObject *obj, int align)
}
title = PyTuple_GET_ITEM(name, 0);
name = PyTuple_GET_ITEM(name, 1);
- if (!PyUString_Check(name)) {
+ if (!PyBaseString_Check(name)) {
goto fail;
}
}
@@ -457,6 +457,17 @@ _convert_from_array_descr(PyObject *obj, int align)
/* Insert name into nameslist */
Py_INCREF(name);
+#if !defined(NPY_PY3K)
+ /* convert unicode name to ascii on Python 2 if possible */
+ if (PyUnicode_Check(name)) {
+ PyObject *tmp = PyUnicode_AsASCIIString(name);
+ Py_DECREF(name);
+ if (tmp == NULL) {
+ goto fail;
+ }
+ name = tmp;
+ }
+#endif
if (PyUString_GET_SIZE(name) == 0) {
Py_DECREF(name);
if (title == NULL) {
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index fba169ebf..43bfb0635 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -4693,10 +4693,15 @@ class TestRecord(object):
y['a']
def test_unicode_field_names(self):
- # Unicode field names are not allowed on Py2
- title = u'b'
- assert_raises(TypeError, np.dtype, [(title, int)])
- assert_raises(TypeError, np.dtype, [(('a', title), int)])
+ # Unicode field names are converted to ascii on Python 2:
+ encodable_name = u'b'
+ assert_equal(np.dtype([(encodable_name, int)]).names[0], b'b')
+ assert_equal(np.dtype([(('a', encodable_name), int)]).names[0], b'b')
+
+ # But raises UnicodeEncodeError if it can't be encoded:
+ nonencodable_name = u'\uc3bc'
+ assert_raises(UnicodeEncodeError, np.dtype, [(nonencodable_name, int)])
+ assert_raises(UnicodeEncodeError, np.dtype, [(('a', nonencodable_name), int)])
def test_field_names(self):
# Test unicode and 8-bit / byte strings can be used