summaryrefslogtreecommitdiff
path: root/numpy/ma/testutils.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/ma/testutils.py')
-rw-r--r--numpy/ma/testutils.py28
1 files changed, 27 insertions, 1 deletions
diff --git a/numpy/ma/testutils.py b/numpy/ma/testutils.py
index c4722f940..af0e818b2 100644
--- a/numpy/ma/testutils.py
+++ b/numpy/ma/testutils.py
@@ -40,6 +40,23 @@ small or zero; it says how small a must be also.
y = filled(masked_array(d2, copy=False, mask=m), 1).astype(float_)
d = N.less_equal(umath.absolute(x-y), atol + rtol * umath.absolute(y))
return d.ravel()
+
+def almost(a, b, decimal=6, fill_value=True):
+ """Returns True if a and b are equal up to decimal places.
+If fill_value is True, masked values considered equal. Otherwise, masked values
+are considered unequal.
+ """
+ m = mask_or(getmask(a), getmask(b))
+ d1 = filled(a)
+ d2 = filled(b)
+ if d1.dtype.char == "O" or d2.dtype.char == "O":
+ return N.equal(d1,d2).ravel()
+ x = filled(masked_array(d1, copy=False, mask=m), fill_value).astype(float_)
+ y = filled(masked_array(d2, copy=False, mask=m), 1).astype(float_)
+ d = N.around(N.abs(x-y),decimal) <= 10.0**(-decimal)
+ return d.ravel()
+
+
#................................................
def _assert_equal_on_sequences(actual, desired, err_msg=''):
"Asserts the equality of two non-array sequences."
@@ -191,7 +208,7 @@ def fail_if_array_equal(x, y, err_msg=''):
assert_array_compare(compare, x, y, err_msg=err_msg,
header='Arrays are not equal')
#............................
-def assert_array_almost_equal(x, y, decimal=6, err_msg=''):
+def assert_array_approx_equal(x, y, decimal=6, err_msg=''):
"""Checks the elementwise equality of two masked arrays, up to a given
number of decimals."""
def compare(x, y):
@@ -200,6 +217,15 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg=''):
assert_array_compare(compare, x, y, err_msg=err_msg,
header='Arrays are not almost equal')
#............................
+def assert_array_almost_equal(x, y, decimal=6, err_msg=''):
+ """Checks the elementwise equality of two masked arrays, up to a given
+ number of decimals."""
+ def compare(x, y):
+ "Returns the result of the loose comparison between x and y)."
+ return almost(x,y,decimal)
+ assert_array_compare(compare, x, y, err_msg=err_msg,
+ header='Arrays are not almost equal')
+#............................
def assert_array_less(x, y, err_msg=''):
"Checks that x is smaller than y elementwise."
assert_array_compare(less, x, y, err_msg=err_msg,