diff options
author | Julian Taylor <jtaylor.debian@googlemail.com> | 2013-10-25 18:56:59 +0200 |
---|---|---|
committer | Julian Taylor <jtaylor.debian@googlemail.com> | 2013-10-25 19:01:29 +0200 |
commit | e549e69020e3dcec08185695db6f7001a62dc934 (patch) | |
tree | cb7dfef52169207eb8b08c7e1b053cd133f870a0 | |
parent | a3e8c12ed88c6db2aa89cfbb7a69fc863e8a40dc (diff) | |
download | numpy-e549e69020e3dcec08185695db6f7001a62dc934.tar.gz |
BUG: fix broken UPLO of eigh in python3
UPLO was cast to bytes and compared to a string which is always false in
python3.
closes gh-3977
-rw-r--r-- | numpy/linalg/linalg.py | 7 | ||||
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 23 |
2 files changed, 26 insertions, 4 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index aa3bdea34..9b91aa7d3 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -914,7 +914,7 @@ def eigvalsh(a, UPLO='L'): A complex- or real-valued matrix whose eigenvalues are to be computed. UPLO : {'L', 'U'}, optional - Same as `lower`, wth 'L' for lower and 'U' for upper triangular. + Same as `lower`, with 'L' for lower and 'U' for upper triangular. Deprecated. Returns @@ -950,10 +950,11 @@ def eigvalsh(a, UPLO='L'): array([ 0.17157288+0.j, 5.82842712+0.j]) """ + UPLO = asbytes(UPLO) extobj = get_linalg_error_extobj( _raise_linalgerror_eigenvalues_nonconvergence) - if UPLO == 'L': + if UPLO == _L: gufunc = _umath_linalg.eigvalsh_lo else: gufunc = _umath_linalg.eigvalsh_up @@ -1203,7 +1204,7 @@ def eigh(a, UPLO='L'): extobj = get_linalg_error_extobj( _raise_linalgerror_eigenvalues_nonconvergence) - if 'L' == UPLO: + if _L == UPLO: gufunc = _umath_linalg.eigh_lo else: gufunc = _umath_linalg.eigh_up diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index cc1404bf1..c35e9f4f2 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -728,7 +728,7 @@ class TestEigh(HermitianTestCase, HermitianGeneralizedTestCase): assert_allclose(dot_generalized(a, evc2), np.asarray(ev2)[...,None,:] * np.asarray(evc2), - rtol=get_rtol(ev.dtype)) + rtol=get_rtol(ev.dtype), err_msg=repr(a)) def test_types(self): def check(dtype): @@ -739,6 +739,27 @@ class TestEigh(HermitianTestCase, HermitianGeneralizedTestCase): for dtype in [single, double, csingle, cdouble]: yield check, dtype + def test_half_filled(self): + expect = np.array([-0.33333333, -0.33333333, -0.33333333, 0.99999999]) + K = np.array([[ 0. , 0. , 0. , 0. ], + [-0.33333333, 0. , 0. , 0. ], + [ 0.33333333, -0.33333333, 0. , 0. ], + [ 0.33333333, -0.33333333, 0.33333333, 0. ]]) + Kr = np.rot90(K, k=2) + w, V = np.linalg.eigh(K) + assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype)) + w, V = np.linalg.eigh(UPLO='L', a=K) + assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype)) + w, V = np.linalg.eigh(Kr, 'U') + assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype)) + + w = np.linalg.eigvalsh(K) + assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype)) + w = np.linalg.eigvalsh(UPLO='L', a=K) + assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype)) + w = np.linalg.eigvalsh(Kr, 'U') + assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype)) + class _TestNorm(object): |