summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorMatti Picus <matti.picus@gmail.com>2019-11-03 14:39:19 -0500
committerGitHub <noreply@github.com>2019-11-03 14:39:19 -0500
commit04229deb0b49861f68c016bab35c791128b7d2d6 (patch)
treef746eeaa31cb757de6ee7cb5253c6ae4c0b4218c /numpy
parent2be03c8d25b14b654064e953feac7d210e6bd44d (diff)
parent87840dc2f09ebeed12c3b9fef68b94dc04f4d16f (diff)
downloadnumpy-04229deb0b49861f68c016bab35c791128b7d2d6.tar.gz
Merge pull request #14197 from zoj613/multivariate_normal_speedups
ENH: Multivariate normal speedups
Diffstat (limited to 'numpy')
-rw-r--r--numpy/random/_generator.pyx59
-rw-r--r--numpy/random/tests/test_generator_mt19937.py24
2 files changed, 69 insertions, 14 deletions
diff --git a/numpy/random/_generator.pyx b/numpy/random/_generator.pyx
index b842d6a32..4b012af2f 100644
--- a/numpy/random/_generator.pyx
+++ b/numpy/random/_generator.pyx
@@ -3456,7 +3456,7 @@ cdef class Generator:
# Multivariate distributions:
def multivariate_normal(self, mean, cov, size=None, check_valid='warn',
- tol=1e-8):
+ tol=1e-8, *, method='svd'):
"""
multivariate_normal(mean, cov, size=None, check_valid='warn', tol=1e-8)
@@ -3486,6 +3486,15 @@ cdef class Generator:
tol : float, optional
Tolerance when checking the singular values in covariance matrix.
cov is cast to double before the check.
+ method : { 'svd', 'eigh', 'cholesky'}, optional
+ The cov input is used to compute a factor matrix A such that
+ ``A @ A.T = cov``. This argument is used to select the method
+ used to compute the factor matrix A. The default method 'svd' is
+ the slowest, while 'cholesky' is the fastest but less robust than
+ the slowest method. The method `eigh` uses eigen decomposition to
+ compute A and is faster than svd but slower than cholesky.
+
+ .. versionadded:: 1.18.0
Returns
-------
@@ -3546,10 +3555,16 @@ cdef class Generator:
--------
>>> mean = (1, 2)
>>> cov = [[1, 0], [0, 1]]
- >>> x = np.random.default_rng().multivariate_normal(mean, cov, (3, 3))
+ >>> rng = np.random.default_rng()
+ >>> x = rng.multivariate_normal(mean, cov, (3, 3))
>>> x.shape
(3, 3, 2)
+ We can use a different method other than the default to factorize cov:
+ >>> y = rng.multivariate_normal(mean, cov, (3, 3), method='cholesky')
+ >>> y.shape
+ (3, 3, 2)
+
The following is probably true, given that 0.6 is roughly twice the
standard deviation:
@@ -3557,7 +3572,9 @@ cdef class Generator:
[True, True] # random
"""
- from numpy.dual import svd
+ if method not in {'eigh', 'svd', 'cholesky'}:
+ raise ValueError(
+ "method must be one of {'eigh', 'svd', 'cholesky'}")
# Check preconditions on arguments
mean = np.array(mean)
@@ -3600,13 +3617,27 @@ cdef class Generator:
# GH10839, ensure double to make tol meaningful
cov = cov.astype(np.double)
- (u, s, v) = svd(cov)
+ if method == 'svd':
+ from numpy.dual import svd
+ (u, s, vh) = svd(cov)
+ elif method == 'eigh':
+ from numpy.dual import eigh
+ # could call linalg.svd(hermitian=True), but that calculates a vh we don't need
+ (s, u) = eigh(cov)
+ else:
+ from numpy.dual import cholesky
+ l = cholesky(cov)
- if check_valid != 'ignore':
+ # make sure check_valid is ignored whe method == 'cholesky'
+ # since the decomposition will have failed if cov is not valid.
+ if check_valid != 'ignore' and method != 'cholesky':
if check_valid != 'warn' and check_valid != 'raise':
- raise ValueError("check_valid must equal 'warn', 'raise', or 'ignore'")
-
- psd = np.allclose(np.dot(v.T * s, v), cov, rtol=tol, atol=tol)
+ raise ValueError(
+ "check_valid must equal 'warn', 'raise', or 'ignore'")
+ if method == 'svd':
+ psd = np.allclose(np.dot(vh.T * s, vh), cov, rtol=tol, atol=tol)
+ else:
+ psd = not np.any(s < -tol)
if not psd:
if check_valid == 'warn':
warnings.warn("covariance is not positive-semidefinite.",
@@ -3614,7 +3645,17 @@ cdef class Generator:
else:
raise ValueError("covariance is not positive-semidefinite.")
- x = np.dot(x, np.sqrt(s)[:, None] * v)
+ if method == 'cholesky':
+ _factor = l
+ elif method == 'eigh':
+ # if check_valid == 'ignore' we need to ensure that np.sqrt does not
+ # return a NaN if s is a very small negative number that is
+ # approximately zero or when the covariance is not positive-semidefinite
+ _factor = u * np.sqrt(abs(s))
+ else:
+ _factor = np.sqrt(s)[:, None] * vh
+
+ x = np.dot(x, _factor)
x += mean
x.shape = tuple(final_shape)
return x
diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py
index d4502d276..d85de6b6d 100644
--- a/numpy/random/tests/test_generator_mt19937.py
+++ b/numpy/random/tests/test_generator_mt19937.py
@@ -3,6 +3,8 @@ import sys
import pytest
import numpy as np
+from numpy.dual import cholesky, eigh, svd
+from numpy.linalg import LinAlgError
from numpy.testing import (
assert_, assert_raises, assert_equal, assert_allclose,
assert_warns, assert_no_warnings, assert_array_equal,
@@ -1196,12 +1198,13 @@ class TestRandomDist(object):
[5, 5, 3, 1, 2, 4]]])
assert_array_equal(actual, desired)
- def test_multivariate_normal(self):
+ @pytest.mark.parametrize("method", ["svd", "eigh", "cholesky"])
+ def test_multivariate_normal(self, method):
random = Generator(MT19937(self.seed))
mean = (.123456789, 10)
cov = [[1, 0], [0, 1]]
size = (3, 2)
- actual = random.multivariate_normal(mean, cov, size)
+ actual = random.multivariate_normal(mean, cov, size, method=method)
desired = np.array([[[-1.747478062846581, 11.25613495182354 ],
[-0.9967333370066214, 10.342002097029821 ]],
[[ 0.7850019631242964, 11.181113712443013 ],
@@ -1212,15 +1215,24 @@ class TestRandomDist(object):
assert_array_almost_equal(actual, desired, decimal=15)
# Check for default size, was raising deprecation warning
- actual = random.multivariate_normal(mean, cov)
+ actual = random.multivariate_normal(mean, cov, method=method)
desired = np.array([0.233278563284287, 9.424140804347195])
assert_array_almost_equal(actual, desired, decimal=15)
+ # Check that non symmetric covariance input raises exception when
+ # check_valid='raises' if using default svd method.
+ mean = [0, 0]
+ cov = [[1, 2], [1, 2]]
+ assert_raises(ValueError, random.multivariate_normal, mean, cov,
+ check_valid='raise')
# Check that non positive-semidefinite covariance warns with
# RuntimeWarning
- mean = [0, 0]
cov = [[1, 2], [2, 1]]
assert_warns(RuntimeWarning, random.multivariate_normal, mean, cov)
+ assert_warns(RuntimeWarning, random.multivariate_normal, mean, cov,
+ method='eigh')
+ assert_raises(LinAlgError, random.multivariate_normal, mean, cov,
+ method='cholesky')
# and that it doesn't warn with RuntimeWarning check_valid='ignore'
assert_no_warnings(random.multivariate_normal, mean, cov,
@@ -1229,10 +1241,12 @@ class TestRandomDist(object):
# and that it raises with RuntimeWarning check_valid='raises'
assert_raises(ValueError, random.multivariate_normal, mean, cov,
check_valid='raise')
+ assert_raises(ValueError, random.multivariate_normal, mean, cov,
+ check_valid='raise', method='eigh')
cov = np.array([[1, 0.1], [0.1, 1]], dtype=np.float32)
with suppress_warnings() as sup:
- random.multivariate_normal(mean, cov)
+ random.multivariate_normal(mean, cov, method=method)
w = sup.record(RuntimeWarning)
assert len(w) == 0