summaryrefslogtreecommitdiff
path: root/numpy/testing/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/testing/utils.py')
-rw-r--r--numpy/testing/utils.py138
1 files changed, 52 insertions, 86 deletions
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py
index fcffc9002..11e3b95d3 100644
--- a/numpy/testing/utils.py
+++ b/numpy/testing/utils.py
@@ -4,6 +4,7 @@ Utility function to facilitate testing.
import os
import sys
+import operator
__all__ = ['assert_equal', 'assert_almost_equal','assert_approx_equal',
'assert_array_equal', 'assert_array_less',
@@ -98,8 +99,9 @@ if os.name=='nt' and sys.version[:3] > '2.3':
processName, instance,
win32pdh.PDH_FMT_LONG, None)
-def build_err_msg(actual, desired, err_msg, header='Items are not equal:',
- verbose=True):
+def build_err_msg(arrays, err_msg, header='Items are not equal:',
+ verbose=True,
+ names=('ACTUAL', 'DESIRED')):
msg = ['\n' + header]
if err_msg:
if err_msg.find('\n') == -1 and len(err_msg) < 79-len(header):
@@ -107,18 +109,13 @@ def build_err_msg(actual, desired, err_msg, header='Items are not equal:',
else:
msg.append(err_msg)
if verbose:
- try:
- rd = repr(desired)
- except:
- rd = '[repr failed]'
- rd = rd[:100]
- msg.append(' DESIRED: ' + rd)
- try:
- ra = repr(actual)
- except:
- ra = '[repr failed]'
- rd = ra[:100]
- msg.append(' ACTUAL: ' + ra)
+ for i, a in enumerate(arrays):
+ try:
+ r = repr(a)
+ except:
+ r = '[repr failed]'
+ r = r[:100]
+ msg.append(' %s: %s' % (names[i], r))
return '\n'.join(msg)
def assert_equal(actual,desired,err_msg='',verbose=True):
@@ -140,7 +137,7 @@ def assert_equal(actual,desired,err_msg='',verbose=True):
from numpy.core import ndarray
if isinstance(actual, ndarray) or isinstance(desired, ndarray):
return assert_array_equal(actual, desired, err_msg)
- msg = build_err_msg(actual, desired, err_msg, verbose=verbose)
+ msg = build_err_msg([actual, desired], err_msg, verbose=verbose)
assert desired == actual, msg
def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True):
@@ -153,7 +150,7 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True):
from numpy.core import ndarray
if isinstance(actual, ndarray) or isinstance(desired, ndarray):
return assert_array_almost_equal(actual, desired, decimal, err_msg)
- msg = build_err_msg(actual, desired, err_msg, verbose=verbose)
+ msg = build_err_msg([actual, desired], err_msg, verbose=verbose)
assert round(abs(desired - actual),decimal) == 0, msg
@@ -177,88 +174,57 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True):
sc_actual = actual/scale
except ZeroDivisionError:
sc_actual = 0.0
- msg = build_err_msg(actual, desired, err_msg,
+ msg = build_err_msg([actual, desired], err_msg,
header='Items are not equal to %d significant digits:' %
significant,
verbose=verbose)
assert math.fabs(sc_desired - sc_actual) < pow(10.,-1*significant), msg
-
-def assert_array_equal(x,y,err_msg=''):
- from numpy.core import asarray, alltrue, equal, shape, ravel, array2string
- x,y = asarray(x), asarray(y)
- msg = '\nArrays are not equal'
- try:
- assert 0 in [len(shape(x)),len(shape(y))] \
- or (len(shape(x))==len(shape(y)) and \
- alltrue(equal(shape(x),shape(y)))),\
- msg + ' (shapes %s, %s mismatch):\n\t' \
- % (shape(x),shape(y)) + err_msg
- reduced = ravel(equal(x,y))
- cond = alltrue(reduced)
- if not cond:
- s1 = array2string(x,precision=16)
- s2 = array2string(y,precision=16)
- if len(s1)>120: s1 = s1[:120] + '...'
- if len(s2)>120: s2 = s2[:120] + '...'
- match = 100-100.0*reduced.tolist().count(1)/len(reduced)
- msg = msg + ' (mismatch %s%%):\n\tArray 1: %s\n\tArray 2: %s' % (match,s1,s2)
- assert cond,\
- msg + '\n\t' + err_msg
- except ValueError:
- raise ValueError, msg
-
-
-def assert_array_almost_equal(x,y,decimal=6,err_msg=''):
- from numpy.core import asarray, alltrue, equal, shape, ravel,\
- array2string, less_equal, around
+def assert_array_compare(comparision, x, y, err_msg='', verbose=True,
+ header=''):
+ from numpy.core import asarray
x = asarray(x)
y = asarray(y)
- msg = '\nArrays are not almost equal'
try:
- cond = alltrue(equal(shape(x),shape(y)))
- if not cond:
- msg = msg + ' (shapes mismatch):\n\t'\
- 'Shape of array 1: %s\n\tShape of array 2: %s' % (shape(x),shape(y))
- assert cond, msg + '\n\t' + err_msg
- reduced = ravel(equal(less_equal(around(abs(x-y),decimal),10.0**(-decimal)),1))
- cond = alltrue(reduced)
+ cond = (x.shape==() or y.shape==()) or x.shape == y.shape
if not cond:
- s1 = array2string(x,precision=decimal+1)
- s2 = array2string(y,precision=decimal+1)
- if len(s1)>120: s1 = s1[:120] + '...'
- if len(s2)>120: s2 = s2[:120] + '...'
- match = 100-100.0*reduced.tolist().count(1)/len(reduced)
- msg = msg + ' (mismatch %s%%):\n\tArray 1: %s\n\tArray 2: %s' % (match,s1,s2)
- assert cond,\
- msg + '\n\t' + err_msg
- except ValueError:
- print sys.exc_value
- print shape(x),shape(y)
- print x, y
- raise ValueError, 'arrays are not almost equal'
-
-def assert_array_less(x,y,err_msg=''):
- from numpy.core import asarray, alltrue, less, equal, shape, ravel, array2string
- x,y = asarray(x), asarray(y)
- msg = '\nArrays are not less-ordered'
- try:
- assert alltrue(equal(shape(x),shape(y))),\
- msg + ' (shapes mismatch):\n\t' + err_msg
- reduced = ravel(less(x,y))
- cond = alltrue(reduced)
+ msg = build_err_msg([x, y],
+ err_msg
+ + '\n(shapes %s, %s mismatch)' % (x.shape,
+ y.shape),
+ verbose=verbose, header=header,
+ names=('x', 'y'))
+ assert cond, msg
+ reduced = comparision(x, y).ravel()
+ cond = reduced.all()
if not cond:
- s1 = array2string(x,precision=16)
- s2 = array2string(y,precision=16)
- if len(s1)>120: s1 = s1[:120] + '...'
- if len(s2)>120: s2 = s2[:120] + '...'
match = 100-100.0*reduced.tolist().count(1)/len(reduced)
- msg = msg + ' (mismatch %s%%):\n\tArray 1: %s\n\tArray 2: %s' % (match,s1,s2)
- assert cond,\
- msg + '\n\t' + err_msg
+ msg = build_err_msg([x, y],
+ err_msg
+ + '\n(mismatch %s%%)' % (match,),
+ verbose=verbose, header=header,
+ names=('x', 'y'))
+ assert cond, msg
except ValueError:
- print shape(x),shape(y)
- raise ValueError, 'arrays are not less-ordered'
+ msg = build_err_msg([x, y], err_msg, verbose=verbose, header=header,
+ names=('x', 'y'))
+ raise ValueError(msg)
+
+def assert_array_equal(x, y, err_msg='', verbose=True):
+ assert_array_compare(operator.__eq__, x, y, err_msg=err_msg,
+ verbose=verbose, header='Arrays are not equal')
+
+def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
+ from numpy.core import around
+ def compare(x, y):
+ return around(abs(x-y),decimal) <= 10.0**(-decimal)
+ assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
+ header='Arrays are not almost equal')
+
+def assert_array_less(x, y, err_msg='', verbose=True):
+ assert_array_compare(operator.__lt__, x, y, err_msg=err_msg,
+ verbose=verbose,
+ header='Arrays are not less-ordered')
def runstring(astr, dict):
exec astr in dict