| author | pierregm <pierregm@localhost> | 2009-10-09 02:17:30 +0000 |
|---|---|---|
| committer | pierregm <pierregm@localhost> | 2009-10-09 02:17:30 +0000 |
| commit | 8cad335f8b97100df988dbb6fd5d06072c667515 (patch) | |
| tree | 5eb53d599da9028fc42d31ffdd979f022268a472 | |
| parent | e54df63621db89eddbadc7bf0f36798ee1f79e0a (diff) | |
| download | numpy-8cad335f8b97100df988dbb6fd5d06072c667515.tar.gz | |
* ma.masked_equal : force the `fill_value` of the output to `value` (ticket #1253)
* lib._iotools:
- NameValidator : add the `nbfields` optional argument to validate
- add easy_dtype
* lib.io.genfromtxt :
- add the `autostrip` optional argument (ticket #1238)
- use `invalid_raise=True` as default
- use the easy_dtype mechanism (ticket #1252)
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | numpy/lib/_iotools.py | 150 |
| -rw-r--r-- | numpy/lib/io.py | 111 |
| -rw-r--r-- | numpy/lib/tests/test__iotools.py | 87 |
| -rw-r--r-- | numpy/lib/tests/test_io.py | 81 |
| -rw-r--r-- | numpy/ma/core.py | 4 |
| -rw-r--r-- | numpy/ma/tests/test_core.py | 6 |

6 files changed, 331 insertions, 108 deletions
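The genfromtxt changes listed in the commit message are easiest to see in a short sketch, modelled on the tests added to `numpy/lib/tests/test_io.py` below (the sample data and the `f_%02i` format string are illustrative, not part of the commit):

```python
import StringIO
import numpy as np

# defaultfmt controls the names given to unnamed fields when the columns
# have mixed types (mirrors test_default_field_format below).
data = StringIO.StringIO("0, 1, 2.3\n4, 5, 6.7")
a = np.ndfromtxt(data, delimiter=",", dtype=None, defaultfmt="f_%02i")
# expected field names under this sketch: ('f_00', 'f_01', 'f_02')

# autostrip=True (ticket #1238) strips the blanks around each field before
# conversion, so string columns lose their padding (mirrors test_autostrip).
data = StringIO.StringIO("01/01/2003  , 1.3,   abcde")
b = np.ndfromtxt(data, delimiter=",", dtype=None, autostrip=True)
# expected first record: ('01/01/2003', 1.3, 'abcde')
# instead of            ('01/01/2003  ', 1.3, '   abcde')
```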
diff --git a/numpy/lib/_iotools.py b/numpy/lib/_iotools.py
index 02385305b..9e8bcce2a 100644
--- a/numpy/lib/_iotools.py
+++ b/numpy/lib/_iotools.py
@@ -118,6 +118,9 @@ def flatten_dtype(ndtype, flatten_base=False):
+
+
+
 class LineSplitter:
     """
     Object to split a string at a given delimiter or at given places.
@@ -256,19 +259,19 @@ class NameValidator:
     defaultdeletechars = set("""~!@#$%^&*()-=+~\|]}[{';: /?.>,<""")
     #
     def __init__(self, excludelist=None, deletechars=None, case_sensitive=None):
-        #
+        # Process the exclusion list ..
         if excludelist is None:
            excludelist = []
         excludelist.extend(self.defaultexcludelist)
         self.excludelist = excludelist
-        #
+        # Process the list of characters to delete
         if deletechars is None:
            delete = self.defaultdeletechars
         else:
            delete = set(deletechars)
         delete.add('"')
         self.deletechars = delete
-
+        # Process the case option .....
         if (case_sensitive is None) or (case_sensitive is True):
            self.case_converter = lambda x: x
         elif (case_sensitive is False) or ('u' in case_sensitive):
@@ -277,18 +280,21 @@ class NameValidator:
            self.case_converter = lambda x: x.lower()
         else:
            self.case_converter = lambda x: x
-    #
-    def validate(self, names, default='f'):
+
+    def validate(self, names, defaultfmt="f%i", nbfields=None):
         """
         Validate a list of strings to use as field names for a structured array.

         Parameters
         ----------
-        names : list of str
-            The strings that are to be validated.
-        default : str, optional
-            The default field name, used if validating a given string reduces its
+        names : sequence of str
+            Strings to be validated.
+        defaultfmt : str, optional
+            Default format string, used if validating a given string reduces its
             length to zero.
+        nboutput : integer, optional
+            Final number of validated names, used to expand or shrink the initial
+            list of names.

         Returns
         -------
@@ -301,24 +307,38 @@ class NameValidator:
         calling `validate`. For examples, see `NameValidator`.

         """
-        #
-        if names is None:
-            return
-        #
-        validatednames = []
-        seen = dict()
-        #
+        # Initial checks ..............
+        if (names is None):
+            if (nbfields is None):
+                return None
+            names = []
+        if isinstance(names, basestring):
+            names = [names,]
+        if nbfields is not None:
+            nbnames = len(names)
+            if (nbnames < nbfields):
+                names = list(names) + [''] * (nbfields - nbnames)
+            elif (nbnames > nbfields):
+                names = names[:nbfields]
+        # Set some shortcuts ...........
         deletechars = self.deletechars
         excludelist = self.excludelist
-        #
         case_converter = self.case_converter
+        # Initializes some variables ...
+        validatednames = []
+        seen = dict()
+        nbempty = 0
         #
-        for i, item in enumerate(names):
+        for item in names:
             item = case_converter(item)
             item = item.strip().replace(' ', '_')
             item = ''.join([c for c in item if c not in deletechars])
-            if not len(item):
-                item = '%s%d' % (default, i)
+            if item == '':
+                item = defaultfmt % nbempty
+                while item in names:
+                    nbempty += 1
+                    item = defaultfmt % nbempty
+                nbempty += 1
             elif item in excludelist:
                 item += '_'
             cnt = seen.get(item, 0)
@@ -326,11 +346,11 @@ class NameValidator:
                 validatednames.append(item + '_%d' % cnt)
             else:
                 validatednames.append(item)
-            seen[item] = cnt+1
-        return validatednames
+            seen[item] = cnt + 1
+        return tuple(validatednames)
     #
-    def __call__(self, names, default='f'):
-        return self.validate(names, default)
+    def __call__(self, names, defaultfmt="f%i", nbfields=None):
+        return self.validate(names, defaultfmt=defaultfmt, nbfields=nbfields)
@@ -376,6 +396,10 @@ class ConverterError(Exception):
 class ConverterLockError(ConverterError):
     pass


+class ConversionWarning(UserWarning):
+    pass
+
+
 class StringConverter:
     """
@@ -455,7 +479,7 @@ class StringConverter:
         --------
         >>> import dateutil.parser
         >>> import datetime
-        >>> dateparser = datetutil.parser.parse
+        >>> dateparser = datetustil.parser.parse
         >>> defaultdate = datetime.date(2000, 1, 1)
         >>> StringConverter.upgrade_mapper(dateparser, default=defaultdate)
         """
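A quick sketch of the new `nbfields`/`defaultfmt` arguments to `NameValidator`, mirroring the tests added to `test__iotools.py` further down (expected results shown as comments):

```python
from numpy.lib._iotools import NameValidator

validator = NameValidator()
# nbfields truncates or pads the validated names to a fixed length,
# filling the gaps with names built from defaultfmt.
validator(('a', 'b', 'c'), nbfields=1)                    # -> ('a',)
validator(('a', 'b', 'c'), nbfields=5, defaultfmt="g%i")  # -> ('a', 'b', 'c', 'g0', 'g1')
# With names=None, nbfields switches the result from None to generated names.
validator(None, nbfields=3)                               # -> ('f0', 'f1', 'f2')
```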
@@ -659,3 +683,79 @@ class StringConverter:
                 self.missing_values.add(val)
         else:
             self.missing_values = []
+
+
+
+def easy_dtype(ndtype, names=None, defaultfmt="f%i", **validationargs):
+    """
+    Convenience function to create a `np.dtype` object.
+
+    The function processes the input dtype and matches it with the given names.
+
+    Parameters
+    ----------
+    ndtype : var
+        Definition of the dtype. Can be any string or dictionary recognized
+        by the `np.dtype` function or a sequence of types.
+    names : str or sequence, optional
+        Sequence of strings to use as field names for a structured dtype.
+        For convenience, `names` can be a string of a comma-separated list of
+        names
+    defaultfmt : str, optional
+        Format string used to define missing names, such as "f%i" (default),
+        "fields_%02i"...
+    validationargs : optional
+        A series of optional arguments used to initialize a NameValidator.
+
+    Examples
+    --------
+    >>> np.lib._iotools.easy_dtype(float)
+    dtype('float64')
+    >>> np.lib._iotools.easy_dtype("i4, f8")
+    dtype([('f0', '<i4'), ('f1', '<f8')])
+    >>> np.lib._iotools.easy_dtype("i4, f8", defaultfmt="field_%03i")
+    dtype([('field_000', '<i4'), ('field_001', '<f8')])
+    >>> np.lib._iotools.easy_dtype((int, float, float), names="a,b,c")
+    dtype([('a', '<i8'), ('b', '<f8'), ('c', '<f8')])
+    >>> np.lib._iotools.easy_dtype(float, names="a,b,c")
+    dtype([('a', '<f8'), ('b', '<f8'), ('c', '<f8')])
+    """
+    try:
+        ndtype = np.dtype(ndtype)
+    except TypeError:
+        validate = NameValidator(**validationargs)
+        nbfields = len(ndtype)
+        if names is None:
+            names = [''] * len(ndtype)
+        elif isinstance(names, basestring):
+            names = names.split(",")
+        names = validate(names, nbfields=nbfields, defaultfmt=defaultfmt)
+        ndtype = np.dtype(dict(formats=ndtype, names=names))
+    else:
+        nbtypes = len(ndtype)
+        # Explicit names
+        if names is not None:
+            validate = NameValidator(**validationargs)
+            if isinstance(names, basestring):
+                names = names.split(",")
+            # Simple dtype: repeat to match the nb of names
+            if nbtypes == 0:
+                formats = tuple([ndtype.type] * len(names))
+                names = validate(names, defaultfmt=defaultfmt)
+                ndtype = np.dtype(zip(names, formats))
+            # Structured dtype: just validate the names as needed
+            else:
+                ndtype.names = validate(names, nbfields=nbtypes,
+                                        defaultfmt=defaultfmt)
+        # No implicit names
+        elif (nbtypes > 0):
+            validate = NameValidator(**validationargs)
+            # Default initial names : should we change the format ?
+            if (ndtype.names == tuple("f%i" % i for i in range(nbtypes))) and \
+               (defaultfmt != "f%i"):
+                ndtype.names = validate([''] * nbtypes, defaultfmt=defaultfmt)
+            # Explicit initial names : just validate
+            else:
+                ndtype.names = validate(ndtype.names, defaultfmt=defaultfmt)
+    return ndtype
+
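For reference, the examples from the new `easy_dtype` docstring above, condensed into a sketch (the `<i4`/`<i8` byte order and integer width in the comments depend on the platform):

```python
from numpy.lib._iotools import easy_dtype

easy_dtype(float)                               # dtype('float64')
easy_dtype("i4, f8")                            # dtype([('f0', '<i4'), ('f1', '<f8')])
easy_dtype("i4, f8", defaultfmt="field_%03i")   # dtype([('field_000', '<i4'), ('field_001', '<f8')])
easy_dtype((int, float, float), names="a,b,c")  # dtype([('a', '<i8'), ('b', '<f8'), ('c', '<f8')])
easy_dtype(float, names="a,b,c")                # dtype([('a', '<f8'), ('b', '<f8'), ('c', '<f8')])
```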
diff --git a/numpy/lib/io.py b/numpy/lib/io.py
index 255c5a7f5..239d0808e 100644
--- a/numpy/lib/io.py
+++ b/numpy/lib/io.py
@@ -19,8 +19,9 @@ from _datasource import DataSource
 from _compiled_base import packbits, unpackbits

 from _iotools import LineSplitter, NameValidator, StringConverter, \
-                     ConverterError, ConverterLockError, \
-                     _is_string_like, has_nested_fields, flatten_dtype
+                     ConverterError, ConverterLockError, ConversionWarning, \
+                     _is_string_like, has_nested_fields, flatten_dtype, \
+                     easy_dtype

 _file = file
 _string_like = _is_string_like
@@ -872,8 +873,8 @@ def fromregex(file, regexp, dtype):

 def genfromtxt(fname, dtype=float, comments='#', delimiter=None, skiprows=0,
                converters=None, missing='', missing_values=None, usecols=None,
                names=None, excludelist=None, deletechars=None, autostrip=False,
-               case_sensitive=True, unpack=None, usemask=False, loose=True,
-               invalid_raise=False):
+               case_sensitive=True, defaultfmt="f%i", unpack=None,
+               usemask=False, loose=True, invalid_raise=True):
     """
     Load data from a text file, with missing values handled as specified.
@@ -930,6 +931,8 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None, skiprows=0,
     deletechars : str, optional
         A string combining invalid characters that must be deleted from the
         names.
+    defaultfmt : str, optional
+        A format used to define default field names, such as "f%i" or "f_%02i".
     autostrip : bool, optional
         Whether to automatically strip white spaces from the variables.
     case_sensitive : {True, False, 'upper', 'lower'}, optional
@@ -945,7 +948,7 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None, skiprows=0,
     invalid_raise : bool, optional
         If True, an exception is raised if an inconsistency is detected in the
         number of columns.
-        If False, a warning is emitted but the incriminating lines are skipped.
+        If False, a warning is emitted and the offending lines are skipped.

     Returns
     -------
@@ -1063,9 +1066,6 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None, skiprows=0,
     nbcols = len(usecols or first_values)

     # Check the names and overwrite the dtype.names if needed
-    if dtype is not None:
-        dtype = np.dtype(dtype)
-        dtypenames = getattr(dtype, 'names', None)
     if names is True:
         names = validate_names([_.strip() for _ in first_values])
         first_line = ''
@@ -1073,10 +1073,9 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None, skiprows=0,
         names = validate_names([_.strip() for _ in names.split(',')])
     elif names:
         names = validate_names(names)
-    elif dtypenames:
-        dtype.names = validate_names(dtypenames)
-    if names and dtypenames:
-        dtype.names = names
+    # Get the dtype
+    if dtype is not None:
+        dtype = easy_dtype(dtype, defaultfmt=defaultfmt, names=names)

     # If usecols is a list of names, convert to a list of indices
     if usecols:
@@ -1213,7 +1212,7 @@
             raise ValueError(errmsg)
         # Issue a warning ?
         else:
-            warnings.warn(errmsg)
+            warnings.warn(errmsg, ConversionWarning)

     # Convert each value according to the converter:
     # We want to modify the list in place to avoid creating a new one...
@@ -1234,26 +1233,28 @@
     data = rows
     if dtype is None:
         # Get the dtypes from the types of the converters
-        coldtypes = [conv.type for conv in converters]
+        column_types = [conv.type for conv in converters]
         # Find the columns with strings...
-        strcolidx = [i for (i, v) in enumerate(coldtypes)
+        strcolidx = [i for (i, v) in enumerate(column_types)
                      if v in (type('S'), np.string_)]
         # ... and take the largest number of chars.
         for i in strcolidx:
-            coldtypes[i] = "|S%i" % max(len(row[i]) for row in data)
+            column_types[i] = "|S%i" % max(len(row[i]) for row in data)
         #
         if names is None:
             # If the dtype is uniform, don't define names, else use ''
             base = set([c.type for c in converters if c._checked])
-
             if len(base) == 1:
                 (ddtype, mdtype) = (list(base)[0], np.bool)
             else:
-                ddtype = [('', dt) for dt in coldtypes]
-                mdtype = [('', np.bool) for dt in coldtypes]
+                ddtype = [(defaultfmt % i, dt)
+                          for (i, dt) in enumerate(column_types)]
+                if usemask:
+                    mdtype = [(defaultfmt % i, np.bool)
+                              for (i, dt) in enumerate(column_types)]
         else:
-            ddtype = zip(names, coldtypes)
-            mdtype = zip(names, [np.bool] * len(coldtypes))
+            ddtype = zip(names, column_types)
+            mdtype = zip(names, [np.bool] * len(column_types))
         output = np.array(data, dtype=ddtype)
         if usemask:
             outputmask = np.array(masks, dtype=mdtype)
@@ -1331,11 +1332,8 @@
     return output.squeeze()


-def ndfromtxt(fname, dtype=float, comments='#', delimiter=None, skiprows=0,
-              converters=None, missing='', missing_values=None, usecols=None,
-              names=None, excludelist=None, deletechars=None, autostrip=False,
-              case_sensitive=True, unpack=None, loose=True,
-              invalid_raise=False):
+
+def ndfromtxt(fname, **kwargs):
     """
     Load ASCII data stored in a file and return it as a single array.
@@ -1347,21 +1345,11 @@
     numpy.genfromtxt : generic function.

     """
-    kwargs = dict(dtype=dtype, comments=comments, delimiter=delimiter,
-                  skiprows=skiprows, converters=converters,
-                  missing=missing, missing_values=missing_values,
-                  usecols=usecols, unpack=unpack, names=names,
-                  excludelist=excludelist, deletechars=deletechars,
-                  case_sensitive=case_sensitive, usemask=False,
-                  autostrip=autostrip, loose=loose, invalid_raise=invalid_raise)
+    kwargs['usemask'] = False
     return genfromtxt(fname, **kwargs)


-def mafromtxt(fname, dtype=float, comments='#', delimiter=None, skiprows=0,
-              converters=None, missing='', missing_values=None, usecols=None,
-              names=None, excludelist=None, deletechars=None, autostrip=False,
-              case_sensitive=True, unpack=None, loose=True,
-              invalid_raise=False):
+def mafromtxt(fname, **kwargs):
     """
     Load ASCII data stored in a text file and return a masked array.
@@ -1372,22 +1360,11 @@
     numpy.genfromtxt : generic function to load ASCII data.

     """
-    kwargs = dict(dtype=dtype, comments=comments, delimiter=delimiter,
-                  skiprows=skiprows, converters=converters,
-                  missing=missing, missing_values=missing_values,
-                  usecols=usecols, unpack=unpack, names=names,
-                  excludelist=excludelist, deletechars=deletechars,
-                  case_sensitive=case_sensitive, autostrip=autostrip,
-                  loose=loose, invalid_raise=invalid_raise,
-                  usemask=True)
+    kwargs['usemask'] = True
     return genfromtxt(fname, **kwargs)


-def recfromtxt(fname, dtype=None, comments='#', delimiter=None, skiprows=0,
-               converters=None, missing='', missing_values=None,
-               usecols=None, unpack=None, names=None, autostrip=False,
-               excludelist=None, deletechars=None, case_sensitive=True,
-               loose=True, invalid_raise=False, usemask=False):
+def recfromtxt(fname, **kwargs):
     """
     Load ASCII data from a file and return it in a record array.
@@ -1407,13 +1384,8 @@
     array will be determined from the data.

     """
-    kwargs = dict(dtype=dtype, comments=comments, delimiter=delimiter,
-                  skiprows=skiprows, converters=converters,
-                  missing=missing, missing_values=missing_values,
-                  usecols=usecols, unpack=unpack, names=names,
-                  excludelist=excludelist, deletechars=deletechars,
-                  case_sensitive=case_sensitive, usemask=usemask,
-                  loose=loose, autostrip=autostrip, invalid_raise=invalid_raise)
+    kwargs.update(dtype=kwargs.get('dtype', None))
+    usemask = kwargs.get('usemask', False)
     output = genfromtxt(fname, **kwargs)
     if usemask:
         from numpy.ma.mrecords import MaskedRecords
@@ -1423,12 +1395,7 @@
     return output


-def recfromcsv(fname, dtype=None, comments='#', skiprows=0,
-               converters=None, missing='', missing_values=None,
-               usecols=None, unpack=None, names=True,
-               excludelist=None, deletechars=None, case_sensitive='lower',
-               loose=True, autostrip=False, invalid_raise=False,
-               usemask=False):
+def recfromcsv(fname, **kwargs):
     """
     Load ASCII data stored in a comma-separated file.
@@ -1443,13 +1410,15 @@
     numpy.genfromtxt : generic function to load ASCII data.

     """
-    kwargs = dict(dtype=dtype, comments=comments, delimiter=",",
-                  skiprows=skiprows, converters=converters,
-                  missing=missing, missing_values=missing_values,
-                  usecols=usecols, unpack=unpack, names=names,
-                  excludelist=excludelist, deletechars=deletechars,
-                  case_sensitive=case_sensitive, usemask=usemask,
-                  loose=loose, autostrip=autostrip, invalid_raise=invalid_raise)
+    case_sensitive = kwargs.get('case_sensitive', "lower") or "lower"
+    names = kwargs.get('names', True)
+    if names is None:
+        names = True
+    kwargs.update(dtype=kwargs.get('update', None),
+                  delimiter=kwargs.get('delimiter', ",") or ",",
+                  names=names,
+                  case_sensitive=case_sensitive)
+    usemask = kwargs.get("usemask", False)
     output = genfromtxt(fname, **kwargs)
     if usemask:
         from numpy.ma.mrecords import MaskedRecords
diff --git a/numpy/lib/tests/test__iotools.py b/numpy/lib/tests/test__iotools.py
index 2cb8461c3..c16491aee 100644
--- a/numpy/lib/tests/test__iotools.py
+++ b/numpy/lib/tests/test__iotools.py
@@ -3,7 +3,7 @@ import StringIO

 import numpy as np
 from numpy.lib._iotools import LineSplitter, NameValidator, StringConverter,\
-                               has_nested_fields
+                               has_nested_fields, easy_dtype
 from numpy.testing import *

 class TestLineSplitter(TestCase):
@@ -90,6 +90,35 @@ class TestNameValidator(TestCase):
         validator = NameValidator(excludelist = ['dates', 'data', 'mask'])
         test = validator.validate(names)
         assert_equal(test, ['dates_', 'data_', 'Other_Data', 'mask_'])
+    #
+    def test_missing_names(self):
+        "Test validate missing names"
+        namelist = ('a', 'b', 'c')
+        validator = NameValidator()
+        assert_equal(validator(namelist), ['a', 'b', 'c'])
+        namelist = ('', 'b', 'c')
+        assert_equal(validator(namelist), ['f0', 'b', 'c'])
+        namelist = ('a', 'b', '')
+        assert_equal(validator(namelist), ['a', 'b', 'f0'])
+        namelist = ('', 'f0', '')
+        assert_equal(validator(namelist), ['f1', 'f0', 'f2'])
+    #
+    def test_validate_nb_names(self):
+        "Test validate nb names"
+        namelist = ('a', 'b', 'c')
+        validator = NameValidator()
+        assert_equal(validator(namelist, nbfields=1), ('a', ))
+        assert_equal(validator(namelist, nbfields=5, defaultfmt="g%i"),
+                     ['a', 'b', 'c', 'g0', 'g1'])
+    #
+    def test_validate_wo_names(self):
+        "Test validate no names"
+        namelist = None
+        validator = NameValidator()
+        assert(validator(namelist) is None)
+        assert_equal(validator(namelist, nbfields=3), ['f0', 'f1', 'f2'])
+
+

 #-------------------------------------------------------------------------------
@@ -165,3 +194,59 @@ class TestMiscFunctions(TestCase):
         ndtype = np.dtype([('A', int), ('B', [('BA', float), ('BB', '|S1')])])
         assert_equal(has_nested_fields(ndtype), True)

+    def test_easy_dtype(self):
+        "Test ndtype on dtypes"
+        # Simple case
+        ndtype = float
+        assert_equal(easy_dtype(ndtype), np.dtype(float))
+        # As string w/o names
+        ndtype = "i4, f8"
+        assert_equal(easy_dtype(ndtype),
+                     np.dtype([('f0', "i4"), ('f1', "f8")]))
+        # As string w/o names but different default format
+        assert_equal(easy_dtype(ndtype, defaultfmt="field_%03i"),
+                     np.dtype([('field_000', "i4"), ('field_001', "f8")]))
+        # As string w/ names
+        ndtype = "i4, f8"
+        assert_equal(easy_dtype(ndtype, names="a, b"),
+                     np.dtype([('a', "i4"), ('b', "f8")]))
+        # As string w/ names (too many)
+        ndtype = "i4, f8"
+        assert_equal(easy_dtype(ndtype, names="a, b, c"),
+                     np.dtype([('a', "i4"), ('b', "f8")]))
+        # As string w/ names (not enough)
+        ndtype = "i4, f8"
+        assert_equal(easy_dtype(ndtype, names=", b"),
+                     np.dtype([('f0', "i4"), ('b', "f8")]))
+        # ... (with different default format)
+        assert_equal(easy_dtype(ndtype, names="a", defaultfmt="f%02i"),
+                     np.dtype([('a', "i4"), ('f00', "f8")]))
+        # As list of tuples w/o names
+        ndtype = [('A', int), ('B', float)]
+        assert_equal(easy_dtype(ndtype), np.dtype([('A', int), ('B', float)]))
+        # As list of tuples w/ names
+        assert_equal(easy_dtype(ndtype, names="a,b"),
+                     np.dtype([('a', int), ('b', float)]))
+        # As list of tuples w/ not enough names
+        assert_equal(easy_dtype(ndtype, names="a"),
+                     np.dtype([('a', int), ('f0', float)]))
+        # As list of tuples w/ too many names
+        assert_equal(easy_dtype(ndtype, names="a,b,c"),
+                     np.dtype([('a', int), ('b', float)]))
+        # As list of types w/o names
+        ndtype = (int, float, float)
+        assert_equal(easy_dtype(ndtype),
+                     np.dtype([('f0', int), ('f1', float), ('f2', float)]))
+        # As list of types w names
+        ndtype = (int, float, float)
+        assert_equal(easy_dtype(ndtype, names="a, b, c"),
+                     np.dtype([('a', int), ('b', float), ('c', float)]))
+        # As simple dtype w/ names
+        ndtype = np.dtype(float)
+        assert_equal(easy_dtype(ndtype, names="a, b, c"),
+                     np.dtype([(_, float) for _ in ('a', 'b', 'c')]))
+        # As simple dtype w/o names (but multiple fields)
+        ndtype = np.dtype(float)
+        assert_equal(easy_dtype(ndtype, names=['', '', ''], defaultfmt="f%02i"),
+                     np.dtype([(_, float) for _ in ('f00', 'f01', 'f02')]))
+
diff --git a/numpy/lib/tests/test_io.py b/numpy/lib/tests/test_io.py
index bc05ed4d3..5e0d666c2 100644
--- a/numpy/lib/tests/test_io.py
+++ b/numpy/lib/tests/test_io.py
@@ -763,8 +763,8 @@ M 33 21.99
     def test_withmissing(self):
         data = StringIO.StringIO('A,B\n0,1\n2,N/A')
-        test = np.mafromtxt(data, dtype=None, delimiter=',', missing='N/A',
-                            names=True)
+        kwargs = dict(delimiter=",", missing="N/A", names=True)
+        test = np.mafromtxt(data, dtype=None, **kwargs)
         control = ma.array([(0, 1), (2, -1)],
                            mask=[(False, False), (False, True)],
                            dtype=[('A', np.int), ('B', np.int)])
@@ -772,9 +772,10 @@ M 33 21.99
         assert_equal(test.mask, control.mask)
         #
         data.seek(0)
-        test = np.mafromtxt(data, delimiter=',', missing='N/A', names=True)
+        test = np.mafromtxt(data, **kwargs)
         control = ma.array([(0, 1), (2, -1)],
-                           mask=[[False, False], [False, True]],)
+                           mask=[(False, False), (False, True)],
+                           dtype=[('A', np.float), ('B', np.float)])
         assert_equal(test, control)
         assert_equal(test.mask, control.mask)
@@ -848,13 +849,14 @@ M 33 21.99
         data.insert(0, "a, b, c, d, e")
         mdata = StringIO.StringIO("\n".join(data))
         #
-        mtest = np.ndfromtxt(mdata, delimiter=",", names=True, dtype=None,)
+        kwargs = dict(delimiter=",", dtype=None, names=True)
+        mtest = np.ndfromtxt(mdata, invalid_raise=False, **kwargs)
         assert_equal(len(mtest), 45)
         assert_equal(mtest, np.ones(45, dtype=[(_, int) for _ in 'abcde']))
         #
         mdata.seek(0)
         assert_raises(ValueError, np.ndfromtxt, mdata,
-                      delimiter=",", names=True, invalid_raise=True)
+                      delimiter=",", names=True)

     def test_invalid_raise_with_usecols(self):
         "Test invalid_raise with usecols"
@@ -863,15 +865,15 @@ M 33 21.99
         data[10 * i] = "2, 2, 2, 2 2"
         data.insert(0, "a, b, c, d, e")
         mdata = StringIO.StringIO("\n".join(data))
+        kwargs = dict(delimiter=",", dtype=None, names=True,
+                      invalid_raise=False)
         #
-        mtest = np.ndfromtxt(mdata, delimiter=",", names=True, dtype=None,
-                             usecols=(0, 4))
+        mtest = np.ndfromtxt(mdata, usecols=(0, 4), **kwargs)
         assert_equal(len(mtest), 45)
         assert_equal(mtest, np.ones(45, dtype=[(_, int) for _ in 'ae']))
         #
         mdata.seek(0)
-        mtest = np.ndfromtxt(mdata, delimiter=",", names=True, dtype=None,
-                             usecols=(0, 1))
+        mtest = np.ndfromtxt(mdata, usecols=(0, 1), **kwargs)
         assert_equal(len(mtest), 50)
         control = np.ones(50, dtype=[(_, int) for _ in 'ab'])
         control[[10 * _ for _ in range(5)]] = (2, 2)
@@ -879,6 +881,7 @@ M 33 21.99

     def test_inconsistent_dtype(self):
+        "Test inconsistent dtype"
         data = ["1, 1, 1, 1, -1.1"] * 50
         mdata = StringIO.StringIO("\n".join(data))

@@ -888,6 +891,64 @@ M 33 21.99
         assert_raises(TypeError, np.genfromtxt, mdata, **kwargs)

+    def test_default_field_format(self):
+        "Test default format"
+        data = "0, 1, 2.3\n4, 5, 6.7"
+        mtest = np.ndfromtxt(StringIO.StringIO(data),
+                             delimiter=",", dtype=None, defaultfmt="f%02i")
+        ctrl = np.array([(0, 1, 2.3), (4, 5, 6.7)],
+                        dtype=[("f00", int), ("f01", int), ("f02", float)])
+        assert_equal(mtest, ctrl)
+
+    def test_single_dtype_wo_names(self):
+        "Test single dtype w/o names"
+        data = "0, 1, 2.3\n4, 5, 6.7"
+        mtest = np.ndfromtxt(StringIO.StringIO(data),
+                             delimiter=",", dtype=float, defaultfmt="f%02i")
+        ctrl = np.array([[0., 1., 2.3], [4., 5., 6.7]], dtype=float)
+        assert_equal(mtest, ctrl)
+
+    def test_single_dtype_w_explicit_names(self):
+        "Test single dtype w explicit names"
+        data = "0, 1, 2.3\n4, 5, 6.7"
+        mtest = np.ndfromtxt(StringIO.StringIO(data),
+                             delimiter=",", dtype=float, names="a, b, c")
+        ctrl = np.array([(0., 1., 2.3), (4., 5., 6.7)],
+                        dtype=[(_, float) for _ in "abc"])
+        assert_equal(mtest, ctrl)
+
+    def test_single_dtype_w_implicit_names(self):
+        "Test single dtype w implicit names"
+        data = "a, b, c\n0, 1, 2.3\n4, 5, 6.7"
+        mtest = np.ndfromtxt(StringIO.StringIO(data),
+                             delimiter=",", dtype=float, names=True)
+        ctrl = np.array([(0., 1., 2.3), (4., 5., 6.7)],
+                        dtype=[(_, float) for _ in "abc"])
+        assert_equal(mtest, ctrl)
+
+    def test_easy_structured_dtype(self):
+        "Test easy structured dtype"
+        data = "0, 1, 2.3\n4, 5, 6.7"
+        mtest = np.ndfromtxt(StringIO.StringIO(data), delimiter=",",
+                             dtype=(int, float, float), defaultfmt="f_%02i")
+        ctrl = np.array([(0, 1., 2.3), (4, 5., 6.7)],
+                        dtype=[("f_00", int), ("f_01", float), ("f_02", float)])
+        assert_equal(mtest, ctrl)
+
+    def test_autostrip(self):
+        "Test autostrip"
+        data = "01/01/2003  , 1.3,   abcde"
+        kwargs = dict(delimiter=",", dtype=None)
+        mtest = np.ndfromtxt(StringIO.StringIO(data), **kwargs)
+        ctrl = np.array([('01/01/2003  ', 1.3, '   abcde')],
+                        dtype=[('f0', '|S12'), ('f1', float), ('f2', '|S8')])
+        assert_equal(mtest, ctrl)
+        mtest = np.ndfromtxt(StringIO.StringIO(data), autostrip=True, **kwargs)
+        ctrl = np.array([('01/01/2003', 1.3, 'abcde')],
+                        dtype=[('f0', '|S10'), ('f1', float), ('f2', '|S5')])
+        assert_equal(mtest, ctrl)
+
+
     def test_recfromtxt(self):
         #
         data = StringIO.StringIO('A,B\n0,1\n2,3')
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index d32385493..cee884e9d 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -1963,7 +1963,9 @@ def masked_equal(x, value, copy=True):
     #    c = umath.equal(d, value)
     #    m = mask_or(c, getmask(x))
     #    return array(d, mask=m, copy=copy)
-    return masked_where(equal(x, value), x, copy=copy)
+    output = masked_where(equal(x, value), x, copy=copy)
+    output.fill_value = value
+    return output


 def masked_inside(x, v1, v2, copy=True):
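The `masked_equal` change (ticket #1253) in use, mirroring the new test that follows; the old default fill value for integers (999999) is mentioned only for contrast:

```python
import numpy.ma as ma

mx = ma.masked_equal([1, 2, 3], 3)
mx.mask        # [False, False, True]
mx.fill_value  # now 3 (the masked value) instead of the generic default 999999
```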
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index 61f315753..e3447b546 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -2591,6 +2591,12 @@ class TestMaskedArrayFunctions(TestCase):
         assert_equal(mx, x)
         assert_equal(mx._mask, [1,1,0])

+    def test_masked_equal_fill_value(self):
+        x = [1, 2, 3]
+        mx = masked_equal(x, 3)
+        assert_equal(mx._mask, [0, 0, 1])
+        assert_equal(mx.fill_value, 3)
+
     def test_masked_where_condition(self):
         "Tests masking functions."
         x = array([1.,2.,3.,4.,5.])