summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDerek Homeier <derek@astro.physik.uni-goettingen.de>2011-05-06 12:07:32 +0200
committerRalf Gommers <ralf.gommers@googlemail.com>2011-05-07 22:12:29 +0200
commitb233716a031cb523f9bc65dda2c22f69f6f0736a (patch)
tree13b7a7fa4d38e75d4c36c5b51c465cdddba7f98a
parent607d2b3bbe984892fbf345788a54eafebdf967ed (diff)
downloadnumpy-b233716a031cb523f9bc65dda2c22f69f6f0736a.tar.gz
use np.atleast_Nd() to boost dimensions to ndmin
-rw-r--r--numpy/lib/npyio.py13
-rw-r--r--numpy/lib/tests/test_io.py18
2 files changed, 24 insertions, 7 deletions
diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py
index 4d5e96b93..13f659d70 100644
--- a/numpy/lib/npyio.py
+++ b/numpy/lib/npyio.py
@@ -627,7 +627,7 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None,
data-type, arrays are returned for each field. Default is False.
ndmin : int, optional
The returned array will have at least `ndmin` dimensions.
- Otherwise single-dimensional axes will be squeezed.
+ Otherwise mono-dimensional axes will be squeezed.
Legal values: 0 (default), 1 or 2.
.. versionadded:: 1.6.0
@@ -803,6 +803,8 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None,
fh.close()
X = np.array(X, dtype)
+ # Multicolumn data are returned with shape (1, N, M), i.e.
+ # (1, 1, M) for a single row - remove the singleton dimension there
if X.ndim == 3 and X.shape[:2] == (1, 1):
X.shape = (1, -1)
@@ -810,15 +812,16 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None,
# Check correctness of the values of `ndmin`
if not ndmin in [0, 1, 2]:
raise ValueError('Illegal value of ndmin keyword: %s' % ndmin)
- # Tweak the size and shape of the arrays
+ # Tweak the size and shape of the arrays - remove extraneous dimensions
if X.ndim > ndmin:
X = np.squeeze(X)
- # Has to be in this order for the odd case ndmin=1, X.squeeze().ndim=0
+ # and ensure we have the minimum number of dimensions asked for
+ # - has to be in this order for the odd case ndmin=1, X.squeeze().ndim=0
if X.ndim < ndmin:
if ndmin == 1:
- X.shape = (X.size, )
+ X = np.atleast_1d(X)
elif ndmin == 2:
- X.shape = (X.size, 1)
+ X = np.atleast_2d(X).T
if unpack:
if len(dtype_types) > 1:
diff --git a/numpy/lib/tests/test_io.py b/numpy/lib/tests/test_io.py
index 97633d525..e83c82ecd 100644
--- a/numpy/lib/tests/test_io.py
+++ b/numpy/lib/tests/test_io.py
@@ -479,6 +479,10 @@ class TestLoadTxt(TestCase):
c = StringIO()
c.write(asbytes('1,2,3\n4,5,6'))
c.seek(0)
+ assert_raises(ValueError, np.loadtxt, c, ndmin=3)
+ c.seek(0)
+ assert_raises(ValueError, np.loadtxt, c, ndmin=1.5)
+ c.seek(0)
x = np.loadtxt(c, dtype=int, delimiter=',', ndmin=1)
a = np.array([[1, 2, 3], [4, 5, 6]])
assert_array_equal(x, a)
@@ -487,13 +491,23 @@ class TestLoadTxt(TestCase):
d.seek(0)
x = np.loadtxt(d, dtype=int, delimiter=',', ndmin=2)
assert_(x.shape == (1, 3))
- assert_raises(ValueError, np.loadtxt, d, ndmin=3)
- assert_raises(ValueError, np.loadtxt, d, ndmin=1.5)
+ d.seek(0)
+ x = np.loadtxt(d, dtype=int, delimiter=',', ndmin=1)
+ assert_(x.shape == (3,))
+ d.seek(0)
+ x = np.loadtxt(d, dtype=int, delimiter=',', ndmin=0)
+ assert_(x.shape == (3,))
e = StringIO()
e.write(asbytes('0\n1\n2'))
e.seek(0)
x = np.loadtxt(e, dtype=int, delimiter=',', ndmin=2)
assert_(x.shape == (3, 1))
+ e.seek(0)
+ x = np.loadtxt(e, dtype=int, delimiter=',', ndmin=1)
+ assert_(x.shape == (3,))
+ e.seek(0)
+ x = np.loadtxt(e, dtype=int, delimiter=',', ndmin=0)
+ assert_(x.shape == (3,))
f = StringIO()
assert_(np.loadtxt(f, ndmin=2).shape == (0, 1,))
assert_(np.loadtxt(f, ndmin=1).shape == (0,))