diff options
Diffstat (limited to 'numpy/linalg/linalg.py')
-rw-r--r-- | numpy/linalg/linalg.py | 79 |
1 files changed, 43 insertions, 36 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 7a3821675..bcbcd3f8d 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -19,30 +19,37 @@ from numpy.core import * from numpy.lib import * import lapack_lite +fortran_int = int32 + # Error object class LinAlgError(Exception): pass -# Helper routines -_array_kind = {'i':0, 'l': 0, 'q': 0, 'f': 0, 'd': 0, 'F': 1, 'D': 1} -_array_precision = {'i': 1, 'l': 1, 'q': 1, 'f': 0, 'd': 1, 'F': 0, 'D': 1} -_array_type = [['f', 'd'], ['F', 'D']] - def _makearray(a): new = asarray(a) wrap = getattr(a, "__array_wrap__", new.__array_wrap__) return new, wrap def _commonType(*arrays): - kind = 0 -# precision = 0 -# force higher precision in lite version - precision = 1 + # in lite version, use higher precision (always double or cdouble) + maxtype = (0, double) for a in arrays: - t = a.dtype.char - kind = max(kind, _array_kind[t]) - precision = max(precision, _array_precision[t]) - return _array_type[kind][precision] + if issubclass(a.dtype.type, inexact): + if a.dtype.type in (single, double): + t = (0, double) + elif a.dtype.type in (csingle, cdouble): + t = (1, cdouble) + else: + # unsupported inexact scalar + raise TypeError("array type %s is unsupported in linalg" % + (a.dtype.name,)) + else: + t = (0, double) + maxtype = max(maxtype, t) + return maxtype[1] + +def _realType(t): + return {double : double, cdouble : double} def _castCopyAndTranspose(type, *arrays): if len(arrays) == 1: @@ -93,12 +100,12 @@ def solve(a, b): raise LinAlgError, 'Incompatible dimensions' t =_commonType(a, b) # lapack_routine = _findLapackRoutine('gesv', t) - if _array_kind[t] == 1: # Complex routines take different arguments + if issubclass(t, complexfloating): lapack_routine = lapack_lite.zgesv else: lapack_routine = lapack_lite.dgesv a, b = _fastCopyAndTranspose(t, a, b) - pivots = zeros(n_eq, 'i') + pivots = zeros(n_eq, fortran_int) results = lapack_routine(n_eq, n_rhs, a, n_eq, pivots, b, n_eq, 0) if results['info'] > 0: raise LinAlgError, 'Singular matrix' @@ -123,7 +130,7 @@ def cholesky(a): a = _castCopyAndTranspose(t, a) m = a.shape[0] n = a.shape[1] - if _array_kind[t] == 1: + if issubclass(t, complexfloating): lapack_routine = lapack_lite.zpotrf else: lapack_routine = lapack_lite.dpotrf @@ -138,11 +145,11 @@ def eigvals(a): _assertRank2(a) _assertSquareness(a) t =_commonType(a) - real_t = _array_type[0][_array_precision[t]] + real_t = _realType(t) a = _fastCopyAndTranspose(t, a) n = a.shape[0] dummy = zeros((1,), t) - if _array_kind[t] == 1: # Complex routines take different arguments + if issubclass(t, complexfloating): lapack_routine = lapack_lite.zgeev w = zeros((n,), t) rwork = zeros((n,),real_t) @@ -178,13 +185,13 @@ def eigvals(a): def eigvalsh(a, UPLO='L'): _assertRank2(a) _assertSquareness(a) - t =_commonType(a) - real_t = _array_type[0][_array_precision[t]] + t = _commonType(a) + real_t = _realType(t) a = _castCopyAndTranspose(t, a) n = a.shape[0] liwork = 5*n+3 - iwork = zeros((liwork,),'i') - if _array_kind[t] == 1: # Complex routines take different arguments + iwork = zeros((liwork,), fortran_int) + if issubclass(t, complexfloating): lapack_routine = lapack_lite.zheevd w = zeros((n,), real_t) lwork = 1 @@ -288,13 +295,13 @@ def eigh(a, UPLO='L'): a, wrap = _makearray(a) _assertRank2(a) _assertSquareness(a) - t =_commonType(a) - real_t = _array_type[0][_array_precision[t]] + t = _commonType(a) + real_t = _realType(t) a = _castCopyAndTranspose(t, a) n = a.shape[0] liwork = 5*n+3 - iwork = zeros((liwork,),'i') - if _array_kind[t] == 1: # Complex routines take different arguments + iwork = zeros((liwork,), fortran_int) + if issubclass(t, complexfloating): lapack_routine = lapack_lite.zheevd w = zeros((n,), real_t) lwork = 1 @@ -343,8 +350,8 @@ def svd(a, full_matrices=1, compute_uv=1): a, wrap = _makearray(a) _assertRank2(a) m, n = a.shape - t =_commonType(a) - real_t = _array_type[0][_array_precision[t]] + t = _commonType(a) + real_t = _realType(t) a = _fastCopyAndTranspose(t, a) s = zeros((min(n,m),), real_t) if compute_uv: @@ -365,8 +372,8 @@ def svd(a, full_matrices=1, compute_uv=1): u = empty((1,1),t) vt = empty((1,1),t) - iwork = zeros((8*min(m,n),), 'i') - if _array_kind[t] == 1: # Complex routines take different arguments + iwork = zeros((8*min(m,n),), fortran_int) + if issubclass(t, complexfloating): lapack_routine = lapack_lite.zgesdd rwork = zeros((5*min(m,n)*min(m,n) + 5*min(m,n),), real_t) lwork = 1 @@ -438,11 +445,11 @@ def det(a): t =_commonType(a) a = _fastCopyAndTranspose(t, a) n = a.shape[0] - if _array_kind[t] == 1: + if issubclass(t, complexfloating): lapack_routine = lapack_lite.zgetrf else: lapack_routine = lapack_lite.dgetrf - pivots = zeros((n,), 'i') + pivots = zeros((n,), fortran_int) results = lapack_routine(n, n, a, n, pivots, 0) sign = add.reduce(not_equal(pivots, arange(1, n+1))) % 2 return (1.-2.*sign)*multiply.reduce(diagonal(a),axis=-1) @@ -475,15 +482,15 @@ Singular values less than s[0]*rcond are treated as zero. ldb = max(n,m) if m != b.shape[0]: raise LinAlgError, 'Incompatible dimensions' - t =_commonType(a, b) - real_t = _array_type[0][_array_precision[t]] + t = _commonType(a, b) + real_t = _realType(t) bstar = zeros((ldb,n_rhs),t) bstar[:b.shape[0],:n_rhs] = b.copy() a,bstar = _castCopyAndTranspose(t, a, bstar) s = zeros((min(m,n),),real_t) nlvl = max( 0, int( math.log( float(min( m,n ))/2. ) ) + 1 ) - iwork = zeros((3*min(m,n)*nlvl+11*min(m,n),), 'i') - if _array_kind[t] == 1: # Complex routines take different arguments + iwork = zeros((3*min(m,n)*nlvl+11*min(m,n),), fortran_int) + if issubclass(t, complexfloating): lapack_routine = lapack_lite.zgelsd lwork = 1 rwork = zeros((lwork,), real_t) |