diff options
-rw-r--r-- | numpy/linalg/linalg.py | 27 |
1 files changed, 17 insertions, 10 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 6c440d25a..cce432b6a 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -1,7 +1,7 @@ -"""Lite version of numpy.linalg. +"""Lite version of scipy.linalg. """ -# This module is a lite version of LinAlg.py module which contains +# This module is a lite version of the linalg.py module in SciPy which contains # high-level Python interface to the LAPACK library. The lite version # only accesses the following LAPACK functions: dgesv, zgesv, dgeev, # zgeev, dgesdd, zgesdd, dgelsd, zgelsd, dsyevd, zheevd, dgetrf, dpotrf. @@ -28,6 +28,14 @@ _array_kind = {'i':0, 'l': 0, 'f': 0, 'd': 0, 'F': 1, 'D': 1} _array_precision = {'i': 1, 'l': 1, 'f': 0, 'd': 1, 'F': 0, 'D': 1} _array_type = [['f', 'd'], ['F', 'D']] +def _makearray(a): + new = asarray(a) + if isinstance(a, ndarray): + wrap = a.__array_wrap__ + else: + wrap = new.__array_wrap__ + return new, wrap + def _commonType(*arrays): kind = 0 # precision = 0 @@ -228,11 +236,10 @@ def eig(a): v is a matrix of eigenvectors with vector v[:,i] corresponds to eigenvalue u[i]. Satisfies the equation dot(a, v[:,i]) = u[i]*v[:,i] """ - a = asarray(a) + a, wrap = _makearray(a) _assertRank2(a) _assertSquareness(a) a,t = _convertarray(a) # convert to float_ or complex_ type - wrap = a.__array_wrap__ real_t = 'd' n = a.shape[0] dummy = zeros((1,), t) @@ -282,12 +289,12 @@ eigenvalue u[i]. Satisfies the equation dot(a, v[:,i]) = u[i]*v[:,i] def eigh(a, UPLO='L'): + a, wrap = _makearray(a) _assertRank2(a) _assertSquareness(a) t =_commonType(a) real_t = _array_type[0][_array_precision[t]] a = _castCopyAndTranspose(t, a) - wrap = a.__array_wrap__ n = a.shape[0] liwork = 5*n+3 iwork = zeros((liwork,),'i') @@ -321,12 +328,12 @@ def eigh(a, UPLO='L'): # Singular value decomposition def svd(a, full_matrices=1): + a, wrap = _makearray(a) _assertRank2(a) m, n = a.shape t =_commonType(a) real_t = _array_type[0][_array_precision[t]] a = _fastCopyAndTranspose(t, a) - wrap = a.__array_wrap__ if full_matrices: nu = m nvt = n @@ -369,7 +376,7 @@ def svd(a, full_matrices=1): # Generalized inverse def generalized_inverse(a, rcond = 1.e-10): - a = array(a, copy=0) + a, wrap = _makearray(a) if a.dtype.char in typecodes['Complex']: a = conjugate(a) u, s, vt = svd(a, 0) @@ -381,13 +388,13 @@ def generalized_inverse(a, rcond = 1.e-10): s[i] = 1./s[i] else: s[i] = 0.; - wrap = a.__array_wrap__ return wrap(dot(transpose(vt), multiply(s[:, NewAxis],transpose(u)))) # Determinant def determinant(a): + a = asarray(a) _assertRank2(a) _assertSquareness(a) t =_commonType(a) @@ -419,7 +426,7 @@ otherwise resids = sum((b-dot(A,x)**2). Singular values less than s[0]*rcond are treated as zero. """ a = asarray(a) - b = asarray(b) + b, wrap = _makearray(b) one_eq = len(b.shape) == 1 if one_eq: b = b[:, NewAxis] @@ -477,7 +484,7 @@ Singular values less than s[0]*rcond are treated as zero. x = transpose(bstar)[:n,:].copy() if (results['rank']==n) and (m>n): resids = sum((transpose(bstar)[n:,:])**2).copy() - return x,resids,results['rank'],s[:min(n,m)].copy() + return wrap(x),resids,results['rank'],s[:min(n,m)].copy() def singular_value_decomposition(A, full_matrices=0): return svd(A, full_matrices) |