diff options
author | Mark Wiebe <mwwiebe@gmail.com> | 2011-01-16 23:02:06 -0800 |
---|---|---|
committer | Mark Wiebe <mwwiebe@gmail.com> | 2011-01-16 23:02:06 -0800 |
commit | e6c3a0c9411fa0fb701261a206a2a4d002de68cd (patch) | |
tree | c948986c19549cebf7f180d98e1be9a83aa2df34 /numpy/lib | |
parent | 03dd59a0fa368d30a05b5a3d7da8d41b253e087a (diff) | |
download | numpy-e6c3a0c9411fa0fb701261a206a2a4d002de68cd.tar.gz |
ENH: core: Implement PyArray_CopyInto using the new iterator
This change also uses the dtype conversion code implemented for the new
iterator's buffering, which differs slightly from the previous casting
behavior. In particular, fields are matched up by name instead of by
position, so code depending on the old positional matching breaks. The
loadtxt function has been fixed to not depend on that casting behavior.
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/npyio.py | 60 | ||||
-rw-r--r-- | numpy/lib/tests/test_io.py | 10 |
2 files changed, 47 insertions, 23 deletions
diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py index 34bbd1469..3f4db4593 100644 --- a/numpy/lib/npyio.py +++ b/numpy/lib/npyio.py @@ -683,19 +683,44 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None, X = [] def flatten_dtype(dt): - """Unpack a structured data-type.""" + """Unpack a structured data-type, and produce re-packing info.""" if dt.names is None: # If the dtype is flattened, return. # If the dtype has a shape, the dtype occurs # in the list more than once. - return [dt.base] * int(np.prod(dt.shape)) + shape = dt.shape + if len(shape) == 0: + return ([dt.base], None) + else: + packing = [(shape[-1], tuple)] + if len(shape) > 1: + for dim in dt.shape[-2:0:-1]: + packing = [(dim*packing[0][0],packing*dim)] + packing = packing*shape[0] + return ([dt.base] * int(np.prod(dt.shape)), packing) else: types = [] + packing = [] for field in dt.names: tp, bytes = dt.fields[field] - flat_dt = flatten_dtype(tp) + flat_dt, flat_packing = flatten_dtype(tp) types.extend(flat_dt) - return types + packing.append((len(flat_dt),flat_packing)) + return (types, packing) + + def pack_items(items, packing): + """Pack items into nested lists based on re-packing info.""" + if packing == None: + return items[0] + elif packing is tuple: + return tuple(items) + else: + start = 0 + ret = [] + for length, subpacking in packing: + ret.append(pack_items(items[start:start+length], subpacking)) + start += length + return tuple(ret) def split_line(line): """Chop off comments, strip, and split at delimiter.""" @@ -724,7 +749,7 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None, first_vals = split_line(first_line) N = len(usecols or first_vals) - dtype_types = flatten_dtype(dtype) + dtype_types, packing = flatten_dtype(dtype) if len(dtype_types) > 1: # We're dealing with a structured array, each field of # the dtype matches a column @@ -732,6 +757,8 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None, else: # All fields have the same dtype 
converters = [defconv for i in xrange(N)] + if N > 1: + packing = [(N, tuple)] # By preference, use the converters specified by the user for i, conv in (user_converters or {}).iteritems(): @@ -753,27 +780,16 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None, vals = [vals[i] for i in usecols] # Convert each value according to its column and store - X.append(tuple([conv(val) for (conv, val) in zip(converters, vals)])) + items = [conv(val) for (conv, val) in zip(converters, vals)] + # Then pack it according to the dtype's nesting + items = pack_items(items, packing) + + X.append(items) finally: if own_fh: fh.close() - if len(dtype_types) > 1: - # We're dealing with a structured array, with a dtype such as - # [('x', int), ('y', [('s', int), ('t', float)])] - # - # First, create the array using a flattened dtype: - # [('x', int), ('s', int), ('t', float)] - # - # Then, view the array using the specified dtype. - try: - X = np.array(X, dtype=np.dtype([('', t) for t in dtype_types])) - X = X.view(dtype) - except TypeError: - # In the case we have an object dtype - X = np.array(X, dtype=dtype) - else: - X = np.array(X, dtype) + X = np.array(X, dtype) X = np.squeeze(X) if unpack: diff --git a/numpy/lib/tests/test_io.py b/numpy/lib/tests/test_io.py index a85b01909..04497dee8 100644 --- a/numpy/lib/tests/test_io.py +++ b/numpy/lib/tests/test_io.py @@ -381,6 +381,15 @@ class TestLoadTxt(TestCase): dtype=dt) assert_array_equal(x, a) + def test_3d_shaped_dtype(self): + c = StringIO("aaaa 1.0 8.0 1 2 3 4 5 6 7 8 9 10 11 12") + dt = np.dtype([('name', 'S4'), ('x', float), ('y', float), + ('block', int, (2, 2, 3))]) + x = np.loadtxt(c, dtype=dt) + a = np.array([('aaaa', 1.0, 8.0, [[[1, 2, 3], [4, 5, 6]],[[7, 8, 9], [10, 11, 12]]])], + dtype=dt) + assert_array_equal(x, a) + def test_empty_file(self): c = StringIO() assert_raises(IOError, np.loadtxt, c) @@ -884,7 +893,6 @@ M 33 21.99 dtype=dt) assert_array_equal(x, a) - def test_withmissing(self): data = 
StringIO('A,B\n0,1\n2,N/A') kwargs = dict(delimiter=",", missing_values="N/A", names=True) |