diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2008-07-22 06:37:48 +0000 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2008-07-22 06:37:48 +0000 |
commit | a2dcde587eb07398d0a30189a898c614ea1ba1aa (patch) | |
tree | 8c135ebad9b5e532df6f74f9a620dcde96035e81 /numpy/lib/io.py | |
parent | e41b0e3b2222095c2eb75952602fc8b779798cbc (diff) | |
download | numpy-a2dcde587eb07398d0a30189a898c614ea1ba1aa.tar.gz |
Apply Stefan's patch for Ryan's loadtext fix.
Diffstat (limited to 'numpy/lib/io.py')
-rw-r--r-- | numpy/lib/io.py | 112 |
1 files changed, 78 insertions, 34 deletions
diff --git a/numpy/lib/io.py b/numpy/lib/io.py index b1ae192ec..7ac154df5 100644 --- a/numpy/lib/io.py +++ b/numpy/lib/io.py @@ -10,6 +10,7 @@ import format import cStringIO import tempfile import os +import itertools from cPickle import load as _cload, loads from _datasource import DataSource @@ -286,44 +287,87 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None, converters=None, raise ValueError('fname must be a string or file handle') X = [] - dtype = np.dtype(dtype) - defconv = _getconv(dtype) - converterseq = None - if converters is None: - converters = {} - if dtype.names is not None: - if usecols is None: - converterseq = [_getconv(dtype.fields[name][0]) \ - for name in dtype.names] - else: - converters.update([(col,_getconv(dtype.fields[name][0])) \ - for col,name in zip(usecols, dtype.names)]) - - for i,line in enumerate(fh): - if i<skiprows: continue - comment_start = line.find(comments) - if comment_start != -1: - line = line[:comment_start].strip() + def flatten_dtype(dt): + """Unpack a structured data-type.""" + if dt.names is None: + return [dt] else: - line = line.strip() - if not len(line): continue - vals = line.split(delimiter) - if converterseq is None: - converterseq = [converters.get(j,defconv) \ - for j in xrange(len(vals))] - if usecols is not None: - row = [converterseq[j](vals[j]) for j in usecols] + types = [] + for field in dt.names: + tp, bytes = dt.fields[field] + flat_dt = flatten_dtype(tp) + types.extend(flat_dt) + return types + + def split_line(line): + """Chop off comments, strip, and split at delimiter.""" + line = line.split(comments)[0].strip() + if line: + return line.split(delimiter) else: - row = [converterseq[j](val) for j,val in enumerate(vals)] - if dtype.names is not None: - row = tuple(row) - X.append(row) + return [] - X = np.array(X, dtype) - X = np.squeeze(X) - if unpack: return X.T - else: return X + # Make sure we're dealing with a proper dtype + dtype = np.dtype(dtype) + defconv = _getconv(dtype) + # Skip the first `skiprows` lines + for i in xrange(skiprows): + fh.readline() + + # Read until we find a line with some values, and use + # it to estimate the number of columns, N. + read_line = None + while not read_line: + first_line = fh.readline() + read_line = split_line(first_line) + N = len(usecols or read_line) + + dtype_types = flatten_dtype(dtype) + if len(dtype_types) > 1: + # We're dealing with a structured array, each field of + # the dtype matches a column + converterseq = [_getconv(dt) for dt in dtype_types] + else: + # All fields have the same dtype + converterseq = [defconv for i in xrange(N)] + + # By preference, use the converters specified by the user + for i, conv in (converters or {}).iteritems(): + if usecols: + i = usecols.find(i) + converterseq[i] = conv + + # Parse each line, including the first + for i, line in enumerate(itertools.chain([first_line], fh)): + vals = split_line(line) + if len(vals) == 0: + continue + + if usecols: + vals = [vals[i] for i in usecols] + + # Convert each value according to its column and store + X.append(tuple(conv(val) for (conv, val) in zip(converterseq, vals))) + + if len(dtype_types) > 1: + # We're dealing with a structured array, with a dtype such as + # [('x', int), ('y', [('s', int), ('t', float)])] + # + # First, create the array using a flattened dtype: + # [('x', int), ('s', int), ('t', float)] + # + # Then, view the array using the specified dtype. + X = np.array(X, dtype=np.dtype([('', t) for t in dtype_types])) + X = X.view(dtype) + else: + X = np.array(X, dtype) + + X = np.squeeze(X) + if unpack: + return X.T + else: + return X def savetxt(fname, X, fmt='%.18e',delimiter=' '): |