summaryrefslogtreecommitdiff
path: root/numpy/random/tests
diff options
context:
space:
mode:
authorRedRuM <44142765+zoj613@users.noreply.github.com>2019-10-16 22:13:59 +0200
committerRedRuM <44142765+zoj613@users.noreply.github.com>2019-10-16 22:13:59 +0200
commit9d9a89445bdf0edaaba4a4a3260493bc3f043a15 (patch)
treef6c98921e0ae5fcc282bbf5c1aef50b0d6471472 /numpy/random/tests
parent418a85eb738c76109779d1a70ba7da6c3a3f1603 (diff)
downloadnumpy-9d9a89445bdf0edaaba4a4a3260493bc3f043a15.tar.gz
TST: Parametrize tests using different methods
Diffstat (limited to 'numpy/random/tests')
-rw-r--r--numpy/random/tests/test_generator_mt19937.py37
1 files changed, 11 insertions, 26 deletions
diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py
index 70c8851fc..602904107 100644
--- a/numpy/random/tests/test_generator_mt19937.py
+++ b/numpy/random/tests/test_generator_mt19937.py
@@ -979,19 +979,13 @@ class TestRandomDist(object):
[5, 5, 3, 1, 2, 4]]])
assert_array_equal(actual, desired)
- def test_multivariate_normal(self):
+ @pytest.mark.parametrize("method", ["svd", "eigh", "cholesky"])
+ def test_multivariate_normal(self, method):
random = Generator(MT19937(self.seed))
mean = (.123456789, 10)
cov = [[1, 0], [0, 1]]
size = (3, 2)
- actual_svd = random.multivariate_normal(mean, cov, size)
- # the seed needs to be reset for each method
- random = Generator(MT19937(self.seed))
- actual_eigh = random.multivariate_normal(mean, cov, size,
- method='eigh')
- random = Generator(MT19937(self.seed))
- actual_chol = random.multivariate_normal(mean, cov, size,
- method='cholesky')
+ actual = random.multivariate_normal(mean, cov, size, method=method)
desired = np.array([[[-1.747478062846581, 11.25613495182354 ],
[-0.9967333370066214, 10.342002097029821 ]],
[[ 0.7850019631242964, 11.181113712443013 ],
@@ -999,23 +993,16 @@ class TestRandomDist(object):
[[ 0.7130260107430003, 9.551628690083056 ],
[ 0.7127098726541128, 11.991709234143173 ]]])
- assert_array_almost_equal(actual_svd, desired, decimal=15)
- assert_array_almost_equal(actual_eigh, desired, decimal=15)
- assert_array_almost_equal(actual_chol, desired, decimal=15)
+ assert_array_almost_equal(actual, desired, decimal=15)
# Check for default size, was raising deprecation warning
- random = Generator(MT19937(self.seed))
- actual_svd = random.multivariate_normal(mean, cov)
- random = Generator(MT19937(self.seed))
- actual_eigh = random.multivariate_normal(mean, cov, method='eigh')
- random = Generator(MT19937(self.seed))
- actual_chol = random.multivariate_normal(mean, cov, method='cholesky')
+ #random = Generator(MT19937(self.seed))
+ actual = random.multivariate_normal(mean, cov, method=method)
# the factor matrix is the same for all methods
- random = Generator(MT19937(self.seed))
- desired = np.array([-1.747478062846581, 11.25613495182354])
- assert_array_almost_equal(actual_svd, desired, decimal=15)
- assert_array_almost_equal(actual_eigh, desired, decimal=15)
- assert_array_almost_equal(actual_chol, desired, decimal=15)
+ #random = Generator(MT19937(self.seed))
+ #desired = np.array([-1.747478062846581, 11.25613495182354])
+ desired = np.array([0.233278563284287, 9.424140804347195])
+ assert_array_almost_equal(actual, desired, decimal=15)
# Check that non symmetric covariance input raises exception when
# check_valid='raises' if using default svd method.
mean = [0, 0]
@@ -1044,9 +1031,7 @@ class TestRandomDist(object):
cov = np.array([[1, 0.1], [0.1, 1]], dtype=np.float32)
with suppress_warnings() as sup:
- random.multivariate_normal(mean, cov)
- random.multivariate_normal(mean, cov, method='eigh')
- random.multivariate_normal(mean, cov, method='cholesky')
+ random.multivariate_normal(mean, cov, method=method)
w = sup.record(RuntimeWarning)
assert len(w) == 0