diff options
author | Travis Oliphant <oliphant@enthought.com> | 2006-10-13 18:18:18 +0000 |
---|---|---|
committer | Travis Oliphant <oliphant@enthought.com> | 2006-10-13 18:18:18 +0000 |
commit | 4f83a34fe0961bf00e27313949676c4101c44b73 (patch) | |
tree | f1b40f9170ab14cfcd13c9b01c26550806fd2a34 /numpy/linalg | |
parent | 45f9558ec55e04a214b1e259fb0a73ed08c1bd9f (diff) | |
download | numpy-4f83a34fe0961bf00e27313949676c4101c44b73.tar.gz |
Fix-up tensor solve and tensor inv and rename to match tensordot.
Diffstat (limited to 'numpy/linalg')
-rw-r--r-- | numpy/linalg/linalg.py | 59 |
1 files changed, 34 insertions, 25 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 2d72c7a8b..cd70e9776 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -6,7 +6,7 @@ only accesses the following LAPACK functions: dgesv, zgesv, dgeev, zgeev, dgesdd, zgesdd, dgelsd, zgelsd, dsyevd, zheevd, dgetrf, dpotrf. """ -__all__ = ['solve', 'solvetensor', 'invtensor', +__all__ = ['solve', 'tensorsolve', 'tensorinv', 'inv', 'cholesky', 'eigvals', 'eigvalsh', 'pinv', @@ -122,10 +122,11 @@ def _assertSquareness(*arrays): # Linear equations -def solvetensor(a, b, axes=None): +def tensorsolve(a, b, axes=None): """Solves the tensor equation a x = b for x - where it is assumed that all the indices of x are summed over in the product. + where it is assumed that all the indices of x are summed over in + the product. a can be N-dimensional. x will have the dimensions of A subtracted from the dimensions of b. @@ -181,40 +182,48 @@ def solve(a, b): return b.transpose().astype(result_t) -def invtensor(a, ind=2): - """Find the inverse tensor. +def tensorinv(a, ind=2): + """Find the 'inverse' of a N-d array - ind > 0 ==> first (ind) indices of a are summed over - ind < 0 ==> last (-ind) indices of a are summed over + ind must be a positive integer specifying + how many indices at the front of the array are involved + in the inverse sum. + + the result is ainv with shape a.shape[ind:] + a.shape[:ind] - if ind is a list, then it specifies the summed over axes + tensordot(ainv, a, ind) is an identity operator - When the inv tensor and the tensor are summed over the - indicated axes a separable identity tensor remains. + and so is - The inverse has the summed over axes at the end. - """ + tensordot(a, ainv, a.shape-newind) + + Example: + + a = rand(4,6,8,3) + ainv = tensorinv(a) + # ainv.shape is (8,3,4,6) + # suppose b has shape (4,6) + tensordot(ainv, b) # produces same (8,3)-shaped output as + tensorsolve(a, b) + + a = rand(24,8,3) + ainv = tensorinv(a,1) + # ainv.shape is (8,3,24) + # suppose b has shape (24,) + tensordot(ainv, b, 1) # produces the same (8,3)-shaped output as + tensorsolve(a, b) + """ a = asarray(a) oldshape = a.shape prod = 1 - if iterable(ind): - invshape = range(a.ndim) - for axis in ind: - invshape.remove(axis) - invshape.insert(a.ndim,axis) - prod *= oldshape[axis] - elif ind > 0: + if ind > 0: invshape = oldshape[ind:] + oldshape[:ind] - for k in oldshape[:ind]: - prod *= k - elif ind < 0: - invshape = oldshape[:-ind] + oldshape[-ind:] - for k in oldshape[-ind:]: + for k in oldshape[ind:]: prod *= k else: raise ValueError, "Invalid ind argument." - a = a.reshape(-1,prod) + a = a.reshape(prod,-1) ia = inv(a) return ia.reshape(*invshape) |