"""Miscellaneous functions for testing masked arrays and subclasses :author: Pierre Gerard-Marchant :contact: pierregm_at_uga_dot_edu :version: $Id: testutils.py 3529 2007-11-13 08:01:14Z jarrod.millman $ """ __author__ = "Pierre GF Gerard-Marchant ($Author: jarrod.millman $)" __version__ = "1.0" __revision__ = "$Revision: 3529 $" __date__ = "$Date: 2007-11-13 10:01:14 +0200 (Tue, 13 Nov 2007) $" import numpy as N from numpy.core import ndarray from numpy.core.numerictypes import float_ import numpy.core.umath as umath from numpy.testing import NumpyTest, NumpyTestCase from numpy.testing.utils import build_err_msg, rand import core from core import mask_or, getmask, getmaskarray, masked_array, nomask, masked from core import filled, equal, less #------------------------------------------------------------------------------ def approx (a, b, fill_value=True, rtol=1.e-5, atol=1.e-8): """Returns true if all components of a and b are equal subject to given tolerances. If fill_value is True, masked values considered equal. Otherwise, masked values are considered unequal. The relative error rtol should be positive and << 1.0 The absolute error atol comes into play for those elements of b that are very small or zero; it says how small a must be also. """ m = mask_or(getmask(a), getmask(b)) d1 = filled(a) d2 = filled(b) if d1.dtype.char == "O" or d2.dtype.char == "O": return N.equal(d1,d2).ravel() x = filled(masked_array(d1, copy=False, mask=m), fill_value).astype(float_) y = filled(masked_array(d2, copy=False, mask=m), 1).astype(float_) d = N.less_equal(umath.absolute(x-y), atol + rtol * umath.absolute(y)) return d.ravel() def almost(a, b, decimal=6, fill_value=True): """Returns True if a and b are equal up to decimal places. If fill_value is True, masked values considered equal. Otherwise, masked values are considered unequal. 
""" m = mask_or(getmask(a), getmask(b)) d1 = filled(a) d2 = filled(b) if d1.dtype.char == "O" or d2.dtype.char == "O": return N.equal(d1,d2).ravel() x = filled(masked_array(d1, copy=False, mask=m), fill_value).astype(float_) y = filled(masked_array(d2, copy=False, mask=m), 1).astype(float_) d = N.around(N.abs(x-y),decimal) <= 10.0**(-decimal) return d.ravel() #................................................ def _assert_equal_on_sequences(actual, desired, err_msg=''): "Asserts the equality of two non-array sequences." assert_equal(len(actual),len(desired),err_msg) for k in range(len(desired)): assert_equal(actual[k], desired[k], 'item=%r\n%s' % (k,err_msg)) return def assert_equal_records(a,b): """Asserts that two records are equal. Pretty crude for now.""" assert_equal(a.dtype, b.dtype) for f in a.dtype.names: (af, bf) = (getattr(a,f), getattr(b,f)) if not (af is masked) and not (bf is masked): assert_equal(getattr(a,f), getattr(b,f)) return def assert_equal(actual,desired,err_msg=''): """Asserts that two items are equal. """ # Case #1: dictionary ..... if isinstance(desired, dict): assert isinstance(actual, dict), repr(type(actual)) assert_equal(len(actual),len(desired),err_msg) for k,i in desired.items(): assert k in actual, repr(k) assert_equal(actual[k], desired[k], 'key=%r\n%s' % (k,err_msg)) return # Case #2: lists ..... if isinstance(desired, (list,tuple)) and isinstance(actual, (list,tuple)): return _assert_equal_on_sequences(actual, desired, err_msg='') if not (isinstance(actual, ndarray) or isinstance(desired, ndarray)): msg = build_err_msg([actual, desired], err_msg,) assert desired == actual, msg return # Case #4. 
arrays or equivalent if ((actual is masked) and not (desired is masked)) or \ ((desired is masked) and not (actual is masked)): msg = build_err_msg([actual, desired], err_msg, header='', names=('x', 'y')) raise ValueError(msg) actual = N.array(actual, copy=False, subok=True) desired = N.array(desired, copy=False, subok=True) if actual.dtype.char in "OS" and desired.dtype.char in "OS": return _assert_equal_on_sequences(actual.tolist(), desired.tolist(), err_msg='') return assert_array_equal(actual, desired, err_msg) #............................. def fail_if_equal(actual,desired,err_msg='',): """Raises an assertion error if two items are equal. """ if isinstance(desired, dict): assert isinstance(actual, dict), repr(type(actual)) fail_if_equal(len(actual),len(desired),err_msg) for k,i in desired.items(): assert k in actual, repr(k) fail_if_equal(actual[k], desired[k], 'key=%r\n%s' % (k,err_msg)) return if isinstance(desired, (list,tuple)) and isinstance(actual, (list,tuple)): fail_if_equal(len(actual),len(desired),err_msg) for k in range(len(desired)): fail_if_equal(actual[k], desired[k], 'item=%r\n%s' % (k,err_msg)) return if isinstance(actual, N.ndarray) or isinstance(desired, N.ndarray): return fail_if_array_equal(actual, desired, err_msg) msg = build_err_msg([actual, desired], err_msg) assert desired != actual, msg assert_not_equal = fail_if_equal #............................ def assert_almost_equal(actual,desired,decimal=7,err_msg=''): """Asserts that two items are almost equal. The test is equivalent to abs(desired-actual) < 0.5 * 10**(-decimal) """ if isinstance(actual, N.ndarray) or isinstance(desired, N.ndarray): return assert_array_almost_equal(actual, desired, decimal, err_msg) msg = build_err_msg([actual, desired], err_msg) assert round(abs(desired - actual),decimal) == 0, msg #............................ 
def _zero_out_nans(arr):
    """Cast a filled operand to float and replace its NaNs by 0.

    Object ("O") and string ("S") data are returned untouched.  The result
    of the float path always keeps its shape: a NaN scalar becomes a 0-d
    array holding 0 rather than the Python int 0, which has no `.shape`.
    """
    if arr.dtype.char in ("O", "S"):
        return arr
    arr = arr.astype(float_)
    nan_flags = N.isnan(arr)
    # .any() is well defined for scalars and for empty arrays alike.
    if nan_flags.any():
        arr = N.where(nan_flags, 0., arr)
    return arr


def assert_array_compare(comparison, x, y, err_msg='', header='',
                         fill_value=True):
    """Asserts that a comparison relation between two masked arrays is
    satisfied elementwise.

    comparison : callable taking the two prepared operands and returning
        either a bool or an array of booleans.
    x, y : the items to compare.
    err_msg, header : text passed to build_err_msg for the failure report.
    fill_value : value used to fill both operands wherever either is masked.
        When it is true and a mask is present, the comparison result is
        masked again with that combined mask before being reduced.

    Non object/string operands are cast to float and their NaNs are
    replaced by 0 before the comparison, so NaNs compare equal to each
    other (and to 0).

    Raises AssertionError if the shapes are incompatible or the comparison
    is not satisfied everywhere.  A ValueError raised during the comparison
    is re-raised with the formatted report as its message.
    """
    xf = filled(x)
    yf = filled(y)
    m = mask_or(getmask(x), getmask(y))

    # Apply the combined mask to both sides, then fill with the same value.
    x = masked_array(xf, copy=False, subok=False, mask=m).filled(fill_value)
    y = masked_array(yf, copy=False, subok=False, mask=m).filled(fill_value)

    if (x is masked) != (y is masked):
        # Exactly one side is the masked constant.
        msg = build_err_msg([x, y], err_msg, header=header, names=('x', 'y'))
        raise ValueError(msg)

    x = _zero_out_nans(x)
    y = _zero_out_nans(y)

    try:
        # A 0-d operand is accepted against any shape.
        cond = (x.shape == () or y.shape == ()) or x.shape == y.shape
        if not cond:
            msg = build_err_msg([x, y],
                                err_msg
                                + '\n(shapes %s, %s mismatch)' % (x.shape,
                                                                  y.shape),
                                header=header,
                                names=('x', 'y'))
            # Explicit raise rather than `assert`, so it survives -O.
            raise AssertionError(msg)
        val = comparison(x, y)
        if m is not nomask and fill_value:
            val = masked_array(val, mask=m, copy=False)
        if isinstance(val, bool):
            cond = val
            # Placeholder so that a failure is reported as a 100% mismatch.
            reduced = [0]
        else:
            reduced = val.ravel()
            cond = reduced.all()
            reduced = reduced.tolist()
        if not cond:
            # Percentage of entries that do not satisfy the comparison.
            match = 100 - 100.0 * reduced.count(1) / len(reduced)
            msg = build_err_msg([x, y],
                                err_msg
                                + '\n(mismatch %s%%)' % (match,),
                                header=header,
                                names=('x', 'y'))
            raise AssertionError(msg)
    except ValueError:
        msg = build_err_msg([x, y], err_msg, header=header, names=('x', 'y'))
        raise ValueError(msg)


#............................
def assert_array_equal(x, y, err_msg=''):
    """Checks the elementwise equality of two masked arrays."""
    assert_array_compare(equal, x, y, err_msg=err_msg,
                         header='Arrays are not equal')


##............................
def fail_if_array_equal(x, y, err_msg=''):
    """Raises an assertion error if two masked arrays are equal.

    The arrays count as equal when every element passes `approx` with its
    default tolerances; a single differing element is enough to pass.
    """
    def compare(x, y):
        "Returns True if at least one element of x and y differs."
        return (not N.alltrue(approx(x, y)))
    assert_array_compare(compare, x, y, err_msg=err_msg,
                         header='Arrays are not equal')


#............................
def assert_array_approx_equal(x, y, decimal=6, err_msg=''):
    """Checks the elementwise equality of two masked arrays, up to a
    relative tolerance of 10**-decimal (see `approx`)."""
    def compare(x, y):
        "Returns the result of the loose comparison between x and y."
        return approx(x, y, rtol=10.**-decimal)
    assert_array_compare(compare, x, y, err_msg=err_msg,
                         header='Arrays are not almost equal')


#............................
def assert_array_almost_equal(x, y, decimal=6, err_msg=''):
    """Checks the elementwise equality of two masked arrays, up to a given
    number of decimals (see `almost`)."""
    def compare(x, y):
        "Returns the result of the loose comparison between x and y."
        return almost(x, y, decimal)
    assert_array_compare(compare, x, y, err_msg=err_msg,
                         header='Arrays are not almost equal')


#............................
def assert_array_less(x, y, err_msg=''):
    "Checks that x is smaller than y elementwise."
    assert_array_compare(less, x, y, err_msg=err_msg,
                         header='Arrays are not less-ordered')


#............................
assert_close = assert_almost_equal


#............................
def assert_mask_equal(m1, m2):
    """Asserts the equality of two masks.

    If either mask is `nomask`, the other one must be `nomask` as well;
    the two masks are then compared elementwise.
    """
    # Explicit raise rather than `assert`, so the check survives -O.
    if (m1 is nomask) != (m2 is nomask):
        raise AssertionError("only one of the two masks is nomask")
    assert_array_equal(m1, m2)


if __name__ == '__main__':
    pass