summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJulian Taylor <jtaylor.debian@googlemail.com>2013-10-25 19:36:33 +0200
committerJulian Taylor <jtaylor.debian@googlemail.com>2013-10-25 20:07:13 +0200
commitf8e07275f05e95a4d0af098b06d37925602f7861 (patch)
tree69facb9e6b295e5c6b6ba74074ad711d162cc879
parente549e69020e3dcec08185695db6f7001a62dc934 (diff)
downloadnumpy-f8e07275f05e95a4d0af098b06d37925602f7861.tar.gz
BUG: reject invalid UPLO with ValueError in eigh/eigvalsh
to prevent unintentional use of wrong function. Restores 1.7 behavior.
-rw-r--r--numpy/linalg/linalg.py8
-rw-r--r--numpy/linalg/tests/test_linalg.py32
2 files changed, 35 insertions, 5 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 9b91aa7d3..c5621eace 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -950,7 +950,9 @@ def eigvalsh(a, UPLO='L'):
array([ 0.17157288+0.j, 5.82842712+0.j])
"""
- UPLO = asbytes(UPLO)
+ UPLO = asbytes(UPLO.upper())
+ if UPLO not in (b'L', b'U'):
+ raise ValueError("UPLO argument must be 'L' or 'U'")
extobj = get_linalg_error_extobj(
_raise_linalgerror_eigenvalues_nonconvergence)
@@ -1195,7 +1197,9 @@ def eigh(a, UPLO='L'):
[ 0.00000000+0.38268343j, 0.00000000-0.92387953j]])
"""
- UPLO = asbytes(UPLO)
+ UPLO = asbytes(UPLO.upper())
+ if UPLO not in (b'L', b'U'):
+ raise ValueError("UPLO argument must be 'L' or 'U'")
a, wrap = _makearray(a)
_assertRankAtLeast2(a)
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index c35e9f4f2..803b4c88f 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -739,26 +739,52 @@ class TestEigh(HermitianTestCase, HermitianGeneralizedTestCase):
for dtype in [single, double, csingle, cdouble]:
yield check, dtype
+ def test_invalid(self):
+ x = np.array([[1, 0.5], [0.5, 1]], dtype=np.float32)
+ assert_raises(ValueError, np.linalg.eigh, x, UPLO="lrong")
+ assert_raises(ValueError, np.linalg.eigh, x, "lower")
+ assert_raises(ValueError, np.linalg.eigh, x, "upper")
+ assert_raises(ValueError, np.linalg.eigvalsh, x, UPLO="lrong")
+ assert_raises(ValueError, np.linalg.eigvalsh, x, "lower")
+ assert_raises(ValueError, np.linalg.eigvalsh, x, "upper")
+
def test_half_filled(self):
expect = np.array([-0.33333333, -0.33333333, -0.33333333, 0.99999999])
K = np.array([[ 0. , 0. , 0. , 0. ],
- [-0.33333333, 0. , 0. , 0. ],
- [ 0.33333333, -0.33333333, 0. , 0. ],
- [ 0.33333333, -0.33333333, 0.33333333, 0. ]])
+ [-0.33333333, 0. , 0. , 0. ],
+ [ 0.33333333, -0.33333333, 0. , 0. ],
+ [ 0.33333333, -0.33333333, 0.33333333, 0. ]])
Kr = np.rot90(K, k=2)
+
w, V = np.linalg.eigh(K)
assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype))
+
w, V = np.linalg.eigh(UPLO='L', a=K)
assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype))
+ w, V = np.linalg.eigh(K, 'l')
+ w2, V2 = np.linalg.eigh(K, 'L')
+ assert_allclose(w, w2, rtol=get_rtol(K.dtype))
+ assert_allclose(V, V2, rtol=get_rtol(K.dtype))
+
w, V = np.linalg.eigh(Kr, 'U')
assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype))
+ w, V = np.linalg.eigh(Kr, 'u')
+ w2, V2 = np.linalg.eigh(Kr, 'u')
+ assert_allclose(w, w2, rtol=get_rtol(K.dtype))
+ assert_allclose(V, V2, rtol=get_rtol(K.dtype))
w = np.linalg.eigvalsh(K)
assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype))
+
w = np.linalg.eigvalsh(UPLO='L', a=K)
assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype))
+ assert_allclose(np.linalg.eigvalsh(K, 'L'),
+ np.linalg.eigvalsh(K, 'l'), rtol=get_rtol(K.dtype))
+
w = np.linalg.eigvalsh(Kr, 'U')
assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype))
+ assert_allclose(np.linalg.eigvalsh(Kr, 'U'),
+ np.linalg.eigvalsh(Kr, 'u'), rtol=get_rtol(K.dtype))
class _TestNorm(object):