From 465afd2bd933a69e0bca229284199425acfee1dd Mon Sep 17 00:00:00 2001 From: David Cournapeau Date: Mon, 2 Mar 2009 14:18:15 +0000 Subject: Abstract away dtype for norm test. --- numpy/linalg/tests/test_linalg.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) (limited to 'numpy') diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index 2d6d29ffa..f6d21fd85 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -1,6 +1,7 @@ """ Test functions for linalg module """ +import numpy as np from numpy.testing import * from numpy import array, single, double, csingle, cdouble, dot, identity from numpy import multiply, atleast_2d, inf, asarray, matrix @@ -257,17 +258,19 @@ class TestEigh(HermitianTestCase, TestCase): evalues, evectors = linalg.eig(a) assert_almost_equal(ev, evalues) -class TestNorm(TestCase): +class _TestNorm(TestCase): + dt = None def test_empty(self): assert_equal(norm([]), 0.0) - assert_equal(norm(array([], dtype = double)), 0.0) - assert_equal(norm(atleast_2d(array([], dtype = double))), 0.0) + assert_equal(norm(array([], dtype=self.dt)), 0.0) + assert_equal(norm(atleast_2d(array([], dtype=self.dt))), 0.0) def test_vector(self): a = [1.0,2.0,3.0,4.0] b = [-1.0,-2.0,-3.0,-4.0] c = [-1.0, 2.0,-3.0, 4.0] - for v in (a,array(a),b,array(b),c,array(c)): + for v in (a,array(a, dtype=self.dt),b,array(b, dtype=self.dt),c,array(c, + dtype=self.dt)): assert_almost_equal(norm(v), 30**0.5) assert_almost_equal(norm(v,inf), 4.0) assert_almost_equal(norm(v,-inf), 1.0) @@ -283,19 +286,22 @@ class TestNorm(TestCase): self.assertRaises(ValueError, norm, array([1., 2., 3.]), 'fro') def test_matrix(self): - A = matrix([[1.,3.],[5.,7.]], dtype=single) - A = matrix([[1.,3.],[5.,7.]], dtype=single) + A = matrix([[1.,3.],[5.,7.]], dtype=self.dt) + A = matrix([[1.,3.],[5.,7.]], dtype=self.dt) assert_almost_equal(norm(A), 84**0.5) assert_almost_equal(norm(A,'fro'), 84**0.5) assert_almost_equal(norm(A,inf), 12.0) assert_almost_equal(norm(A,-inf), 4.0) assert_almost_equal(norm(A,1), 10.0) assert_almost_equal(norm(A,-1), 6.0) - assert_almost_equal(norm(A,2), 9.12310563) - assert_almost_equal(norm(A,-2), 0.87689437) + assert_almost_equal(norm(A,2), 9.1231056256176615) + assert_almost_equal(norm(A,-2), 0.87689437438234041) self.assertRaises(ValueError, norm, A, 'nofro') self.assertRaises(ValueError, norm, A, -3) +class TestNormDouble(_TestNorm): + dt = np.double + if __name__ == "__main__": run_module_suite() -- cgit v1.2.1