summaryrefslogtreecommitdiff
path: root/numpy/random/tests/test_generator_mt19937.py
diff options
context:
space:
mode:
authorMatti Picus <matti.picus@gmail.com>2019-11-03 14:39:19 -0500
committerGitHub <noreply@github.com>2019-11-03 14:39:19 -0500
commit04229deb0b49861f68c016bab35c791128b7d2d6 (patch)
treef746eeaa31cb757de6ee7cb5253c6ae4c0b4218c /numpy/random/tests/test_generator_mt19937.py
parent2be03c8d25b14b654064e953feac7d210e6bd44d (diff)
parent87840dc2f09ebeed12c3b9fef68b94dc04f4d16f (diff)
downloadnumpy-04229deb0b49861f68c016bab35c791128b7d2d6.tar.gz
Merge pull request #14197 from zoj613/multivariate_normal_speedups
ENH: Multivariate normal speedups
Diffstat (limited to 'numpy/random/tests/test_generator_mt19937.py')
-rw-r--r--numpy/random/tests/test_generator_mt19937.py24
1 files changed, 19 insertions, 5 deletions
diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py
index d4502d276..d85de6b6d 100644
--- a/numpy/random/tests/test_generator_mt19937.py
+++ b/numpy/random/tests/test_generator_mt19937.py
@@ -3,6 +3,8 @@ import sys
import pytest
import numpy as np
+from numpy.dual import cholesky, eigh, svd
+from numpy.linalg import LinAlgError
from numpy.testing import (
assert_, assert_raises, assert_equal, assert_allclose,
assert_warns, assert_no_warnings, assert_array_equal,
@@ -1196,12 +1198,13 @@ class TestRandomDist(object):
[5, 5, 3, 1, 2, 4]]])
assert_array_equal(actual, desired)
- def test_multivariate_normal(self):
+ @pytest.mark.parametrize("method", ["svd", "eigh", "cholesky"])
+ def test_multivariate_normal(self, method):
random = Generator(MT19937(self.seed))
mean = (.123456789, 10)
cov = [[1, 0], [0, 1]]
size = (3, 2)
- actual = random.multivariate_normal(mean, cov, size)
+ actual = random.multivariate_normal(mean, cov, size, method=method)
desired = np.array([[[-1.747478062846581, 11.25613495182354 ],
[-0.9967333370066214, 10.342002097029821 ]],
[[ 0.7850019631242964, 11.181113712443013 ],
@@ -1212,15 +1215,24 @@ class TestRandomDist(object):
assert_array_almost_equal(actual, desired, decimal=15)
# Check for default size, was raising deprecation warning
- actual = random.multivariate_normal(mean, cov)
+ actual = random.multivariate_normal(mean, cov, method=method)
desired = np.array([0.233278563284287, 9.424140804347195])
assert_array_almost_equal(actual, desired, decimal=15)
+ # Check that non symmetric covariance input raises exception when
+ # check_valid='raises' if using default svd method.
+ mean = [0, 0]
+ cov = [[1, 2], [1, 2]]
+ assert_raises(ValueError, random.multivariate_normal, mean, cov,
+ check_valid='raise')
# Check that non positive-semidefinite covariance warns with
# RuntimeWarning
- mean = [0, 0]
cov = [[1, 2], [2, 1]]
assert_warns(RuntimeWarning, random.multivariate_normal, mean, cov)
+ assert_warns(RuntimeWarning, random.multivariate_normal, mean, cov,
+ method='eigh')
+ assert_raises(LinAlgError, random.multivariate_normal, mean, cov,
+ method='cholesky')
# and that it doesn't warn with RuntimeWarning check_valid='ignore'
assert_no_warnings(random.multivariate_normal, mean, cov,
@@ -1229,10 +1241,12 @@ class TestRandomDist(object):
# and that it raises with RuntimeWarning check_valid='raises'
assert_raises(ValueError, random.multivariate_normal, mean, cov,
check_valid='raise')
+ assert_raises(ValueError, random.multivariate_normal, mean, cov,
+ check_valid='raise', method='eigh')
cov = np.array([[1, 0.1], [0.1, 1]], dtype=np.float32)
with suppress_warnings() as sup:
- random.multivariate_normal(mean, cov)
+ random.multivariate_normal(mean, cov, method=method)
w = sup.record(RuntimeWarning)
assert len(w) == 0