diff options
Diffstat (limited to 'numpy/lib')
| -rw-r--r-- | numpy/lib/npyio.py | 37 | ||||
| -rw-r--r-- | numpy/lib/tests/test_io.py | 7 |
2 files changed, 27 insertions, 17 deletions
diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py index 588e26b4c..096f1a3a4 100644 --- a/numpy/lib/npyio.py +++ b/numpy/lib/npyio.py @@ -2050,7 +2050,6 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None, strcolidx = [i for (i, v) in enumerate(column_types) if v == np.unicode_] - type_str = np.unicode_ if byte_converters and strcolidx: # convert strings back to bytes for backward compatibility warnings.warn( @@ -2066,33 +2065,37 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None, try: data = [encode_unicode_cols(r) for r in data] - type_str = np.bytes_ except UnicodeEncodeError: pass + else: + for i in strcolidx: + column_types[i] = np.bytes_ + # Update string types to be the right length + sized_column_types = column_types[:] + for i, col_type in enumerate(column_types): + if np.issubdtype(col_type, np.character): + n_chars = max(len(row[i]) for row in data) + sized_column_types[i] = (col_type, n_chars) - # ... and take the largest number of chars. - for i in strcolidx: - max_line_length = max(len(row[i]) for row in data) - column_types[i] = np.dtype((type_str, max_line_length)) - # if names is None: - # If the dtype is uniform, don't define names, else use '' - base = set([c.type for c in converters if c._checked]) + # If the dtype is uniform (before sizing strings) + base = set([ + c_type + for c, c_type in zip(converters, column_types) + if c._checked]) if len(base) == 1: - if strcolidx: - (ddtype, mdtype) = (type_str, bool) - else: - (ddtype, mdtype) = (list(base)[0], bool) + uniform_type, = base + (ddtype, mdtype) = (uniform_type, bool) else: ddtype = [(defaultfmt % i, dt) - for (i, dt) in enumerate(column_types)] + for (i, dt) in enumerate(sized_column_types)] if usemask: mdtype = [(defaultfmt % i, bool) - for (i, dt) in enumerate(column_types)] + for (i, dt) in enumerate(sized_column_types)] else: - ddtype = list(zip(names, column_types)) - mdtype = list(zip(names, [bool] * len(column_types))) + ddtype = list(zip(names, sized_column_types)) + mdtype = list(zip(names, [bool] * len(sized_column_types))) output = np.array(data, dtype=ddtype) if usemask: outputmask = np.array(masks, dtype=mdtype) diff --git a/numpy/lib/tests/test_io.py b/numpy/lib/tests/test_io.py index a274636da..d05fcd543 100644 --- a/numpy/lib/tests/test_io.py +++ b/numpy/lib/tests/test_io.py @@ -2056,6 +2056,13 @@ M 33 21.99 assert_(isinstance(test, np.recarray)) assert_equal(test, control) + #gh-10394 + data = TextIO('color\n"red"\n"blue"') + test = np.recfromcsv(data, converters={0: lambda x: x.strip(b'\"')}) + control = np.array([('red',), ('blue',)], dtype=[('color', (bytes, 4))]) + assert_equal(test.dtype, control.dtype) + assert_equal(test, control) + def test_max_rows(self): # Test the `max_rows` keyword argument. data = '1 2\n3 4\n5 6\n7 8\n9 10\n' |
