summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorpierregm <pierregm@localhost>2009-10-09 02:17:30 +0000
committerpierregm <pierregm@localhost>2009-10-09 02:17:30 +0000
commit8cad335f8b97100df988dbb6fd5d06072c667515 (patch)
tree5eb53d599da9028fc42d31ffdd979f022268a472
parente54df63621db89eddbadc7bf0f36798ee1f79e0a (diff)
downloadnumpy-8cad335f8b97100df988dbb6fd5d06072c667515.tar.gz
* ma.masked_equal : force the `fill_value` of the output to `value` (ticket #1253)
* lib._iotools:
  - NameValidator : add the `nbfields` optional argument to validate
  - add easy_dtype
* lib.io.genfromtxt :
  - add the `autostrip` optional argument (ticket #1238)
  - use `invalid_raise=True` as default
  - use the easy_dtype mechanism (ticket #1252)
-rw-r--r--numpy/lib/_iotools.py150
-rw-r--r--numpy/lib/io.py111
-rw-r--r--numpy/lib/tests/test__iotools.py87
-rw-r--r--numpy/lib/tests/test_io.py81
-rw-r--r--numpy/ma/core.py4
-rw-r--r--numpy/ma/tests/test_core.py6
6 files changed, 331 insertions, 108 deletions
diff --git a/numpy/lib/_iotools.py b/numpy/lib/_iotools.py
index 02385305b..9e8bcce2a 100644
--- a/numpy/lib/_iotools.py
+++ b/numpy/lib/_iotools.py
@@ -118,6 +118,9 @@ def flatten_dtype(ndtype, flatten_base=False):
+
+
+
class LineSplitter:
"""
Object to split a string at a given delimiter or at given places.
@@ -256,19 +259,19 @@ class NameValidator:
defaultdeletechars = set("""~!@#$%^&*()-=+~\|]}[{';: /?.>,<""")
#
def __init__(self, excludelist=None, deletechars=None, case_sensitive=None):
- #
+ # Process the exclusion list ..
if excludelist is None:
excludelist = []
excludelist.extend(self.defaultexcludelist)
self.excludelist = excludelist
- #
+ # Process the list of characters to delete
if deletechars is None:
delete = self.defaultdeletechars
else:
delete = set(deletechars)
delete.add('"')
self.deletechars = delete
-
+ # Process the case option .....
if (case_sensitive is None) or (case_sensitive is True):
self.case_converter = lambda x: x
elif (case_sensitive is False) or ('u' in case_sensitive):
@@ -277,18 +280,21 @@ class NameValidator:
self.case_converter = lambda x: x.lower()
else:
self.case_converter = lambda x: x
- #
- def validate(self, names, default='f'):
+
+ def validate(self, names, defaultfmt="f%i", nbfields=None):
"""
Validate a list of strings to use as field names for a structured array.
Parameters
----------
- names : list of str
- The strings that are to be validated.
- default : str, optional
- The default field name, used if validating a given string reduces its
+ names : sequence of str
+ Strings to be validated.
+ defaultfmt : str, optional
+ Default format string, used if validating a given string reduces its
length to zero.
+ nbfields : integer, optional
+ Final number of validated names, used to expand or shrink the initial
+ list of names.
Returns
-------
@@ -301,24 +307,38 @@ class NameValidator:
calling `validate`. For examples, see `NameValidator`.
"""
- #
- if names is None:
- return
- #
- validatednames = []
- seen = dict()
- #
+ # Initial checks ..............
+ if (names is None):
+ if (nbfields is None):
+ return None
+ names = []
+ if isinstance(names, basestring):
+ names = [names,]
+ if nbfields is not None:
+ nbnames = len(names)
+ if (nbnames < nbfields):
+ names = list(names) + [''] * (nbfields - nbnames)
+ elif (nbnames > nbfields):
+ names = names[:nbfields]
+ # Set some shortcuts ...........
deletechars = self.deletechars
excludelist = self.excludelist
- #
case_converter = self.case_converter
+ # Initializes some variables ...
+ validatednames = []
+ seen = dict()
+ nbempty = 0
#
- for i, item in enumerate(names):
+ for item in names:
item = case_converter(item)
item = item.strip().replace(' ', '_')
item = ''.join([c for c in item if c not in deletechars])
- if not len(item):
- item = '%s%d' % (default, i)
+ if item == '':
+ item = defaultfmt % nbempty
+ while item in names:
+ nbempty += 1
+ item = defaultfmt % nbempty
+ nbempty += 1
elif item in excludelist:
item += '_'
cnt = seen.get(item, 0)
@@ -326,11 +346,11 @@ class NameValidator:
validatednames.append(item + '_%d' % cnt)
else:
validatednames.append(item)
- seen[item] = cnt+1
- return validatednames
+ seen[item] = cnt + 1
+ return tuple(validatednames)
#
- def __call__(self, names, default='f'):
- return self.validate(names, default)
+ def __call__(self, names, defaultfmt="f%i", nbfields=None):
+ return self.validate(names, defaultfmt=defaultfmt, nbfields=nbfields)
@@ -376,6 +396,10 @@ class ConverterError(Exception):
class ConverterLockError(ConverterError):
pass
+class ConversionWarning(UserWarning):
+ pass
+
+
class StringConverter:
"""
@@ -455,7 +479,7 @@ class StringConverter:
--------
>>> import dateutil.parser
>>> import datetime
- >>> dateparser = datetutil.parser.parse
+ >>> dateparser = dateutil.parser.parse
>>> defaultdate = datetime.date(2000, 1, 1)
>>> StringConverter.upgrade_mapper(dateparser, default=defaultdate)
"""
@@ -659,3 +683,79 @@ class StringConverter:
self.missing_values.add(val)
else:
self.missing_values = []
+
+
+
+def easy_dtype(ndtype, names=None, defaultfmt="f%i", **validationargs):
+ """
+ Convenience function to create a `np.dtype` object.
+
+ The function processes the input dtype and matches it with the given names.
+
+ Parameters
+ ----------
+ ndtype : var
+ Definition of the dtype. Can be any string or dictionary recognized
+ by the `np.dtype` function or a sequence of types.
+ names : str or sequence, optional
+ Sequence of strings to use as field names for a structured dtype.
+ For convenience, `names` can be a string of a comma-separated list of
+ names
+ defaultfmt : str, optional
+ Format string used to define missing names, such as "f%i" (default),
+ "fields_%02i"...
+ validationargs : optional
+ A series of optional arguments used to initialize a NameValidator.
+
+ Examples
+ --------
+ >>> np.lib._iotools.easy_dtype(float)
+ dtype('float64')
+ >>> np.lib._iotools.easy_dtype("i4, f8")
+ dtype([('f0', '<i4'), ('f1', '<f8')])
+ >>> np.lib._iotools.easy_dtype("i4, f8", defaultfmt="field_%03i")
+ dtype([('field_000', '<i4'), ('field_001', '<f8')])
+ >>> np.lib._iotools.easy_dtype((int, float, float), names="a,b,c")
+ dtype([('a', '<i8'), ('b', '<f8'), ('c', '<f8')])
+ >>> np.lib._iotools.easy_dtype(float, names="a,b,c")
+ dtype([('a', '<f8'), ('b', '<f8'), ('c', '<f8')])
+ """
+ try:
+ ndtype = np.dtype(ndtype)
+ except TypeError:
+ validate = NameValidator(**validationargs)
+ nbfields = len(ndtype)
+ if names is None:
+ names = [''] * len(ndtype)
+ elif isinstance(names, basestring):
+ names = names.split(",")
+ names = validate(names, nbfields=nbfields, defaultfmt=defaultfmt)
+ ndtype = np.dtype(dict(formats=ndtype, names=names))
+ else:
+ nbtypes = len(ndtype)
+ # Explicit names
+ if names is not None:
+ validate = NameValidator(**validationargs)
+ if isinstance(names, basestring):
+ names = names.split(",")
+ # Simple dtype: repeat to match the nb of names
+ if nbtypes == 0:
+ formats = tuple([ndtype.type] * len(names))
+ names = validate(names, defaultfmt=defaultfmt)
+ ndtype = np.dtype(zip(names, formats))
+ # Structured dtype: just validate the names as needed
+ else:
+ ndtype.names = validate(names, nbfields=nbtypes,
+ defaultfmt=defaultfmt)
+ # No explicit names
+ elif (nbtypes > 0):
+ validate = NameValidator(**validationargs)
+ # Default initial names : should we change the format ?
+ if (ndtype.names == tuple("f%i" % i for i in range(nbtypes))) and \
+ (defaultfmt != "f%i"):
+ ndtype.names = validate([''] * nbtypes, defaultfmt=defaultfmt)
+ # Explicit initial names : just validate
+ else:
+ ndtype.names = validate(ndtype.names, defaultfmt=defaultfmt)
+ return ndtype
+
diff --git a/numpy/lib/io.py b/numpy/lib/io.py
index 255c5a7f5..239d0808e 100644
--- a/numpy/lib/io.py
+++ b/numpy/lib/io.py
@@ -19,8 +19,9 @@ from _datasource import DataSource
from _compiled_base import packbits, unpackbits
from _iotools import LineSplitter, NameValidator, StringConverter, \
- ConverterError, ConverterLockError, \
- _is_string_like, has_nested_fields, flatten_dtype
+ ConverterError, ConverterLockError, ConversionWarning, \
+ _is_string_like, has_nested_fields, flatten_dtype, \
+ easy_dtype
_file = file
_string_like = _is_string_like
@@ -872,8 +873,8 @@ def fromregex(file, regexp, dtype):
def genfromtxt(fname, dtype=float, comments='#', delimiter=None, skiprows=0,
converters=None, missing='', missing_values=None, usecols=None,
names=None, excludelist=None, deletechars=None, autostrip=False,
- case_sensitive=True, unpack=None, usemask=False, loose=True,
- invalid_raise=False):
+ case_sensitive=True, defaultfmt="f%i", unpack=None,
+ usemask=False, loose=True, invalid_raise=True):
"""
Load data from a text file, with missing values handled as specified.
@@ -930,6 +931,8 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None, skiprows=0,
deletechars : str, optional
A string combining invalid characters that must be deleted from the
names.
+ defaultfmt : str, optional
+ A format used to define default field names, such as "f%i" or "f_%02i".
autostrip : bool, optional
Whether to automatically strip white spaces from the variables.
case_sensitive : {True, False, 'upper', 'lower'}, optional
@@ -945,7 +948,7 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None, skiprows=0,
invalid_raise : bool, optional
If True, an exception is raised if an inconsistency is detected in the
number of columns.
- If False, a warning is emitted but the incriminating lines are skipped.
+ If False, a warning is emitted and the offending lines are skipped.
Returns
-------
@@ -1063,9 +1066,6 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None, skiprows=0,
nbcols = len(usecols or first_values)
# Check the names and overwrite the dtype.names if needed
- if dtype is not None:
- dtype = np.dtype(dtype)
- dtypenames = getattr(dtype, 'names', None)
if names is True:
names = validate_names([_.strip() for _ in first_values])
first_line = ''
@@ -1073,10 +1073,9 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None, skiprows=0,
names = validate_names([_.strip() for _ in names.split(',')])
elif names:
names = validate_names(names)
- elif dtypenames:
- dtype.names = validate_names(dtypenames)
- if names and dtypenames:
- dtype.names = names
+ # Get the dtype
+ if dtype is not None:
+ dtype = easy_dtype(dtype, defaultfmt=defaultfmt, names=names)
# If usecols is a list of names, convert to a list of indices
if usecols:
@@ -1213,7 +1212,7 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None, skiprows=0,
raise ValueError(errmsg)
# Issue a warning ?
else:
- warnings.warn(errmsg)
+ warnings.warn(errmsg, ConversionWarning)
# Convert each value according to the converter:
# We want to modify the list in place to avoid creating a new one...
@@ -1234,26 +1233,28 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None, skiprows=0,
data = rows
if dtype is None:
# Get the dtypes from the types of the converters
- coldtypes = [conv.type for conv in converters]
+ column_types = [conv.type for conv in converters]
# Find the columns with strings...
- strcolidx = [i for (i, v) in enumerate(coldtypes)
+ strcolidx = [i for (i, v) in enumerate(column_types)
if v in (type('S'), np.string_)]
# ... and take the largest number of chars.
for i in strcolidx:
- coldtypes[i] = "|S%i" % max(len(row[i]) for row in data)
+ column_types[i] = "|S%i" % max(len(row[i]) for row in data)
#
if names is None:
# If the dtype is uniform, don't define names, else use ''
base = set([c.type for c in converters if c._checked])
-
if len(base) == 1:
(ddtype, mdtype) = (list(base)[0], np.bool)
else:
- ddtype = [('', dt) for dt in coldtypes]
- mdtype = [('', np.bool) for dt in coldtypes]
+ ddtype = [(defaultfmt % i, dt)
+ for (i, dt) in enumerate(column_types)]
+ if usemask:
+ mdtype = [(defaultfmt % i, np.bool)
+ for (i, dt) in enumerate(column_types)]
else:
- ddtype = zip(names, coldtypes)
- mdtype = zip(names, [np.bool] * len(coldtypes))
+ ddtype = zip(names, column_types)
+ mdtype = zip(names, [np.bool] * len(column_types))
output = np.array(data, dtype=ddtype)
if usemask:
outputmask = np.array(masks, dtype=mdtype)
@@ -1331,11 +1332,8 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None, skiprows=0,
return output.squeeze()
-def ndfromtxt(fname, dtype=float, comments='#', delimiter=None, skiprows=0,
- converters=None, missing='', missing_values=None, usecols=None,
- names=None, excludelist=None, deletechars=None, autostrip=False,
- case_sensitive=True, unpack=None, loose=True,
- invalid_raise=False):
+
+def ndfromtxt(fname, **kwargs):
"""
Load ASCII data stored in a file and return it as a single array.
@@ -1347,21 +1345,11 @@ def ndfromtxt(fname, dtype=float, comments='#', delimiter=None, skiprows=0,
numpy.genfromtxt : generic function.
"""
- kwargs = dict(dtype=dtype, comments=comments, delimiter=delimiter,
- skiprows=skiprows, converters=converters,
- missing=missing, missing_values=missing_values,
- usecols=usecols, unpack=unpack, names=names,
- excludelist=excludelist, deletechars=deletechars,
- case_sensitive=case_sensitive, usemask=False,
- autostrip=autostrip, loose=loose, invalid_raise=invalid_raise)
+ kwargs['usemask'] = False
return genfromtxt(fname, **kwargs)
-def mafromtxt(fname, dtype=float, comments='#', delimiter=None, skiprows=0,
- converters=None, missing='', missing_values=None, usecols=None,
- names=None, excludelist=None, deletechars=None, autostrip=False,
- case_sensitive=True, unpack=None, loose=True,
- invalid_raise=False):
+def mafromtxt(fname, **kwargs):
"""
Load ASCII data stored in a text file and return a masked array.
@@ -1372,22 +1360,11 @@ def mafromtxt(fname, dtype=float, comments='#', delimiter=None, skiprows=0,
numpy.genfromtxt : generic function to load ASCII data.
"""
- kwargs = dict(dtype=dtype, comments=comments, delimiter=delimiter,
- skiprows=skiprows, converters=converters,
- missing=missing, missing_values=missing_values,
- usecols=usecols, unpack=unpack, names=names,
- excludelist=excludelist, deletechars=deletechars,
- case_sensitive=case_sensitive, autostrip=autostrip,
- loose=loose, invalid_raise=invalid_raise,
- usemask=True)
+ kwargs['usemask'] = True
return genfromtxt(fname, **kwargs)
-def recfromtxt(fname, dtype=None, comments='#', delimiter=None, skiprows=0,
- converters=None, missing='', missing_values=None,
- usecols=None, unpack=None, names=None, autostrip=False,
- excludelist=None, deletechars=None, case_sensitive=True,
- loose=True, invalid_raise=False, usemask=False):
+def recfromtxt(fname, **kwargs):
"""
Load ASCII data from a file and return it in a record array.
@@ -1407,13 +1384,8 @@ def recfromtxt(fname, dtype=None, comments='#', delimiter=None, skiprows=0,
array will be determined from the data.
"""
- kwargs = dict(dtype=dtype, comments=comments, delimiter=delimiter,
- skiprows=skiprows, converters=converters,
- missing=missing, missing_values=missing_values,
- usecols=usecols, unpack=unpack, names=names,
- excludelist=excludelist, deletechars=deletechars,
- case_sensitive=case_sensitive, usemask=usemask,
- loose=loose, autostrip=autostrip, invalid_raise=invalid_raise)
+ kwargs.update(dtype=kwargs.get('dtype', None))
+ usemask = kwargs.get('usemask', False)
output = genfromtxt(fname, **kwargs)
if usemask:
from numpy.ma.mrecords import MaskedRecords
@@ -1423,12 +1395,7 @@ def recfromtxt(fname, dtype=None, comments='#', delimiter=None, skiprows=0,
return output
-def recfromcsv(fname, dtype=None, comments='#', skiprows=0,
- converters=None, missing='', missing_values=None,
- usecols=None, unpack=None, names=True,
- excludelist=None, deletechars=None, case_sensitive='lower',
- loose=True, autostrip=False, invalid_raise=False,
- usemask=False):
+def recfromcsv(fname, **kwargs):
"""
Load ASCII data stored in a comma-separated file.
@@ -1443,13 +1410,15 @@ def recfromcsv(fname, dtype=None, comments='#', skiprows=0,
numpy.genfromtxt : generic function to load ASCII data.
"""
- kwargs = dict(dtype=dtype, comments=comments, delimiter=",",
- skiprows=skiprows, converters=converters,
- missing=missing, missing_values=missing_values,
- usecols=usecols, unpack=unpack, names=names,
- excludelist=excludelist, deletechars=deletechars,
- case_sensitive=case_sensitive, usemask=usemask,
- loose=loose, autostrip=autostrip, invalid_raise=invalid_raise)
+ case_sensitive = kwargs.get('case_sensitive', "lower") or "lower"
+ names = kwargs.get('names', True)
+ if names is None:
+ names = True
+ kwargs.update(dtype=kwargs.get('update', None),
+ delimiter=kwargs.get('delimiter', ",") or ",",
+ names=names,
+ case_sensitive=case_sensitive)
+ usemask = kwargs.get("usemask", False)
output = genfromtxt(fname, **kwargs)
if usemask:
from numpy.ma.mrecords import MaskedRecords
diff --git a/numpy/lib/tests/test__iotools.py b/numpy/lib/tests/test__iotools.py
index 2cb8461c3..c16491aee 100644
--- a/numpy/lib/tests/test__iotools.py
+++ b/numpy/lib/tests/test__iotools.py
@@ -3,7 +3,7 @@ import StringIO
import numpy as np
from numpy.lib._iotools import LineSplitter, NameValidator, StringConverter,\
- has_nested_fields
+ has_nested_fields, easy_dtype
from numpy.testing import *
class TestLineSplitter(TestCase):
@@ -90,6 +90,35 @@ class TestNameValidator(TestCase):
validator = NameValidator(excludelist = ['dates', 'data', 'mask'])
test = validator.validate(names)
assert_equal(test, ['dates_', 'data_', 'Other_Data', 'mask_'])
+ #
+ def test_missing_names(self):
+ "Test validate missing names"
+ namelist = ('a', 'b', 'c')
+ validator = NameValidator()
+ assert_equal(validator(namelist), ['a', 'b', 'c'])
+ namelist = ('', 'b', 'c')
+ assert_equal(validator(namelist), ['f0', 'b', 'c'])
+ namelist = ('a', 'b', '')
+ assert_equal(validator(namelist), ['a', 'b', 'f0'])
+ namelist = ('', 'f0', '')
+ assert_equal(validator(namelist), ['f1', 'f0', 'f2'])
+ #
+ def test_validate_nb_names(self):
+ "Test validate nb names"
+ namelist = ('a', 'b', 'c')
+ validator = NameValidator()
+ assert_equal(validator(namelist, nbfields=1), ('a', ))
+ assert_equal(validator(namelist, nbfields=5, defaultfmt="g%i"),
+ ['a', 'b', 'c', 'g0', 'g1'])
+ #
+ def test_validate_wo_names(self):
+ "Test validate no names"
+ namelist = None
+ validator = NameValidator()
+ assert(validator(namelist) is None)
+ assert_equal(validator(namelist, nbfields=3), ['f0', 'f1', 'f2'])
+
+
#-------------------------------------------------------------------------------
@@ -165,3 +194,59 @@ class TestMiscFunctions(TestCase):
ndtype = np.dtype([('A', int), ('B', [('BA', float), ('BB', '|S1')])])
assert_equal(has_nested_fields(ndtype), True)
+ def test_easy_dtype(self):
+ "Test ndtype on dtypes"
+ # Simple case
+ ndtype = float
+ assert_equal(easy_dtype(ndtype), np.dtype(float))
+ # As string w/o names
+ ndtype = "i4, f8"
+ assert_equal(easy_dtype(ndtype),
+ np.dtype([('f0', "i4"), ('f1', "f8")]))
+ # As string w/o names but different default format
+ assert_equal(easy_dtype(ndtype, defaultfmt="field_%03i"),
+ np.dtype([('field_000', "i4"), ('field_001', "f8")]))
+ # As string w/ names
+ ndtype = "i4, f8"
+ assert_equal(easy_dtype(ndtype, names="a, b"),
+ np.dtype([('a', "i4"), ('b', "f8")]))
+ # As string w/ names (too many)
+ ndtype = "i4, f8"
+ assert_equal(easy_dtype(ndtype, names="a, b, c"),
+ np.dtype([('a', "i4"), ('b', "f8")]))
+ # As string w/ names (not enough)
+ ndtype = "i4, f8"
+ assert_equal(easy_dtype(ndtype, names=", b"),
+ np.dtype([('f0', "i4"), ('b', "f8")]))
+ # ... (with different default format)
+ assert_equal(easy_dtype(ndtype, names="a", defaultfmt="f%02i"),
+ np.dtype([('a', "i4"), ('f00', "f8")]))
+ # As list of tuples w/o names
+ ndtype = [('A', int), ('B', float)]
+ assert_equal(easy_dtype(ndtype), np.dtype([('A', int), ('B', float)]))
+ # As list of tuples w/ names
+ assert_equal(easy_dtype(ndtype, names="a,b"),
+ np.dtype([('a', int), ('b', float)]))
+ # As list of tuples w/ not enough names
+ assert_equal(easy_dtype(ndtype, names="a"),
+ np.dtype([('a', int), ('f0', float)]))
+ # As list of tuples w/ too many names
+ assert_equal(easy_dtype(ndtype, names="a,b,c"),
+ np.dtype([('a', int), ('b', float)]))
+ # As list of types w/o names
+ ndtype = (int, float, float)
+ assert_equal(easy_dtype(ndtype),
+ np.dtype([('f0', int), ('f1', float), ('f2', float)]))
+ # As list of types w names
+ ndtype = (int, float, float)
+ assert_equal(easy_dtype(ndtype, names="a, b, c"),
+ np.dtype([('a', int), ('b', float), ('c', float)]))
+ # As simple dtype w/ names
+ ndtype = np.dtype(float)
+ assert_equal(easy_dtype(ndtype, names="a, b, c"),
+ np.dtype([(_, float) for _ in ('a', 'b', 'c')]))
+ # As simple dtype w/o names (but multiple fields)
+ ndtype = np.dtype(float)
+ assert_equal(easy_dtype(ndtype, names=['', '', ''], defaultfmt="f%02i"),
+ np.dtype([(_, float) for _ in ('f00', 'f01', 'f02')]))
+
diff --git a/numpy/lib/tests/test_io.py b/numpy/lib/tests/test_io.py
index bc05ed4d3..5e0d666c2 100644
--- a/numpy/lib/tests/test_io.py
+++ b/numpy/lib/tests/test_io.py
@@ -763,8 +763,8 @@ M 33 21.99
def test_withmissing(self):
data = StringIO.StringIO('A,B\n0,1\n2,N/A')
- test = np.mafromtxt(data, dtype=None, delimiter=',', missing='N/A',
- names=True)
+ kwargs = dict(delimiter=",", missing="N/A", names=True)
+ test = np.mafromtxt(data, dtype=None, **kwargs)
control = ma.array([(0, 1), (2, -1)],
mask=[(False, False), (False, True)],
dtype=[('A', np.int), ('B', np.int)])
@@ -772,9 +772,10 @@ M 33 21.99
assert_equal(test.mask, control.mask)
#
data.seek(0)
- test = np.mafromtxt(data, delimiter=',', missing='N/A', names=True)
+ test = np.mafromtxt(data, **kwargs)
control = ma.array([(0, 1), (2, -1)],
- mask=[[False, False], [False, True]],)
+ mask=[(False, False), (False, True)],
+ dtype=[('A', np.float), ('B', np.float)])
assert_equal(test, control)
assert_equal(test.mask, control.mask)
@@ -848,13 +849,14 @@ M 33 21.99
data.insert(0, "a, b, c, d, e")
mdata = StringIO.StringIO("\n".join(data))
#
- mtest = np.ndfromtxt(mdata, delimiter=",", names=True, dtype=None,)
+ kwargs = dict(delimiter=",", dtype=None, names=True)
+ mtest = np.ndfromtxt(mdata, invalid_raise=False, **kwargs)
assert_equal(len(mtest), 45)
assert_equal(mtest, np.ones(45, dtype=[(_, int) for _ in 'abcde']))
#
mdata.seek(0)
assert_raises(ValueError, np.ndfromtxt, mdata,
- delimiter=",", names=True, invalid_raise=True)
+ delimiter=",", names=True)
def test_invalid_raise_with_usecols(self):
"Test invalid_raise with usecols"
@@ -863,15 +865,15 @@ M 33 21.99
data[10 * i] = "2, 2, 2, 2 2"
data.insert(0, "a, b, c, d, e")
mdata = StringIO.StringIO("\n".join(data))
+ kwargs = dict(delimiter=",", dtype=None, names=True,
+ invalid_raise=False)
#
- mtest = np.ndfromtxt(mdata, delimiter=",", names=True, dtype=None,
- usecols=(0, 4))
+ mtest = np.ndfromtxt(mdata, usecols=(0, 4), **kwargs)
assert_equal(len(mtest), 45)
assert_equal(mtest, np.ones(45, dtype=[(_, int) for _ in 'ae']))
#
mdata.seek(0)
- mtest = np.ndfromtxt(mdata, delimiter=",", names=True, dtype=None,
- usecols=(0, 1))
+ mtest = np.ndfromtxt(mdata, usecols=(0, 1), **kwargs)
assert_equal(len(mtest), 50)
control = np.ones(50, dtype=[(_, int) for _ in 'ab'])
control[[10 * _ for _ in range(5)]] = (2, 2)
@@ -879,6 +881,7 @@ M 33 21.99
def test_inconsistent_dtype(self):
+ "Test inconsistent dtype"
data = ["1, 1, 1, 1, -1.1"] * 50
mdata = StringIO.StringIO("\n".join(data))
@@ -888,6 +891,64 @@ M 33 21.99
assert_raises(TypeError, np.genfromtxt, mdata, **kwargs)
+ def test_default_field_format(self):
+ "Test default format"
+ data = "0, 1, 2.3\n4, 5, 6.7"
+ mtest = np.ndfromtxt(StringIO.StringIO(data),
+ delimiter=",", dtype=None, defaultfmt="f%02i")
+ ctrl = np.array([(0, 1, 2.3), (4, 5, 6.7)],
+ dtype=[("f00", int), ("f01", int), ("f02", float)])
+ assert_equal(mtest, ctrl)
+
+ def test_single_dtype_wo_names(self):
+ "Test single dtype w/o names"
+ data = "0, 1, 2.3\n4, 5, 6.7"
+ mtest = np.ndfromtxt(StringIO.StringIO(data),
+ delimiter=",", dtype=float, defaultfmt="f%02i")
+ ctrl = np.array([[0., 1., 2.3], [4., 5., 6.7]], dtype=float)
+ assert_equal(mtest, ctrl)
+
+ def test_single_dtype_w_explicit_names(self):
+ "Test single dtype w explicit names"
+ data = "0, 1, 2.3\n4, 5, 6.7"
+ mtest = np.ndfromtxt(StringIO.StringIO(data),
+ delimiter=",", dtype=float, names="a, b, c")
+ ctrl = np.array([(0., 1., 2.3), (4., 5., 6.7)],
+ dtype=[(_, float) for _ in "abc"])
+ assert_equal(mtest, ctrl)
+
+ def test_single_dtype_w_implicit_names(self):
+ "Test single dtype w implicit names"
+ data = "a, b, c\n0, 1, 2.3\n4, 5, 6.7"
+ mtest = np.ndfromtxt(StringIO.StringIO(data),
+ delimiter=",", dtype=float, names=True)
+ ctrl = np.array([(0., 1., 2.3), (4., 5., 6.7)],
+ dtype=[(_, float) for _ in "abc"])
+ assert_equal(mtest, ctrl)
+
+ def test_easy_structured_dtype(self):
+ "Test easy structured dtype"
+ data = "0, 1, 2.3\n4, 5, 6.7"
+ mtest = np.ndfromtxt(StringIO.StringIO(data), delimiter=",",
+ dtype=(int, float, float), defaultfmt="f_%02i")
+ ctrl = np.array([(0, 1., 2.3), (4, 5., 6.7)],
+ dtype=[("f_00", int), ("f_01", float), ("f_02", float)])
+ assert_equal(mtest, ctrl)
+
+ def test_autostrip(self):
+ "Test autostrip"
+ data = "01/01/2003 , 1.3, abcde"
+ kwargs = dict(delimiter=",", dtype=None)
+ mtest = np.ndfromtxt(StringIO.StringIO(data), **kwargs)
+ ctrl = np.array([('01/01/2003 ', 1.3, ' abcde')],
+ dtype=[('f0', '|S12'), ('f1', float), ('f2', '|S8')])
+ assert_equal(mtest, ctrl)
+ mtest = np.ndfromtxt(StringIO.StringIO(data), autostrip=True, **kwargs)
+ ctrl = np.array([('01/01/2003', 1.3, 'abcde')],
+ dtype=[('f0', '|S10'), ('f1', float), ('f2', '|S5')])
+ assert_equal(mtest, ctrl)
+
+
def test_recfromtxt(self):
#
data = StringIO.StringIO('A,B\n0,1\n2,3')
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index d32385493..cee884e9d 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -1963,7 +1963,9 @@ def masked_equal(x, value, copy=True):
# c = umath.equal(d, value)
# m = mask_or(c, getmask(x))
# return array(d, mask=m, copy=copy)
- return masked_where(equal(x, value), x, copy=copy)
+ output = masked_where(equal(x, value), x, copy=copy)
+ output.fill_value = value
+ return output
def masked_inside(x, v1, v2, copy=True):
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index 61f315753..e3447b546 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -2591,6 +2591,12 @@ class TestMaskedArrayFunctions(TestCase):
assert_equal(mx, x)
assert_equal(mx._mask, [1,1,0])
+ def test_masked_equal_fill_value(self):
+ x = [1, 2, 3]
+ mx = masked_equal(x, 3)
+ assert_equal(mx._mask, [0, 0, 1])
+ assert_equal(mx.fill_value, 3)
+
def test_masked_where_condition(self):
"Tests masking functions."
x = array([1.,2.,3.,4.,5.])