diff options
author | Julian Taylor <jtaylor.debian@googlemail.com> | 2013-10-25 19:36:33 +0200 |
---|---|---|
committer | Julian Taylor <jtaylor.debian@googlemail.com> | 2013-10-25 20:07:13 +0200 |
commit | f8e07275f05e95a4d0af098b06d37925602f7861 (patch) | |
tree | 69facb9e6b295e5c6b6ba74074ad711d162cc879 | |
parent | e549e69020e3dcec08185695db6f7001a62dc934 (diff) | |
download | numpy-f8e07275f05e95a4d0af098b06d37925602f7861.tar.gz |
BUG: reject invalid UPLO with ValueError in eigh/eigvalsh
to prevent unintentional use of wrong function. Restores 1.7 behavior.
-rw-r--r-- | numpy/linalg/linalg.py | 8 | ||||
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 32 |
2 files changed, 35 insertions, 5 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 9b91aa7d3..c5621eace 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -950,7 +950,9 @@ def eigvalsh(a, UPLO='L'): array([ 0.17157288+0.j, 5.82842712+0.j]) """ - UPLO = asbytes(UPLO) + UPLO = asbytes(UPLO.upper()) + if UPLO not in (b'L', b'U'): + raise ValueError("UPLO argument must be 'L' or 'U'") extobj = get_linalg_error_extobj( _raise_linalgerror_eigenvalues_nonconvergence) @@ -1195,7 +1197,9 @@ def eigh(a, UPLO='L'): [ 0.00000000+0.38268343j, 0.00000000-0.92387953j]]) """ - UPLO = asbytes(UPLO) + UPLO = asbytes(UPLO.upper()) + if UPLO not in (b'L', b'U'): + raise ValueError("UPLO argument must be 'L' or 'U'") a, wrap = _makearray(a) _assertRankAtLeast2(a) diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index c35e9f4f2..803b4c88f 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -739,26 +739,52 @@ class TestEigh(HermitianTestCase, HermitianGeneralizedTestCase): for dtype in [single, double, csingle, cdouble]: yield check, dtype + def test_invalid(self): + x = np.array([[1, 0.5], [0.5, 1]], dtype=np.float32) + assert_raises(ValueError, np.linalg.eigh, x, UPLO="lrong") + assert_raises(ValueError, np.linalg.eigh, x, "lower") + assert_raises(ValueError, np.linalg.eigh, x, "upper") + assert_raises(ValueError, np.linalg.eigvalsh, x, UPLO="lrong") + assert_raises(ValueError, np.linalg.eigvalsh, x, "lower") + assert_raises(ValueError, np.linalg.eigvalsh, x, "upper") + def test_half_filled(self): expect = np.array([-0.33333333, -0.33333333, -0.33333333, 0.99999999]) K = np.array([[ 0. , 0. , 0. , 0. ], - [-0.33333333, 0. , 0. , 0. ], - [ 0.33333333, -0.33333333, 0. , 0. ], - [ 0.33333333, -0.33333333, 0.33333333, 0. ]]) + [-0.33333333, 0. , 0. , 0. ], + [ 0.33333333, -0.33333333, 0. , 0. ], + [ 0.33333333, -0.33333333, 0.33333333, 0. ]]) Kr = np.rot90(K, k=2) + w, V = np.linalg.eigh(K) assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype)) + w, V = np.linalg.eigh(UPLO='L', a=K) assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype)) + w, V = np.linalg.eigh(K, 'l') + w2, V2 = np.linalg.eigh(K, 'L') + assert_allclose(w, w2, rtol=get_rtol(K.dtype)) + assert_allclose(V, V2, rtol=get_rtol(K.dtype)) + w, V = np.linalg.eigh(Kr, 'U') assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype)) + w, V = np.linalg.eigh(Kr, 'u') + w2, V2 = np.linalg.eigh(Kr, 'u') + assert_allclose(w, w2, rtol=get_rtol(K.dtype)) + assert_allclose(V, V2, rtol=get_rtol(K.dtype)) w = np.linalg.eigvalsh(K) assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype)) + w = np.linalg.eigvalsh(UPLO='L', a=K) assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype)) + assert_allclose(np.linalg.eigvalsh(K, 'L'), + np.linalg.eigvalsh(K, 'l'), rtol=get_rtol(K.dtype)) + w = np.linalg.eigvalsh(Kr, 'U') assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype)) + assert_allclose(np.linalg.eigvalsh(Kr, 'U'), + np.linalg.eigvalsh(Kr, 'u'), rtol=get_rtol(K.dtype)) class _TestNorm(object): |