diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/_internal.py | 4 | ||||
-rw-r--r-- | numpy/core/records.py | 4 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 115 |
3 files changed, 70 insertions, 53 deletions
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py index 1cf89aab0..ce7ef7060 100644 --- a/numpy/core/_internal.py +++ b/numpy/core/_internal.py @@ -9,7 +9,7 @@ from __future__ import division, absolute_import, print_function import re import sys -from numpy.compat import basestring +from numpy.compat import basestring, unicode from .multiarray import dtype, array, ndarray try: import ctypes @@ -294,7 +294,7 @@ def _newnames(datatype, order): """ oldnames = datatype.names nameslist = list(oldnames) - if isinstance(order, str): + if isinstance(order, (str, unicode)): order = [order] seen = set() if isinstance(order, (list, tuple)): diff --git a/numpy/core/records.py b/numpy/core/records.py index 612d39322..a483871ba 100644 --- a/numpy/core/records.py +++ b/numpy/core/records.py @@ -42,7 +42,7 @@ import warnings from . import numeric as sb from . import numerictypes as nt -from numpy.compat import isfileobj, bytes, long +from numpy.compat import isfileobj, bytes, long, unicode from .arrayprint import get_printoptions # All of the functions allow formats to be a dtype @@ -174,7 +174,7 @@ class format_parser(object): if (names): if (type(names) in [list, tuple]): pass - elif isinstance(names, str): + elif isinstance(names, (str, unicode)): names = names.split(',') else: raise NameError("illegal input names %s" % repr(names)) diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 1511f5b6b..d751657db 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -4748,55 +4748,72 @@ class TestRecord(object): # Error raised when multiple fields have the same name assert_raises(ValueError, test_assign) - if sys.version_info[0] >= 3: - def test_bytes_fields(self): - # Bytes are not allowed in field names and not recognized in titles - # on Py3 - assert_raises(TypeError, np.dtype, [(b'a', int)]) - assert_raises(TypeError, np.dtype, [(('b', b'a'), int)]) - - dt = np.dtype([((b'a', 'b'), int)]) - assert_raises(TypeError, dt.__getitem__, b'a') - - x = np.array([(1,), (2,), (3,)], dtype=dt) - assert_raises(IndexError, x.__getitem__, b'a') - - y = x[0] - assert_raises(IndexError, y.__getitem__, b'a') - - def test_multiple_field_name_unicode(self): - def test_assign_unicode(): - dt = np.dtype([("\u20B9", "f8"), - ("B", "f8"), - ("\u20B9", "f8")]) - - # Error raised when multiple fields have the same name(unicode included) - assert_raises(ValueError, test_assign_unicode) - - else: - def test_unicode_field_titles(self): - # Unicode field titles are added to field dict on Py2 - title = u'b' - dt = np.dtype([((title, 'a'), int)]) - dt[title] - dt['a'] - x = np.array([(1,), (2,), (3,)], dtype=dt) - x[title] - x['a'] - y = x[0] - y[title] - y['a'] - - def test_unicode_field_names(self): - # Unicode field names are converted to ascii on Python 2: - encodable_name = u'b' - assert_equal(np.dtype([(encodable_name, int)]).names[0], b'b') - assert_equal(np.dtype([(('a', encodable_name), int)]).names[0], b'b') - - # But raises UnicodeEncodeError if it can't be encoded: - nonencodable_name = u'\uc3bc' - assert_raises(UnicodeEncodeError, np.dtype, [(nonencodable_name, int)]) - assert_raises(UnicodeEncodeError, np.dtype, [(('a', nonencodable_name), int)]) + @pytest.mark.skipif(sys.version_info[0] < 3, reason="Not Python 3") + def test_bytes_fields(self): + # Bytes are not allowed in field names and not recognized in titles + # on Py3 + assert_raises(TypeError, np.dtype, [(b'a', int)]) + assert_raises(TypeError, np.dtype, [(('b', b'a'), int)]) + + dt = np.dtype([((b'a', 'b'), int)]) + assert_raises(TypeError, dt.__getitem__, b'a') + + x = np.array([(1,), (2,), (3,)], dtype=dt) + assert_raises(IndexError, x.__getitem__, b'a') + + y = x[0] + assert_raises(IndexError, y.__getitem__, b'a') + + @pytest.mark.skipif(sys.version_info[0] < 3, reason="Not Python 3") + def test_multiple_field_name_unicode(self): + def test_assign_unicode(): + dt = np.dtype([("\u20B9", "f8"), + ("B", "f8"), + ("\u20B9", "f8")]) + + # Error raised when multiple fields have the same name(unicode included) + assert_raises(ValueError, test_assign_unicode) + + @pytest.mark.skipif(sys.version_info[0] >= 3, reason="Not Python 2") + def test_unicode_field_titles(self): + # Unicode field titles are added to field dict on Py2 + title = u'b' + dt = np.dtype([((title, 'a'), int)]) + dt[title] + dt['a'] + x = np.array([(1,), (2,), (3,)], dtype=dt) + x[title] + x['a'] + y = x[0] + y[title] + y['a'] + + @pytest.mark.skipif(sys.version_info[0] >= 3, reason="Not Python 2") + def test_unicode_field_names(self): + # Unicode field names are converted to ascii on Python 2: + encodable_name = u'b' + assert_equal(np.dtype([(encodable_name, int)]).names[0], b'b') + assert_equal(np.dtype([(('a', encodable_name), int)]).names[0], b'b') + + # But raises UnicodeEncodeError if it can't be encoded: + nonencodable_name = u'\uc3bc' + assert_raises(UnicodeEncodeError, np.dtype, [(nonencodable_name, int)]) + assert_raises(UnicodeEncodeError, np.dtype, [(('a', nonencodable_name), int)]) + + def test_fromarrays_unicode(self): + # A single name string provided to fromarrays() is allowed to be unicode + # on both Python 2 and 3: + x = np.core.records.fromarrays([[0], [1]], names=u'a,b', formats=u'i4,i4') + assert_equal(x['a'][0], 0) + assert_equal(x['b'][0], 1) + + def test_unicode_order(self): + # Test that we can sort with order as a unicode field name in both Python 2 and + # 3: + name = u'b' + x = np.array([1, 3, 2], dtype=[(name, int)]) + x.sort(order=name) + assert_equal(x[u'b'], np.array([1, 2, 3])) def test_field_names(self): # Test unicode and 8-bit / byte strings can be used |