diff options
| -rw-r--r-- | numpy/lib/npyio.py | 35 | ||||
| -rw-r--r-- | numpy/lib/tests/test_io.py | 52 |
2 files changed, 57 insertions, 30 deletions
diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py index c8cebaed8..db1d87435 100644 --- a/numpy/lib/npyio.py +++ b/numpy/lib/npyio.py @@ -1205,7 +1205,7 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None, excludelist=None, deletechars=None, replace_space='_', autostrip=False, case_sensitive=True, defaultfmt="f%i", unpack=None, usemask=False, loose=True, invalid_raise=True, - nrows=None): + max_rows=None): """ Load data from a text file, with missing values handled as specified. @@ -1286,9 +1286,10 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None, If True, an exception is raised if an inconsistency is detected in the number of columns. If False, a warning is emitted and the offending lines are skipped. - nrows : int, optional - The number of rows to read. Must not be used with skip_footer at the - same time. + max_rows : int, optional + The maximum number of rows to read. Must not be used with skip_footer + at the same time. If given, the value must be at least 1. Default is + to read the entire file. .. versionadded:: 1.10.0 @@ -1359,11 +1360,13 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None, dtype=[('intvar', '<i8'), ('fltvar', '<f8'), ('strvar', '|S5')]) """ - # Check keywords conflict - if skip_footer and (nrows is not None): - raise ValueError( - "keywords 'skip_footer' and 'nrows' can not be specified " - "at the same time") + if max_rows is not None: + if skip_footer: + raise ValueError( + "The keywords 'skip_footer' and 'max_rows' can not be " + "specified at the same time.") + if max_rows < 1: + raise ValueError("'max_rows' must be at least 1.") # Py3 data conversions to bytes, for convenience if comments is not None: @@ -1654,15 +1657,13 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None, # Parse each line for (i, line) in enumerate(itertools.chain([first_line, ], fhd)): - if (nrows is not None) and (len(rows) >= nrows): - break values = split_line(line) nbvalues = len(values) # Skip an empty line if nbvalues == 0: continue - # Select only the columns we need if usecols: + # Select only the columns we need try: values = [values[_] for _ in usecols] except IndexError: @@ -1675,16 +1676,14 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None, append_to_rows(tuple(values)) if usemask: append_to_masks(tuple([v.strip() in m - for (v, m) in zip(values, missing_values)])) + for (v, m) in zip(values, + missing_values)])) + if len(rows) == max_rows: + break if own_fhd: fhd.close() - if (nrows is not None) and (len(rows) != nrows): - raise AssertionError( - "%d rows required but got %d valid rows instead" - %(nrows, len(rows))) - # Upgrade the converters (if needed) if dtype is None: for (i, converter) in enumerate(converters): diff --git a/numpy/lib/tests/test_io.py b/numpy/lib/tests/test_io.py index df5ab1a2a..2ce78575b 100644 --- a/numpy/lib/tests/test_io.py +++ b/numpy/lib/tests/test_io.py @@ -1641,29 +1641,57 @@ M 33 21.99 self.assertTrue(isinstance(test, np.recarray)) assert_equal(test, control) - def test_nrows(self): - # + def test_max_rows(self): + # Test the `max_rows` keyword argument. + data = '1 2\n3 4\n5 6\n7 8\n9 10\n' + txt = TextIO(data) + a1 = np.genfromtxt(txt, max_rows=3) + a2 = np.genfromtxt(txt) + assert_equal(a1, [[1, 2], [3, 4], [5, 6]]) + assert_equal(a2, [[7, 8], [9, 10]]) + + # max_rows must be at least 1. + assert_raises(ValueError, np.genfromtxt, TextIO(data), max_rows=0) + + # An input with several invalid rows. data = '1 1\n2 2\n0 \n3 3\n4 4\n5 \n6 \n7 \n' - test = np.genfromtxt(TextIO(data), nrows=2) + + test = np.genfromtxt(TextIO(data), max_rows=2) control = np.array([[1., 1.], [2., 2.]]) - assert_equal(test, control) + assert_equal(test, control) + # Test keywords conflict - assert_raises(ValueError, np.genfromtxt, TextIO(data), skip_footer=1, nrows=4) + assert_raises(ValueError, np.genfromtxt, TextIO(data), skip_footer=1, + max_rows=4) + # Test with invalid value - assert_raises(ValueError, np.genfromtxt, TextIO(data), nrows=4) + assert_raises(ValueError, np.genfromtxt, TextIO(data), max_rows=4) + # Test with invalid not raise with warnings.catch_warnings(): warnings.filterwarnings("ignore") - test = np.genfromtxt(TextIO(data), nrows=4, invalid_raise=False) + + test = np.genfromtxt(TextIO(data), max_rows=4, invalid_raise=False) control = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]]) assert_equal(test, control) - # Test without enough valid rows - assert_raises(AssertionError, np.genfromtxt, TextIO(data), nrows=5) - data = 'a b\n#c d\n1 1\n2 2\n#0 \n3 3\n4 4\n5 \n6 \n7 \n' + test = np.genfromtxt(TextIO(data), max_rows=5, invalid_raise=False) + control = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]]) + assert_equal(test, control) + + # Structured array with field names. + data = 'a b\n#c d\n1 1\n2 2\n#0 \n3 3\n4 4\n5 5\n' + # Test with header, names and comments - test = np.genfromtxt(TextIO(data), skip_header=1, nrows=4, names=True) - control = np.array([(1.0, 1.0), (2.0, 2.0), (3.0, 3.0), (4.0, 4.0)], + txt = TextIO(data) + test = np.genfromtxt(txt, skip_header=1, max_rows=3, names=True) + control = np.array([(1.0, 1.0), (2.0, 2.0), (3.0, 3.0)], + dtype=[('c', '<f8'), ('d', '<f8')]) + assert_equal(test, control) + # To continue reading the same "file", don't use skip_header or + # names, and use the previously determined dtype. + test = np.genfromtxt(txt, max_rows=None, dtype=test.dtype) + control = np.array([(4.0, 4.0), (5.0, 5.0)], dtype=[('c', '<f8'), ('d', '<f8')]) assert_equal(test, control) |
