Diffstat (limited to 'numpy/lib/npyio.py')
-rw-r--r-- | numpy/lib/npyio.py | 60
1 file changed, 38 insertions, 22 deletions
diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py
index 34bbd1469..3f4db4593 100644
--- a/numpy/lib/npyio.py
+++ b/numpy/lib/npyio.py
@@ -683,19 +683,44 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None,
     X = []
 
     def flatten_dtype(dt):
-        """Unpack a structured data-type."""
+        """Unpack a structured data-type, and produce re-packing info."""
         if dt.names is None:
             # If the dtype is flattened, return.
             # If the dtype has a shape, the dtype occurs
             # in the list more than once.
-            return [dt.base] * int(np.prod(dt.shape))
+            shape = dt.shape
+            if len(shape) == 0:
+                return ([dt.base], None)
+            else:
+                packing = [(shape[-1], tuple)]
+                if len(shape) > 1:
+                    for dim in dt.shape[-2:0:-1]:
+                        packing = [(dim*packing[0][0],packing*dim)]
+                    packing = packing*shape[0]
+                return ([dt.base] * int(np.prod(dt.shape)), packing)
         else:
             types = []
+            packing = []
             for field in dt.names:
                 tp, bytes = dt.fields[field]
-                flat_dt = flatten_dtype(tp)
+                flat_dt, flat_packing = flatten_dtype(tp)
                 types.extend(flat_dt)
-            return types
+                packing.append((len(flat_dt),flat_packing))
+            return (types, packing)
+
+    def pack_items(items, packing):
+        """Pack items into nested lists based on re-packing info."""
+        if packing == None:
+            return items[0]
+        elif packing is tuple:
+            return tuple(items)
+        else:
+            start = 0
+            ret = []
+            for length, subpacking in packing:
+                ret.append(pack_items(items[start:start+length], subpacking))
+                start += length
+            return tuple(ret)
 
     def split_line(line):
         """Chop off comments, strip, and split at delimiter."""
@@ -724,7 +749,7 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None,
             first_vals = split_line(first_line)
         N = len(usecols or first_vals)
 
-        dtype_types = flatten_dtype(dtype)
+        dtype_types, packing = flatten_dtype(dtype)
         if len(dtype_types) > 1:
             # We're dealing with a structured array, each field of
             # the dtype matches a column
@@ -732,6 +757,8 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None,
         else:
             # All fields have the same dtype
             converters = [defconv for i in xrange(N)]
+            if N > 1:
+                packing = [(N, tuple)]
 
         # By preference, use the converters specified by the user
         for i, conv in (user_converters or {}).iteritems():
@@ -753,27 +780,16 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None,
             if usecols:
                 vals = [vals[i] for i in usecols]
             # Convert each value according to its column and store
-            X.append(tuple([conv(val) for (conv, val) in zip(converters, vals)]))
+            items = [conv(val) for (conv, val) in zip(converters, vals)]
+            # Then pack it according to the dtype's nesting
+            items = pack_items(items, packing)
+
+            X.append(items)
     finally:
         if own_fh:
             fh.close()
 
-    if len(dtype_types) > 1:
-        # We're dealing with a structured array, with a dtype such as
-        # [('x', int), ('y', [('s', int), ('t', float)])]
-        #
-        # First, create the array using a flattened dtype:
-        # [('x', int), ('s', int), ('t', float)]
-        #
-        # Then, view the array using the specified dtype.
-        try:
-            X = np.array(X, dtype=np.dtype([('', t) for t in dtype_types]))
-            X = X.view(dtype)
-        except TypeError:
-            # In the case we have an object dtype
-            X = np.array(X, dtype=dtype)
-    else:
-        X = np.array(X, dtype)
+    X = np.array(X, dtype)
 
     X = np.squeeze(X)
     if unpack:
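
To make the re-packing scheme concrete, here is a minimal standalone sketch of the two helpers added above. In the patch they are nested inside loadtxt; the example dtype and the small stylistic differences (e.g. `is None`) below are illustrative, not part of the commit.

import numpy as np

def flatten_dtype(dt):
    # Return the flat list of base dtypes for `dt`, plus the "packing"
    # recipe needed to re-nest one flat row of converted values.
    if dt.names is None:
        shape = dt.shape
        if len(shape) == 0:
            return [dt.base], None
        packing = [(shape[-1], tuple)]
        if len(shape) > 1:
            for dim in shape[-2:0:-1]:
                packing = [(dim * packing[0][0], packing * dim)]
            packing = packing * shape[0]
        return [dt.base] * int(np.prod(shape)), packing
    types, packing = [], []
    for field in dt.names:
        tp, _ = dt.fields[field]
        flat_dt, flat_packing = flatten_dtype(tp)
        types.extend(flat_dt)
        packing.append((len(flat_dt), flat_packing))
    return types, packing

def pack_items(items, packing):
    # Re-nest a flat list of items according to the packing recipe:
    # None -> a scalar, tuple -> one flat tuple, list -> recurse per entry.
    if packing is None:
        return items[0]
    if packing is tuple:
        return tuple(items)
    start, out = 0, []
    for length, subpacking in packing:
        out.append(pack_items(items[start:start + length], subpacking))
        start += length
    return tuple(out)

dt = np.dtype([('x', int), ('y', [('s', int), ('t', float)])])
types, packing = flatten_dtype(dt)
# packing == [(1, None), (2, [(1, None), (1, None)])]
row = pack_items([1, 2, 3.5], packing)
# row == (1, (2, 3.5)) -- a nested tuple matching the dtype's structure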
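
Because each parsed line is now packed into nested tuples before being appended to X, the final conversion collapses to the single np.array(X, dtype) call at the end of the diff, and the removed flattened-dtype-plus-view workaround is no longer needed. A small illustration of that end state (the sample rows below are made up, not taken from the commit):

import numpy as np

dt = np.dtype([('x', int), ('y', [('s', int), ('t', float)])])

# What X would contain after pack_items has run on two parsed lines:
X = [(1, (2, 3.5)), (4, (5, 6.5))]

a = np.array(X, dt)      # the single call that replaces the old view() dance
print(a['y']['t'])       # [3.5 6.5]

With this change in place, loadtxt called with dtype=dt on whitespace-separated columns should yield records nested like these rows.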