summaryrefslogtreecommitdiff
path: root/numpy/linalg/linalg.py
diff options
context:
space:
mode:
authorTravis Oliphant <oliphant@enthought.com>2006-02-23 07:06:19 +0000
committerTravis Oliphant <oliphant@enthought.com>2006-02-23 07:06:19 +0000
commit97342eff4ae87caeb83f4e419270b0da88a161e1 (patch)
treebc02dede08a8eed95b329d6721511fc13d2023b8 /numpy/linalg/linalg.py
parent5c566d6098fe1afd13290951fe60888317deb8b1 (diff)
downloadnumpy-97342eff4ae87caeb83f4e419270b0da88a161e1.tar.gz
Make matrices survive through more functions.
Diffstat (limited to 'numpy/linalg/linalg.py')
-rw-r--r--numpy/linalg/linalg.py27
1 files changed, 17 insertions, 10 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 6c440d25a..cce432b6a 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -1,7 +1,7 @@
-"""Lite version of numpy.linalg.
+"""Lite version of scipy.linalg.
"""
-# This module is a lite version of LinAlg.py module which contains
+# This module is a lite version of the linalg.py module in SciPy which contains
# high-level Python interface to the LAPACK library. The lite version
# only accesses the following LAPACK functions: dgesv, zgesv, dgeev,
# zgeev, dgesdd, zgesdd, dgelsd, zgelsd, dsyevd, zheevd, dgetrf, dpotrf.
@@ -28,6 +28,14 @@ _array_kind = {'i':0, 'l': 0, 'f': 0, 'd': 0, 'F': 1, 'D': 1}
_array_precision = {'i': 1, 'l': 1, 'f': 0, 'd': 1, 'F': 0, 'D': 1}
_array_type = [['f', 'd'], ['F', 'D']]
+def _makearray(a):
+ new = asarray(a)
+ if isinstance(a, ndarray):
+ wrap = a.__array_wrap__
+ else:
+ wrap = new.__array_wrap__
+ return new, wrap
+
def _commonType(*arrays):
kind = 0
# precision = 0
@@ -228,11 +236,10 @@ def eig(a):
v is a matrix of eigenvectors with vector v[:,i] corresponds to
eigenvalue u[i]. Satisfies the equation dot(a, v[:,i]) = u[i]*v[:,i]
"""
- a = asarray(a)
+ a, wrap = _makearray(a)
_assertRank2(a)
_assertSquareness(a)
a,t = _convertarray(a) # convert to float_ or complex_ type
- wrap = a.__array_wrap__
real_t = 'd'
n = a.shape[0]
dummy = zeros((1,), t)
@@ -282,12 +289,12 @@ eigenvalue u[i]. Satisfies the equation dot(a, v[:,i]) = u[i]*v[:,i]
def eigh(a, UPLO='L'):
+ a, wrap = _makearray(a)
_assertRank2(a)
_assertSquareness(a)
t =_commonType(a)
real_t = _array_type[0][_array_precision[t]]
a = _castCopyAndTranspose(t, a)
- wrap = a.__array_wrap__
n = a.shape[0]
liwork = 5*n+3
iwork = zeros((liwork,),'i')
@@ -321,12 +328,12 @@ def eigh(a, UPLO='L'):
# Singular value decomposition
def svd(a, full_matrices=1):
+ a, wrap = _makearray(a)
_assertRank2(a)
m, n = a.shape
t =_commonType(a)
real_t = _array_type[0][_array_precision[t]]
a = _fastCopyAndTranspose(t, a)
- wrap = a.__array_wrap__
if full_matrices:
nu = m
nvt = n
@@ -369,7 +376,7 @@ def svd(a, full_matrices=1):
# Generalized inverse
def generalized_inverse(a, rcond = 1.e-10):
- a = array(a, copy=0)
+ a, wrap = _makearray(a)
if a.dtype.char in typecodes['Complex']:
a = conjugate(a)
u, s, vt = svd(a, 0)
@@ -381,13 +388,13 @@ def generalized_inverse(a, rcond = 1.e-10):
s[i] = 1./s[i]
else:
s[i] = 0.;
- wrap = a.__array_wrap__
return wrap(dot(transpose(vt),
multiply(s[:, NewAxis],transpose(u))))
# Determinant
def determinant(a):
+ a = asarray(a)
_assertRank2(a)
_assertSquareness(a)
t =_commonType(a)
@@ -419,7 +426,7 @@ otherwise resids = sum((b-dot(A,x)**2).
Singular values less than s[0]*rcond are treated as zero.
"""
a = asarray(a)
- b = asarray(b)
+ b, wrap = _makearray(b)
one_eq = len(b.shape) == 1
if one_eq:
b = b[:, NewAxis]
@@ -477,7 +484,7 @@ Singular values less than s[0]*rcond are treated as zero.
x = transpose(bstar)[:n,:].copy()
if (results['rank']==n) and (m>n):
resids = sum((transpose(bstar)[n:,:])**2).copy()
- return x,resids,results['rank'],s[:min(n,m)].copy()
+ return wrap(x),resids,results['rank'],s[:min(n,m)].copy()
def singular_value_decomposition(A, full_matrices=0):
return svd(A, full_matrices)