summaryrefslogtreecommitdiff
path: root/numpy/ma/testutils.py
diff options
context:
space:
mode:
authorpierregm <pierregm@localhost>2008-04-30 19:36:42 +0000
committerpierregm <pierregm@localhost>2008-04-30 19:36:42 +0000
commit21a7341b383356ba99c10eda9654a4470e1247a0 (patch)
tree95160bb6eca3599eacc59ba8a8f9d412dc7bd418 /numpy/ma/testutils.py
parent233af718df4cf92bad9b93993e11faba5b6fff26 (diff)
downloadnumpy-21a7341b383356ba99c10eda9654a4470e1247a0.tar.gz
core : fixed a bug w/ array((0,0))/0.
testutils : introduced assert_almost_equal/assert_approx_equal: use assert_almost_equal(a,b,decimal) to compare a and b up to decimal places use assert_approx_equal(a,b,decimal) to compare a and b up to b*10.**-decimal
Diffstat (limited to 'numpy/ma/testutils.py')
-rw-r--r--numpy/ma/testutils.py28
1 files changed, 27 insertions, 1 deletions
diff --git a/numpy/ma/testutils.py b/numpy/ma/testutils.py
index c4722f940..af0e818b2 100644
--- a/numpy/ma/testutils.py
+++ b/numpy/ma/testutils.py
@@ -40,6 +40,23 @@ small or zero; it says how small a must be also.
y = filled(masked_array(d2, copy=False, mask=m), 1).astype(float_)
d = N.less_equal(umath.absolute(x-y), atol + rtol * umath.absolute(y))
return d.ravel()
+
+def almost(a, b, decimal=6, fill_value=True):
+ """Returns True if a and b are equal up to decimal places.
+If fill_value is True, masked values considered equal. Otherwise, masked values
+are considered unequal.
+ """
+ m = mask_or(getmask(a), getmask(b))
+ d1 = filled(a)
+ d2 = filled(b)
+ if d1.dtype.char == "O" or d2.dtype.char == "O":
+ return N.equal(d1,d2).ravel()
+ x = filled(masked_array(d1, copy=False, mask=m), fill_value).astype(float_)
+ y = filled(masked_array(d2, copy=False, mask=m), 1).astype(float_)
+ d = N.around(N.abs(x-y),decimal) <= 10.0**(-decimal)
+ return d.ravel()
+
+
#................................................
def _assert_equal_on_sequences(actual, desired, err_msg=''):
"Asserts the equality of two non-array sequences."
@@ -191,7 +208,7 @@ def fail_if_array_equal(x, y, err_msg=''):
assert_array_compare(compare, x, y, err_msg=err_msg,
header='Arrays are not equal')
#............................
-def assert_array_almost_equal(x, y, decimal=6, err_msg=''):
+def assert_array_approx_equal(x, y, decimal=6, err_msg=''):
"""Checks the elementwise equality of two masked arrays, up to a given
number of decimals."""
def compare(x, y):
@@ -200,6 +217,15 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg=''):
assert_array_compare(compare, x, y, err_msg=err_msg,
header='Arrays are not almost equal')
#............................
+def assert_array_almost_equal(x, y, decimal=6, err_msg=''):
+ """Checks the elementwise equality of two masked arrays, up to a given
+ number of decimals."""
+ def compare(x, y):
+ "Returns the result of the loose comparison between x and y)."
+ return almost(x,y,decimal)
+ assert_array_compare(compare, x, y, err_msg=err_msg,
+ header='Arrays are not almost equal')
+#............................
def assert_array_less(x, y, err_msg=''):
"Checks that x is smaller than y elementwise."
assert_array_compare(less, x, y, err_msg=err_msg,