summaryrefslogtreecommitdiff
path: root/numpy/ma/testutils.py
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2015-07-04 17:09:26 -0600
committerCharles Harris <charlesr.harris@gmail.com>2015-07-04 23:50:39 -0600
commit7c8c9adda27efe7f84fc98240ee1b7fa15714f06 (patch)
tree0506690eafdb29bad6b8f91527f05597c37e1718 /numpy/ma/testutils.py
parentc2ae6aa0103aecdb5e2a71504583451cada1bfbc (diff)
downloadnumpy-7c8c9adda27efe7f84fc98240ee1b7fa15714f06.tar.gz
STY,MAINT: PEP8 and pyflakes fixes for numpy/ma/*.py
Also * Add __all__ to numpy/ma/testutils.py * Remove various stray "#" We might want to consider removing/refactoring both numpy/ma/bench.py and numpy/ma/timer_comparison.
Diffstat (limited to 'numpy/ma/testutils.py')
-rw-r--r--numpy/ma/testutils.py138
1 files changed, 86 insertions, 52 deletions
diff --git a/numpy/ma/testutils.py b/numpy/ma/testutils.py
index feff3e879..2af39d6b4 100644
--- a/numpy/ma/testutils.py
+++ b/numpy/ma/testutils.py
@@ -7,32 +7,33 @@
"""
from __future__ import division, absolute_import, print_function
-__author__ = "Pierre GF Gerard-Marchant ($Author: jarrod.millman $)"
-__version__ = "1.0"
-__revision__ = "$Revision: 3529 $"
-__date__ = "$Date: 2007-11-13 10:01:14 +0200 (Tue, 13 Nov 2007) $"
-
-
import operator
import numpy as np
from numpy import ndarray, float_
import numpy.core.umath as umath
-from numpy.testing import *
+from numpy.testing import assert_, build_err_msg
import numpy.testing.utils as utils
+from .core import mask_or, getmask, masked_array, nomask, masked, filled
+
+__all__ = [
+ 'almost', 'approx', 'assert_almost_equal', 'assert_array_almost_equal',
+ 'assert_array_approx_equal', 'assert_array_compare',
+ 'assert_array_equal', 'assert_array_less', 'assert_close',
+ 'assert_equal', 'assert_equal_records', 'assert_mask_equal',
+ 'assert_not_equal', 'fail_if_array_equal',
+ ]
-from .core import mask_or, getmask, masked_array, nomask, masked, filled, \
- equal, less
+def approx(a, b, fill_value=True, rtol=1e-5, atol=1e-8):
+ """
+ Returns true if all components of a and b are equal to given tolerances.
-#------------------------------------------------------------------------------
-def approx (a, b, fill_value=True, rtol=1e-5, atol=1e-8):
- """Returns true if all components of a and b are equal subject to given tolerances.
+ If fill_value is True, masked values considered equal. Otherwise,
+ masked values are considered unequal. The relative error rtol should
+ be positive and << 1.0 The absolute error atol comes into play for
+ those elements of b that are very small or zero; it says how small a
+ must be also.
-If fill_value is True, masked values considered equal. Otherwise, masked values
-are considered unequal.
-The relative error rtol should be positive and << 1.0
-The absolute error atol comes into play for those elements of b that are very
-small or zero; it says how small a must be also.
"""
m = mask_or(getmask(a), getmask(b))
d1 = filled(a)
@@ -46,9 +47,12 @@ small or zero; it says how small a must be also.
def almost(a, b, decimal=6, fill_value=True):
- """Returns True if a and b are equal up to decimal places.
-If fill_value is True, masked values considered equal. Otherwise, masked values
-are considered unequal.
+ """
+ Returns True if a and b are equal up to decimal places.
+
+ If fill_value is True, masked values considered equal. Otherwise,
+ masked values are considered unequal.
+
"""
m = mask_or(getmask(a), getmask(b))
d1 = filled(a)
@@ -61,16 +65,24 @@ are considered unequal.
return d.ravel()
-#................................................
def _assert_equal_on_sequences(actual, desired, err_msg=''):
- "Asserts the equality of two non-array sequences."
+ """
+ Asserts the equality of two non-array sequences.
+
+ """
assert_equal(len(actual), len(desired), err_msg)
for k in range(len(desired)):
assert_equal(actual[k], desired[k], 'item=%r\n%s' % (k, err_msg))
return
+
def assert_equal_records(a, b):
- """Asserts that two records are equal. Pretty crude for now."""
+ """
+ Asserts that two records are equal.
+
+ Pretty crude for now.
+
+ """
assert_equal(a.dtype, b.dtype)
for f in a.dtype.names:
(af, bf) = (operator.getitem(a, f), operator.getitem(b, f))
@@ -80,14 +92,17 @@ def assert_equal_records(a, b):
def assert_equal(actual, desired, err_msg=''):
- "Asserts that two items are equal."
+ """
+ Asserts that two items are equal.
+
+ """
# Case #1: dictionary .....
if isinstance(desired, dict):
if not isinstance(actual, dict):
raise AssertionError(repr(type(actual)))
assert_equal(len(actual), len(desired), err_msg)
for k, i in desired.items():
- if not k in actual:
+ if k not in actual:
raise AssertionError("%s not in %s" % (k, actual))
assert_equal(actual[k], desired[k], 'key=%r\n%s' % (k, err_msg))
return
@@ -101,7 +116,7 @@ def assert_equal(actual, desired, err_msg=''):
return
# Case #4. arrays or equivalent
if ((actual is masked) and not (desired is masked)) or \
- ((desired is masked) and not (actual is masked)):
+ ((desired is masked) and not (actual is masked)):
msg = build_err_msg([actual, desired],
err_msg, header='', names=('x', 'y'))
raise ValueError(msg)
@@ -112,26 +127,20 @@ def assert_equal(actual, desired, err_msg=''):
return _assert_equal_on_sequences(actual.tolist(),
desired.tolist(),
err_msg='')
-# elif actual_dtype.char in "OV" and desired_dtype.char in "OV":
-# if (actual_dtype != desired_dtype) and actual_dtype:
-# msg = build_err_msg([actual_dtype, desired_dtype],
-# err_msg, header='', names=('actual', 'desired'))
-# raise ValueError(msg)
-# return _assert_equal_on_sequences(actual.tolist(),
-# desired.tolist(),
-# err_msg='')
return assert_array_equal(actual, desired, err_msg)
def fail_if_equal(actual, desired, err_msg='',):
- """Raises an assertion error if two items are equal.
+ """
+ Raises an assertion error if two items are equal.
+
"""
if isinstance(desired, dict):
if not isinstance(actual, dict):
raise AssertionError(repr(type(actual)))
fail_if_equal(len(actual), len(desired), err_msg)
for k, i in desired.items():
- if not k in actual:
+ if k not in actual:
raise AssertionError(repr(k))
fail_if_equal(actual[k], desired[k], 'key=%r\n%s' % (k, err_msg))
return
@@ -146,12 +155,16 @@ def fail_if_equal(actual, desired, err_msg='',):
if not desired != actual:
raise AssertionError(msg)
+
assert_not_equal = fail_if_equal
def assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True):
- """Asserts that two items are almost equal.
- The test is equivalent to abs(desired-actual) < 0.5 * 10**(-decimal)
+ """
+ Asserts that two items are almost equal.
+
+ The test is equivalent to abs(desired-actual) < 0.5 * 10**(-decimal).
+
"""
if isinstance(actual, np.ndarray) or isinstance(desired, np.ndarray):
return assert_array_almost_equal(actual, desired, decimal=decimal,
@@ -167,17 +180,18 @@ assert_close = assert_almost_equal
def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
fill_value=True):
- """Asserts that a comparison relation between two masked arrays is satisfied
- elementwise."""
- # Fill the data first
-# xf = filled(x)
-# yf = filled(y)
+ """
+ Asserts that comparison between two masked arrays is satisfied.
+
+ The comparison is elementwise.
+
+ """
# Allocate a common mask and refill
m = mask_or(getmask(x), getmask(y))
x = masked_array(x, copy=False, mask=m, keep_mask=False, subok=False)
y = masked_array(y, copy=False, mask=m, keep_mask=False, subok=False)
if ((x is masked) and not (y is masked)) or \
- ((y is masked) and not (x is masked)):
+ ((y is masked) and not (x is masked)):
msg = build_err_msg([x, y], err_msg=err_msg, verbose=verbose,
header=header, names=('x', 'y'))
raise ValueError(msg)
@@ -190,14 +204,20 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
def assert_array_equal(x, y, err_msg='', verbose=True):
- """Checks the elementwise equality of two masked arrays."""
+ """
+ Checks the elementwise equality of two masked arrays.
+
+ """
assert_array_compare(operator.__eq__, x, y,
err_msg=err_msg, verbose=verbose,
header='Arrays are not equal')
def fail_if_array_equal(x, y, err_msg='', verbose=True):
- "Raises an assertion error if two masked arrays are not equal (elementwise)."
+ """
+ Raises an assertion error if two masked arrays are not equal elementwise.
+
+ """
def compare(x, y):
return (not np.alltrue(approx(x, y)))
assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
@@ -205,8 +225,12 @@ def fail_if_array_equal(x, y, err_msg='', verbose=True):
def assert_array_approx_equal(x, y, decimal=6, err_msg='', verbose=True):
- """Checks the elementwise equality of two masked arrays, up to a given
- number of decimals."""
+ """
+ Checks the equality of two masked arrays, up to given number odecimals.
+
+ The equality is checked elementwise.
+
+ """
def compare(x, y):
"Returns the result of the loose comparison between x and y)."
return approx(x, y, rtol=10. ** -decimal)
@@ -215,8 +239,12 @@ def assert_array_approx_equal(x, y, decimal=6, err_msg='', verbose=True):
def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
- """Checks the elementwise equality of two masked arrays, up to a given
- number of decimals."""
+ """
+ Checks the equality of two masked arrays, up to given number odecimals.
+
+ The equality is checked elementwise.
+
+ """
def compare(x, y):
"Returns the result of the loose comparison between x and y)."
return almost(x, y, decimal)
@@ -225,14 +253,20 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
def assert_array_less(x, y, err_msg='', verbose=True):
- "Checks that x is smaller than y elementwise."
+ """
+ Checks that x is smaller than y elementwise.
+
+ """
assert_array_compare(operator.__lt__, x, y,
err_msg=err_msg, verbose=verbose,
header='Arrays are not less-ordered')
def assert_mask_equal(m1, m2, err_msg=''):
- """Asserts the equality of two masks."""
+ """
+ Asserts the equality of two masks.
+
+ """
if m1 is nomask:
assert_(m2 is nomask)
if m2 is nomask: