diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2013-10-27 18:23:02 -0700 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2013-10-27 18:23:02 -0700 |
commit | 85b83e6938fa6f5176eaab8e8fd1652b27d53aa0 (patch) | |
tree | b18759a0f483dbe5fd473826d3957094d6712cfe | |
parent | 47b5af987bf31553329334fa08898dac67dbf1ac (diff) | |
parent | 3fc490759fc87dd4702b5a6174638e9fd70019dd (diff) | |
download | numpy-85b83e6938fa6f5176eaab8e8fd1652b27d53aa0.tar.gz |
Merge pull request #3982 from charris/refactor-eigh-eigvalsh
MAINT: Refactor eigh and eigvalsh and associated tests.
-rw-r--r-- | numpy/linalg/linalg.py | 12 | ||||
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 91 |
2 files changed, 56 insertions, 47 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index c5621eace..1b82c7cc0 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -950,13 +950,13 @@ def eigvalsh(a, UPLO='L'): array([ 0.17157288+0.j, 5.82842712+0.j]) """ - UPLO = asbytes(UPLO.upper()) - if UPLO not in (b'L', b'U'): + UPLO = UPLO.upper() + if UPLO not in ('L', 'U'): raise ValueError("UPLO argument must be 'L' or 'U'") extobj = get_linalg_error_extobj( _raise_linalgerror_eigenvalues_nonconvergence) - if UPLO == _L: + if UPLO == 'L': gufunc = _umath_linalg.eigvalsh_lo else: gufunc = _umath_linalg.eigvalsh_up @@ -1197,8 +1197,8 @@ def eigh(a, UPLO='L'): [ 0.00000000+0.38268343j, 0.00000000-0.92387953j]]) """ - UPLO = asbytes(UPLO.upper()) - if UPLO not in (b'L', b'U'): + UPLO = UPLO.upper() + if UPLO not in ('L', 'U'): raise ValueError("UPLO argument must be 'L' or 'U'") a, wrap = _makearray(a) @@ -1208,7 +1208,7 @@ def eigh(a, UPLO='L'): extobj = get_linalg_error_extobj( _raise_linalgerror_eigenvalues_nonconvergence) - if _L == UPLO: + if UPLO == 'L': gufunc = _umath_linalg.eigh_lo else: gufunc = _umath_linalg.eigh_up diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index 803b4c88f..dd4cbcc4f 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -50,7 +50,7 @@ def get_complex_dtype(dtype): def get_rtol(dtype): # Choose a safe rtol - if dtype in (np.single, csingle): + if dtype in (single, csingle): return 1e-5 else: return 1e-11 @@ -707,6 +707,34 @@ class TestEigvalsh(HermitianTestCase, HermitianGeneralizedTestCase): for dtype in [single, double, csingle, cdouble]: yield check, dtype + def test_invalid(self): + x = np.array([[1, 0.5], [0.5, 1]], dtype=np.float32) + assert_raises(ValueError, np.linalg.eigvalsh, x, UPLO="lrong") + assert_raises(ValueError, np.linalg.eigvalsh, x, "lower") + assert_raises(ValueError, np.linalg.eigvalsh, x, "upper") + + def test_UPLO(self): + Klo = np.array([[0, 0],[1, 0]], dtype=np.double) + Kup = np.array([[0, 1],[0, 0]], dtype=np.double) + tgt = np.array([-1, 1], dtype=np.double) + rtol = get_rtol(np.double) + + # Check default is 'L' + w = np.linalg.eigvalsh(Klo) + assert_allclose(np.sort(w), tgt, rtol=rtol) + # Check 'L' + w = np.linalg.eigvalsh(Klo, UPLO='L') + assert_allclose(np.sort(w), tgt, rtol=rtol) + # Check 'l' + w = np.linalg.eigvalsh(Klo, UPLO='l') + assert_allclose(np.sort(w), tgt, rtol=rtol) + # Check 'U' + w = np.linalg.eigvalsh(Kup, UPLO='U') + assert_allclose(np.sort(w), tgt, rtol=rtol) + # Check 'u' + w = np.linalg.eigvalsh(Kup, UPLO='u') + assert_allclose(np.sort(w), tgt, rtol=rtol) + class TestEigh(HermitianTestCase, HermitianGeneralizedTestCase): def do(self, a, b): @@ -744,47 +772,28 @@ class TestEigh(HermitianTestCase, HermitianGeneralizedTestCase): assert_raises(ValueError, np.linalg.eigh, x, UPLO="lrong") assert_raises(ValueError, np.linalg.eigh, x, "lower") assert_raises(ValueError, np.linalg.eigh, x, "upper") - assert_raises(ValueError, np.linalg.eigvalsh, x, UPLO="lrong") - assert_raises(ValueError, np.linalg.eigvalsh, x, "lower") - assert_raises(ValueError, np.linalg.eigvalsh, x, "upper") - def test_half_filled(self): - expect = np.array([-0.33333333, -0.33333333, -0.33333333, 0.99999999]) - K = np.array([[ 0. , 0. , 0. , 0. ], - [-0.33333333, 0. , 0. , 0. ], - [ 0.33333333, -0.33333333, 0. , 0. ], - [ 0.33333333, -0.33333333, 0.33333333, 0. ]]) - Kr = np.rot90(K, k=2) - - w, V = np.linalg.eigh(K) - assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype)) - - w, V = np.linalg.eigh(UPLO='L', a=K) - assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype)) - w, V = np.linalg.eigh(K, 'l') - w2, V2 = np.linalg.eigh(K, 'L') - assert_allclose(w, w2, rtol=get_rtol(K.dtype)) - assert_allclose(V, V2, rtol=get_rtol(K.dtype)) - - w, V = np.linalg.eigh(Kr, 'U') - assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype)) - w, V = np.linalg.eigh(Kr, 'u') - w2, V2 = np.linalg.eigh(Kr, 'u') - assert_allclose(w, w2, rtol=get_rtol(K.dtype)) - assert_allclose(V, V2, rtol=get_rtol(K.dtype)) - - w = np.linalg.eigvalsh(K) - assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype)) - - w = np.linalg.eigvalsh(UPLO='L', a=K) - assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype)) - assert_allclose(np.linalg.eigvalsh(K, 'L'), - np.linalg.eigvalsh(K, 'l'), rtol=get_rtol(K.dtype)) - - w = np.linalg.eigvalsh(Kr, 'U') - assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype)) - assert_allclose(np.linalg.eigvalsh(Kr, 'U'), - np.linalg.eigvalsh(Kr, 'u'), rtol=get_rtol(K.dtype)) + def test_UPLO(self): + Klo = np.array([[0, 0],[1, 0]], dtype=np.double) + Kup = np.array([[0, 1],[0, 0]], dtype=np.double) + tgt = np.array([-1, 1], dtype=np.double) + rtol = get_rtol(np.double) + + # Check default is 'L' + w, v = np.linalg.eigh(Klo) + assert_allclose(np.sort(w), tgt, rtol=rtol) + # Check 'L' + w, v = np.linalg.eigh(Klo, UPLO='L') + assert_allclose(np.sort(w), tgt, rtol=rtol) + # Check 'l' + w, v = np.linalg.eigh(Klo, UPLO='l') + assert_allclose(np.sort(w), tgt, rtol=rtol) + # Check 'U' + w, v = np.linalg.eigh(Kup, UPLO='U') + assert_allclose(np.sort(w), tgt, rtol=rtol) + # Check 'u' + w, v = np.linalg.eigh(Kup, UPLO='u') + assert_allclose(np.sort(w), tgt, rtol=rtol) class _TestNorm(object): |