summaryrefslogtreecommitdiff
path: root/numpy/linalg
diff options
context:
space:
mode:
authorcookedm <cookedm@localhost>2006-07-14 09:42:54 +0000
committercookedm <cookedm@localhost>2006-07-14 09:42:54 +0000
commit46fa7a119d9995f5f348694ae0595f2abdf44762 (patch)
tree471aae8ca795ae1fe12257a5590fa1dd40025781 /numpy/linalg
parentacd97c630b02b7a48b715eaf67155784642d95d8 (diff)
downloadnumpy-46fa7a119d9995f5f348694ae0595f2abdf44762.tar.gz
numpy.linalg: fix bug where complex arrays weren't being returned.
Also improved test cases.
Diffstat (limited to 'numpy/linalg')
-rw-r--r--numpy/linalg/linalg.py7
-rw-r--r--numpy/linalg/tests/test_linalg.py22
2 files changed, 18 insertions, 11 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 1a0dd3c6e..142de2340 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -18,7 +18,8 @@ __all__ = ['solve',
from numpy.core import array, asarray, zeros, empty, transpose, \
intc, single, double, csingle, cdouble, inexact, complexfloating, \
newaxis, ravel, all, Inf, dot, add, multiply, identity, sqrt, \
- maximum, nonzero, diagonal, arange, fastCopyAndTranspose, sum
+ maximum, nonzero, diagonal, arange, fastCopyAndTranspose, sum, \
+ argsort
from numpy.lib import triu
from numpy.linalg import lapack_lite
@@ -75,7 +76,7 @@ def _commonType(*arrays):
result_type = _complex_types_map[result_type]
else:
t = double
- return t, rt
+ return t, result_type
def _castCopyAndTranspose(type, *arrays):
if len(arrays) == 1:
@@ -351,7 +352,7 @@ def eigh(a, UPLO='L'):
if results['info'] > 0:
raise LinAlgError, 'Eigenvalues did not converge'
at = a.transpose().astype(result_t)
- return w.astype(result_t), wrap(at)
+ return w.astype(_realType(result_t)), wrap(at)
# Singular value decomposition
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index cc50595e2..ca9557117 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -17,19 +17,25 @@ def assert_almost_equal(a, b, **kw):
old_assert_almost_equal(a, b, decimal=decimal, **kw)
class LinalgTestCase(NumpyTestCase):
- def _check(self, dtype):
- a = array([[1.,2.], [3.,4.]], dtype=dtype)
- b = array([2., 1.], dtype=dtype)
+ def check_single(self):
+ a = array([[1.,2.], [3.,4.]], dtype=single)
+ b = array([2., 1.], dtype=single)
self.do(a, b)
- def check_single(self):
- self._check(single)
def check_double(self):
- self._check(double)
+ a = array([[1.,2.], [3.,4.]], dtype=double)
+ b = array([2., 1.], dtype=double)
+ self.do(a, b)
+
def check_csingle(self):
- self._check(csingle)
+ a = array([[1.+2j,2+3j], [3+4j,4+5j]], dtype=csingle)
+ b = array([2.+1j, 1.+2j], dtype=csingle)
+ self.do(a, b)
+
def check_cdouble(self):
- self._check(cdouble)
+ a = array([[1.+2j,2+3j], [3+4j,4+5j]], dtype=cdouble)
+ b = array([2.+1j, 1.+2j], dtype=cdouble)
+ self.do(a, b)
class test_solve(LinalgTestCase):
def do(self, a, b):