diff options
author | RedRuM <44142765+zoj613@users.noreply.github.com> | 2019-10-16 22:13:59 +0200 |
---|---|---|
committer | RedRuM <44142765+zoj613@users.noreply.github.com> | 2019-10-16 22:13:59 +0200 |
commit | 9d9a89445bdf0edaaba4a4a3260493bc3f043a15 (patch) | |
tree | f6c98921e0ae5fcc282bbf5c1aef50b0d6471472 /numpy/random/tests | |
parent | 418a85eb738c76109779d1a70ba7da6c3a3f1603 (diff) | |
download | numpy-9d9a89445bdf0edaaba4a4a3260493bc3f043a15.tar.gz |
TST: Parametrize tests using different methods
Diffstat (limited to 'numpy/random/tests')
-rw-r--r-- | numpy/random/tests/test_generator_mt19937.py | 37 |
1 files changed, 11 insertions, 26 deletions
diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py index 70c8851fc..602904107 100644 --- a/numpy/random/tests/test_generator_mt19937.py +++ b/numpy/random/tests/test_generator_mt19937.py @@ -979,19 +979,13 @@ class TestRandomDist(object): [5, 5, 3, 1, 2, 4]]]) assert_array_equal(actual, desired) - def test_multivariate_normal(self): + @pytest.mark.parametrize("method", ["svd", "eigh", "cholesky"]) + def test_multivariate_normal(self, method): random = Generator(MT19937(self.seed)) mean = (.123456789, 10) cov = [[1, 0], [0, 1]] size = (3, 2) - actual_svd = random.multivariate_normal(mean, cov, size) - # the seed needs to be reset for each method - random = Generator(MT19937(self.seed)) - actual_eigh = random.multivariate_normal(mean, cov, size, - method='eigh') - random = Generator(MT19937(self.seed)) - actual_chol = random.multivariate_normal(mean, cov, size, - method='cholesky') + actual = random.multivariate_normal(mean, cov, size, method=method) desired = np.array([[[-1.747478062846581, 11.25613495182354 ], [-0.9967333370066214, 10.342002097029821 ]], [[ 0.7850019631242964, 11.181113712443013 ], @@ -999,23 +993,16 @@ class TestRandomDist(object): [[ 0.7130260107430003, 9.551628690083056 ], [ 0.7127098726541128, 11.991709234143173 ]]]) - assert_array_almost_equal(actual_svd, desired, decimal=15) - assert_array_almost_equal(actual_eigh, desired, decimal=15) - assert_array_almost_equal(actual_chol, desired, decimal=15) + assert_array_almost_equal(actual, desired, decimal=15) # Check for default size, was raising deprecation warning - random = Generator(MT19937(self.seed)) - actual_svd = random.multivariate_normal(mean, cov) - random = Generator(MT19937(self.seed)) - actual_eigh = random.multivariate_normal(mean, cov, method='eigh') - random = Generator(MT19937(self.seed)) - actual_chol = random.multivariate_normal(mean, cov, method='cholesky') + #random = Generator(MT19937(self.seed)) + actual = random.multivariate_normal(mean, cov, method=method) # the factor matrix is the same for all methods - random = Generator(MT19937(self.seed)) - desired = np.array([-1.747478062846581, 11.25613495182354]) - assert_array_almost_equal(actual_svd, desired, decimal=15) - assert_array_almost_equal(actual_eigh, desired, decimal=15) - assert_array_almost_equal(actual_chol, desired, decimal=15) + #random = Generator(MT19937(self.seed)) + #desired = np.array([-1.747478062846581, 11.25613495182354]) + desired = np.array([0.233278563284287, 9.424140804347195]) + assert_array_almost_equal(actual, desired, decimal=15) # Check that non symmetric covariance input raises exception when # check_valid='raises' if using default svd method. mean = [0, 0] @@ -1044,9 +1031,7 @@ class TestRandomDist(object): cov = np.array([[1, 0.1], [0.1, 1]], dtype=np.float32) with suppress_warnings() as sup: - random.multivariate_normal(mean, cov) - random.multivariate_normal(mean, cov, method='eigh') - random.multivariate_normal(mean, cov, method='cholesky') + random.multivariate_normal(mean, cov, method=method) w = sup.record(RuntimeWarning) assert len(w) == 0 |