diff options
author | pierregm <pierregm@localhost> | 2008-04-30 19:36:42 +0000 |
---|---|---|
committer | pierregm <pierregm@localhost> | 2008-04-30 19:36:42 +0000 |
commit | 21a7341b383356ba99c10eda9654a4470e1247a0 (patch) | |
tree | 95160bb6eca3599eacc59ba8a8f9d412dc7bd418 /numpy/ma/testutils.py | |
parent | 233af718df4cf92bad9b93993e11faba5b6fff26 (diff) | |
download | numpy-21a7341b383356ba99c10eda9654a4470e1247a0.tar.gz |
core : fixed a bug w/ array((0,0))/0.
testutils : introduced assert_almost_equal/assert_approx_equal:
use assert_almost_equal(a,b,decimal) to compare a and b up to decimal places
use assert_approx_equal(a,b,decimal) to compare a and b up to b*10.**-decimal
Diffstat (limited to 'numpy/ma/testutils.py')
-rw-r--r-- | numpy/ma/testutils.py | 28 |
1 files changed, 27 insertions, 1 deletions
diff --git a/numpy/ma/testutils.py b/numpy/ma/testutils.py index c4722f940..af0e818b2 100644 --- a/numpy/ma/testutils.py +++ b/numpy/ma/testutils.py @@ -40,6 +40,23 @@ small or zero; it says how small a must be also. y = filled(masked_array(d2, copy=False, mask=m), 1).astype(float_) d = N.less_equal(umath.absolute(x-y), atol + rtol * umath.absolute(y)) return d.ravel() + +def almost(a, b, decimal=6, fill_value=True): + """Returns True if a and b are equal up to decimal places. +If fill_value is True, masked values considered equal. Otherwise, masked values +are considered unequal. + """ + m = mask_or(getmask(a), getmask(b)) + d1 = filled(a) + d2 = filled(b) + if d1.dtype.char == "O" or d2.dtype.char == "O": + return N.equal(d1,d2).ravel() + x = filled(masked_array(d1, copy=False, mask=m), fill_value).astype(float_) + y = filled(masked_array(d2, copy=False, mask=m), 1).astype(float_) + d = N.around(N.abs(x-y),decimal) <= 10.0**(-decimal) + return d.ravel() + + #................................................ def _assert_equal_on_sequences(actual, desired, err_msg=''): "Asserts the equality of two non-array sequences." @@ -191,7 +208,7 @@ def fail_if_array_equal(x, y, err_msg=''): assert_array_compare(compare, x, y, err_msg=err_msg, header='Arrays are not equal') #............................ -def assert_array_almost_equal(x, y, decimal=6, err_msg=''): +def assert_array_approx_equal(x, y, decimal=6, err_msg=''): """Checks the elementwise equality of two masked arrays, up to a given number of decimals.""" def compare(x, y): @@ -200,6 +217,15 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg=''): assert_array_compare(compare, x, y, err_msg=err_msg, header='Arrays are not almost equal') #............................ +def assert_array_almost_equal(x, y, decimal=6, err_msg=''): + """Checks the elementwise equality of two masked arrays, up to a given + number of decimals.""" + def compare(x, y): + "Returns the result of the loose comparison between x and y)." + return almost(x,y,decimal) + assert_array_compare(compare, x, y, err_msg=err_msg, + header='Arrays are not almost equal') +#............................ def assert_array_less(x, y, err_msg=''): "Checks that x is smaller than y elementwise." assert_array_compare(less, x, y, err_msg=err_msg, |