diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2020-04-04 19:57:55 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-04-04 19:57:55 -0600 |
commit | 569dfcd79aeb2a584e008861e5e29b7b5115d94e (patch) | |
tree | 931c0e7de585caecf890a48c3c11809642ed7e93 /numpy/random/tests | |
parent | 063b3140ec3792f0420da9f386d8240693213ec9 (diff) | |
parent | 72457f01832d10c72a1839aafc178cf4f53449cb (diff) | |
download | numpy-569dfcd79aeb2a584e008861e5e29b7b5115d94e.tar.gz |
Merge pull request #15872 from Balandat/fix_eigh_mvn_sampling
BUG: Fix eigh and cholesky methods of numpy.random.multivariate_normal
Diffstat (limited to 'numpy/random/tests')
-rw-r--r-- | numpy/random/tests/test_generator_mt19937.py | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py index 6f4407373..b10c1310e 100644 --- a/numpy/random/tests/test_generator_mt19937.py +++ b/numpy/random/tests/test_generator_mt19937.py @@ -1242,6 +1242,17 @@ class TestRandomDist: assert_raises(ValueError, random.multivariate_normal, mean, cov, check_valid='raise', method='eigh') + # check degenerate samples from singular covariance matrix + cov = [[1, 1], [1, 1]] + if method in ('svd', 'eigh'): + samples = random.multivariate_normal(mean, cov, size=(3, 2), + method=method) + assert_array_almost_equal(samples[..., 0], samples[..., 1], + decimal=6) + else: + assert_raises(LinAlgError, random.multivariate_normal, mean, cov, + method='cholesky') + cov = np.array([[1, 0.1], [0.1, 1]], dtype=np.float32) with suppress_warnings() as sup: random.multivariate_normal(mean, cov, method=method) @@ -1259,6 +1270,19 @@ class TestRandomDist: assert_raises(ValueError, random.multivariate_normal, mu, np.eye(3)) + @pytest.mark.parametrize("method", ["svd", "eigh", "cholesky"]) + def test_multivariate_normal_basic_stats(self, method): + random = Generator(MT19937(self.seed)) + n_s = 1000 + mean = np.array([1, 2]) + cov = np.array([[2, 1], [1, 2]]) + s = random.multivariate_normal(mean, cov, size=(n_s,), method=method) + s_center = s - mean + cov_emp = (s_center.T @ s_center) / (n_s - 1) + # these are pretty loose and are only designed to detect major errors + assert np.all(np.abs(s_center.mean(-2)) < 0.1) + assert np.all(np.abs(cov_emp - cov) < 0.2) + def test_negative_binomial(self): random = Generator(MT19937(self.seed)) actual = random.negative_binomial(n=100, p=.12345, size=(3, 2)) |