From 8e77ab879c8cb1f25f0245c74bc9bba402c40b51 Mon Sep 17 00:00:00 2001 From: pierregm Date: Tue, 23 Dec 2008 23:43:43 +0000 Subject: testutils: * assert_equal : use assert_equal_array on records * assert_array_compare : prevent the common mask to be back-propagated to the initial input arrays. * assert_equal_array : use operator.__eq__ instead of ma.equal * assert_equal_less: use operator.__less__ instead of ma.less core: * Fixed _check_fill_value for nested flexible types * Add a ndtype option to _make_mask_descr * Fixed mask_or for nested flexible types * Fixed the printing of masked arrays w/ flexible types. --- numpy/ma/core.py | 115 +++++++++++++++++++++++++++------------- numpy/ma/tests/test_core.py | 28 +++++++--- numpy/ma/tests/test_mrecords.py | 4 +- numpy/ma/testutils.py | 27 +++++----- 4 files changed, 116 insertions(+), 58 deletions(-) (limited to 'numpy') diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 6b4dc98e6..8ee19778c 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -217,6 +217,28 @@ def maximum_fill_value(obj): raise TypeError(errmsg) +def _recursive_set_default_fill_value(dtypedescr): + deflist = [] + for currentdescr in dtypedescr: + currenttype = currentdescr[1] + if isinstance(currenttype, list): + deflist.append(tuple(_recursive_set_default_fill_value(currenttype))) + else: + deflist.append(default_fill_value(np.dtype(currenttype))) + return tuple(deflist) + +def _recursive_set_fill_value(fillvalue, dtypedescr): + fillvalue = np.resize(fillvalue, len(dtypedescr)) + output_value = [] + for (fval, descr) in zip(fillvalue, dtypedescr): + cdtype = descr[1] + if isinstance(cdtype, list): + output_value.append(tuple(_recursive_set_fill_value(fval, cdtype))) + else: + output_value.append(np.array(fval, dtype=cdtype).item()) + return tuple(output_value) + + def _check_fill_value(fill_value, ndtype): """ Private function validating the given `fill_value` for the given dtype. @@ -233,10 +255,9 @@ def _check_fill_value(fill_value, ndtype): fields = ndtype.fields if fill_value is None: if fields: - fdtype = [(_[0], _[1]) for _ in ndtype.descr] - fill_value = np.array(tuple([default_fill_value(fields[n][0]) - for n in ndtype.names]), - dtype=fdtype) + descr = ndtype.descr + fill_value = np.array(_recursive_set_default_fill_value(descr), + dtype=ndtype) else: fill_value = default_fill_value(ndtype) elif fields: @@ -248,10 +269,9 @@ def _check_fill_value(fill_value, ndtype): err_msg = "Unable to transform %s to dtype %s" raise ValueError(err_msg % (fill_value, fdtype)) else: - fval = np.resize(fill_value, len(ndtype.descr)) - fill_value = [np.asarray(f).astype(desc[1]).item() - for (f, desc) in zip(fval, ndtype.descr)] - fill_value = np.array(tuple(fill_value), copy=False, dtype=fdtype) + descr = ndtype.descr + fill_value = np.array(_recursive_set_fill_value(fill_value, descr), + dtype=ndtype) else: if isinstance(fill_value, basestring) and (ndtype.char not in 'SV'): fill_value = default_fill_value(ndtype) @@ -831,35 +851,35 @@ mod = _DomainedBinaryOperation(umath.mod, _DomainSafeDivide(), 0, 1) #####-------------------------------------------------------------------------- #---- --- Mask creation functions --- #####-------------------------------------------------------------------------- +def _recursive_make_descr(datatype, newtype=bool_): + "Private function allowing recursion in make_descr." + # Do we have some name fields ? + if datatype.names: + descr = [] + for name in datatype.names: + field = datatype.fields[name] + if len(field) == 3: + # Prepend the title to the name + name = (field[-1], name) + descr.append((name, _recursive_make_descr(field[0], newtype))) + return descr + # Is this some kind of composite a la (np.float,2) + elif datatype.subdtype: + mdescr = list(datatype.subdtype) + mdescr[0] = newtype + return tuple(mdescr) + else: + return newtype def make_mask_descr(ndtype): """Constructs a dtype description list from a given dtype. Each field is set to a bool. """ - def _make_descr(datatype): - "Private function allowing recursion." - # Do we have some name fields ? - if datatype.names: - descr = [] - for name in datatype.names: - field = datatype.fields[name] - if len(field) == 3: - # Prepend the title to the name - name = (field[-1], name) - descr.append((name, _make_descr(field[0]))) - return descr - # Is this some kind of composite a la (np.float,2) - elif datatype.subdtype: - mdescr = list(datatype.subdtype) - mdescr[0] = np.dtype(bool) - return tuple(mdescr) - else: - return np.bool # Make sure we do have a dtype if not isinstance(ndtype, np.dtype): ndtype = np.dtype(ndtype) - return np.dtype(_make_descr(ndtype)) + return np.dtype(_recursive_make_descr(ndtype, np.bool)) def get_mask(a): """Return the mask of a, if any, or nomask. @@ -988,7 +1008,17 @@ def mask_or (m1, m2, copy=False, shrink=True): ValueError If m1 and m2 have different flexible dtypes. - """ + """ + def _recursive_mask_or(m1, m2, newmask): + names = m1.dtype.names + for name in names: + current1 = m1[name] + if current1.dtype.names: + _recursive_mask_or(current1, m2[name], newmask[name]) + else: + umath.logical_or(current1, m2[name], newmask[name]) + return + # if (m1 is nomask) or (m1 is False): dtype = getattr(m2, 'dtype', MaskType) return make_mask(m2, copy=copy, shrink=shrink, dtype=dtype) @@ -1002,8 +1032,7 @@ def mask_or (m1, m2, copy=False, shrink=True): raise ValueError("Incompatible dtypes '%s'<>'%s'" % (dtype1, dtype2)) if dtype1.names: newmask = np.empty_like(m1) - for n in dtype1.names: - newmask[n] = umath.logical_or(m1[n], m2[n]) + _recursive_mask_or(m1, m2, newmask) return newmask return make_mask(umath.logical_or(m1, m2), copy=copy, shrink=shrink) @@ -1291,6 +1320,22 @@ class _MaskedPrintOption: #if you single index into a masked location you get this object. masked_print_option = _MaskedPrintOption('--') + +def _recursive_printoption(result, mask, printopt): + """ + Puts printoptions in result where mask is True. + Private function allowing for recursion + """ + names = result.dtype.names + for name in names: + (curdata, curmask) = (result[name], mask[name]) + if curdata.dtype.names: + _recursive_printoption(curdata, curmask, printopt) + else: + np.putmask(curdata, curmask, printopt) + return + + #####-------------------------------------------------------------------------- #---- --- MaskedArray class --- #####-------------------------------------------------------------------------- @@ -2184,13 +2229,9 @@ class MaskedArray(ndarray): res = self._data.astype("|O8") res[m] = f else: - rdtype = [list(_) for _ in self.dtype.descr] - for r in rdtype: - r[1] = '|O8' - rdtype = [tuple(_) for _ in rdtype] + rdtype = _recursive_make_descr(self.dtype, "|O8") res = self._data.astype(rdtype) - for field in names: - np.putmask(res[field], m[field], f) + _recursive_printoption(res, m, f) else: res = self.filled(self.fill_value) return str(res) diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index 8a0ca08ac..5c65ef5f4 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -483,6 +483,16 @@ class TestMaskedArray(TestCase): y._optinfo['info'] = '!!!' assert_equal(x._optinfo['info'], '???') + + def test_fancy_printoptions(self): + "Test printing a masked array w/ fancy dtype." + fancydtype = np.dtype([('x', int), ('y', [('t', int), ('s', float)])]) + test = array([(1, (2, 3.0)), (4, (5, 6.0))], + mask=[(1, (0, 1)), (0, (1, 0))], + dtype=fancydtype) + control = "[(--, (2, --)) (4, (--, 6.0))]" + assert_equal(str(test), control) + #------------------------------------------------------------------------------ class TestMaskedArrayArithmetic(TestCase): @@ -1049,19 +1059,19 @@ class TestFillingValues(TestCase): # The shape shouldn't matter ndtype = [('f0', float, (2, 2))] control = np.array((default_fill_value(0.),), - dtype=[('f0',float)]) + dtype=[('f0',float)]).astype(ndtype) assert_equal(_check_fill_value(None, ndtype), control) - control = np.array((0,), dtype=[('f0',float)]) + control = np.array((0,), dtype=[('f0',float)]).astype(ndtype) assert_equal(_check_fill_value(0, ndtype), control) # ndtype = np.dtype("int, (2,3)float, float") control = np.array((default_fill_value(0), default_fill_value(0.), default_fill_value(0.),), - dtype="int, float, float") + dtype="int, float, float").astype(ndtype) test = _check_fill_value(None, ndtype) assert_equal(test, control) - control = np.array((0,0,0), dtype="int, float, float") + control = np.array((0,0,0), dtype="int, float, float").astype(ndtype) assert_equal(_check_fill_value(0, ndtype), control) #------------------------------------------------------------------------------ @@ -1912,8 +1922,8 @@ class TestMaskedArrayMethods(TestCase): dtype=ndtype) data[[0,1,2,-1]] = masked record = data.torecords() - assert_equal(record['_data'], data._data) - assert_equal(record['_mask'], data._mask) + assert_equal_records(record['_data'], data._data) + assert_equal_records(record['_mask'], data._mask) #------------------------------------------------------------------------------ @@ -2531,6 +2541,12 @@ class TestMaskedArrayFunctions(TestCase): test = mask_or(mask, other) except ValueError: pass + # Using nested arrays + dtype = [('a', np.bool), ('b', [('ba', np.bool), ('bb', np.bool)])] + amask = np.array([(0, (1, 0)), (0, (1, 0))], dtype=dtype) + bmask = np.array([(1, (0, 1)), (0, (0, 0))], dtype=dtype) + cntrl = np.array([(1, (1, 1)), (0, (1, 0))], dtype=dtype) + assert_equal(mask_or(amask, bmask), cntrl) def test_flatten_mask(self): diff --git a/numpy/ma/tests/test_mrecords.py b/numpy/ma/tests/test_mrecords.py index 6e9248953..769c5330e 100644 --- a/numpy/ma/tests/test_mrecords.py +++ b/numpy/ma/tests/test_mrecords.py @@ -334,8 +334,8 @@ class TestMRecords(TestCase): mult[0] = masked mult[1] = (1, 1, 1) mult.filled(0) - assert_equal(mult.filled(0), - np.array([(0,0,0),(1,1,1)], dtype=mult.dtype)) + assert_equal_records(mult.filled(0), + np.array([(0,0,0),(1,1,1)], dtype=mult.dtype)) class TestView(TestCase): diff --git a/numpy/ma/testutils.py b/numpy/ma/testutils.py index df7484f0a..5234e0db4 100644 --- a/numpy/ma/testutils.py +++ b/numpy/ma/testutils.py @@ -110,14 +110,14 @@ def assert_equal(actual,desired,err_msg=''): return _assert_equal_on_sequences(actual.tolist(), desired.tolist(), err_msg='') - elif actual_dtype.char in "OV" and desired_dtype.char in "OV": - if (actual_dtype != desired_dtype) and actual_dtype: - msg = build_err_msg([actual_dtype, desired_dtype], - err_msg, header='', names=('actual', 'desired')) - raise ValueError(msg) - return _assert_equal_on_sequences(actual.tolist(), - desired.tolist(), - err_msg='') +# elif actual_dtype.char in "OV" and desired_dtype.char in "OV": +# if (actual_dtype != desired_dtype) and actual_dtype: +# msg = build_err_msg([actual_dtype, desired_dtype], +# err_msg, header='', names=('actual', 'desired')) +# raise ValueError(msg) +# return _assert_equal_on_sequences(actual.tolist(), +# desired.tolist(), +# err_msg='') return assert_array_equal(actual, desired, err_msg) @@ -171,15 +171,14 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='', # yf = filled(y) # Allocate a common mask and refill m = mask_or(getmask(x), getmask(y)) - x = masked_array(x, copy=False, mask=m, subok=False) - y = masked_array(y, copy=False, mask=m, subok=False) + x = masked_array(x, copy=False, mask=m, keep_mask=False, subok=False) + y = masked_array(y, copy=False, mask=m, keep_mask=False, subok=False) if ((x is masked) and not (y is masked)) or \ ((y is masked) and not (x is masked)): msg = build_err_msg([x, y], err_msg=err_msg, verbose=verbose, header=header, names=('x', 'y')) raise ValueError(msg) # OK, now run the basic tests on filled versions - comparison = getattr(np, comparison.__name__, lambda x,y: True) return utils.assert_array_compare(comparison, x.filled(fill_value), y.filled(fill_value), @@ -189,7 +188,8 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='', def assert_array_equal(x, y, err_msg='', verbose=True): """Checks the elementwise equality of two masked arrays.""" - assert_array_compare(equal, x, y, err_msg=err_msg, verbose=verbose, + assert_array_compare(operator.__eq__, x, y, + err_msg=err_msg, verbose=verbose, header='Arrays are not equal') @@ -223,7 +223,8 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): def assert_array_less(x, y, err_msg='', verbose=True): "Checks that x is smaller than y elementwise." - assert_array_compare(less, x, y, err_msg=err_msg, verbose=verbose, + assert_array_compare(operator.__lt__, x, y, + err_msg=err_msg, verbose=verbose, header='Arrays are not less-ordered') -- cgit v1.2.1