diff options
author | Ralf Gommers <ralf.gommers@googlemail.com> | 2011-04-03 12:33:07 +0200 |
---|---|---|
committer | Ralf Gommers <ralf.gommers@googlemail.com> | 2011-04-03 13:02:14 +0200 |
commit | a311969ea2f47b486da14da99a26e72c12a0c20f (patch) | |
tree | f42c8786574e6f1929724cb6d541d41f9a5fafb3 /numpy/lib | |
parent | e340e665adba35a5aba7fac09e28ac1f2e4d949b (diff) | |
download | numpy-a311969ea2f47b486da14da99a26e72c12a0c20f.tar.gz |
ENH: add ndmin keyword to loadtxt. Closes #1562.
Thanks to Paul Anton Letnes and Derek Homeier.
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/npyio.py | 21 | ||||
-rw-r--r-- | numpy/lib/tests/test_io.py | 17 |
2 files changed, 36 insertions, 2 deletions
diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py index 25737cbbe..d891c6be1 100644 --- a/numpy/lib/npyio.py +++ b/numpy/lib/npyio.py @@ -583,7 +583,8 @@ def _getconv(dtype): def loadtxt(fname, dtype=float, comments='#', delimiter=None, - converters=None, skiprows=0, usecols=None, unpack=False): + converters=None, skiprows=0, usecols=None, unpack=False, + ndmin=0): """ Load data from a text file. @@ -621,6 +622,9 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None, If True, the returned array is transposed, so that arguments may be unpacked using ``x, y, z = loadtxt(...)``. When used with a record data-type, arrays are returned for each field. Default is False. + ndmin : int, optional + The returned array must have at least `ndmin` dimensions. + Legal values: 0 (default), 1 or 2. Returns ------- @@ -799,7 +803,20 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None, X = np.array(X, dtype) - X = np.squeeze(X) + # Verify that the array has at least dimensions `ndmin`. + # Check correctness of the values of `ndmin` + if not ndmin in [0, 1, 2]: + raise ValueError('Illegal value of ndmin keyword: %s' % ndmin) + # Tweak the size and shape of the arrays + if X.ndim > ndmin: + X = np.squeeze(X) + # Has to be in this order for the odd case ndmin=1, X.squeeze().ndim=0 + if X.ndim < ndmin: + if ndmin == 1: + X.shape = (X.size, ) + elif ndmin == 2: + X.shape = (X.size, 1) + if unpack: if len(dtype_types) > 1: # For structured arrays, return an array for each field. diff --git a/numpy/lib/tests/test_io.py b/numpy/lib/tests/test_io.py index 9afc5dd0b..30c0998dc 100644 --- a/numpy/lib/tests/test_io.py +++ b/numpy/lib/tests/test_io.py @@ -468,6 +468,23 @@ class TestLoadTxt(TestCase): assert_array_equal(b, np.array([21, 35])) assert_array_equal(c, np.array([ 72., 58.])) + def test_ndmin_keyword(self): + c = StringIO() + c.write(asbytes('1,2,3\n4,5,6')) + c.seek(0) + x = np.loadtxt(c, dtype=int, delimiter=',', ndmin=1) + a = np.array([[1, 2, 3], [4, 5, 6]]) + assert_array_equal(x, a) + d = StringIO() + d.write(asbytes('0,1,2')) + d.seek(0) + x = np.loadtxt(d, dtype=int, delimiter=',', ndmin=2) + assert_(x.shape == (3, 1)) + assert_raises(ValueError, np.loadtxt, d, ndmin=3) + assert_raises(ValueError, np.loadtxt, d, ndmin=1.5) + e = StringIO() + assert_(np.loadtxt(e, ndmin=2).shape == (0, 1,)) + class Testfromregex(TestCase): def test_record(self): |