diff options
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/_iotools.py | 13 | ||||
-rw-r--r-- | numpy/lib/npyio.py | 4 | ||||
-rw-r--r-- | numpy/lib/recfunctions.py | 22 | ||||
-rw-r--r-- | numpy/lib/tests/test_io.py | 7 | ||||
-rw-r--r-- | numpy/lib/tests/test_recfunctions.py | 21 |
5 files changed, 47 insertions, 20 deletions
diff --git a/numpy/lib/_iotools.py b/numpy/lib/_iotools.py index 0ebd39b8c..c392929fd 100644 --- a/numpy/lib/_iotools.py +++ b/numpy/lib/_iotools.py @@ -121,7 +121,7 @@ def has_nested_fields(ndtype): """ for name in ndtype.names or (): - if ndtype[name].names: + if ndtype[name].names is not None: return True return False @@ -931,28 +931,27 @@ def easy_dtype(ndtype, names=None, defaultfmt="f%i", **validationargs): names = validate(names, nbfields=nbfields, defaultfmt=defaultfmt) ndtype = np.dtype(dict(formats=ndtype, names=names)) else: - nbtypes = len(ndtype) # Explicit names if names is not None: validate = NameValidator(**validationargs) if isinstance(names, basestring): names = names.split(",") # Simple dtype: repeat to match the nb of names - if nbtypes == 0: + if ndtype.names is None: formats = tuple([ndtype.type] * len(names)) names = validate(names, defaultfmt=defaultfmt) ndtype = np.dtype(list(zip(names, formats))) # Structured dtype: just validate the names as needed else: - ndtype.names = validate(names, nbfields=nbtypes, + ndtype.names = validate(names, nbfields=len(ndtype.names), defaultfmt=defaultfmt) # No implicit names - elif (nbtypes > 0): + elif ndtype.names is not None: validate = NameValidator(**validationargs) # Default initial names : should we change the format ? - if ((ndtype.names == tuple("f%i" % i for i in range(nbtypes))) and + if ((ndtype.names == tuple("f%i" % i for i in range(len(ndtype.names)))) and (defaultfmt != "f%i")): - ndtype.names = validate([''] * nbtypes, defaultfmt=defaultfmt) + ndtype.names = validate([''] * len(ndtype.names), defaultfmt=defaultfmt) # Explicit initial names : just validate else: ndtype.names = validate(ndtype.names, defaultfmt=defaultfmt) diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py index c45622edd..e57a6dd47 100644 --- a/numpy/lib/npyio.py +++ b/numpy/lib/npyio.py @@ -2180,7 +2180,7 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None, outputmask = np.array(masks, dtype=mdtype) else: # Overwrite the initial dtype names if needed - if names and dtype.names: + if names and dtype.names is not None: dtype.names = names # Case 1. We have a structured type if len(dtype_flat) > 1: @@ -2230,7 +2230,7 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None, # output = np.array(data, dtype) if usemask: - if dtype.names: + if dtype.names is not None: mdtype = [(_, bool) for _ in dtype.names] else: mdtype = bool diff --git a/numpy/lib/recfunctions.py b/numpy/lib/recfunctions.py index 6e257bb3f..014f5e200 100644 --- a/numpy/lib/recfunctions.py +++ b/numpy/lib/recfunctions.py @@ -72,7 +72,7 @@ def recursive_fill_fields(input, output): current = input[field] except ValueError: continue - if current.dtype.names: + if current.dtype.names is not None: recursive_fill_fields(current, output[field]) else: output[field][:len(current)] = current @@ -139,11 +139,11 @@ def get_names(adtype): names = adtype.names for name in names: current = adtype[name] - if current.names: + if current.names is not None: listnames.append((name, tuple(get_names(current)))) else: listnames.append(name) - return tuple(listnames) or None + return tuple(listnames) def get_names_flat(adtype): @@ -176,9 +176,9 @@ def get_names_flat(adtype): for name in names: listnames.append(name) current = adtype[name] - if current.names: + if current.names is not None: listnames.extend(get_names_flat(current)) - return tuple(listnames) or None + return tuple(listnames) def flatten_descr(ndtype): @@ -215,8 +215,8 @@ def _zip_dtype(seqarrays, flatten=False): else: for a in seqarrays: current = a.dtype - if current.names and len(current.names) <= 1: - # special case - dtypes of 0 or 1 field are flattened + if current.names is not None and len(current.names) == 1: + # special case - dtypes of 1 field are flattened newdtype.extend(_get_fieldspec(current)) else: newdtype.append(('', current)) @@ -268,7 +268,7 @@ def get_fieldstructure(adtype, lastname=None, parents=None,): names = adtype.names for name in names: current = adtype[name] - if current.names: + if current.names is not None: if lastname: parents[name] = [lastname, ] else: @@ -281,7 +281,7 @@ def get_fieldstructure(adtype, lastname=None, parents=None,): elif lastname: lastparent = [lastname, ] parents[name] = lastparent or [] - return parents or None + return parents def _izip_fields_flat(iterable): @@ -435,7 +435,7 @@ def merge_arrays(seqarrays, fill_value=-1, flatten=False, if isinstance(seqarrays, (ndarray, np.void)): seqdtype = seqarrays.dtype # Make sure we have named fields - if not seqdtype.names: + if seqdtype.names is None: seqdtype = np.dtype([('', seqdtype)]) if not flatten or _zip_dtype((seqarrays,), flatten=True) == seqdtype: # Minimal processing needed: just make sure everythng's a-ok @@ -653,7 +653,7 @@ def rename_fields(base, namemapper): for name in ndtype.names: newname = namemapper.get(name, name) current = ndtype[name] - if current.names: + if current.names is not None: newdtype.append( (newname, _recursive_rename_fields(current, namemapper)) ) diff --git a/numpy/lib/tests/test_io.py b/numpy/lib/tests/test_io.py index 407bb56bf..6ee17c830 100644 --- a/numpy/lib/tests/test_io.py +++ b/numpy/lib/tests/test_io.py @@ -1565,6 +1565,13 @@ M 33 21.99 test = np.genfromtxt(TextIO(data), delimiter=";", dtype=ndtype, converters=converters) + # nested but empty fields also aren't supported + ndtype = [('idx', int), ('code', object), ('nest', [])] + with assert_raises_regex(NotImplementedError, + 'Nested fields.* not supported.*'): + test = np.genfromtxt(TextIO(data), delimiter=";", + dtype=ndtype, converters=converters) + def test_userconverters_with_explicit_dtype(self): # Test user_converters w/ explicit (standard) dtype data = TextIO('skip,skip,2001-01-01,1.0,skip') diff --git a/numpy/lib/tests/test_recfunctions.py b/numpy/lib/tests/test_recfunctions.py index 0126ccaf8..eb81190b7 100644 --- a/numpy/lib/tests/test_recfunctions.py +++ b/numpy/lib/tests/test_recfunctions.py @@ -115,6 +115,14 @@ class TestRecFunctions(object): test = get_names(ndtype) assert_equal(test, ('a', ('b', ('ba', 'bb')))) + ndtype = np.dtype([('a', int), ('b', [])]) + test = get_names(ndtype) + assert_equal(test, ('a', ('b', ()))) + + ndtype = np.dtype([]) + test = get_names(ndtype) + assert_equal(test, ()) + def test_get_names_flat(self): # Test get_names_flat ndtype = np.dtype([('A', '|S3'), ('B', float)]) @@ -125,6 +133,14 @@ class TestRecFunctions(object): test = get_names_flat(ndtype) assert_equal(test, ('a', 'b', 'ba', 'bb')) + ndtype = np.dtype([('a', int), ('b', [])]) + test = get_names_flat(ndtype) + assert_equal(test, ('a', 'b')) + + ndtype = np.dtype([]) + test = get_names_flat(ndtype) + assert_equal(test, ()) + def test_get_fieldstructure(self): # Test get_fieldstructure @@ -147,6 +163,11 @@ class TestRecFunctions(object): 'BBA': ['B', 'BB'], 'BBB': ['B', 'BB']} assert_equal(test, control) + # 0 fields + ndtype = np.dtype([]) + test = get_fieldstructure(ndtype) + assert_equal(test, {}) + def test_find_duplicates(self): # Test find_duplicates a = ma.array([(2, (2., 'B')), (1, (2., 'B')), (2, (2., 'B')), |