summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRedRuM <44142765+zoj613@users.noreply.github.com>2019-08-04 15:32:25 +0200
committerRedRuM <44142765+zoj613@users.noreply.github.com>2019-08-04 15:32:25 +0200
commit124b81ffdfda98eecb5cc27660bd86893a87ef9a (patch)
tree8d7d0e03554aec3790672e0dd7f706f85c75e8bc
parent71c8a1030d5a32342edc2e2311cb71dc38a7374e (diff)
downloadnumpy-124b81ffdfda98eecb5cc27660bd86893a87ef9a.tar.gz
ENH: Add new test cases corresponding to the enhancements in random.multivariate_normal
-rw-r--r--numpy/random/mtrand.pyx48
1 files changed, 41 insertions, 7 deletions
diff --git a/numpy/random/mtrand.pyx b/numpy/random/mtrand.pyx
index 46b6b3388..d6a24918a 100644
--- a/numpy/random/mtrand.pyx
+++ b/numpy/random/mtrand.pyx
@@ -3609,7 +3609,7 @@ cdef class RandomState:
# Multivariate distributions:
def multivariate_normal(self, mean, cov, size=None, check_valid='warn',
- tol=1e-8):
+ mode="default", factor=None, tol=1e-8):
"""
multivariate_normal(mean, cov, size=None, check_valid='warn', tol=1e-8)
@@ -3710,7 +3710,16 @@ cdef class RandomState:
[True, True] # random
"""
- from numpy.dual import svd
+ if mode not in {'default', 'legacy', 'fast'}:
+ raise ValueError(
+ "mode must to be one of {'default', 'legacy', 'fast'}")
+
+ if mode == 'legacy':
+ from numpy.dual import svd
+ elif mode == 'fast':
+ from numpy.dual import cholesky
+ else:
+ from numpy.dual import eigh
# Check preconditions on arguments
mean = np.array(mean)
@@ -3728,6 +3737,10 @@ cdef class RandomState:
raise ValueError("cov must be 2 dimensional and square")
if mean.shape[0] != cov.shape[0]:
raise ValueError("mean and cov must have same length")
+ # raise exception if cov is not symmetric
+ symmetric = np.allclose(cov, cov.T, rtol=tol, atol=tol)
+ if not symmetric:
+ raise ValueError("covariance is not symmetric.")
# Compute shape of output and create a matrix of independent
# standard normally distributed random numbers. The matrix has rows
@@ -3737,6 +3750,16 @@ cdef class RandomState:
final_shape.append(mean.shape[0])
x = self.standard_normal(final_shape).reshape(-1, mean.shape[0])
+ if factor is not None:
+ factor = np.array(factor)
+ if np.allclose(np.dot(factor, factor.T), cov, rtol=tol, atol=tol):
+ x = mean + np.dot(x, factor)
+ x.shape = tuple(final_shape)
+ return x
+ else:
+ raise ValueError(
+ "Input factor cannot recover the covariance matrix.")
+
# Transform matrix of standard normals into matrix where each row
# contains multivariate normals with the desired covariance.
# Compute A such that dot(transpose(A),A) == cov.
@@ -3753,14 +3776,21 @@ cdef class RandomState:
# GH10839, ensure double to make tol meaningful
cov = cov.astype(np.double)
- (u, s, v) = svd(cov)
+ if mode == 'legacy':
+ (u, s, v) = svd(cov)
+ elif mode == 'fast':
+ l = cholesky(cov)
+ else:
+ s, v = eigh(cov)
- if check_valid != 'ignore':
+ if check_valid != 'ignore' and mode != 'fast':
if check_valid != 'warn' and check_valid != 'raise':
raise ValueError(
"check_valid must equal 'warn', 'raise', or 'ignore'")
-
- psd = np.allclose(np.dot(v.T * s, v), cov, rtol=tol, atol=tol)
+ if mode == 'legacy':
+ psd = np.allclose(np.dot(v.T * s, v), cov, rtol=tol, atol=tol)
+ else:
+ psd = np.all(s >= 0)
if not psd:
if check_valid == 'warn':
warnings.warn("covariance is not positive-semidefinite.",
@@ -3769,7 +3799,11 @@ cdef class RandomState:
raise ValueError(
"covariance is not positive-semidefinite.")
- x = np.dot(x, np.sqrt(s)[:, None] * v)
+ if mode == 'fast':
+ x = np.dot(x, l)
+ else:
+ s = np.abs(s) # to combat negative eigenvalues
+ x = np.dot(x, np.sqrt(s)[:, None] * v)
x += mean
x.shape = tuple(final_shape)
return x