diff options
author | RedRuM <44142765+zoj613@users.noreply.github.com> | 2019-08-17 23:45:26 +0200 |
---|---|---|
committer | RedRuM <44142765+zoj613@users.noreply.github.com> | 2019-08-17 23:45:26 +0200 |
commit | ad56f16191adb1b45576d6c3214c19dffdd0eb34 (patch) | |
tree | 1ed053a5f23d23768b84ee5953162fe18ff6b8b4 /numpy/random/tests | |
parent | e6f9a354ff3e2354d343493f6463328895f702c2 (diff) | |
download | numpy-ad56f16191adb1b45576d6c3214c19dffdd0eb34.tar.gz |
ENH: re-write the implementation with new keywords method and use_factor
Diffstat (limited to 'numpy/random/tests')
-rw-r--r-- | numpy/random/tests/test_generator_mt19937.py | 87 |
1 files changed, 81 insertions, 6 deletions
diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py index a962fe84e..2677b04fa 100644 --- a/numpy/random/tests/test_generator_mt19937.py +++ b/numpy/random/tests/test_generator_mt19937.py @@ -3,6 +3,8 @@ import sys import pytest import numpy as np +from numpy.dual import cholesky, eigh, svd +from numpy.linalg import LinAlgError from numpy.testing import ( assert_, assert_raises, assert_equal, assert_warns, assert_no_warnings, assert_array_equal, @@ -982,7 +984,18 @@ class TestRandomDist(object): mean = (.123456789, 10) cov = [[1, 0], [0, 1]] size = (3, 2) - actual = random.multivariate_normal(mean, cov, size) + actual_svd = random.multivariate_normal(mean, cov, size) + # the seed needs to be reset for each method + random = Generator(MT19937(self.seed)) + actual_eigh = random.multivariate_normal(mean, cov, size, + method='eigh') + random = Generator(MT19937(self.seed)) + actual_chol = random.multivariate_normal(mean, cov, size, + method='cholesky') + # the factor matrix is the same for all methods + random = Generator(MT19937(self.seed)) + actual_factor = random.multivariate_normal(mean, cov, size, + use_factor=True) desired = np.array([[[-1.747478062846581, 11.25613495182354 ], [-0.9967333370066214, 10.342002097029821 ]], [[ 0.7850019631242964, 11.181113712443013 ], @@ -990,18 +1003,76 @@ class TestRandomDist(object): [[ 0.7130260107430003, 9.551628690083056 ], [ 0.7127098726541128, 11.991709234143173 ]]]) - assert_array_almost_equal(actual, desired, decimal=15) + assert_array_almost_equal(actual_svd, desired, decimal=15) + assert_array_almost_equal(actual_eigh, desired, decimal=15) + assert_array_almost_equal(actual_chol, desired, decimal=15) + assert_array_almost_equal(actual_factor, desired, decimal=15) # Check for default size, was raising deprecation warning - actual = random.multivariate_normal(mean, cov) - desired = np.array([0.233278563284287, 9.424140804347195]) - assert_array_almost_equal(actual, desired, decimal=15) + random = Generator(MT19937(self.seed)) + actual_svd = random.multivariate_normal(mean, cov) + random = Generator(MT19937(self.seed)) + actual_eigh = random.multivariate_normal(mean, cov, method='eigh') + random = Generator(MT19937(self.seed)) + actual_chol = random.multivariate_normal(mean, cov, method='cholesky') + # the factor matrix is the same for all methods + random = Generator(MT19937(self.seed)) + actual_factor = random.multivariate_normal(mean, cov, use_factor=True) + desired = np.array([-1.747478062846581, 11.25613495182354]) + assert_array_almost_equal(actual_svd, desired, decimal=15) + assert_array_almost_equal(actual_eigh, desired, decimal=15) + assert_array_almost_equal(actual_chol, desired, decimal=15) + assert_array_almost_equal(actual_factor, desired, decimal=15) + + # test if raises ValueError when non-square factor is used + non_square_factor = [[1, 0], [0, 1], [1, 0]] + assert_raises(ValueError, random.multivariate_normal, mean, + non_square_factor, use_factor=True) + + # Check that non symmetric covariance input raises exception when + # check_valid='raises' if using default svd method. + mean = [0, 0] + cov = [[1, 2], [1, 2]] + assert_raises(ValueError, random.multivariate_normal, mean, cov, + check_valid='raise') + + # Check if each method used with use_factor=False returns the same + # output when used with use_factor=True + cov = [[2, 1], [1, 2]] + random = Generator(MT19937(self.seed)) + desired_svd = random.multivariate_normal(mean, cov) + u, s, vh = svd(cov) + svd_factor = np.sqrt(s)[:, None] * vh + random = Generator(MT19937(self.seed)) + actual_svd = random.multivariate_normal(mean, svd_factor.T, + use_factor=True) + random = Generator(MT19937(self.seed)) + desired_eigh = random.multivariate_normal(mean, cov, method='eigh') + s, u = eigh(cov) + eigh_factor = u * np.sqrt(s) + random = Generator(MT19937(self.seed)) + actual_eigh = random.multivariate_normal(mean, eigh_factor, + method='eigh', + use_factor=True) + random = Generator(MT19937(self.seed)) + desired_chol = random.multivariate_normal(mean, cov, method='cholesky') + chol_factor = cholesky(cov) + random = Generator(MT19937(self.seed)) + actual_chol = random.multivariate_normal(mean, chol_factor, + method='cholesky', + use_factor=True) + assert_array_almost_equal(actual_svd, desired_svd, decimal=15) + assert_array_almost_equal(actual_eigh, desired_eigh, decimal=15) + assert_array_almost_equal(actual_chol, desired_chol, decimal=15) # Check that non positive-semidefinite covariance warns with # RuntimeWarning - mean = [0, 0] cov = [[1, 2], [2, 1]] assert_warns(RuntimeWarning, random.multivariate_normal, mean, cov) + assert_warns(RuntimeWarning, random.multivariate_normal, mean, cov, + method='eigh') + assert_raises(LinAlgError, random.multivariate_normal, mean, cov, + method='cholesky') # and that it doesn't warn with RuntimeWarning check_valid='ignore' assert_no_warnings(random.multivariate_normal, mean, cov, @@ -1010,10 +1081,14 @@ class TestRandomDist(object): # and that it raises with RuntimeWarning check_valid='raises' assert_raises(ValueError, random.multivariate_normal, mean, cov, check_valid='raise') + assert_raises(ValueError, random.multivariate_normal, mean, cov, + check_valid='raise', method='eigh') cov = np.array([[1, 0.1], [0.1, 1]], dtype=np.float32) with suppress_warnings() as sup: random.multivariate_normal(mean, cov) + random.multivariate_normal(mean, cov, method='eigh') + random.multivariate_normal(mean, cov, method='cholesky') w = sup.record(RuntimeWarning) assert len(w) == 0 |