summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorRobert Kern <robert.kern@gmail.com>2008-04-17 20:51:03 +0000
committerRobert Kern <robert.kern@gmail.com>2008-04-17 20:51:03 +0000
commite653089a814b89dc73c08123ded75af4d3b37d17 (patch)
tree5ac097f820c337e46e889cf414f2724bb9aae721 /numpy
parent0d6981095ca7855eafa2118cf4a39baff16475c1 (diff)
downloadnumpy-e653089a814b89dc73c08123ded75af4d3b37d17.tar.gz
Correct dependency on missing code.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/testing/tests/test_utils.py54
-rw-r--r--numpy/testing/utils.py47
2 files changed, 67 insertions, 34 deletions
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py
index 48cd89cdd..6f40c778b 100644
--- a/numpy/testing/tests/test_utils.py
+++ b/numpy/testing/tests/test_utils.py
@@ -1,10 +1,10 @@
import numpy as N
from numpy.testing.utils import *
-class _GenericTest:
- def __init__(self, assert_func):
- self._assert_func = assert_func
+import unittest
+
+class _GenericTest(object):
def _test_equal(self, a, b):
self._assert_func(a, b)
@@ -47,9 +47,9 @@ class _GenericTest:
self._test_not_equal(a, b)
-class TestEqual(_GenericTest):
- def __init__(self):
- _GenericTest.__init__(self, assert_array_equal)
+class TestEqual(_GenericTest, unittest.TestCase):
+ def setUp(self):
+ self._assert_func = assert_array_equal
def test_generic_rank1(self):
"""Test rank 1 array for all dtypes."""
@@ -126,6 +126,42 @@ class TestEqual(_GenericTest):
self._test_not_equal(c, b)
-class TestAlmostEqual(_GenericTest):
- def __init__(self):
- _GenericTest.__init__(self, assert_array_almost_equal)
+class TestAlmostEqual(_GenericTest, unittest.TestCase):
+ def setUp(self):
+ self._assert_func = assert_array_almost_equal
+
+
+class TestRaises(unittest.TestCase):
+ def setUp(self):
+ class MyException(Exception):
+ pass
+
+ self.e = MyException
+
+ def raises_exception(self, e):
+ raise e
+
+ def does_not_raise_exception(self):
+ pass
+
+ def test_correct_catch(self):
+ f = raises(self.e)(self.raises_exception)(self.e)
+
+ def test_wrong_exception(self):
+ try:
+ f = raises(self.e)(self.raises_exception)(RuntimeError)
+ except RuntimeError:
+ return
+ else:
+ raise AssertionError("should have caught RuntimeError")
+
+ def test_catch_no_raise(self):
+ try:
+ f = raises(self.e)(self.does_not_raise_exception)()
+ except AssertionError:
+ return
+ else:
+ raise AssertionError("should have raised an AssertionError")
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py
index ab8d77b76..8374999ff 100644
--- a/numpy/testing/utils.py
+++ b/numpy/testing/utils.py
@@ -294,36 +294,33 @@ def assert_string_equal(actual, desired):
msg = 'Differences in strings:\n%s' % (''.join(diff_list)).rstrip()
assert actual==desired, msg
-# Ripped from nose.tools
-def raises(*exceptions):
- """Test must raise one of expected exceptions to pass.
-
- Example use::
-
- @raises(TypeError, ValueError)
- def test_raises_type_error():
- raise TypeError("This test passes")
-
- @raises(Exception):
- def test_that_fails_by_passing():
- pass
- If you want to test many assertions about exceptions in a single test,
- you may want to use `assert_raises` instead.
+def raises(*exceptions):
+ """ Assert that a test function raises one of the specified exceptions to
+ pass.
"""
- valid = ' or '.join([e.__name__ for e in exceptions])
- def decorate(func):
- name = func.__name__
- def newfunc(*arg, **kw):
+ # FIXME: when we transition to nose, just use its implementation. It's
+ # better.
+ def deco(function):
+ def f2(*args, **kwds):
try:
- func(*arg, **kw)
+ function(*args, **kwds)
except exceptions:
pass
except:
+ # Anything else.
raise
else:
- message = "%s() did not raise %s" % (name, valid)
- raise AssertionError(message)
- newfunc = make_decorator(func)(newfunc)
- return newfunc
- return decorate
+ raise AssertionError('%s() did not raise one of (%s)' %
+ (function.__name__, ', '.join([e.__name__ for e in exceptions])))
+ try:
+ f2.__name__ = function.__name__
+ except TypeError:
+ # Python 2.3 does not permit this.
+ pass
+ f2.__dict__ = function.__dict__
+ f2.__doc__ = function.__doc__
+ f2.__module__ = function.__module__
+ return f2
+
+ return deco