summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2018-01-13 20:15:32 -0800
committerEric Wieser <wieser.eric@gmail.com>2018-01-15 19:09:15 -0800
commit0a8786163133c4227bfa7dbc3c9a6800172b65f7 (patch)
treeb8c598fbdcc95b2eaab2a0b28d82ab6d33133e84 /numpy/lib
parent5f01e54b20e38d483e8bab31bf5f953a860fe8d3 (diff)
downloadnumpy-0a8786163133c4227bfa7dbc3c9a6800172b65f7.tar.gz
BUG: Resize bytes_ columns in genfromtxt
Fixes gh-10394, due to regression in gh-10054
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/npyio.py37
-rw-r--r--numpy/lib/tests/test_io.py7
2 files changed, 27 insertions, 17 deletions
diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py
index e4d827334..9e979bbe6 100644
--- a/numpy/lib/npyio.py
+++ b/numpy/lib/npyio.py
@@ -2042,7 +2042,6 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None,
strcolidx = [i for (i, v) in enumerate(column_types)
if v == np.unicode_]
- type_str = np.unicode_
if byte_converters and strcolidx:
# convert strings back to bytes for backward compatibility
warnings.warn(
@@ -2058,33 +2057,37 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None,
try:
data = [encode_unicode_cols(r) for r in data]
- type_str = np.bytes_
except UnicodeEncodeError:
pass
+ else:
+ for i in strcolidx:
+ column_types[i] = np.bytes_
+ # Update string types to be the right length
+ sized_column_types = column_types[:]
+ for i, col_type in enumerate(column_types):
+ if np.issubdtype(col_type, np.character):
+ n_chars = max(len(row[i]) for row in data)
+ sized_column_types[i] = (col_type, n_chars)
- # ... and take the largest number of chars.
- for i in strcolidx:
- max_line_length = max(len(row[i]) for row in data)
- column_types[i] = np.dtype((type_str, max_line_length))
- #
if names is None:
- # If the dtype is uniform, don't define names, else use ''
- base = set([c.type for c in converters if c._checked])
+ # If the dtype is uniform (before sizing strings)
+ base = set([
+ c_type
+ for c, c_type in zip(converters, column_types)
+ if c._checked])
if len(base) == 1:
- if strcolidx:
- (ddtype, mdtype) = (type_str, bool)
- else:
- (ddtype, mdtype) = (list(base)[0], bool)
+ uniform_type, = base
+ (ddtype, mdtype) = (uniform_type, bool)
else:
ddtype = [(defaultfmt % i, dt)
- for (i, dt) in enumerate(column_types)]
+ for (i, dt) in enumerate(sized_column_types)]
if usemask:
mdtype = [(defaultfmt % i, bool)
- for (i, dt) in enumerate(column_types)]
+ for (i, dt) in enumerate(sized_column_types)]
else:
- ddtype = list(zip(names, column_types))
- mdtype = list(zip(names, [bool] * len(column_types)))
+ ddtype = list(zip(names, sized_column_types))
+ mdtype = list(zip(names, [bool] * len(sized_column_types)))
output = np.array(data, dtype=ddtype)
if usemask:
outputmask = np.array(masks, dtype=mdtype)
diff --git a/numpy/lib/tests/test_io.py b/numpy/lib/tests/test_io.py
index 75a8e4968..2daa0153b 100644
--- a/numpy/lib/tests/test_io.py
+++ b/numpy/lib/tests/test_io.py
@@ -2056,6 +2056,13 @@ M 33 21.99
assert_(isinstance(test, np.recarray))
assert_equal(test, control)
+ #gh-10394
+ data = TextIO('color\n"red"\n"blue"')
+ test = np.recfromcsv(data, converters={0: lambda x: x.strip(b'\"')})
+ control = np.array([('red',), ('blue',)], dtype=[('color', (bytes, 4))])
+ assert_equal(test.dtype, control.dtype)
+ assert_equal(test, control)
+
def test_max_rows(self):
# Test the `max_rows` keyword argument.
data = '1 2\n3 4\n5 6\n7 8\n9 10\n'