diff options
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/_iotools.py | 31 | ||||
-rw-r--r-- | numpy/lib/io.py | 30 | ||||
-rw-r--r-- | numpy/lib/tests/test_io.py | 29 |
3 files changed, 73 insertions, 17 deletions
diff --git a/numpy/lib/_iotools.py b/numpy/lib/_iotools.py index 23053bf4d..250a14795 100644 --- a/numpy/lib/_iotools.py +++ b/numpy/lib/_iotools.py @@ -64,23 +64,46 @@ def has_nested_fields(ndtype): return False -def flatten_dtype(ndtype): +def flatten_dtype(ndtype, flatten_base=False): """ - Unpack a structured data-type. + Unpack a structured data-type by collapsing nested fields and/or fields with + a shape. + Note that the field names are lost. + + Parameters + ---------- + ndtype : dtype + The datatype to collapse + flatten_base : {False, True}, optional + Whether to transform a field with a shape into several fields or not. + + Examples + -------- + >>> dt = np.dtype([('name', 'S4'), ('x', float), ('y', float), + ... ('block', int, (2, 3))]) + >>> flatten_dtype(dt) + [dtype('|S4'), dtype('float64'), dtype('float64'), dtype(('int32',(2, 3)))] + >>> flatten_dtype(dt, flatten_base=True) + [dtype('|S4'), dtype('float64'), dtype('float64'), dtype('int32'), + dtype('int32'), dtype('int32'), dtype('int32'), dtype('int32'), + dtype('int32')] """ names = ndtype.names if names is None: - return [ndtype] + if flatten_base: + return [ndtype.base] * int(np.prod(ndtype.shape)) + return [ndtype.base] else: types = [] for field in names: (typ, _) = ndtype.fields[field] - flat_dt = flatten_dtype(typ) + flat_dt = flatten_dtype(typ, flatten_base) types.extend(flat_dt) return types + class LineSplitter: """ Defines a function to split a string at a given delimiter or at given places. diff --git a/numpy/lib/io.py b/numpy/lib/io.py index 03274845c..98d071fab 100644 --- a/numpy/lib/io.py +++ b/numpy/lib/io.py @@ -505,8 +505,12 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None, converters=None, # [('x', int), ('s', int), ('t', float)] # # Then, view the array using the specified dtype. - X = np.array(X, dtype=np.dtype([('', t) for t in dtype_types])) - X = X.view(dtype) + try: + X = np.array(X, dtype=np.dtype([('', t) for t in dtype_types])) + X = X.view(dtype) + except TypeError: + # In the case we have an object dtype + X = np.array(X, dtype=dtype) else: X = np.array(X, dtype) @@ -895,14 +899,14 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None, skiprows=0, missing_values=missing_values.get(_, defmissing)) for _ in range(nbcols)] else: - flatdtypes = flatten_dtype(dtype) + dtype_flat = flatten_dtype(dtype, flatten_base=True) # Initialize the converters - if len(flatdtypes) > 1: + if len(dtype_flat) > 1: # Flexible type : get a converter from each dtype converters = [StringConverter(dt, missing_values=missing_values.get(i, defmissing), locked=True) - for (i, dt) in enumerate(flatdtypes)] + for (i, dt) in enumerate(dtype_flat)] else: # Set to a default converter (but w/ different missing values) converters = [StringConverter(dtype, @@ -1000,27 +1004,27 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None, skiprows=0, # Overwrite the initial dtype names if needed if names and dtype.names: dtype.names = names - flatdtypes = flatten_dtype(dtype) # Case 1. We have a structured type - if len(flatdtypes) > 1: + if len(dtype_flat) > 1: # Nested dtype, eg [('a', int), ('b', [('b0', int), ('b1', 'f4')])] # First, create the array using a flattened dtype: # [('a', int), ('b1', int), ('b2', float)] # Then, view the array using the specified dtype. - if has_nested_fields(dtype): - if 'O' in (_.char for _ in flatdtypes): + if 'O' in (_.char for _ in dtype_flat): + if has_nested_fields(dtype): errmsg = "Nested fields involving objects "\ "are not supported..." raise NotImplementedError(errmsg) - rows = np.array(data, dtype=[('', t) for t in flatdtypes]) - output = rows.view(dtype) + else: + output = np.array(data, dtype=dtype) else: - output = np.array(data, dtype=dtype) + rows = np.array(data, dtype=[('', _) for _ in dtype_flat]) + output = rows.view(dtype) # Now, process the rowmasks the same way if usemask: rowmasks = np.array(masks, dtype=np.dtype([('', np.bool) - for t in flatdtypes])) + for t in dtype_flat])) # Construct the new dtype mdtype = make_mask_descr(dtype) outputmask = rowmasks.view(mdtype) diff --git a/numpy/lib/tests/test_io.py b/numpy/lib/tests/test_io.py index d9bf03e01..e5a73a86a 100644 --- a/numpy/lib/tests/test_io.py +++ b/numpy/lib/tests/test_io.py @@ -370,6 +370,25 @@ class TestLoadTxt(TestCase): converters={1: lambda s: int(s, 16)}) assert_array_equal(data, [33, 66]) + def test_dtype_with_object(self): + "Test using an explicit dtype with an object" + from datetime import date + import time + data = """ + 1; 2001-01-01 + 2; 2002-01-31 + """ + ndtype = [('idx', int), ('code', np.object)] + func = lambda s: strptime(s.strip(), "%Y-%m-%d") + converters = {1: func} + test = np.loadtxt(StringIO.StringIO(data), delimiter=";", dtype=ndtype, + converters=converters) + control = np.array([(1, datetime(2001,1,1)), (2, datetime(2002,1,31))], + dtype=ndtype) + assert_equal(test, control) + + + class Testfromregex(TestCase): def test_record(self): c = StringIO.StringIO() @@ -717,6 +736,16 @@ M 33 21.99 assert_equal(test, control) + def test_shaped_dtype(self): + c = StringIO.StringIO("aaaa 1.0 8.0 1 2 3 4 5 6") + dt = np.dtype([('name', 'S4'), ('x', float), ('y', float), + ('block', int, (2, 3))]) + x = np.ndfromtxt(c, dtype=dt) + a = np.array([('aaaa', 1.0, 8.0, [[1, 2, 3], [4, 5, 6]])], + dtype=dt) + assert_array_equal(x, a) + + def test_withmissing(self): data = StringIO.StringIO('A,B\n0,1\n2,N/A') test = np.mafromtxt(data, dtype=None, delimiter=',', missing='N/A', |