summaryrefslogtreecommitdiff
path: root/numpy/ma/testutils.py
diff options
context:
space:
mode:
authorpierregm <pierregm@localhost>2010-08-09 20:07:06 +0000
committerpierregm <pierregm@localhost>2010-08-09 20:07:06 +0000
commit32c08068f877965c64bf7ddf5e3885f560b6920f (patch)
tree611c1bca335053d1df77225e426545356adf769b /numpy/ma/testutils.py
parent133bb76452af3dfe8e7b6c928b82bdb9fea70770 (diff)
downloadnumpy-32c08068f877965c64bf7ddf5e3885f560b6920f.tar.gz
Diffstat (limited to 'numpy/ma/testutils.py')
-rw-r--r--numpy/ma/testutils.py54
1 files changed, 27 insertions, 27 deletions
diff --git a/numpy/ma/testutils.py b/numpy/ma/testutils.py
index 1918e97a0..2b69cc4ad 100644
--- a/numpy/ma/testutils.py
+++ b/numpy/ma/testutils.py
@@ -22,7 +22,7 @@ from core import mask_or, getmask, masked_array, nomask, masked, filled, \
equal, less
#------------------------------------------------------------------------------
-def approx (a, b, fill_value=True, rtol=1.e-5, atol=1.e-8):
+def approx (a, b, fill_value=True, rtol=1e-5, atol=1e-8):
"""Returns true if all components of a and b are equal subject to given tolerances.
If fill_value is True, masked values considered equal. Otherwise, masked values
@@ -35,10 +35,10 @@ small or zero; it says how small a must be also.
d1 = filled(a)
d2 = filled(b)
if d1.dtype.char == "O" or d2.dtype.char == "O":
- return np.equal(d1,d2).ravel()
+ return np.equal(d1, d2).ravel()
x = filled(masked_array(d1, copy=False, mask=m), fill_value).astype(float_)
y = filled(masked_array(d2, copy=False, mask=m), 1).astype(float_)
- d = np.less_equal(umath.absolute(x-y), atol + rtol * umath.absolute(y))
+ d = np.less_equal(umath.absolute(x - y), atol + rtol * umath.absolute(y))
return d.ravel()
@@ -51,46 +51,46 @@ are considered unequal.
d1 = filled(a)
d2 = filled(b)
if d1.dtype.char == "O" or d2.dtype.char == "O":
- return np.equal(d1,d2).ravel()
+ return np.equal(d1, d2).ravel()
x = filled(masked_array(d1, copy=False, mask=m), fill_value).astype(float_)
y = filled(masked_array(d2, copy=False, mask=m), 1).astype(float_)
- d = np.around(np.abs(x-y),decimal) <= 10.0**(-decimal)
+ d = np.around(np.abs(x - y), decimal) <= 10.0 ** (-decimal)
return d.ravel()
#................................................
def _assert_equal_on_sequences(actual, desired, err_msg=''):
"Asserts the equality of two non-array sequences."
- assert_equal(len(actual),len(desired),err_msg)
+ assert_equal(len(actual), len(desired), err_msg)
for k in range(len(desired)):
- assert_equal(actual[k], desired[k], 'item=%r\n%s' % (k,err_msg))
+ assert_equal(actual[k], desired[k], 'item=%r\n%s' % (k, err_msg))
return
-def assert_equal_records(a,b):
+def assert_equal_records(a, b):
"""Asserts that two records are equal. Pretty crude for now."""
assert_equal(a.dtype, b.dtype)
for f in a.dtype.names:
- (af, bf) = (operator.getitem(a,f), operator.getitem(b,f))
+ (af, bf) = (operator.getitem(a, f), operator.getitem(b, f))
if not (af is masked) and not (bf is masked):
- assert_equal(operator.getitem(a,f), operator.getitem(b,f))
+ assert_equal(operator.getitem(a, f), operator.getitem(b, f))
return
-def assert_equal(actual,desired,err_msg=''):
+def assert_equal(actual, desired, err_msg=''):
"""Asserts that two items are equal.
"""
# Case #1: dictionary .....
if isinstance(desired, dict):
if not isinstance(actual, dict):
raise AssertionError(repr(type(actual)))
- assert_equal(len(actual),len(desired),err_msg)
- for k,i in desired.items():
+ assert_equal(len(actual), len(desired), err_msg)
+ for k, i in desired.items():
if not k in actual:
- raise AssertionError("%s not in %s" % (k,actual))
- assert_equal(actual[k], desired[k], 'key=%r\n%s' % (k,err_msg))
+ raise AssertionError("%s not in %s" % (k, actual))
+ assert_equal(actual[k], desired[k], 'key=%r\n%s' % (k, err_msg))
return
# Case #2: lists .....
- if isinstance(desired, (list,tuple)) and isinstance(actual, (list,tuple)):
+ if isinstance(desired, (list, tuple)) and isinstance(actual, (list, tuple)):
return _assert_equal_on_sequences(actual, desired, err_msg='')
if not (isinstance(actual, ndarray) or isinstance(desired, ndarray)):
msg = build_err_msg([actual, desired], err_msg,)
@@ -121,22 +121,22 @@ def assert_equal(actual,desired,err_msg=''):
return assert_array_equal(actual, desired, err_msg)
-def fail_if_equal(actual,desired,err_msg='',):
+def fail_if_equal(actual, desired, err_msg='',):
"""Raises an assertion error if two items are equal.
"""
if isinstance(desired, dict):
if not isinstance(actual, dict):
raise AssertionError(repr(type(actual)))
- fail_if_equal(len(actual),len(desired),err_msg)
- for k,i in desired.items():
+ fail_if_equal(len(actual), len(desired), err_msg)
+ for k, i in desired.items():
if not k in actual:
raise AssertionError(repr(k))
- fail_if_equal(actual[k], desired[k], 'key=%r\n%s' % (k,err_msg))
+ fail_if_equal(actual[k], desired[k], 'key=%r\n%s' % (k, err_msg))
return
- if isinstance(desired, (list,tuple)) and isinstance(actual, (list,tuple)):
- fail_if_equal(len(actual),len(desired),err_msg)
+ if isinstance(desired, (list, tuple)) and isinstance(actual, (list, tuple)):
+ fail_if_equal(len(actual), len(desired), err_msg)
for k in range(len(desired)):
- fail_if_equal(actual[k], desired[k], 'item=%r\n%s' % (k,err_msg))
+ fail_if_equal(actual[k], desired[k], 'item=%r\n%s' % (k, err_msg))
return
if isinstance(actual, np.ndarray) or isinstance(desired, np.ndarray):
return fail_if_array_equal(actual, desired, err_msg)
@@ -155,7 +155,7 @@ def assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True):
err_msg=err_msg, verbose=verbose)
msg = build_err_msg([actual, desired],
err_msg=err_msg, verbose=verbose)
- if not round(abs(desired - actual),decimal) == 0:
+ if not round(abs(desired - actual), decimal) == 0:
raise AssertionError(msg)
@@ -195,7 +195,7 @@ def assert_array_equal(x, y, err_msg='', verbose=True):
def fail_if_array_equal(x, y, err_msg='', verbose=True):
"Raises an assertion error if two masked arrays are not equal (elementwise)."
- def compare(x,y):
+ def compare(x, y):
return (not np.alltrue(approx(x, y)))
assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
header='Arrays are not equal')
@@ -206,7 +206,7 @@ def assert_array_approx_equal(x, y, decimal=6, err_msg='', verbose=True):
number of decimals."""
def compare(x, y):
"Returns the result of the loose comparison between x and y)."
- return approx(x,y, rtol=10.**-decimal)
+ return approx(x, y, rtol=10. ** -decimal)
assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
header='Arrays are not almost equal')
@@ -216,7 +216,7 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
number of decimals."""
def compare(x, y):
"Returns the result of the loose comparison between x and y)."
- return almost(x,y,decimal)
+ return almost(x, y, decimal)
assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
header='Arrays are not almost equal')