diff options
author | Matti Picus <matti.picus@gmail.com> | 2019-11-03 14:39:19 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-11-03 14:39:19 -0500 |
commit | 04229deb0b49861f68c016bab35c791128b7d2d6 (patch) | |
tree | f746eeaa31cb757de6ee7cb5253c6ae4c0b4218c /numpy | |
parent | 2be03c8d25b14b654064e953feac7d210e6bd44d (diff) | |
parent | 87840dc2f09ebeed12c3b9fef68b94dc04f4d16f (diff) | |
download | numpy-04229deb0b49861f68c016bab35c791128b7d2d6.tar.gz |
Merge pull request #14197 from zoj613/multivariate_normal_speedups
ENH: Multivariate normal speedups
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/random/_generator.pyx | 59 | ||||
-rw-r--r-- | numpy/random/tests/test_generator_mt19937.py | 24 |
2 files changed, 69 insertions, 14 deletions
diff --git a/numpy/random/_generator.pyx b/numpy/random/_generator.pyx index b842d6a32..4b012af2f 100644 --- a/numpy/random/_generator.pyx +++ b/numpy/random/_generator.pyx @@ -3456,7 +3456,7 @@ cdef class Generator: # Multivariate distributions: def multivariate_normal(self, mean, cov, size=None, check_valid='warn', - tol=1e-8): + tol=1e-8, *, method='svd'): """ multivariate_normal(mean, cov, size=None, check_valid='warn', tol=1e-8) @@ -3486,6 +3486,15 @@ cdef class Generator: tol : float, optional Tolerance when checking the singular values in covariance matrix. cov is cast to double before the check. + method : { 'svd', 'eigh', 'cholesky'}, optional + The cov input is used to compute a factor matrix A such that + ``A @ A.T = cov``. This argument is used to select the method + used to compute the factor matrix A. The default method 'svd' is + the slowest, while 'cholesky' is the fastest but less robust than + the slowest method. The method `eigh` uses eigen decomposition to + compute A and is faster than svd but slower than cholesky. + + .. versionadded:: 1.18.0 Returns ------- @@ -3546,10 +3555,16 @@ cdef class Generator: -------- >>> mean = (1, 2) >>> cov = [[1, 0], [0, 1]] - >>> x = np.random.default_rng().multivariate_normal(mean, cov, (3, 3)) + >>> rng = np.random.default_rng() + >>> x = rng.multivariate_normal(mean, cov, (3, 3)) >>> x.shape (3, 3, 2) + We can use a different method other than the default to factorize cov: + >>> y = rng.multivariate_normal(mean, cov, (3, 3), method='cholesky') + >>> y.shape + (3, 3, 2) + The following is probably true, given that 0.6 is roughly twice the standard deviation: @@ -3557,7 +3572,9 @@ cdef class Generator: [True, True] # random """ - from numpy.dual import svd + if method not in {'eigh', 'svd', 'cholesky'}: + raise ValueError( + "method must be one of {'eigh', 'svd', 'cholesky'}") # Check preconditions on arguments mean = np.array(mean) @@ -3600,13 +3617,27 @@ cdef class Generator: # GH10839, ensure double to make tol meaningful cov = cov.astype(np.double) - (u, s, v) = svd(cov) + if method == 'svd': + from numpy.dual import svd + (u, s, vh) = svd(cov) + elif method == 'eigh': + from numpy.dual import eigh + # could call linalg.svd(hermitian=True), but that calculates a vh we don't need + (s, u) = eigh(cov) + else: + from numpy.dual import cholesky + l = cholesky(cov) - if check_valid != 'ignore': + # make sure check_valid is ignored whe method == 'cholesky' + # since the decomposition will have failed if cov is not valid. + if check_valid != 'ignore' and method != 'cholesky': if check_valid != 'warn' and check_valid != 'raise': - raise ValueError("check_valid must equal 'warn', 'raise', or 'ignore'") - - psd = np.allclose(np.dot(v.T * s, v), cov, rtol=tol, atol=tol) + raise ValueError( + "check_valid must equal 'warn', 'raise', or 'ignore'") + if method == 'svd': + psd = np.allclose(np.dot(vh.T * s, vh), cov, rtol=tol, atol=tol) + else: + psd = not np.any(s < -tol) if not psd: if check_valid == 'warn': warnings.warn("covariance is not positive-semidefinite.", @@ -3614,7 +3645,17 @@ cdef class Generator: else: raise ValueError("covariance is not positive-semidefinite.") - x = np.dot(x, np.sqrt(s)[:, None] * v) + if method == 'cholesky': + _factor = l + elif method == 'eigh': + # if check_valid == 'ignore' we need to ensure that np.sqrt does not + # return a NaN if s is a very small negative number that is + # approximately zero or when the covariance is not positive-semidefinite + _factor = u * np.sqrt(abs(s)) + else: + _factor = np.sqrt(s)[:, None] * vh + + x = np.dot(x, _factor) x += mean x.shape = tuple(final_shape) return x diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py index d4502d276..d85de6b6d 100644 --- a/numpy/random/tests/test_generator_mt19937.py +++ b/numpy/random/tests/test_generator_mt19937.py @@ -3,6 +3,8 @@ import sys import pytest import numpy as np +from numpy.dual import cholesky, eigh, svd +from numpy.linalg import LinAlgError from numpy.testing import ( assert_, assert_raises, assert_equal, assert_allclose, assert_warns, assert_no_warnings, assert_array_equal, @@ -1196,12 +1198,13 @@ class TestRandomDist(object): [5, 5, 3, 1, 2, 4]]]) assert_array_equal(actual, desired) - def test_multivariate_normal(self): + @pytest.mark.parametrize("method", ["svd", "eigh", "cholesky"]) + def test_multivariate_normal(self, method): random = Generator(MT19937(self.seed)) mean = (.123456789, 10) cov = [[1, 0], [0, 1]] size = (3, 2) - actual = random.multivariate_normal(mean, cov, size) + actual = random.multivariate_normal(mean, cov, size, method=method) desired = np.array([[[-1.747478062846581, 11.25613495182354 ], [-0.9967333370066214, 10.342002097029821 ]], [[ 0.7850019631242964, 11.181113712443013 ], @@ -1212,15 +1215,24 @@ class TestRandomDist(object): assert_array_almost_equal(actual, desired, decimal=15) # Check for default size, was raising deprecation warning - actual = random.multivariate_normal(mean, cov) + actual = random.multivariate_normal(mean, cov, method=method) desired = np.array([0.233278563284287, 9.424140804347195]) assert_array_almost_equal(actual, desired, decimal=15) + # Check that non symmetric covariance input raises exception when + # check_valid='raises' if using default svd method. + mean = [0, 0] + cov = [[1, 2], [1, 2]] + assert_raises(ValueError, random.multivariate_normal, mean, cov, + check_valid='raise') # Check that non positive-semidefinite covariance warns with # RuntimeWarning - mean = [0, 0] cov = [[1, 2], [2, 1]] assert_warns(RuntimeWarning, random.multivariate_normal, mean, cov) + assert_warns(RuntimeWarning, random.multivariate_normal, mean, cov, + method='eigh') + assert_raises(LinAlgError, random.multivariate_normal, mean, cov, + method='cholesky') # and that it doesn't warn with RuntimeWarning check_valid='ignore' assert_no_warnings(random.multivariate_normal, mean, cov, @@ -1229,10 +1241,12 @@ class TestRandomDist(object): # and that it raises with RuntimeWarning check_valid='raises' assert_raises(ValueError, random.multivariate_normal, mean, cov, check_valid='raise') + assert_raises(ValueError, random.multivariate_normal, mean, cov, + check_valid='raise', method='eigh') cov = np.array([[1, 0.1], [0.1, 1]], dtype=np.float32) with suppress_warnings() as sup: - random.multivariate_normal(mean, cov) + random.multivariate_normal(mean, cov, method=method) w = sup.record(RuntimeWarning) assert len(w) == 0 |