summaryrefslogtreecommitdiff
path: root/numpy/lib/_iotools.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/_iotools.py')
-rw-r--r--numpy/lib/_iotools.py150
1 files changed, 125 insertions, 25 deletions
diff --git a/numpy/lib/_iotools.py b/numpy/lib/_iotools.py
index 02385305b..9e8bcce2a 100644
--- a/numpy/lib/_iotools.py
+++ b/numpy/lib/_iotools.py
@@ -118,6 +118,9 @@ def flatten_dtype(ndtype, flatten_base=False):
+
+
+
class LineSplitter:
"""
Object to split a string at a given delimiter or at given places.
@@ -256,19 +259,19 @@ class NameValidator:
defaultdeletechars = set("""~!@#$%^&*()-=+~\|]}[{';: /?.>,<""")
#
def __init__(self, excludelist=None, deletechars=None, case_sensitive=None):
- #
+ # Process the exclusion list ..
if excludelist is None:
excludelist = []
excludelist.extend(self.defaultexcludelist)
self.excludelist = excludelist
- #
+ # Process the list of characters to delete
if deletechars is None:
delete = self.defaultdeletechars
else:
delete = set(deletechars)
delete.add('"')
self.deletechars = delete
-
+ # Process the case option .....
if (case_sensitive is None) or (case_sensitive is True):
self.case_converter = lambda x: x
elif (case_sensitive is False) or ('u' in case_sensitive):
@@ -277,18 +280,21 @@ class NameValidator:
self.case_converter = lambda x: x.lower()
else:
self.case_converter = lambda x: x
- #
- def validate(self, names, default='f'):
+
+ def validate(self, names, defaultfmt="f%i", nbfields=None):
"""
Validate a list of strings to use as field names for a structured array.
Parameters
----------
- names : list of str
- The strings that are to be validated.
- default : str, optional
- The default field name, used if validating a given string reduces its
+ names : sequence of str
+ Strings to be validated.
+ defaultfmt : str, optional
+ Default format string, used if validating a given string reduces its
length to zero.
+ nbfields : integer, optional
+ Final number of validated names, used to expand or shrink the initial
+ list of names.
Returns
-------
@@ -301,24 +307,38 @@ class NameValidator:
calling `validate`. For examples, see `NameValidator`.
"""
- #
- if names is None:
- return
- #
- validatednames = []
- seen = dict()
- #
+ # Initial checks ..............
+ if (names is None):
+ if (nbfields is None):
+ return None
+ names = []
+ if isinstance(names, basestring):
+ names = [names,]
+ if nbfields is not None:
+ nbnames = len(names)
+ if (nbnames < nbfields):
+ names = list(names) + [''] * (nbfields - nbnames)
+ elif (nbnames > nbfields):
+ names = names[:nbfields]
+ # Set some shortcuts ...........
deletechars = self.deletechars
excludelist = self.excludelist
- #
case_converter = self.case_converter
+ # Initializes some variables ...
+ validatednames = []
+ seen = dict()
+ nbempty = 0
#
- for i, item in enumerate(names):
+ for item in names:
item = case_converter(item)
item = item.strip().replace(' ', '_')
item = ''.join([c for c in item if c not in deletechars])
- if not len(item):
- item = '%s%d' % (default, i)
+ if item == '':
+ item = defaultfmt % nbempty
+ while item in names:
+ nbempty += 1
+ item = defaultfmt % nbempty
+ nbempty += 1
elif item in excludelist:
item += '_'
cnt = seen.get(item, 0)
@@ -326,11 +346,11 @@ class NameValidator:
validatednames.append(item + '_%d' % cnt)
else:
validatednames.append(item)
- seen[item] = cnt+1
- return validatednames
+ seen[item] = cnt + 1
+ return tuple(validatednames)
#
- def __call__(self, names, default='f'):
- return self.validate(names, default)
+ def __call__(self, names, defaultfmt="f%i", nbfields=None):
+ return self.validate(names, defaultfmt=defaultfmt, nbfields=nbfields)
@@ -376,6 +396,10 @@ class ConverterError(Exception):
class ConverterLockError(ConverterError):
pass
+class ConversionWarning(UserWarning):
+ pass
+
+
class StringConverter:
"""
@@ -455,7 +479,7 @@ class StringConverter:
--------
>>> import dateutil.parser
>>> import datetime
- >>> dateparser = datetutil.parser.parse
+ >>> dateparser = dateutil.parser.parse
>>> defaultdate = datetime.date(2000, 1, 1)
>>> StringConverter.upgrade_mapper(dateparser, default=defaultdate)
"""
@@ -659,3 +683,79 @@ class StringConverter:
self.missing_values.add(val)
else:
self.missing_values = []
+
+
+
+def easy_dtype(ndtype, names=None, defaultfmt="f%i", **validationargs):
+ """
+ Convenience function to create a `np.dtype` object.
+
+ The function processes the input dtype and matches it with the given names.
+
+ Parameters
+ ----------
+ ndtype : var
+ Definition of the dtype. Can be any string or dictionary recognized
+ by the `np.dtype` function or a sequence of types.
+ names : str or sequence, optional
+ Sequence of strings to use as field names for a structured dtype.
+ For convenience, `names` can be a string of a comma-separated list of
+ names.
+ defaultfmt : str, optional
+ Format string used to define missing names, such as "f%i" (default),
+ "fields_%02i"...
+ validationargs : optional
+ A series of optional arguments used to initialize a NameValidator.
+
+ Examples
+ --------
+ >>> np.lib._iotools.easy_dtype(float)
+ dtype('float64')
+ >>> np.lib._iotools.easy_dtype("i4, f8")
+ dtype([('f0', '<i4'), ('f1', '<f8')])
+ >>> np.lib._iotools.easy_dtype("i4, f8", defaultfmt="field_%03i")
+ dtype([('field_000', '<i4'), ('field_001', '<f8')])
+ >>> np.lib._iotools.easy_dtype((int, float, float), names="a,b,c")
+ dtype([('a', '<i8'), ('b', '<f8'), ('c', '<f8')])
+ >>> np.lib._iotools.easy_dtype(float, names="a,b,c")
+ dtype([('a', '<f8'), ('b', '<f8'), ('c', '<f8')])
+ """
+ try:
+ ndtype = np.dtype(ndtype)
+ except TypeError:
+ validate = NameValidator(**validationargs)
+ nbfields = len(ndtype)
+ if names is None:
+ names = [''] * len(ndtype)
+ elif isinstance(names, basestring):
+ names = names.split(",")
+ names = validate(names, nbfields=nbfields, defaultfmt=defaultfmt)
+ ndtype = np.dtype(dict(formats=ndtype, names=names))
+ else:
+ nbtypes = len(ndtype)
+ # Explicit names
+ if names is not None:
+ validate = NameValidator(**validationargs)
+ if isinstance(names, basestring):
+ names = names.split(",")
+ # Simple dtype: repeat to match the nb of names
+ if nbtypes == 0:
+ formats = tuple([ndtype.type] * len(names))
+ names = validate(names, defaultfmt=defaultfmt)
+ ndtype = np.dtype(zip(names, formats))
+ # Structured dtype: just validate the names as needed
+ else:
+ ndtype.names = validate(names, nbfields=nbtypes,
+ defaultfmt=defaultfmt)
+ # No explicit names: work from the names already carried by the dtype
+ elif (nbtypes > 0):
+ validate = NameValidator(**validationargs)
+ # Default initial names : should we change the format ?
+ if (ndtype.names == tuple("f%i" % i for i in range(nbtypes))) and \
+ (defaultfmt != "f%i"):
+ ndtype.names = validate([''] * nbtypes, defaultfmt=defaultfmt)
+ # Explicit initial names : just validate
+ else:
+ ndtype.names = validate(ndtype.names, defaultfmt=defaultfmt)
+ return ndtype
+