summaryrefslogtreecommitdiff
path: root/numpy/testing/tests/test_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/testing/tests/test_utils.py')
-rw-r--r--numpy/testing/tests/test_utils.py54
1 files changed, 45 insertions, 9 deletions
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py
index 48cd89cdd..6f40c778b 100644
--- a/numpy/testing/tests/test_utils.py
+++ b/numpy/testing/tests/test_utils.py
@@ -1,10 +1,10 @@
import numpy as N
from numpy.testing.utils import *
-class _GenericTest:
- def __init__(self, assert_func):
- self._assert_func = assert_func
+import unittest
+
+class _GenericTest(object):
def _test_equal(self, a, b):
self._assert_func(a, b)
@@ -47,9 +47,9 @@ class _GenericTest:
self._test_not_equal(a, b)
-class TestEqual(_GenericTest):
- def __init__(self):
- _GenericTest.__init__(self, assert_array_equal)
+class TestEqual(_GenericTest, unittest.TestCase):
+ def setUp(self):
+ self._assert_func = assert_array_equal
def test_generic_rank1(self):
"""Test rank 1 array for all dtypes."""
@@ -126,6 +126,42 @@ class TestEqual(_GenericTest):
self._test_not_equal(c, b)
-class TestAlmostEqual(_GenericTest):
- def __init__(self):
- _GenericTest.__init__(self, assert_array_almost_equal)
+class TestAlmostEqual(_GenericTest, unittest.TestCase):
+ def setUp(self):
+ self._assert_func = assert_array_almost_equal
+
+
+class TestRaises(unittest.TestCase):
+ def setUp(self):
+ class MyException(Exception):
+ pass
+
+ self.e = MyException
+
+ def raises_exception(self, e):
+ raise e
+
+ def does_not_raise_exception(self):
+ pass
+
+ def test_correct_catch(self):
+ f = raises(self.e)(self.raises_exception)(self.e)
+
+ def test_wrong_exception(self):
+ try:
+ f = raises(self.e)(self.raises_exception)(RuntimeError)
+ except RuntimeError:
+ return
+ else:
+ raise AssertionError("should have caught RuntimeError")
+
+ def test_catch_no_raise(self):
+ try:
+ f = raises(self.e)(self.does_not_raise_exception)()
+ except AssertionError:
+ return
+ else:
+ raise AssertionError("should have raised an AssertionError")
+
+if __name__ == '__main__':
+ unittest.main()