summaryrefslogtreecommitdiff
path: root/numpy/random/tests
diff options
context:
space:
mode:
authorRedRuM <44142765+zoj613@users.noreply.github.com>2019-08-17 23:45:26 +0200
committerRedRuM <44142765+zoj613@users.noreply.github.com>2019-08-17 23:45:26 +0200
commitad56f16191adb1b45576d6c3214c19dffdd0eb34 (patch)
tree1ed053a5f23d23768b84ee5953162fe18ff6b8b4 /numpy/random/tests
parente6f9a354ff3e2354d343493f6463328895f702c2 (diff)
downloadnumpy-ad56f16191adb1b45576d6c3214c19dffdd0eb34.tar.gz
ENH: re-write the implementation with new keywords method and use_factor
Diffstat (limited to 'numpy/random/tests')
-rw-r--r--numpy/random/tests/test_generator_mt19937.py87
1 files changed, 81 insertions, 6 deletions
diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py
index a962fe84e..2677b04fa 100644
--- a/numpy/random/tests/test_generator_mt19937.py
+++ b/numpy/random/tests/test_generator_mt19937.py
@@ -3,6 +3,8 @@ import sys
import pytest
import numpy as np
+from numpy.dual import cholesky, eigh, svd
+from numpy.linalg import LinAlgError
from numpy.testing import (
assert_, assert_raises, assert_equal,
assert_warns, assert_no_warnings, assert_array_equal,
@@ -982,7 +984,18 @@ class TestRandomDist(object):
mean = (.123456789, 10)
cov = [[1, 0], [0, 1]]
size = (3, 2)
- actual = random.multivariate_normal(mean, cov, size)
+ actual_svd = random.multivariate_normal(mean, cov, size)
+ # the seed needs to be reset for each method
+ random = Generator(MT19937(self.seed))
+ actual_eigh = random.multivariate_normal(mean, cov, size,
+ method='eigh')
+ random = Generator(MT19937(self.seed))
+ actual_chol = random.multivariate_normal(mean, cov, size,
+ method='cholesky')
+ # the factor matrix is the same for all methods
+ random = Generator(MT19937(self.seed))
+ actual_factor = random.multivariate_normal(mean, cov, size,
+ use_factor=True)
desired = np.array([[[-1.747478062846581, 11.25613495182354 ],
[-0.9967333370066214, 10.342002097029821 ]],
[[ 0.7850019631242964, 11.181113712443013 ],
@@ -990,18 +1003,76 @@ class TestRandomDist(object):
[[ 0.7130260107430003, 9.551628690083056 ],
[ 0.7127098726541128, 11.991709234143173 ]]])
- assert_array_almost_equal(actual, desired, decimal=15)
+ assert_array_almost_equal(actual_svd, desired, decimal=15)
+ assert_array_almost_equal(actual_eigh, desired, decimal=15)
+ assert_array_almost_equal(actual_chol, desired, decimal=15)
+ assert_array_almost_equal(actual_factor, desired, decimal=15)
# Check for default size, was raising deprecation warning
- actual = random.multivariate_normal(mean, cov)
- desired = np.array([0.233278563284287, 9.424140804347195])
- assert_array_almost_equal(actual, desired, decimal=15)
+ random = Generator(MT19937(self.seed))
+ actual_svd = random.multivariate_normal(mean, cov)
+ random = Generator(MT19937(self.seed))
+ actual_eigh = random.multivariate_normal(mean, cov, method='eigh')
+ random = Generator(MT19937(self.seed))
+ actual_chol = random.multivariate_normal(mean, cov, method='cholesky')
+ # the factor matrix is the same for all methods
+ random = Generator(MT19937(self.seed))
+ actual_factor = random.multivariate_normal(mean, cov, use_factor=True)
+ desired = np.array([-1.747478062846581, 11.25613495182354])
+ assert_array_almost_equal(actual_svd, desired, decimal=15)
+ assert_array_almost_equal(actual_eigh, desired, decimal=15)
+ assert_array_almost_equal(actual_chol, desired, decimal=15)
+ assert_array_almost_equal(actual_factor, desired, decimal=15)
+
+ # test if raises ValueError when non-square factor is used
+ non_square_factor = [[1, 0], [0, 1], [1, 0]]
+ assert_raises(ValueError, random.multivariate_normal, mean,
+ non_square_factor, use_factor=True)
+
+ # Check that non symmetric covariance input raises exception when
+ # check_valid='raises' if using default svd method.
+ mean = [0, 0]
+ cov = [[1, 2], [1, 2]]
+ assert_raises(ValueError, random.multivariate_normal, mean, cov,
+ check_valid='raise')
+
+ # Check if each method used with use_factor=False returns the same
+ # output when used with use_factor=True
+ cov = [[2, 1], [1, 2]]
+ random = Generator(MT19937(self.seed))
+ desired_svd = random.multivariate_normal(mean, cov)
+ u, s, vh = svd(cov)
+ svd_factor = np.sqrt(s)[:, None] * vh
+ random = Generator(MT19937(self.seed))
+ actual_svd = random.multivariate_normal(mean, svd_factor.T,
+ use_factor=True)
+ random = Generator(MT19937(self.seed))
+ desired_eigh = random.multivariate_normal(mean, cov, method='eigh')
+ s, u = eigh(cov)
+ eigh_factor = u * np.sqrt(s)
+ random = Generator(MT19937(self.seed))
+ actual_eigh = random.multivariate_normal(mean, eigh_factor,
+ method='eigh',
+ use_factor=True)
+ random = Generator(MT19937(self.seed))
+ desired_chol = random.multivariate_normal(mean, cov, method='cholesky')
+ chol_factor = cholesky(cov)
+ random = Generator(MT19937(self.seed))
+ actual_chol = random.multivariate_normal(mean, chol_factor,
+ method='cholesky',
+ use_factor=True)
+ assert_array_almost_equal(actual_svd, desired_svd, decimal=15)
+ assert_array_almost_equal(actual_eigh, desired_eigh, decimal=15)
+ assert_array_almost_equal(actual_chol, desired_chol, decimal=15)
# Check that non positive-semidefinite covariance warns with
# RuntimeWarning
- mean = [0, 0]
cov = [[1, 2], [2, 1]]
assert_warns(RuntimeWarning, random.multivariate_normal, mean, cov)
+ assert_warns(RuntimeWarning, random.multivariate_normal, mean, cov,
+ method='eigh')
+ assert_raises(LinAlgError, random.multivariate_normal, mean, cov,
+ method='cholesky')
# and that it doesn't warn with RuntimeWarning check_valid='ignore'
assert_no_warnings(random.multivariate_normal, mean, cov,
@@ -1010,10 +1081,14 @@ class TestRandomDist(object):
# and that it raises with RuntimeWarning check_valid='raises'
assert_raises(ValueError, random.multivariate_normal, mean, cov,
check_valid='raise')
+ assert_raises(ValueError, random.multivariate_normal, mean, cov,
+ check_valid='raise', method='eigh')
cov = np.array([[1, 0.1], [0.1, 1]], dtype=np.float32)
with suppress_warnings() as sup:
random.multivariate_normal(mean, cov)
+ random.multivariate_normal(mean, cov, method='eigh')
+ random.multivariate_normal(mean, cov, method='cholesky')
w = sup.record(RuntimeWarning)
assert len(w) == 0