summaryrefslogtreecommitdiff
path: root/numpy/linalg/linalg.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/linalg/linalg.py')
-rw-r--r--numpy/linalg/linalg.py59
1 files changed, 34 insertions, 25 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 2d72c7a8b..cd70e9776 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -6,7 +6,7 @@ only accesses the following LAPACK functions: dgesv, zgesv, dgeev,
zgeev, dgesdd, zgesdd, dgelsd, zgelsd, dsyevd, zheevd, dgetrf, dpotrf.
"""
-__all__ = ['solve', 'solvetensor', 'invtensor',
+__all__ = ['solve', 'tensorsolve', 'tensorinv',
'inv', 'cholesky',
'eigvals',
'eigvalsh', 'pinv',
@@ -122,10 +122,11 @@ def _assertSquareness(*arrays):
# Linear equations
-def solvetensor(a, b, axes=None):
+def tensorsolve(a, b, axes=None):
"""Solves the tensor equation a x = b for x
- where it is assumed that all the indices of x are summed over in the product.
+ where it is assumed that all the indices of x are summed over in
+ the product.
a can be N-dimensional. x will have the dimensions of A subtracted from
the dimensions of b.
@@ -181,40 +182,48 @@ def solve(a, b):
return b.transpose().astype(result_t)
-def invtensor(a, ind=2):
- """Find the inverse tensor.
+def tensorinv(a, ind=2):
+ """Find the 'inverse' of a N-d array
- ind > 0 ==> first (ind) indices of a are summed over
- ind < 0 ==> last (-ind) indices of a are summed over
+ ind must be a positive integer specifying
+ how many indices at the front of the array are involved
+ in the inverse sum.
+
+ the result is ainv with shape a.shape[ind:] + a.shape[:ind]
- if ind is a list, then it specifies the summed over axes
+ tensordot(ainv, a, ind) is an identity operator
- When the inv tensor and the tensor are summed over the
- indicated axes a separable identity tensor remains.
+ and so is
- The inverse has the summed over axes at the end.
- """
+ tensordot(a, ainv, a.shape-newind)
+
+ Example:
+
+ a = rand(4,6,8,3)
+ ainv = tensorinv(a)
+ # ainv.shape is (8,3,4,6)
+ # suppose b has shape (4,6)
+ tensordot(ainv, b) # produces same (8,3)-shaped output as
+ tensorsolve(a, b)
+
+ a = rand(24,8,3)
+ ainv = tensorinv(a,1)
+ # ainv.shape is (8,3,24)
+ # suppose b has shape (24,)
+ tensordot(ainv, b, 1) # produces the same (8,3)-shaped output as
+ tensorsolve(a, b)
+ """
a = asarray(a)
oldshape = a.shape
prod = 1
- if iterable(ind):
- invshape = range(a.ndim)
- for axis in ind:
- invshape.remove(axis)
- invshape.insert(a.ndim,axis)
- prod *= oldshape[axis]
- elif ind > 0:
+ if ind > 0:
invshape = oldshape[ind:] + oldshape[:ind]
- for k in oldshape[:ind]:
- prod *= k
- elif ind < 0:
- invshape = oldshape[:-ind] + oldshape[-ind:]
- for k in oldshape[-ind:]:
+ for k in oldshape[ind:]:
prod *= k
else:
raise ValueError, "Invalid ind argument."
- a = a.reshape(-1,prod)
+ a = a.reshape(prod,-1)
ia = inv(a)
return ia.reshape(*invshape)