summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2013-10-27 18:23:02 -0700
committerCharles Harris <charlesr.harris@gmail.com>2013-10-27 18:23:02 -0700
commit85b83e6938fa6f5176eaab8e8fd1652b27d53aa0 (patch)
treeb18759a0f483dbe5fd473826d3957094d6712cfe
parent47b5af987bf31553329334fa08898dac67dbf1ac (diff)
parent3fc490759fc87dd4702b5a6174638e9fd70019dd (diff)
downloadnumpy-85b83e6938fa6f5176eaab8e8fd1652b27d53aa0.tar.gz
Merge pull request #3982 from charris/refactor-eigh-eigvalsh
MAINT: Refactor eigh and eigvalsh and associated tests.
-rw-r--r--numpy/linalg/linalg.py12
-rw-r--r--numpy/linalg/tests/test_linalg.py91
2 files changed, 56 insertions, 47 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index c5621eace..1b82c7cc0 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -950,13 +950,13 @@ def eigvalsh(a, UPLO='L'):
array([ 0.17157288+0.j, 5.82842712+0.j])
"""
- UPLO = asbytes(UPLO.upper())
- if UPLO not in (b'L', b'U'):
+ UPLO = UPLO.upper()
+ if UPLO not in ('L', 'U'):
raise ValueError("UPLO argument must be 'L' or 'U'")
extobj = get_linalg_error_extobj(
_raise_linalgerror_eigenvalues_nonconvergence)
- if UPLO == _L:
+ if UPLO == 'L':
gufunc = _umath_linalg.eigvalsh_lo
else:
gufunc = _umath_linalg.eigvalsh_up
@@ -1197,8 +1197,8 @@ def eigh(a, UPLO='L'):
[ 0.00000000+0.38268343j, 0.00000000-0.92387953j]])
"""
- UPLO = asbytes(UPLO.upper())
- if UPLO not in (b'L', b'U'):
+ UPLO = UPLO.upper()
+ if UPLO not in ('L', 'U'):
raise ValueError("UPLO argument must be 'L' or 'U'")
a, wrap = _makearray(a)
@@ -1208,7 +1208,7 @@ def eigh(a, UPLO='L'):
extobj = get_linalg_error_extobj(
_raise_linalgerror_eigenvalues_nonconvergence)
- if _L == UPLO:
+ if UPLO == 'L':
gufunc = _umath_linalg.eigh_lo
else:
gufunc = _umath_linalg.eigh_up
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index 803b4c88f..dd4cbcc4f 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -50,7 +50,7 @@ def get_complex_dtype(dtype):
def get_rtol(dtype):
# Choose a safe rtol
- if dtype in (np.single, csingle):
+ if dtype in (single, csingle):
return 1e-5
else:
return 1e-11
@@ -707,6 +707,34 @@ class TestEigvalsh(HermitianTestCase, HermitianGeneralizedTestCase):
for dtype in [single, double, csingle, cdouble]:
yield check, dtype
+ def test_invalid(self):
+ x = np.array([[1, 0.5], [0.5, 1]], dtype=np.float32)
+ assert_raises(ValueError, np.linalg.eigvalsh, x, UPLO="lrong")
+ assert_raises(ValueError, np.linalg.eigvalsh, x, "lower")
+ assert_raises(ValueError, np.linalg.eigvalsh, x, "upper")
+
+ def test_UPLO(self):
+ Klo = np.array([[0, 0],[1, 0]], dtype=np.double)
+ Kup = np.array([[0, 1],[0, 0]], dtype=np.double)
+ tgt = np.array([-1, 1], dtype=np.double)
+ rtol = get_rtol(np.double)
+
+ # Check default is 'L'
+ w = np.linalg.eigvalsh(Klo)
+ assert_allclose(np.sort(w), tgt, rtol=rtol)
+ # Check 'L'
+ w = np.linalg.eigvalsh(Klo, UPLO='L')
+ assert_allclose(np.sort(w), tgt, rtol=rtol)
+ # Check 'l'
+ w = np.linalg.eigvalsh(Klo, UPLO='l')
+ assert_allclose(np.sort(w), tgt, rtol=rtol)
+ # Check 'U'
+ w = np.linalg.eigvalsh(Kup, UPLO='U')
+ assert_allclose(np.sort(w), tgt, rtol=rtol)
+ # Check 'u'
+ w = np.linalg.eigvalsh(Kup, UPLO='u')
+ assert_allclose(np.sort(w), tgt, rtol=rtol)
+
class TestEigh(HermitianTestCase, HermitianGeneralizedTestCase):
def do(self, a, b):
@@ -744,47 +772,28 @@ class TestEigh(HermitianTestCase, HermitianGeneralizedTestCase):
assert_raises(ValueError, np.linalg.eigh, x, UPLO="lrong")
assert_raises(ValueError, np.linalg.eigh, x, "lower")
assert_raises(ValueError, np.linalg.eigh, x, "upper")
- assert_raises(ValueError, np.linalg.eigvalsh, x, UPLO="lrong")
- assert_raises(ValueError, np.linalg.eigvalsh, x, "lower")
- assert_raises(ValueError, np.linalg.eigvalsh, x, "upper")
- def test_half_filled(self):
- expect = np.array([-0.33333333, -0.33333333, -0.33333333, 0.99999999])
- K = np.array([[ 0. , 0. , 0. , 0. ],
- [-0.33333333, 0. , 0. , 0. ],
- [ 0.33333333, -0.33333333, 0. , 0. ],
- [ 0.33333333, -0.33333333, 0.33333333, 0. ]])
- Kr = np.rot90(K, k=2)
-
- w, V = np.linalg.eigh(K)
- assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype))
-
- w, V = np.linalg.eigh(UPLO='L', a=K)
- assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype))
- w, V = np.linalg.eigh(K, 'l')
- w2, V2 = np.linalg.eigh(K, 'L')
- assert_allclose(w, w2, rtol=get_rtol(K.dtype))
- assert_allclose(V, V2, rtol=get_rtol(K.dtype))
-
- w, V = np.linalg.eigh(Kr, 'U')
- assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype))
- w, V = np.linalg.eigh(Kr, 'u')
- w2, V2 = np.linalg.eigh(Kr, 'u')
- assert_allclose(w, w2, rtol=get_rtol(K.dtype))
- assert_allclose(V, V2, rtol=get_rtol(K.dtype))
-
- w = np.linalg.eigvalsh(K)
- assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype))
-
- w = np.linalg.eigvalsh(UPLO='L', a=K)
- assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype))
- assert_allclose(np.linalg.eigvalsh(K, 'L'),
- np.linalg.eigvalsh(K, 'l'), rtol=get_rtol(K.dtype))
-
- w = np.linalg.eigvalsh(Kr, 'U')
- assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype))
- assert_allclose(np.linalg.eigvalsh(Kr, 'U'),
- np.linalg.eigvalsh(Kr, 'u'), rtol=get_rtol(K.dtype))
+ def test_UPLO(self):
+ Klo = np.array([[0, 0],[1, 0]], dtype=np.double)
+ Kup = np.array([[0, 1],[0, 0]], dtype=np.double)
+ tgt = np.array([-1, 1], dtype=np.double)
+ rtol = get_rtol(np.double)
+
+ # Check default is 'L'
+ w, v = np.linalg.eigh(Klo)
+ assert_allclose(np.sort(w), tgt, rtol=rtol)
+ # Check 'L'
+ w, v = np.linalg.eigh(Klo, UPLO='L')
+ assert_allclose(np.sort(w), tgt, rtol=rtol)
+ # Check 'l'
+ w, v = np.linalg.eigh(Klo, UPLO='l')
+ assert_allclose(np.sort(w), tgt, rtol=rtol)
+ # Check 'U'
+ w, v = np.linalg.eigh(Kup, UPLO='U')
+ assert_allclose(np.sort(w), tgt, rtol=rtol)
+ # Check 'u'
+ w, v = np.linalg.eigh(Kup, UPLO='u')
+ assert_allclose(np.sort(w), tgt, rtol=rtol)
class _TestNorm(object):