summary | refs | log | tree | commit | diff
path: root/numpy/lib
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib')
-rw-r--r-- numpy/lib/npyio.py 37
-rw-r--r-- numpy/lib/tests/test_io.py 7
2 files changed, 27 insertions, 17 deletions
diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py
index 588e26b4c..096f1a3a4 100644
--- a/numpy/lib/npyio.py
+++ b/numpy/lib/npyio.py
@@ -2050,7 +2050,6 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None,
strcolidx = [i for (i, v) in enumerate(column_types)
if v == np.unicode_]
- type_str = np.unicode_
if byte_converters and strcolidx:
# convert strings back to bytes for backward compatibility
warnings.warn(
@@ -2066,33 +2065,37 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None,
try:
data = [encode_unicode_cols(r) for r in data]
- type_str = np.bytes_
except UnicodeEncodeError:
pass
+ else:
+ for i in strcolidx:
+ column_types[i] = np.bytes_
+ # Update string types to be the right length
+ sized_column_types = column_types[:]
+ for i, col_type in enumerate(column_types):
+ if np.issubdtype(col_type, np.character):
+ n_chars = max(len(row[i]) for row in data)
+ sized_column_types[i] = (col_type, n_chars)
- # ... and take the largest number of chars.
- for i in strcolidx:
- max_line_length = max(len(row[i]) for row in data)
- column_types[i] = np.dtype((type_str, max_line_length))
- #
if names is None:
- # If the dtype is uniform, don't define names, else use ''
- base = set([c.type for c in converters if c._checked])
+ # If the dtype is uniform (before sizing strings)
+ base = set([
+ c_type
+ for c, c_type in zip(converters, column_types)
+ if c._checked])
if len(base) == 1:
- if strcolidx:
- (ddtype, mdtype) = (type_str, bool)
- else:
- (ddtype, mdtype) = (list(base)[0], bool)
+ uniform_type, = base
+ (ddtype, mdtype) = (uniform_type, bool)
else:
ddtype = [(defaultfmt % i, dt)
- for (i, dt) in enumerate(column_types)]
+ for (i, dt) in enumerate(sized_column_types)]
if usemask:
mdtype = [(defaultfmt % i, bool)
- for (i, dt) in enumerate(column_types)]
+ for (i, dt) in enumerate(sized_column_types)]
else:
- ddtype = list(zip(names, column_types))
- mdtype = list(zip(names, [bool] * len(column_types)))
+ ddtype = list(zip(names, sized_column_types))
+ mdtype = list(zip(names, [bool] * len(sized_column_types)))
output = np.array(data, dtype=ddtype)
if usemask:
outputmask = np.array(masks, dtype=mdtype)
diff --git a/numpy/lib/tests/test_io.py b/numpy/lib/tests/test_io.py
index a274636da..d05fcd543 100644
--- a/numpy/lib/tests/test_io.py
+++ b/numpy/lib/tests/test_io.py
@@ -2056,6 +2056,13 @@ M 33 21.99
assert_(isinstance(test, np.recarray))
assert_equal(test, control)
+ #gh-10394
+ data = TextIO('color\n"red"\n"blue"')
+ test = np.recfromcsv(data, converters={0: lambda x: x.strip(b'\"')})
+ control = np.array([('red',), ('blue',)], dtype=[('color', (bytes, 4))])
+ assert_equal(test.dtype, control.dtype)
+ assert_equal(test, control)
+
def test_max_rows(self):
# Test the `max_rows` keyword argument.
data = '1 2\n3 4\n5 6\n7 8\n9 10\n'