summaryrefslogtreecommitdiff
path: root/numpy/ma/testutils.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/ma/testutils.py')
-rw-r--r--numpy/ma/testutils.py163
1 files changed, 69 insertions, 94 deletions
diff --git a/numpy/ma/testutils.py b/numpy/ma/testutils.py
index d6f3aa1dd..83aec7ea2 100644
--- a/numpy/ma/testutils.py
+++ b/numpy/ma/testutils.py
@@ -10,16 +10,18 @@ __revision__ = "$Revision: 3529 $"
__date__ = "$Date: 2007-11-13 10:01:14 +0200 (Tue, 13 Nov 2007) $"
-import numpy as N
-from numpy.core import ndarray
-from numpy.core.numerictypes import float_
+import operator
+
+import numpy as np
+from numpy import ndarray, float_
import numpy.core.umath as umath
from numpy.testing import NumpyTest, NumpyTestCase
+import numpy.testing.utils as utils
from numpy.testing.utils import build_err_msg, rand
import core
from core import mask_or, getmask, getmaskarray, masked_array, nomask, masked
-from core import filled, equal, less
+from core import fix_invalid, filled, equal, less
#------------------------------------------------------------------------------
def approx (a, b, fill_value=True, rtol=1.e-5, atol=1.e-8):
@@ -35,12 +37,13 @@ small or zero; it says how small a must be also.
d1 = filled(a)
d2 = filled(b)
if d1.dtype.char == "O" or d2.dtype.char == "O":
- return N.equal(d1,d2).ravel()
+ return np.equal(d1,d2).ravel()
x = filled(masked_array(d1, copy=False, mask=m), fill_value).astype(float_)
y = filled(masked_array(d2, copy=False, mask=m), 1).astype(float_)
- d = N.less_equal(umath.absolute(x-y), atol + rtol * umath.absolute(y))
+ d = np.less_equal(umath.absolute(x-y), atol + rtol * umath.absolute(y))
return d.ravel()
+
def almost(a, b, decimal=6, fill_value=True):
"""Returns True if a and b are equal up to decimal places.
If fill_value is True, masked values considered equal. Otherwise, masked values
@@ -50,10 +53,10 @@ are considered unequal.
d1 = filled(a)
d2 = filled(b)
if d1.dtype.char == "O" or d2.dtype.char == "O":
- return N.equal(d1,d2).ravel()
+ return np.equal(d1,d2).ravel()
x = filled(masked_array(d1, copy=False, mask=m), fill_value).astype(float_)
y = filled(masked_array(d2, copy=False, mask=m), 1).astype(float_)
- d = N.around(N.abs(x-y),decimal) <= 10.0**(-decimal)
+ d = np.around(np.abs(x-y),decimal) <= 10.0**(-decimal)
return d.ravel()
@@ -69,11 +72,12 @@ def assert_equal_records(a,b):
"""Asserts that two records are equal. Pretty crude for now."""
assert_equal(a.dtype, b.dtype)
for f in a.dtype.names:
- (af, bf) = (getattr(a,f), getattr(b,f))
+ (af, bf) = (operator.getitem(a,f), operator.getitem(b,f))
if not (af is masked) and not (bf is masked):
- assert_equal(getattr(a,f), getattr(b,f))
+ assert_equal(operator.getitem(a,f), operator.getitem(b,f))
return
+
def assert_equal(actual,desired,err_msg=''):
"""Asserts that two items are equal.
"""
@@ -95,16 +99,18 @@ def assert_equal(actual,desired,err_msg=''):
# Case #4. arrays or equivalent
if ((actual is masked) and not (desired is masked)) or \
((desired is masked) and not (actual is masked)):
- msg = build_err_msg([actual, desired], err_msg, header='', names=('x', 'y'))
+ msg = build_err_msg([actual, desired],
+ err_msg, header='', names=('x', 'y'))
raise ValueError(msg)
- actual = N.array(actual, copy=False, subok=True)
- desired = N.array(desired, copy=False, subok=True)
- if actual.dtype.char in "OS" and desired.dtype.char in "OS":
+ actual = np.array(actual, copy=False, subok=True)
+ desired = np.array(desired, copy=False, subok=True)
+ if actual.dtype.char in "OSV" and desired.dtype.char in "OSV":
return _assert_equal_on_sequences(actual.tolist(),
desired.tolist(),
err_msg='')
return assert_array_equal(actual, desired, err_msg)
-#.............................
+
+
def fail_if_equal(actual,desired,err_msg='',):
"""Raises an assertion error if two items are equal.
"""
@@ -120,119 +126,91 @@ def fail_if_equal(actual,desired,err_msg='',):
for k in range(len(desired)):
fail_if_equal(actual[k], desired[k], 'item=%r\n%s' % (k,err_msg))
return
- if isinstance(actual, N.ndarray) or isinstance(desired, N.ndarray):
+ if isinstance(actual, np.ndarray) or isinstance(desired, np.ndarray):
return fail_if_array_equal(actual, desired, err_msg)
msg = build_err_msg([actual, desired], err_msg)
assert desired != actual, msg
assert_not_equal = fail_if_equal
-#............................
-def assert_almost_equal(actual,desired,decimal=7,err_msg=''):
+
+
+def assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True):
"""Asserts that two items are almost equal.
The test is equivalent to abs(desired-actual) < 0.5 * 10**(-decimal)
"""
- if isinstance(actual, N.ndarray) or isinstance(desired, N.ndarray):
- return assert_array_almost_equal(actual, desired, decimal, err_msg)
- msg = build_err_msg([actual, desired], err_msg)
+ if isinstance(actual, np.ndarray) or isinstance(desired, np.ndarray):
+ return assert_array_almost_equal(actual, desired, decimal=decimal,
+ err_msg=err_msg, verbose=verbose)
+ msg = build_err_msg([actual, desired],
+ err_msg=err_msg, verbose=verbose)
assert round(abs(desired - actual),decimal) == 0, msg
-#............................
-def assert_array_compare(comparison, x, y, err_msg='', header='',
+
+
+assert_close = assert_almost_equal
+
+
+def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
fill_value=True):
"""Asserts that a comparison relation between two masked arrays is satisfied
elementwise."""
+ # Fill the data first
xf = filled(x)
yf = filled(y)
+ # Allocate a common mask and refill
m = mask_or(getmask(x), getmask(y))
-
- x = masked_array(xf, copy=False, subok=False, mask=m).filled(fill_value)
- y = masked_array(yf, copy=False, subok=False, mask=m).filled(fill_value)
-
+ x = masked_array(xf, copy=False, mask=m)
+ y = masked_array(yf, copy=False, mask=m)
if ((x is masked) and not (y is masked)) or \
((y is masked) and not (x is masked)):
- msg = build_err_msg([x, y], err_msg, header=header, names=('x', 'y'))
+ msg = build_err_msg([x, y], err_msg=err_msg, verbose=verbose,
+ header=header, names=('x', 'y'))
raise ValueError(msg)
+ # OK, now run the basic tests on filled versions
+ return utils.assert_array_compare(comparison,
+ x.filled(fill_value), y.filled(fill_value),
+ err_msg=err_msg,
+ verbose=verbose, header=header)
- if (x.dtype.char != "O") and (x.dtype.char != "S"):
- x = x.astype(float_)
- if isinstance(x, N.ndarray) and x.size > 1:
- x[N.isnan(x)] = 0
- elif N.isnan(x):
- x = 0
- if (y.dtype.char != "O") and (y.dtype.char != "S"):
- y = y.astype(float_)
- if isinstance(y, N.ndarray) and y.size > 1:
- y[N.isnan(y)] = 0
- elif N.isnan(y):
- y = 0
- try:
- cond = (x.shape==() or y.shape==()) or x.shape == y.shape
- if not cond:
- msg = build_err_msg([x, y],
- err_msg
- + '\n(shapes %s, %s mismatch)' % (x.shape,
- y.shape),
- header=header,
- names=('x', 'y'))
- assert cond, msg
- val = comparison(x,y)
- if m is not nomask and fill_value:
- val = masked_array(val, mask=m, copy=False)
- if isinstance(val, bool):
- cond = val
- reduced = [0]
- else:
- reduced = val.ravel()
- cond = reduced.all()
- reduced = reduced.tolist()
- if not cond:
- match = 100-100.0*reduced.count(1)/len(reduced)
- msg = build_err_msg([x, y],
- err_msg
- + '\n(mismatch %s%%)' % (match,),
- header=header,
- names=('x', 'y'))
- assert cond, msg
- except ValueError:
- msg = build_err_msg([x, y], err_msg, header=header, names=('x', 'y'))
- raise ValueError(msg)
-#............................
-def assert_array_equal(x, y, err_msg=''):
+
+def assert_array_equal(x, y, err_msg='', verbose=True):
"""Checks the elementwise equality of two masked arrays."""
- assert_array_compare(equal, x, y, err_msg=err_msg,
+ assert_array_compare(equal, x, y, err_msg=err_msg, verbose=verbose,
header='Arrays are not equal')
-##............................
-def fail_if_array_equal(x, y, err_msg=''):
+
+
+def fail_if_array_equal(x, y, err_msg='', verbose=True):
"Raises an assertion error if two masked arrays are not equal (elementwise)."
def compare(x,y):
-
- return (not N.alltrue(approx(x, y)))
- assert_array_compare(compare, x, y, err_msg=err_msg,
+ return (not np.alltrue(approx(x, y)))
+ assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
header='Arrays are not equal')
-#............................
-def assert_array_approx_equal(x, y, decimal=6, err_msg=''):
+
+
+def assert_array_approx_equal(x, y, decimal=6, err_msg='', verbose=True):
"""Checks the elementwise equality of two masked arrays, up to a given
number of decimals."""
def compare(x, y):
"Returns the result of the loose comparison between x and y)."
return approx(x,y, rtol=10.**-decimal)
- assert_array_compare(compare, x, y, err_msg=err_msg,
+ assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
header='Arrays are not almost equal')
-#............................
-def assert_array_almost_equal(x, y, decimal=6, err_msg=''):
+
+
+def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
"""Checks the elementwise equality of two masked arrays, up to a given
number of decimals."""
def compare(x, y):
"Returns the result of the loose comparison between x and y)."
return almost(x,y,decimal)
- assert_array_compare(compare, x, y, err_msg=err_msg,
+ assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
header='Arrays are not almost equal')
-#............................
-def assert_array_less(x, y, err_msg=''):
+
+
+def assert_array_less(x, y, err_msg='', verbose=True):
"Checks that x is smaller than y elementwise."
- assert_array_compare(less, x, y, err_msg=err_msg,
+ assert_array_compare(less, x, y, err_msg=err_msg, verbose=verbose,
header='Arrays are not less-ordered')
-#............................
-assert_close = assert_almost_equal
-#............................
+
+
def assert_mask_equal(m1, m2):
"""Asserts the equality of two masks."""
if m1 is nomask:
@@ -240,6 +218,3 @@ def assert_mask_equal(m1, m2):
if m2 is nomask:
assert(m1 is nomask)
assert_array_equal(m1, m2)
-
-if __name__ == '__main__':
- pass