diff options
Diffstat (limited to 'numpy/lib/_iotools.py')
-rw-r--r-- | numpy/lib/_iotools.py | 150 |
1 files changed, 125 insertions, 25 deletions
diff --git a/numpy/lib/_iotools.py b/numpy/lib/_iotools.py index 02385305b..9e8bcce2a 100644 --- a/numpy/lib/_iotools.py +++ b/numpy/lib/_iotools.py @@ -118,6 +118,9 @@ def flatten_dtype(ndtype, flatten_base=False): + + + class LineSplitter: """ Object to split a string at a given delimiter or at given places. @@ -256,19 +259,19 @@ class NameValidator: defaultdeletechars = set("""~!@#$%^&*()-=+~\|]}[{';: /?.>,<""") # def __init__(self, excludelist=None, deletechars=None, case_sensitive=None): - # + # Process the exclusion list .. if excludelist is None: excludelist = [] excludelist.extend(self.defaultexcludelist) self.excludelist = excludelist - # + # Process the list of characters to delete if deletechars is None: delete = self.defaultdeletechars else: delete = set(deletechars) delete.add('"') self.deletechars = delete - + # Process the case option ..... if (case_sensitive is None) or (case_sensitive is True): self.case_converter = lambda x: x elif (case_sensitive is False) or ('u' in case_sensitive): @@ -277,18 +280,21 @@ class NameValidator: self.case_converter = lambda x: x.lower() else: self.case_converter = lambda x: x - # - def validate(self, names, default='f'): + + def validate(self, names, defaultfmt="f%i", nbfields=None): """ Validate a list of strings to use as field names for a structured array. Parameters ---------- - names : list of str - The strings that are to be validated. - default : str, optional - The default field name, used if validating a given string reduces its + names : sequence of str + Strings to be validated. + defaultfmt : str, optional + Default format string, used if validating a given string reduces its length to zero. + nboutput : integer, optional + Final number of validated names, used to expand or shrink the initial + list of names. Returns ------- @@ -301,24 +307,38 @@ class NameValidator: calling `validate`. For examples, see `NameValidator`. """ - # - if names is None: - return - # - validatednames = [] - seen = dict() - # + # Initial checks .............. + if (names is None): + if (nbfields is None): + return None + names = [] + if isinstance(names, basestring): + names = [names,] + if nbfields is not None: + nbnames = len(names) + if (nbnames < nbfields): + names = list(names) + [''] * (nbfields - nbnames) + elif (nbnames > nbfields): + names = names[:nbfields] + # Set some shortcuts ........... deletechars = self.deletechars excludelist = self.excludelist - # case_converter = self.case_converter + # Initializes some variables ... + validatednames = [] + seen = dict() + nbempty = 0 # - for i, item in enumerate(names): + for item in names: item = case_converter(item) item = item.strip().replace(' ', '_') item = ''.join([c for c in item if c not in deletechars]) - if not len(item): - item = '%s%d' % (default, i) + if item == '': + item = defaultfmt % nbempty + while item in names: + nbempty += 1 + item = defaultfmt % nbempty + nbempty += 1 elif item in excludelist: item += '_' cnt = seen.get(item, 0) @@ -326,11 +346,11 @@ class NameValidator: validatednames.append(item + '_%d' % cnt) else: validatednames.append(item) - seen[item] = cnt+1 - return validatednames + seen[item] = cnt + 1 + return tuple(validatednames) # - def __call__(self, names, default='f'): - return self.validate(names, default) + def __call__(self, names, defaultfmt="f%i", nbfields=None): + return self.validate(names, defaultfmt=defaultfmt, nbfields=nbfields) @@ -376,6 +396,10 @@ class ConverterError(Exception): class ConverterLockError(ConverterError): pass +class ConversionWarning(UserWarning): + pass + + class StringConverter: """ @@ -455,7 +479,7 @@ class StringConverter: -------- >>> import dateutil.parser >>> import datetime - >>> dateparser = datetutil.parser.parse + >>> dateparser = datetustil.parser.parse >>> defaultdate = datetime.date(2000, 1, 1) >>> StringConverter.upgrade_mapper(dateparser, default=defaultdate) """ @@ -659,3 +683,79 @@ class StringConverter: self.missing_values.add(val) else: self.missing_values = [] + + + +def easy_dtype(ndtype, names=None, defaultfmt="f%i", **validationargs): + """ + Convenience function to create a `np.dtype` object. + + The function processes the input dtype and matches it with the given names. + + Parameters + ---------- + ndtype : var + Definition of the dtype. Can be any string or dictionary recognized + by the `np.dtype` function or a sequence of types. + names : str or sequence, optional + Sequence of strings to use as field names for a structured dtype. + For convenience, `names` can be a string of a comma-separated list of + names + defaultfmt : str, optional + Format string used to define missing names, such as "f%i" (default), + "fields_%02i"... + validationargs : optional + A series of optional arguments used to initialize a NameValidator. + + Examples + -------- + >>> np.lib._iotools.easy_dtype(float) + dtype('float64') + >>> np.lib._iotools.easy_dtype("i4, f8") + dtype([('f0', '<i4'), ('f1', '<f8')]) + >>> np.lib._iotools.easy_dtype("i4, f8", defaultfmt="field_%03i") + dtype([('field_000', '<i4'), ('field_001', '<f8')]) + >>> np.lib._iotools.easy_dtype((int, float, float), names="a,b,c") + dtype([('a', '<i8'), ('b', '<f8'), ('c', '<f8')]) + >>> np.lib._iotools.easy_dtype(float, names="a,b,c") + dtype([('a', '<f8'), ('b', '<f8'), ('c', '<f8')]) + """ + try: + ndtype = np.dtype(ndtype) + except TypeError: + validate = NameValidator(**validationargs) + nbfields = len(ndtype) + if names is None: + names = [''] * len(ndtype) + elif isinstance(names, basestring): + names = names.split(",") + names = validate(names, nbfields=nbfields, defaultfmt=defaultfmt) + ndtype = np.dtype(dict(formats=ndtype, names=names)) + else: + nbtypes = len(ndtype) + # Explicit names + if names is not None: + validate = NameValidator(**validationargs) + if isinstance(names, basestring): + names = names.split(",") + # Simple dtype: repeat to match the nb of names + if nbtypes == 0: + formats = tuple([ndtype.type] * len(names)) + names = validate(names, defaultfmt=defaultfmt) + ndtype = np.dtype(zip(names, formats)) + # Structured dtype: just validate the names as needed + else: + ndtype.names = validate(names, nbfields=nbtypes, + defaultfmt=defaultfmt) + # No implicit names + elif (nbtypes > 0): + validate = NameValidator(**validationargs) + # Default initial names : should we change the format ? + if (ndtype.names == tuple("f%i" % i for i in range(nbtypes))) and \ + (defaultfmt != "f%i"): + ndtype.names = validate([''] * nbtypes, defaultfmt=defaultfmt) + # Explicit initial names : just validate + else: + ndtype.names = validate(ndtype.names, defaultfmt=defaultfmt) + return ndtype + |