summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorcookedm <cookedm@localhost>2006-07-04 20:43:11 +0000
committercookedm <cookedm@localhost>2006-07-04 20:43:11 +0000
commitd4840f8838361de985aa8858aaf99bbe85a43621 (patch)
tree8473e23db0754c933f248261b6ba6db39593a1de
parenta8672c2c5e1fe4ae0c51805aa1eed5d736a5eedf (diff)
downloadnumpy-d4840f8838361de985aa8858aaf99bbe85a43621.tar.gz
Convert linalg to use dtypes instead of typecodes
-rw-r--r--numpy/linalg/linalg.py79
1 files changed, 43 insertions, 36 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 7a3821675..bcbcd3f8d 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -19,30 +19,37 @@ from numpy.core import *
from numpy.lib import *
import lapack_lite
+fortran_int = int32
+
# Error object
class LinAlgError(Exception):
pass
-# Helper routines
-_array_kind = {'i':0, 'l': 0, 'q': 0, 'f': 0, 'd': 0, 'F': 1, 'D': 1}
-_array_precision = {'i': 1, 'l': 1, 'q': 1, 'f': 0, 'd': 1, 'F': 0, 'D': 1}
-_array_type = [['f', 'd'], ['F', 'D']]
-
def _makearray(a):
new = asarray(a)
wrap = getattr(a, "__array_wrap__", new.__array_wrap__)
return new, wrap
def _commonType(*arrays):
- kind = 0
-# precision = 0
-# force higher precision in lite version
- precision = 1
+ # in lite version, use higher precision (always double or cdouble)
+ maxtype = (0, double)
for a in arrays:
- t = a.dtype.char
- kind = max(kind, _array_kind[t])
- precision = max(precision, _array_precision[t])
- return _array_type[kind][precision]
+ if issubclass(a.dtype.type, inexact):
+ if a.dtype.type in (single, double):
+ t = (0, double)
+ elif a.dtype.type in (csingle, cdouble):
+ t = (1, cdouble)
+ else:
+ # unsupported inexact scalar
+ raise TypeError("array type %s is unsupported in linalg" %
+ (a.dtype.name,))
+ else:
+ t = (0, double)
+ maxtype = max(maxtype, t)
+ return maxtype[1]
+
+def _realType(t):
+ return {double : double, cdouble : double}
def _castCopyAndTranspose(type, *arrays):
if len(arrays) == 1:
@@ -93,12 +100,12 @@ def solve(a, b):
raise LinAlgError, 'Incompatible dimensions'
t =_commonType(a, b)
# lapack_routine = _findLapackRoutine('gesv', t)
- if _array_kind[t] == 1: # Complex routines take different arguments
+ if issubclass(t, complexfloating):
lapack_routine = lapack_lite.zgesv
else:
lapack_routine = lapack_lite.dgesv
a, b = _fastCopyAndTranspose(t, a, b)
- pivots = zeros(n_eq, 'i')
+ pivots = zeros(n_eq, fortran_int)
results = lapack_routine(n_eq, n_rhs, a, n_eq, pivots, b, n_eq, 0)
if results['info'] > 0:
raise LinAlgError, 'Singular matrix'
@@ -123,7 +130,7 @@ def cholesky(a):
a = _castCopyAndTranspose(t, a)
m = a.shape[0]
n = a.shape[1]
- if _array_kind[t] == 1:
+ if issubclass(t, complexfloating):
lapack_routine = lapack_lite.zpotrf
else:
lapack_routine = lapack_lite.dpotrf
@@ -138,11 +145,11 @@ def eigvals(a):
_assertRank2(a)
_assertSquareness(a)
t =_commonType(a)
- real_t = _array_type[0][_array_precision[t]]
+ real_t = _realType(t)
a = _fastCopyAndTranspose(t, a)
n = a.shape[0]
dummy = zeros((1,), t)
- if _array_kind[t] == 1: # Complex routines take different arguments
+ if issubclass(t, complexfloating):
lapack_routine = lapack_lite.zgeev
w = zeros((n,), t)
rwork = zeros((n,),real_t)
@@ -178,13 +185,13 @@ def eigvals(a):
def eigvalsh(a, UPLO='L'):
_assertRank2(a)
_assertSquareness(a)
- t =_commonType(a)
- real_t = _array_type[0][_array_precision[t]]
+ t = _commonType(a)
+ real_t = _realType(t)
a = _castCopyAndTranspose(t, a)
n = a.shape[0]
liwork = 5*n+3
- iwork = zeros((liwork,),'i')
- if _array_kind[t] == 1: # Complex routines take different arguments
+ iwork = zeros((liwork,), fortran_int)
+ if issubclass(t, complexfloating):
lapack_routine = lapack_lite.zheevd
w = zeros((n,), real_t)
lwork = 1
@@ -288,13 +295,13 @@ def eigh(a, UPLO='L'):
a, wrap = _makearray(a)
_assertRank2(a)
_assertSquareness(a)
- t =_commonType(a)
- real_t = _array_type[0][_array_precision[t]]
+ t = _commonType(a)
+ real_t = _realType(t)
a = _castCopyAndTranspose(t, a)
n = a.shape[0]
liwork = 5*n+3
- iwork = zeros((liwork,),'i')
- if _array_kind[t] == 1: # Complex routines take different arguments
+ iwork = zeros((liwork,), fortran_int)
+ if issubclass(t, complexfloating):
lapack_routine = lapack_lite.zheevd
w = zeros((n,), real_t)
lwork = 1
@@ -343,8 +350,8 @@ def svd(a, full_matrices=1, compute_uv=1):
a, wrap = _makearray(a)
_assertRank2(a)
m, n = a.shape
- t =_commonType(a)
- real_t = _array_type[0][_array_precision[t]]
+ t = _commonType(a)
+ real_t = _realType(t)
a = _fastCopyAndTranspose(t, a)
s = zeros((min(n,m),), real_t)
if compute_uv:
@@ -365,8 +372,8 @@ def svd(a, full_matrices=1, compute_uv=1):
u = empty((1,1),t)
vt = empty((1,1),t)
- iwork = zeros((8*min(m,n),), 'i')
- if _array_kind[t] == 1: # Complex routines take different arguments
+ iwork = zeros((8*min(m,n),), fortran_int)
+ if issubclass(t, complexfloating):
lapack_routine = lapack_lite.zgesdd
rwork = zeros((5*min(m,n)*min(m,n) + 5*min(m,n),), real_t)
lwork = 1
@@ -438,11 +445,11 @@ def det(a):
t =_commonType(a)
a = _fastCopyAndTranspose(t, a)
n = a.shape[0]
- if _array_kind[t] == 1:
+ if issubclass(t, complexfloating):
lapack_routine = lapack_lite.zgetrf
else:
lapack_routine = lapack_lite.dgetrf
- pivots = zeros((n,), 'i')
+ pivots = zeros((n,), fortran_int)
results = lapack_routine(n, n, a, n, pivots, 0)
sign = add.reduce(not_equal(pivots, arange(1, n+1))) % 2
return (1.-2.*sign)*multiply.reduce(diagonal(a),axis=-1)
@@ -475,15 +482,15 @@ Singular values less than s[0]*rcond are treated as zero.
ldb = max(n,m)
if m != b.shape[0]:
raise LinAlgError, 'Incompatible dimensions'
- t =_commonType(a, b)
- real_t = _array_type[0][_array_precision[t]]
+ t = _commonType(a, b)
+ real_t = _realType(t)
bstar = zeros((ldb,n_rhs),t)
bstar[:b.shape[0],:n_rhs] = b.copy()
a,bstar = _castCopyAndTranspose(t, a, bstar)
s = zeros((min(m,n),),real_t)
nlvl = max( 0, int( math.log( float(min( m,n ))/2. ) ) + 1 )
- iwork = zeros((3*min(m,n)*nlvl+11*min(m,n),), 'i')
- if _array_kind[t] == 1: # Complex routines take different arguments
+ iwork = zeros((3*min(m,n)*nlvl+11*min(m,n),), fortran_int)
+ if issubclass(t, complexfloating):
lapack_routine = lapack_lite.zgelsd
lwork = 1
rwork = zeros((lwork,), real_t)