import numpy as N from numpy.testing.utils import * import unittest class _GenericTest(object): def _test_equal(self, a, b): self._assert_func(a, b) def _test_not_equal(self, a, b): passed = False try: self._assert_func(a, b) passed = True except AssertionError: pass if passed: raise AssertionError("a and b are found equal but are not") def test_array_rank1_eq(self): """Test two equal array of rank 1 are found equal.""" a = N.array([1, 2]) b = N.array([1, 2]) self._test_equal(a, b) def test_array_rank1_noteq(self): """Test two different array of rank 1 are found not equal.""" a = N.array([1, 2]) b = N.array([2, 2]) self._test_not_equal(a, b) def test_array_rank2_eq(self): """Test two equal array of rank 2 are found equal.""" a = N.array([[1, 2], [3, 4]]) b = N.array([[1, 2], [3, 4]]) self._test_equal(a, b) def test_array_diffshape(self): """Test two arrays with different shapes are found not equal.""" a = N.array([1, 2]) b = N.array([[1, 2], [1, 2]]) self._test_not_equal(a, b) class TestEqual(_GenericTest, unittest.TestCase): def setUp(self): self._assert_func = assert_array_equal def test_generic_rank1(self): """Test rank 1 array for all dtypes.""" def foo(t): a = N.empty(2, t) a.fill(1) b = a.copy() c = a.copy() c.fill(0) self._test_equal(a, b) self._test_not_equal(c, b) # Test numeric types and object for t in '?bhilqpBHILQPfdgFDG': foo(t) # Test strings for t in ['S1', 'U1']: foo(t) def test_generic_rank3(self): """Test rank 3 array for all dtypes.""" def foo(t): a = N.empty((4, 2, 3), t) a.fill(1) b = a.copy() c = a.copy() c.fill(0) self._test_equal(a, b) self._test_not_equal(c, b) # Test numeric types and object for t in '?bhilqpBHILQPfdgFDG': foo(t) # Test strings for t in ['S1', 'U1']: foo(t) def test_nan_array(self): """Test arrays with nan values in them.""" a = N.array([1, 2, N.nan]) b = N.array([1, 2, N.nan]) self._test_equal(a, b) c = N.array([1, 2, 3]) self._test_not_equal(c, b) def test_string_arrays(self): """Test two arrays with different shapes are found not equal.""" a = N.array(['floupi', 'floupa']) b = N.array(['floupi', 'floupa']) self._test_equal(a, b) c = N.array(['floupipi', 'floupa']) self._test_not_equal(c, b) def test_recarrays(self): """Test record arrays.""" a = N.empty(2, [('floupi', N.float), ('floupa', N.float)]) a['floupi'] = [1, 2] a['floupa'] = [1, 2] b = a.copy() self._test_equal(a, b) c = N.empty(2, [('floupipi', N.float), ('floupa', N.float)]) c['floupipi'] = a['floupi'].copy() c['floupa'] = a['floupa'].copy() self._test_not_equal(c, b) class TestAlmostEqual(_GenericTest, unittest.TestCase): def setUp(self): self._assert_func = assert_array_almost_equal class TestRaises(unittest.TestCase): def setUp(self): class MyException(Exception): pass self.e = MyException def raises_exception(self, e): raise e def does_not_raise_exception(self): pass def test_correct_catch(self): f = raises(self.e)(self.raises_exception)(self.e) def test_wrong_exception(self): try: f = raises(self.e)(self.raises_exception)(RuntimeError) except RuntimeError: return else: raise AssertionError("should have caught RuntimeError") def test_catch_no_raise(self): try: f = raises(self.e)(self.does_not_raise_exception)() except AssertionError: return else: raise AssertionError("should have raised an AssertionError") if __name__ == '__main__': unittest.main()