diff options
author | RedRuM <44142765+zoj613@users.noreply.github.com> | 2019-08-04 15:32:25 +0200 |
---|---|---|
committer | RedRuM <44142765+zoj613@users.noreply.github.com> | 2019-08-04 15:32:25 +0200 |
commit | 124b81ffdfda98eecb5cc27660bd86893a87ef9a (patch) | |
tree | 8d7d0e03554aec3790672e0dd7f706f85c75e8bc | |
parent | 71c8a1030d5a32342edc2e2311cb71dc38a7374e (diff) | |
download | numpy-124b81ffdfda98eecb5cc27660bd86893a87ef9a.tar.gz |
ENH: Add new test cases corresponding to the enhancements in random.multivariate_normal
-rw-r--r-- | numpy/random/mtrand.pyx | 48 |
1 files changed, 41 insertions, 7 deletions
diff --git a/numpy/random/mtrand.pyx b/numpy/random/mtrand.pyx index 46b6b3388..d6a24918a 100644 --- a/numpy/random/mtrand.pyx +++ b/numpy/random/mtrand.pyx @@ -3609,7 +3609,7 @@ cdef class RandomState: # Multivariate distributions: def multivariate_normal(self, mean, cov, size=None, check_valid='warn', - tol=1e-8): + mode="default", factor=None, tol=1e-8): """ multivariate_normal(mean, cov, size=None, check_valid='warn', tol=1e-8) @@ -3710,7 +3710,16 @@ cdef class RandomState: [True, True] # random """ - from numpy.dual import svd + if mode not in {'default', 'legacy', 'fast'}: + raise ValueError( + "mode must to be one of {'default', 'legacy', 'fast'}") + + if mode == 'legacy': + from numpy.dual import svd + elif mode == 'fast': + from numpy.dual import cholesky + else: + from numpy.dual import eigh # Check preconditions on arguments mean = np.array(mean) @@ -3728,6 +3737,10 @@ cdef class RandomState: raise ValueError("cov must be 2 dimensional and square") if mean.shape[0] != cov.shape[0]: raise ValueError("mean and cov must have same length") + # raise exception if cov is not symmetric + symmetric = np.allclose(cov, cov.T, rtol=tol, atol=tol) + if not symmetric: + raise ValueError("covariance is not symmetric.") # Compute shape of output and create a matrix of independent # standard normally distributed random numbers. The matrix has rows @@ -3737,6 +3750,16 @@ cdef class RandomState: final_shape.append(mean.shape[0]) x = self.standard_normal(final_shape).reshape(-1, mean.shape[0]) + if factor is not None: + factor = np.array(factor) + if np.allclose(np.dot(factor, factor.T), cov, rtol=tol, atol=tol): + x = mean + np.dot(x, factor) + x.shape = tuple(final_shape) + return x + else: + raise ValueError( + "Input factor cannot recover the covariance matrix.") + # Transform matrix of standard normals into matrix where each row # contains multivariate normals with the desired covariance. # Compute A such that dot(transpose(A),A) == cov. @@ -3753,14 +3776,21 @@ cdef class RandomState: # GH10839, ensure double to make tol meaningful cov = cov.astype(np.double) - (u, s, v) = svd(cov) + if mode == 'legacy': + (u, s, v) = svd(cov) + elif mode == 'fast': + l = cholesky(cov) + else: + s, v = eigh(cov) - if check_valid != 'ignore': + if check_valid != 'ignore' and mode != 'fast': if check_valid != 'warn' and check_valid != 'raise': raise ValueError( "check_valid must equal 'warn', 'raise', or 'ignore'") - - psd = np.allclose(np.dot(v.T * s, v), cov, rtol=tol, atol=tol) + if mode == 'legacy': + psd = np.allclose(np.dot(v.T * s, v), cov, rtol=tol, atol=tol) + else: + psd = np.all(s >= 0) if not psd: if check_valid == 'warn': warnings.warn("covariance is not positive-semidefinite.", @@ -3769,7 +3799,11 @@ cdef class RandomState: raise ValueError( "covariance is not positive-semidefinite.") - x = np.dot(x, np.sqrt(s)[:, None] * v) + if mode == 'fast': + x = np.dot(x, l) + else: + s = np.abs(s) # to combat negative eigenvalues + x = np.dot(x, np.sqrt(s)[:, None] * v) x += mean x.shape = tuple(final_shape) return x |