summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- numpy/lib/npyio.py 35
-rw-r--r-- numpy/lib/tests/test_io.py 52
2 files changed, 57 insertions, 30 deletions
diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py
index c8cebaed8..db1d87435 100644
--- a/numpy/lib/npyio.py
+++ b/numpy/lib/npyio.py
@@ -1205,7 +1205,7 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None,
excludelist=None, deletechars=None, replace_space='_',
autostrip=False, case_sensitive=True, defaultfmt="f%i",
unpack=None, usemask=False, loose=True, invalid_raise=True,
- nrows=None):
+ max_rows=None):
"""
Load data from a text file, with missing values handled as specified.
@@ -1286,9 +1286,10 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None,
If True, an exception is raised if an inconsistency is detected in the
number of columns.
If False, a warning is emitted and the offending lines are skipped.
- nrows : int, optional
- The number of rows to read. Must not be used with skip_footer at the
- same time.
+ max_rows : int, optional
+ The maximum number of rows to read. Must not be used with skip_footer
+ at the same time. If given, the value must be at least 1. Default is
+ to read the entire file.
.. versionadded:: 1.10.0
@@ -1359,11 +1360,13 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None,
dtype=[('intvar', '<i8'), ('fltvar', '<f8'), ('strvar', '|S5')])
"""
- # Check keywords conflict
- if skip_footer and (nrows is not None):
- raise ValueError(
- "keywords 'skip_footer' and 'nrows' can not be specified "
- "at the same time")
+ if max_rows is not None:
+ if skip_footer:
+ raise ValueError(
+ "The keywords 'skip_footer' and 'max_rows' can not be "
+ "specified at the same time.")
+ if max_rows < 1:
+ raise ValueError("'max_rows' must be at least 1.")
# Py3 data conversions to bytes, for convenience
if comments is not None:
@@ -1654,15 +1657,13 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None,
# Parse each line
for (i, line) in enumerate(itertools.chain([first_line, ], fhd)):
- if (nrows is not None) and (len(rows) >= nrows):
- break
values = split_line(line)
nbvalues = len(values)
# Skip an empty line
if nbvalues == 0:
continue
- # Select only the columns we need
if usecols:
+ # Select only the columns we need
try:
values = [values[_] for _ in usecols]
except IndexError:
@@ -1675,16 +1676,14 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None,
append_to_rows(tuple(values))
if usemask:
append_to_masks(tuple([v.strip() in m
- for (v, m) in zip(values, missing_values)]))
+ for (v, m) in zip(values,
+ missing_values)]))
+ if len(rows) == max_rows:
+ break
if own_fhd:
fhd.close()
- if (nrows is not None) and (len(rows) != nrows):
- raise AssertionError(
- "%d rows required but got %d valid rows instead"
- %(nrows, len(rows)))
-
# Upgrade the converters (if needed)
if dtype is None:
for (i, converter) in enumerate(converters):
diff --git a/numpy/lib/tests/test_io.py b/numpy/lib/tests/test_io.py
index df5ab1a2a..2ce78575b 100644
--- a/numpy/lib/tests/test_io.py
+++ b/numpy/lib/tests/test_io.py
@@ -1641,29 +1641,57 @@ M 33 21.99
self.assertTrue(isinstance(test, np.recarray))
assert_equal(test, control)
- def test_nrows(self):
- #
+ def test_max_rows(self):
+ # Test the `max_rows` keyword argument.
+ data = '1 2\n3 4\n5 6\n7 8\n9 10\n'
+ txt = TextIO(data)
+ a1 = np.genfromtxt(txt, max_rows=3)
+ a2 = np.genfromtxt(txt)
+ assert_equal(a1, [[1, 2], [3, 4], [5, 6]])
+ assert_equal(a2, [[7, 8], [9, 10]])
+
+ # max_rows must be at least 1.
+ assert_raises(ValueError, np.genfromtxt, TextIO(data), max_rows=0)
+
+ # An input with several invalid rows.
data = '1 1\n2 2\n0 \n3 3\n4 4\n5 \n6 \n7 \n'
- test = np.genfromtxt(TextIO(data), nrows=2)
+
+ test = np.genfromtxt(TextIO(data), max_rows=2)
control = np.array([[1., 1.], [2., 2.]])
- assert_equal(test, control)
+ assert_equal(test, control)
+
# Test keywords conflict
- assert_raises(ValueError, np.genfromtxt, TextIO(data), skip_footer=1, nrows=4)
+ assert_raises(ValueError, np.genfromtxt, TextIO(data), skip_footer=1,
+ max_rows=4)
+
# Test with invalid value
- assert_raises(ValueError, np.genfromtxt, TextIO(data), nrows=4)
+ assert_raises(ValueError, np.genfromtxt, TextIO(data), max_rows=4)
+
# Test with invalid not raise
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
- test = np.genfromtxt(TextIO(data), nrows=4, invalid_raise=False)
+
+ test = np.genfromtxt(TextIO(data), max_rows=4, invalid_raise=False)
control = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]])
assert_equal(test, control)
- # Test without enough valid rows
- assert_raises(AssertionError, np.genfromtxt, TextIO(data), nrows=5)
- data = 'a b\n#c d\n1 1\n2 2\n#0 \n3 3\n4 4\n5 \n6 \n7 \n'
+ test = np.genfromtxt(TextIO(data), max_rows=5, invalid_raise=False)
+ control = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]])
+ assert_equal(test, control)
+
+ # Structured array with field names.
+ data = 'a b\n#c d\n1 1\n2 2\n#0 \n3 3\n4 4\n5 5\n'
+
# Test with header, names and comments
- test = np.genfromtxt(TextIO(data), skip_header=1, nrows=4, names=True)
- control = np.array([(1.0, 1.0), (2.0, 2.0), (3.0, 3.0), (4.0, 4.0)],
+ txt = TextIO(data)
+ test = np.genfromtxt(txt, skip_header=1, max_rows=3, names=True)
+ control = np.array([(1.0, 1.0), (2.0, 2.0), (3.0, 3.0)],
+ dtype=[('c', '<f8'), ('d', '<f8')])
+ assert_equal(test, control)
+ # To continue reading the same "file", don't use skip_header or
+ # names, and use the previously determined dtype.
+ test = np.genfromtxt(txt, max_rows=None, dtype=test.dtype)
+ control = np.array([(4.0, 4.0), (5.0, 5.0)],
dtype=[('c', '<f8'), ('d', '<f8')])
assert_equal(test, control)