diff options
author | pierregm <pierregm@localhost> | 2010-08-09 20:07:06 +0000 |
---|---|---|
committer | pierregm <pierregm@localhost> | 2010-08-09 20:07:06 +0000 |
commit | 32c08068f877965c64bf7ddf5e3885f560b6920f (patch) | |
tree | 611c1bca335053d1df77225e426545356adf769b /numpy/ma/testutils.py | |
parent | 133bb76452af3dfe8e7b6c928b82bdb9fea70770 (diff) | |
download | numpy-32c08068f877965c64bf7ddf5e3885f560b6920f.tar.gz |
Diffstat (limited to 'numpy/ma/testutils.py')
-rw-r--r-- | numpy/ma/testutils.py | 54 |
1 files changed, 27 insertions, 27 deletions
diff --git a/numpy/ma/testutils.py b/numpy/ma/testutils.py index 1918e97a0..2b69cc4ad 100644 --- a/numpy/ma/testutils.py +++ b/numpy/ma/testutils.py @@ -22,7 +22,7 @@ from core import mask_or, getmask, masked_array, nomask, masked, filled, \ equal, less #------------------------------------------------------------------------------ -def approx (a, b, fill_value=True, rtol=1.e-5, atol=1.e-8): +def approx (a, b, fill_value=True, rtol=1e-5, atol=1e-8): """Returns true if all components of a and b are equal subject to given tolerances. If fill_value is True, masked values considered equal. Otherwise, masked values @@ -35,10 +35,10 @@ small or zero; it says how small a must be also. d1 = filled(a) d2 = filled(b) if d1.dtype.char == "O" or d2.dtype.char == "O": - return np.equal(d1,d2).ravel() + return np.equal(d1, d2).ravel() x = filled(masked_array(d1, copy=False, mask=m), fill_value).astype(float_) y = filled(masked_array(d2, copy=False, mask=m), 1).astype(float_) - d = np.less_equal(umath.absolute(x-y), atol + rtol * umath.absolute(y)) + d = np.less_equal(umath.absolute(x - y), atol + rtol * umath.absolute(y)) return d.ravel() @@ -51,46 +51,46 @@ are considered unequal. d1 = filled(a) d2 = filled(b) if d1.dtype.char == "O" or d2.dtype.char == "O": - return np.equal(d1,d2).ravel() + return np.equal(d1, d2).ravel() x = filled(masked_array(d1, copy=False, mask=m), fill_value).astype(float_) y = filled(masked_array(d2, copy=False, mask=m), 1).astype(float_) - d = np.around(np.abs(x-y),decimal) <= 10.0**(-decimal) + d = np.around(np.abs(x - y), decimal) <= 10.0 ** (-decimal) return d.ravel() #................................................ def _assert_equal_on_sequences(actual, desired, err_msg=''): "Asserts the equality of two non-array sequences." - assert_equal(len(actual),len(desired),err_msg) + assert_equal(len(actual), len(desired), err_msg) for k in range(len(desired)): - assert_equal(actual[k], desired[k], 'item=%r\n%s' % (k,err_msg)) + assert_equal(actual[k], desired[k], 'item=%r\n%s' % (k, err_msg)) return -def assert_equal_records(a,b): +def assert_equal_records(a, b): """Asserts that two records are equal. Pretty crude for now.""" assert_equal(a.dtype, b.dtype) for f in a.dtype.names: - (af, bf) = (operator.getitem(a,f), operator.getitem(b,f)) + (af, bf) = (operator.getitem(a, f), operator.getitem(b, f)) if not (af is masked) and not (bf is masked): - assert_equal(operator.getitem(a,f), operator.getitem(b,f)) + assert_equal(operator.getitem(a, f), operator.getitem(b, f)) return -def assert_equal(actual,desired,err_msg=''): +def assert_equal(actual, desired, err_msg=''): """Asserts that two items are equal. """ # Case #1: dictionary ..... if isinstance(desired, dict): if not isinstance(actual, dict): raise AssertionError(repr(type(actual))) - assert_equal(len(actual),len(desired),err_msg) - for k,i in desired.items(): + assert_equal(len(actual), len(desired), err_msg) + for k, i in desired.items(): if not k in actual: - raise AssertionError("%s not in %s" % (k,actual)) - assert_equal(actual[k], desired[k], 'key=%r\n%s' % (k,err_msg)) + raise AssertionError("%s not in %s" % (k, actual)) + assert_equal(actual[k], desired[k], 'key=%r\n%s' % (k, err_msg)) return # Case #2: lists ..... - if isinstance(desired, (list,tuple)) and isinstance(actual, (list,tuple)): + if isinstance(desired, (list, tuple)) and isinstance(actual, (list, tuple)): return _assert_equal_on_sequences(actual, desired, err_msg='') if not (isinstance(actual, ndarray) or isinstance(desired, ndarray)): msg = build_err_msg([actual, desired], err_msg,) @@ -121,22 +121,22 @@ def assert_equal(actual,desired,err_msg=''): return assert_array_equal(actual, desired, err_msg) -def fail_if_equal(actual,desired,err_msg='',): +def fail_if_equal(actual, desired, err_msg='',): """Raises an assertion error if two items are equal. """ if isinstance(desired, dict): if not isinstance(actual, dict): raise AssertionError(repr(type(actual))) - fail_if_equal(len(actual),len(desired),err_msg) - for k,i in desired.items(): + fail_if_equal(len(actual), len(desired), err_msg) + for k, i in desired.items(): if not k in actual: raise AssertionError(repr(k)) - fail_if_equal(actual[k], desired[k], 'key=%r\n%s' % (k,err_msg)) + fail_if_equal(actual[k], desired[k], 'key=%r\n%s' % (k, err_msg)) return - if isinstance(desired, (list,tuple)) and isinstance(actual, (list,tuple)): - fail_if_equal(len(actual),len(desired),err_msg) + if isinstance(desired, (list, tuple)) and isinstance(actual, (list, tuple)): + fail_if_equal(len(actual), len(desired), err_msg) for k in range(len(desired)): - fail_if_equal(actual[k], desired[k], 'item=%r\n%s' % (k,err_msg)) + fail_if_equal(actual[k], desired[k], 'item=%r\n%s' % (k, err_msg)) return if isinstance(actual, np.ndarray) or isinstance(desired, np.ndarray): return fail_if_array_equal(actual, desired, err_msg) @@ -155,7 +155,7 @@ def assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True): err_msg=err_msg, verbose=verbose) msg = build_err_msg([actual, desired], err_msg=err_msg, verbose=verbose) - if not round(abs(desired - actual),decimal) == 0: + if not round(abs(desired - actual), decimal) == 0: raise AssertionError(msg) @@ -195,7 +195,7 @@ def assert_array_equal(x, y, err_msg='', verbose=True): def fail_if_array_equal(x, y, err_msg='', verbose=True): "Raises an assertion error if two masked arrays are not equal (elementwise)." - def compare(x,y): + def compare(x, y): return (not np.alltrue(approx(x, y))) assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose, header='Arrays are not equal') @@ -206,7 +206,7 @@ def assert_array_approx_equal(x, y, decimal=6, err_msg='', verbose=True): number of decimals.""" def compare(x, y): "Returns the result of the loose comparison between x and y)." - return approx(x,y, rtol=10.**-decimal) + return approx(x, y, rtol=10. ** -decimal) assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose, header='Arrays are not almost equal') @@ -216,7 +216,7 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): number of decimals.""" def compare(x, y): "Returns the result of the loose comparison between x and y)." - return almost(x,y,decimal) + return almost(x, y, decimal) assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose, header='Arrays are not almost equal') |