summaryrefslogtreecommitdiff
path: root/numpy/ma
diff options
context:
space:
mode:
authorpierregm <pierregm@localhost>2008-12-23 23:43:43 +0000
committerpierregm <pierregm@localhost>2008-12-23 23:43:43 +0000
commit8e77ab879c8cb1f25f0245c74bc9bba402c40b51 (patch)
tree78a00afed9fd6742657bb6e1f444b21ec203d216 /numpy/ma
parented916ff07cf2f1fbb6056cdc01fe3ae82a027ed5 (diff)
downloadnumpy-8e77ab879c8cb1f25f0245c74bc9bba402c40b51.tar.gz
testutils:
* assert_equal : use assert_equal_array on records * assert_array_compare : prevent the common mask to be back-propagated to the initial input arrays. * assert_equal_array : use operator.__eq__ instead of ma.equal * assert_equal_less: use operator.__less__ instead of ma.less core: * Fixed _check_fill_value for nested flexible types * Add a ndtype option to _make_mask_descr * Fixed mask_or for nested flexible types * Fixed the printing of masked arrays w/ flexible types.
Diffstat (limited to 'numpy/ma')
-rw-r--r--numpy/ma/core.py115
-rw-r--r--numpy/ma/tests/test_core.py28
-rw-r--r--numpy/ma/tests/test_mrecords.py4
-rw-r--r--numpy/ma/testutils.py27
4 files changed, 116 insertions, 58 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index 6b4dc98e6..8ee19778c 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -217,6 +217,28 @@ def maximum_fill_value(obj):
raise TypeError(errmsg)
+def _recursive_set_default_fill_value(dtypedescr):
+ deflist = []
+ for currentdescr in dtypedescr:
+ currenttype = currentdescr[1]
+ if isinstance(currenttype, list):
+ deflist.append(tuple(_recursive_set_default_fill_value(currenttype)))
+ else:
+ deflist.append(default_fill_value(np.dtype(currenttype)))
+ return tuple(deflist)
+
+def _recursive_set_fill_value(fillvalue, dtypedescr):
+ fillvalue = np.resize(fillvalue, len(dtypedescr))
+ output_value = []
+ for (fval, descr) in zip(fillvalue, dtypedescr):
+ cdtype = descr[1]
+ if isinstance(cdtype, list):
+ output_value.append(tuple(_recursive_set_fill_value(fval, cdtype)))
+ else:
+ output_value.append(np.array(fval, dtype=cdtype).item())
+ return tuple(output_value)
+
+
def _check_fill_value(fill_value, ndtype):
"""
Private function validating the given `fill_value` for the given dtype.
@@ -233,10 +255,9 @@ def _check_fill_value(fill_value, ndtype):
fields = ndtype.fields
if fill_value is None:
if fields:
- fdtype = [(_[0], _[1]) for _ in ndtype.descr]
- fill_value = np.array(tuple([default_fill_value(fields[n][0])
- for n in ndtype.names]),
- dtype=fdtype)
+ descr = ndtype.descr
+ fill_value = np.array(_recursive_set_default_fill_value(descr),
+ dtype=ndtype)
else:
fill_value = default_fill_value(ndtype)
elif fields:
@@ -248,10 +269,9 @@ def _check_fill_value(fill_value, ndtype):
err_msg = "Unable to transform %s to dtype %s"
raise ValueError(err_msg % (fill_value, fdtype))
else:
- fval = np.resize(fill_value, len(ndtype.descr))
- fill_value = [np.asarray(f).astype(desc[1]).item()
- for (f, desc) in zip(fval, ndtype.descr)]
- fill_value = np.array(tuple(fill_value), copy=False, dtype=fdtype)
+ descr = ndtype.descr
+ fill_value = np.array(_recursive_set_fill_value(fill_value, descr),
+ dtype=ndtype)
else:
if isinstance(fill_value, basestring) and (ndtype.char not in 'SV'):
fill_value = default_fill_value(ndtype)
@@ -831,35 +851,35 @@ mod = _DomainedBinaryOperation(umath.mod, _DomainSafeDivide(), 0, 1)
#####--------------------------------------------------------------------------
#---- --- Mask creation functions ---
#####--------------------------------------------------------------------------
+def _recursive_make_descr(datatype, newtype=bool_):
+ "Private function allowing recursion in make_descr."
+ # Do we have some name fields ?
+ if datatype.names:
+ descr = []
+ for name in datatype.names:
+ field = datatype.fields[name]
+ if len(field) == 3:
+ # Prepend the title to the name
+ name = (field[-1], name)
+ descr.append((name, _recursive_make_descr(field[0], newtype)))
+ return descr
+ # Is this some kind of composite a la (np.float,2)
+ elif datatype.subdtype:
+ mdescr = list(datatype.subdtype)
+ mdescr[0] = newtype
+ return tuple(mdescr)
+ else:
+ return newtype
def make_mask_descr(ndtype):
"""Constructs a dtype description list from a given dtype.
Each field is set to a bool.
"""
- def _make_descr(datatype):
- "Private function allowing recursion."
- # Do we have some name fields ?
- if datatype.names:
- descr = []
- for name in datatype.names:
- field = datatype.fields[name]
- if len(field) == 3:
- # Prepend the title to the name
- name = (field[-1], name)
- descr.append((name, _make_descr(field[0])))
- return descr
- # Is this some kind of composite a la (np.float,2)
- elif datatype.subdtype:
- mdescr = list(datatype.subdtype)
- mdescr[0] = np.dtype(bool)
- return tuple(mdescr)
- else:
- return np.bool
# Make sure we do have a dtype
if not isinstance(ndtype, np.dtype):
ndtype = np.dtype(ndtype)
- return np.dtype(_make_descr(ndtype))
+ return np.dtype(_recursive_make_descr(ndtype, np.bool))
def get_mask(a):
"""Return the mask of a, if any, or nomask.
@@ -988,7 +1008,17 @@ def mask_or (m1, m2, copy=False, shrink=True):
ValueError
If m1 and m2 have different flexible dtypes.
- """
+ """
+ def _recursive_mask_or(m1, m2, newmask):
+ names = m1.dtype.names
+ for name in names:
+ current1 = m1[name]
+ if current1.dtype.names:
+ _recursive_mask_or(current1, m2[name], newmask[name])
+ else:
+ umath.logical_or(current1, m2[name], newmask[name])
+ return
+ #
if (m1 is nomask) or (m1 is False):
dtype = getattr(m2, 'dtype', MaskType)
return make_mask(m2, copy=copy, shrink=shrink, dtype=dtype)
@@ -1002,8 +1032,7 @@ def mask_or (m1, m2, copy=False, shrink=True):
raise ValueError("Incompatible dtypes '%s'<>'%s'" % (dtype1, dtype2))
if dtype1.names:
newmask = np.empty_like(m1)
- for n in dtype1.names:
- newmask[n] = umath.logical_or(m1[n], m2[n])
+ _recursive_mask_or(m1, m2, newmask)
return newmask
return make_mask(umath.logical_or(m1, m2), copy=copy, shrink=shrink)
@@ -1291,6 +1320,22 @@ class _MaskedPrintOption:
#if you single index into a masked location you get this object.
masked_print_option = _MaskedPrintOption('--')
+
+def _recursive_printoption(result, mask, printopt):
+ """
+ Puts printoptions in result where mask is True.
+ Private function allowing for recursion
+ """
+ names = result.dtype.names
+ for name in names:
+ (curdata, curmask) = (result[name], mask[name])
+ if curdata.dtype.names:
+ _recursive_printoption(curdata, curmask, printopt)
+ else:
+ np.putmask(curdata, curmask, printopt)
+ return
+
+
#####--------------------------------------------------------------------------
#---- --- MaskedArray class ---
#####--------------------------------------------------------------------------
@@ -2184,13 +2229,9 @@ class MaskedArray(ndarray):
res = self._data.astype("|O8")
res[m] = f
else:
- rdtype = [list(_) for _ in self.dtype.descr]
- for r in rdtype:
- r[1] = '|O8'
- rdtype = [tuple(_) for _ in rdtype]
+ rdtype = _recursive_make_descr(self.dtype, "|O8")
res = self._data.astype(rdtype)
- for field in names:
- np.putmask(res[field], m[field], f)
+ _recursive_printoption(res, m, f)
else:
res = self.filled(self.fill_value)
return str(res)
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index 8a0ca08ac..5c65ef5f4 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -483,6 +483,16 @@ class TestMaskedArray(TestCase):
y._optinfo['info'] = '!!!'
assert_equal(x._optinfo['info'], '???')
+
+ def test_fancy_printoptions(self):
+ "Test printing a masked array w/ fancy dtype."
+ fancydtype = np.dtype([('x', int), ('y', [('t', int), ('s', float)])])
+ test = array([(1, (2, 3.0)), (4, (5, 6.0))],
+ mask=[(1, (0, 1)), (0, (1, 0))],
+ dtype=fancydtype)
+ control = "[(--, (2, --)) (4, (--, 6.0))]"
+ assert_equal(str(test), control)
+
#------------------------------------------------------------------------------
class TestMaskedArrayArithmetic(TestCase):
@@ -1049,19 +1059,19 @@ class TestFillingValues(TestCase):
# The shape shouldn't matter
ndtype = [('f0', float, (2, 2))]
control = np.array((default_fill_value(0.),),
- dtype=[('f0',float)])
+ dtype=[('f0',float)]).astype(ndtype)
assert_equal(_check_fill_value(None, ndtype), control)
- control = np.array((0,), dtype=[('f0',float)])
+ control = np.array((0,), dtype=[('f0',float)]).astype(ndtype)
assert_equal(_check_fill_value(0, ndtype), control)
#
ndtype = np.dtype("int, (2,3)float, float")
control = np.array((default_fill_value(0),
default_fill_value(0.),
default_fill_value(0.),),
- dtype="int, float, float")
+ dtype="int, float, float").astype(ndtype)
test = _check_fill_value(None, ndtype)
assert_equal(test, control)
- control = np.array((0,0,0), dtype="int, float, float")
+ control = np.array((0,0,0), dtype="int, float, float").astype(ndtype)
assert_equal(_check_fill_value(0, ndtype), control)
#------------------------------------------------------------------------------
@@ -1912,8 +1922,8 @@ class TestMaskedArrayMethods(TestCase):
dtype=ndtype)
data[[0,1,2,-1]] = masked
record = data.torecords()
- assert_equal(record['_data'], data._data)
- assert_equal(record['_mask'], data._mask)
+ assert_equal_records(record['_data'], data._data)
+ assert_equal_records(record['_mask'], data._mask)
#------------------------------------------------------------------------------
@@ -2531,6 +2541,12 @@ class TestMaskedArrayFunctions(TestCase):
test = mask_or(mask, other)
except ValueError:
pass
+ # Using nested arrays
+ dtype = [('a', np.bool), ('b', [('ba', np.bool), ('bb', np.bool)])]
+ amask = np.array([(0, (1, 0)), (0, (1, 0))], dtype=dtype)
+ bmask = np.array([(1, (0, 1)), (0, (0, 0))], dtype=dtype)
+ cntrl = np.array([(1, (1, 1)), (0, (1, 0))], dtype=dtype)
+ assert_equal(mask_or(amask, bmask), cntrl)
def test_flatten_mask(self):
diff --git a/numpy/ma/tests/test_mrecords.py b/numpy/ma/tests/test_mrecords.py
index 6e9248953..769c5330e 100644
--- a/numpy/ma/tests/test_mrecords.py
+++ b/numpy/ma/tests/test_mrecords.py
@@ -334,8 +334,8 @@ class TestMRecords(TestCase):
mult[0] = masked
mult[1] = (1, 1, 1)
mult.filled(0)
- assert_equal(mult.filled(0),
- np.array([(0,0,0),(1,1,1)], dtype=mult.dtype))
+ assert_equal_records(mult.filled(0),
+ np.array([(0,0,0),(1,1,1)], dtype=mult.dtype))
class TestView(TestCase):
diff --git a/numpy/ma/testutils.py b/numpy/ma/testutils.py
index df7484f0a..5234e0db4 100644
--- a/numpy/ma/testutils.py
+++ b/numpy/ma/testutils.py
@@ -110,14 +110,14 @@ def assert_equal(actual,desired,err_msg=''):
return _assert_equal_on_sequences(actual.tolist(),
desired.tolist(),
err_msg='')
- elif actual_dtype.char in "OV" and desired_dtype.char in "OV":
- if (actual_dtype != desired_dtype) and actual_dtype:
- msg = build_err_msg([actual_dtype, desired_dtype],
- err_msg, header='', names=('actual', 'desired'))
- raise ValueError(msg)
- return _assert_equal_on_sequences(actual.tolist(),
- desired.tolist(),
- err_msg='')
+# elif actual_dtype.char in "OV" and desired_dtype.char in "OV":
+# if (actual_dtype != desired_dtype) and actual_dtype:
+# msg = build_err_msg([actual_dtype, desired_dtype],
+# err_msg, header='', names=('actual', 'desired'))
+# raise ValueError(msg)
+# return _assert_equal_on_sequences(actual.tolist(),
+# desired.tolist(),
+# err_msg='')
return assert_array_equal(actual, desired, err_msg)
@@ -171,15 +171,14 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
# yf = filled(y)
# Allocate a common mask and refill
m = mask_or(getmask(x), getmask(y))
- x = masked_array(x, copy=False, mask=m, subok=False)
- y = masked_array(y, copy=False, mask=m, subok=False)
+ x = masked_array(x, copy=False, mask=m, keep_mask=False, subok=False)
+ y = masked_array(y, copy=False, mask=m, keep_mask=False, subok=False)
if ((x is masked) and not (y is masked)) or \
((y is masked) and not (x is masked)):
msg = build_err_msg([x, y], err_msg=err_msg, verbose=verbose,
header=header, names=('x', 'y'))
raise ValueError(msg)
# OK, now run the basic tests on filled versions
- comparison = getattr(np, comparison.__name__, lambda x,y: True)
return utils.assert_array_compare(comparison,
x.filled(fill_value),
y.filled(fill_value),
@@ -189,7 +188,8 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
def assert_array_equal(x, y, err_msg='', verbose=True):
"""Checks the elementwise equality of two masked arrays."""
- assert_array_compare(equal, x, y, err_msg=err_msg, verbose=verbose,
+ assert_array_compare(operator.__eq__, x, y,
+ err_msg=err_msg, verbose=verbose,
header='Arrays are not equal')
@@ -223,7 +223,8 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
def assert_array_less(x, y, err_msg='', verbose=True):
"Checks that x is smaller than y elementwise."
- assert_array_compare(less, x, y, err_msg=err_msg, verbose=verbose,
+ assert_array_compare(operator.__lt__, x, y,
+ err_msg=err_msg, verbose=verbose,
header='Arrays are not less-ordered')