summary | refs | log | tree | commit | diff
path: root/numpy/lib/io.py
diff options
context:
space:
mode:
author Charles Harris <charlesr.harris@gmail.com> 2008-07-22 06:37:48 +0000
committer Charles Harris <charlesr.harris@gmail.com> 2008-07-22 06:37:48 +0000
commit a2dcde587eb07398d0a30189a898c614ea1ba1aa (patch)
tree 8c135ebad9b5e532df6f74f9a620dcde96035e81 /numpy/lib/io.py
parent e41b0e3b2222095c2eb75952602fc8b779798cbc (diff)
download numpy-a2dcde587eb07398d0a30189a898c614ea1ba1aa.tar.gz
Apply Stefan's patch for Ryan's loadtxt fix.
Diffstat (limited to 'numpy/lib/io.py')
-rw-r--r-- numpy/lib/io.py 112
1 file changed, 78 insertions, 34 deletions
diff --git a/numpy/lib/io.py b/numpy/lib/io.py
index b1ae192ec..7ac154df5 100644
--- a/numpy/lib/io.py
+++ b/numpy/lib/io.py
@@ -10,6 +10,7 @@ import format
import cStringIO
import tempfile
import os
+import itertools
from cPickle import load as _cload, loads
from _datasource import DataSource
@@ -286,44 +287,87 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None, converters=None,
raise ValueError('fname must be a string or file handle')
X = []
- dtype = np.dtype(dtype)
- defconv = _getconv(dtype)
- converterseq = None
- if converters is None:
- converters = {}
- if dtype.names is not None:
- if usecols is None:
- converterseq = [_getconv(dtype.fields[name][0]) \
- for name in dtype.names]
- else:
- converters.update([(col,_getconv(dtype.fields[name][0])) \
- for col,name in zip(usecols, dtype.names)])
-
- for i,line in enumerate(fh):
- if i<skiprows: continue
- comment_start = line.find(comments)
- if comment_start != -1:
- line = line[:comment_start].strip()
+ def flatten_dtype(dt):
+ """Unpack a structured data-type."""
+ if dt.names is None:
+ return [dt]
else:
- line = line.strip()
- if not len(line): continue
- vals = line.split(delimiter)
- if converterseq is None:
- converterseq = [converters.get(j,defconv) \
- for j in xrange(len(vals))]
- if usecols is not None:
- row = [converterseq[j](vals[j]) for j in usecols]
+ types = []
+ for field in dt.names:
+ tp, bytes = dt.fields[field]
+ flat_dt = flatten_dtype(tp)
+ types.extend(flat_dt)
+ return types
+
+ def split_line(line):
+ """Chop off comments, strip, and split at delimiter."""
+ line = line.split(comments)[0].strip()
+ if line:
+ return line.split(delimiter)
else:
- row = [converterseq[j](val) for j,val in enumerate(vals)]
- if dtype.names is not None:
- row = tuple(row)
- X.append(row)
+ return []
- X = np.array(X, dtype)
- X = np.squeeze(X)
- if unpack: return X.T
- else: return X
+ # Make sure we're dealing with a proper dtype
+ dtype = np.dtype(dtype)
+ defconv = _getconv(dtype)
+ # Skip the first `skiprows` lines
+ for i in xrange(skiprows):
+ fh.readline()
+
+ # Read until we find a line with some values, and use
+ # it to estimate the number of columns, N.
+ read_line = None
+ while not read_line:
+ first_line = fh.readline()
+ read_line = split_line(first_line)
+ N = len(usecols or read_line)
+
+ dtype_types = flatten_dtype(dtype)
+ if len(dtype_types) > 1:
+ # We're dealing with a structured array, each field of
+ # the dtype matches a column
+ converterseq = [_getconv(dt) for dt in dtype_types]
+ else:
+ # All fields have the same dtype
+ converterseq = [defconv for i in xrange(N)]
+
+ # By preference, use the converters specified by the user
+ for i, conv in (converters or {}).iteritems():
+ if usecols:
+ i = usecols.find(i)
+ converterseq[i] = conv
+
+ # Parse each line, including the first
+ for i, line in enumerate(itertools.chain([first_line], fh)):
+ vals = split_line(line)
+ if len(vals) == 0:
+ continue
+
+ if usecols:
+ vals = [vals[i] for i in usecols]
+
+ # Convert each value according to its column and store
+ X.append(tuple(conv(val) for (conv, val) in zip(converterseq, vals)))
+
+ if len(dtype_types) > 1:
+ # We're dealing with a structured array, with a dtype such as
+ # [('x', int), ('y', [('s', int), ('t', float)])]
+ #
+ # First, create the array using a flattened dtype:
+ # [('x', int), ('s', int), ('t', float)]
+ #
+ # Then, view the array using the specified dtype.
+ X = np.array(X, dtype=np.dtype([('', t) for t in dtype_types]))
+ X = X.view(dtype)
+ else:
+ X = np.array(X, dtype)
+
+ X = np.squeeze(X)
+ if unpack:
+ return X.T
+ else:
+ return X
def savetxt(fname, X, fmt='%.18e',delimiter=' '):