diff options
author | Stefan van der Walt <stefan@sun.ac.za> | 2007-12-15 01:15:26 +0000 |
---|---|---|
committer | Stefan van der Walt <stefan@sun.ac.za> | 2007-12-15 01:15:26 +0000 |
commit | 703e8d6323b19cbfeb96772c1e35f1cd68629336 (patch) | |
tree | 34bd23200d97ff43369d7d23d37c9c08c3d3a3b4 /numpy/ma/testutils.py | |
parent | 61f9f6d0fb169cadefe35ea2bdd783848aa771f5 (diff) | |
download | numpy-703e8d6323b19cbfeb96772c1e35f1cd68629336.tar.gz |
Move ma to numpy root. Fix unit tests. Remove references to numpy.core.ma.
Diffstat (limited to 'numpy/ma/testutils.py')
-rw-r--r-- | numpy/ma/testutils.py | 220 |
1 files changed, 220 insertions, 0 deletions
diff --git a/numpy/ma/testutils.py b/numpy/ma/testutils.py new file mode 100644 index 000000000..d51e8dab1 --- /dev/null +++ b/numpy/ma/testutils.py @@ -0,0 +1,220 @@ +"""Miscellaneous functions for testing masked arrays and subclasses + +:author: Pierre Gerard-Marchant +:contact: pierregm_at_uga_dot_edu +:version: $Id: testutils.py 3529 2007-11-13 08:01:14Z jarrod.millman $ +""" +__author__ = "Pierre GF Gerard-Marchant ($Author: jarrod.millman $)" +__version__ = "1.0" +__revision__ = "$Revision: 3529 $" +__date__ = "$Date: 2007-11-13 10:01:14 +0200 (Tue, 13 Nov 2007) $" + + +import numpy as N +from numpy.core import ndarray +from numpy.core.numerictypes import float_ +import numpy.core.umath as umath +from numpy.testing import NumpyTest, NumpyTestCase +from numpy.testing.utils import build_err_msg, rand + +import core +from core import mask_or, getmask, getmaskarray, masked_array, nomask, masked +from core import filled, equal, less + +#------------------------------------------------------------------------------ +def approx (a, b, fill_value=True, rtol=1.e-5, atol=1.e-8): + """Returns true if all components of a and b are equal subject to given tolerances. + +If fill_value is True, masked values considered equal. Otherwise, masked values +are considered unequal. +The relative error rtol should be positive and << 1.0 +The absolute error atol comes into play for those elements of b that are very +small or zero; it says how small a must be also. + """ + m = mask_or(getmask(a), getmask(b)) + d1 = filled(a) + d2 = filled(b) + if d1.dtype.char == "O" or d2.dtype.char == "O": + return N.equal(d1,d2).ravel() + x = filled(masked_array(d1, copy=False, mask=m), fill_value).astype(float_) + y = filled(masked_array(d2, copy=False, mask=m), 1).astype(float_) + d = N.less_equal(umath.absolute(x-y), atol + rtol * umath.absolute(y)) + return d.ravel() +#................................................ +def _assert_equal_on_sequences(actual, desired, err_msg=''): + "Asserts the equality of two non-array sequences." + assert_equal(len(actual),len(desired),err_msg) + for k in range(len(desired)): + assert_equal(actual[k], desired[k], 'item=%r\n%s' % (k,err_msg)) + return + +def assert_equal_records(a,b): + """Asserts that two records are equal. Pretty crude for now.""" + assert_equal(a.dtype, b.dtype) + for f in a.dtype.names: + (af, bf) = (getattr(a,f), getattr(b,f)) + if not (af is masked) and not (bf is masked): + assert_equal(getattr(a,f), getattr(b,f)) + return + +def assert_equal(actual,desired,err_msg=''): + """Asserts that two items are equal. + """ + # Case #1: dictionary ..... + if isinstance(desired, dict): + assert isinstance(actual, dict), repr(type(actual)) + assert_equal(len(actual),len(desired),err_msg) + for k,i in desired.items(): + assert k in actual, repr(k) + assert_equal(actual[k], desired[k], 'key=%r\n%s' % (k,err_msg)) + return + # Case #2: lists ..... + if isinstance(desired, (list,tuple)) and isinstance(actual, (list,tuple)): + return _assert_equal_on_sequences(actual, desired, err_msg='') + if not (isinstance(actual, ndarray) or isinstance(desired, ndarray)): + msg = build_err_msg([actual, desired], err_msg,) + assert desired == actual, msg + return + # Case #4. arrays or equivalent + if ((actual is masked) and not (desired is masked)) or \ + ((desired is masked) and not (actual is masked)): + msg = build_err_msg([actual, desired], err_msg, header='', names=('x', 'y')) + raise ValueError(msg) + actual = N.array(actual, copy=False, subok=True) + desired = N.array(desired, copy=False, subok=True) + if actual.dtype.char in "OS" and desired.dtype.char in "OS": + return _assert_equal_on_sequences(actual.tolist(), + desired.tolist(), + err_msg='') + return assert_array_equal(actual, desired, err_msg) +#............................. +def fail_if_equal(actual,desired,err_msg='',): + """Raises an assertion error if two items are equal. + """ + if isinstance(desired, dict): + assert isinstance(actual, dict), repr(type(actual)) + fail_if_equal(len(actual),len(desired),err_msg) + for k,i in desired.items(): + assert k in actual, repr(k) + fail_if_equal(actual[k], desired[k], 'key=%r\n%s' % (k,err_msg)) + return + if isinstance(desired, (list,tuple)) and isinstance(actual, (list,tuple)): + fail_if_equal(len(actual),len(desired),err_msg) + for k in range(len(desired)): + fail_if_equal(actual[k], desired[k], 'item=%r\n%s' % (k,err_msg)) + return + if isinstance(actual, N.ndarray) or isinstance(desired, N.ndarray): + return fail_if_array_equal(actual, desired, err_msg) + msg = build_err_msg([actual, desired], err_msg) + assert desired != actual, msg +assert_not_equal = fail_if_equal +#............................ +def assert_almost_equal(actual,desired,decimal=7,err_msg=''): + """Asserts that two items are almost equal. + The test is equivalent to abs(desired-actual) < 0.5 * 10**(-decimal) + """ + if isinstance(actual, N.ndarray) or isinstance(desired, N.ndarray): + return assert_array_almost_equal(actual, desired, decimal, err_msg) + msg = build_err_msg([actual, desired], err_msg) + assert round(abs(desired - actual),decimal) == 0, msg +#............................ +def assert_array_compare(comparison, x, y, err_msg='', header='', + fill_value=True): + """Asserts that a comparison relation between two masked arrays is satisfied + elementwise.""" + xf = filled(x) + yf = filled(y) + m = mask_or(getmask(x), getmask(y)) + + x = masked_array(xf, copy=False, subok=False, mask=m).filled(fill_value) + y = masked_array(yf, copy=False, subok=False, mask=m).filled(fill_value) + + if ((x is masked) and not (y is masked)) or \ + ((y is masked) and not (x is masked)): + msg = build_err_msg([x, y], err_msg, header=header, names=('x', 'y')) + raise ValueError(msg) + + if (x.dtype.char != "O") and (x.dtype.char != "S"): + x = x.astype(float_) + if isinstance(x, N.ndarray) and x.size > 1: + x[N.isnan(x)] = 0 + elif N.isnan(x): + x = 0 + if (y.dtype.char != "O") and (y.dtype.char != "S"): + y = y.astype(float_) + if isinstance(y, N.ndarray) and y.size > 1: + y[N.isnan(y)] = 0 + elif N.isnan(y): + y = 0 + try: + cond = (x.shape==() or y.shape==()) or x.shape == y.shape + if not cond: + msg = build_err_msg([x, y], + err_msg + + '\n(shapes %s, %s mismatch)' % (x.shape, + y.shape), + header=header, + names=('x', 'y')) + assert cond, msg + val = comparison(x,y) + if m is not nomask and fill_value: + val = masked_array(val, mask=m, copy=False) + if isinstance(val, bool): + cond = val + reduced = [0] + else: + reduced = val.ravel() + cond = reduced.all() + reduced = reduced.tolist() + if not cond: + match = 100-100.0*reduced.count(1)/len(reduced) + msg = build_err_msg([x, y], + err_msg + + '\n(mismatch %s%%)' % (match,), + header=header, + names=('x', 'y')) + assert cond, msg + except ValueError: + msg = build_err_msg([x, y], err_msg, header=header, names=('x', 'y')) + raise ValueError(msg) +#............................ +def assert_array_equal(x, y, err_msg=''): + """Checks the elementwise equality of two masked arrays.""" + assert_array_compare(equal, x, y, err_msg=err_msg, + header='Arrays are not equal') +##............................ +def fail_if_array_equal(x, y, err_msg=''): + "Raises an assertion error if two masked arrays are not equal (elementwise)." + def compare(x,y): + + return (not N.alltrue(approx(x, y))) + assert_array_compare(compare, x, y, err_msg=err_msg, + header='Arrays are not equal') +#............................ +def assert_array_almost_equal(x, y, decimal=6, err_msg=''): + """Checks the elementwise equality of two masked arrays, up to a given + number of decimals.""" + def compare(x, y): + "Returns the result of the loose comparison between x and y)." + return approx(x,y, rtol=10.**-decimal) + assert_array_compare(compare, x, y, err_msg=err_msg, + header='Arrays are not almost equal') +#............................ +def assert_array_less(x, y, err_msg=''): + "Checks that x is smaller than y elementwise." + assert_array_compare(less, x, y, err_msg=err_msg, + header='Arrays are not less-ordered') +#............................ +assert_close = assert_almost_equal +#............................ +def assert_mask_equal(m1, m2): + """Asserts the equality of two masks.""" + if m1 is nomask: + assert(m2 is nomask) + if m2 is nomask: + assert(m1 is nomask) + assert_array_equal(m1, m2) + +if __name__ == '__main__': + a = 12 + assert_equal(a, masked) |