diff options
author | Mark Wiebe <mwwiebe@gmail.com> | 2011-08-18 20:37:36 -0700 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2011-08-20 10:09:33 -0600 |
commit | a54a0bdc89dd82ca3cfee89ae70a7fe3ee00e2a2 (patch) | |
tree | 563dec10e782285ee72b79ba82a330e75b7e1063 | |
parent | 78f7542add429689465c79d2d2f2042c38239672 (diff) | |
download | numpy-a54a0bdc89dd82ca3cfee89ae70a7fe3ee00e2a2.tar.gz |
BUG: loadtxt: There was some extra nesting for subarray dtypes (Ticket #1936)
-rw-r--r-- | numpy/lib/npyio.py | 15 | ||||
-rw-r--r-- | numpy/lib/tests/test_regression.py | 20 |
2 files changed, 30 insertions, 5 deletions
diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py index 30c8cc5be..cb71ffad5 100644 --- a/numpy/lib/npyio.py +++ b/numpy/lib/npyio.py @@ -705,11 +705,10 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None, if len(shape) == 0: return ([dt.base], None) else: - packing = [(shape[-1], tuple)] + packing = [(shape[-1], list)] if len(shape) > 1: - for dim in dt.shape[-2:0:-1]: - packing = [(dim*packing[0][0],packing*dim)] - packing = packing*shape[0] + for dim in dt.shape[-2::-1]: + packing = [(dim*packing[0][0], packing*dim)] return ([dt.base] * int(np.prod(dt.shape)), packing) else: types = [] @@ -718,7 +717,11 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None, tp, bytes = dt.fields[field] flat_dt, flat_packing = flatten_dtype(tp) types.extend(flat_dt) - packing.append((len(flat_dt),flat_packing)) + # Avoid extra nesting for subarrays + if len(tp.shape) > 0: + packing.extend(flat_packing) + else: + packing.append((len(flat_dt), flat_packing)) return (types, packing) def pack_items(items, packing): @@ -727,6 +730,8 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None, return items[0] elif packing is tuple: return tuple(items) + elif packing is list: + return list(items) else: start = 0 ret = [] diff --git a/numpy/lib/tests/test_regression.py b/numpy/lib/tests/test_regression.py index c0cfff9a5..c244aea87 100644 --- a/numpy/lib/tests/test_regression.py +++ b/numpy/lib/tests/test_regression.py @@ -202,5 +202,25 @@ class TestRegression(TestCase): except: raise AssertionError() + def test_loadtxt_fields_subarrays(self): + # For ticket #1936 + from StringIO import StringIO + dt = [("a", 'u1', 2), ("b", 'u1', 2)] + x = np.loadtxt(StringIO("0 1 2 3"), dtype=dt) + assert_equal(x, np.array([((0, 1), (2, 3))], dtype=dt)) + + dt = [("a", [("a", 'u1', (1,3)), ("b", 'u1')])] + x = np.loadtxt(StringIO("0 1 2 3"), dtype=dt) + assert_equal(x, np.array([(((0,1,2), 3),)], dtype=dt)) + + dt = [("a", 'u1', (2,2))] + x = np.loadtxt(StringIO("0 1 2 3"), dtype=dt) + assert_equal(x, np.array([(((0, 1), (2, 3)),)], dtype=dt)) + + dt = [("a", 'u1', (2,3,2))] + x = np.loadtxt(StringIO("0 1 2 3 4 5 6 7 8 9 10 11"), dtype=dt) + data = [((((0,1), (2,3), (4,5)), ((6,7), (8,9), (10,11))),)] + assert_equal(x, np.array(data, dtype=dt)) + if __name__ == "__main__": run_module_suite() |