summaryrefslogtreecommitdiff
path: root/numpy/core/numeric.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r--numpy/core/numeric.py43
1 files changed, 42 insertions, 1 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index b77fed3db..b89fc0586 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -7,7 +7,7 @@ __all__ = ['newaxis', 'ndarray', 'flatiter', 'ufunc',
'asarray', 'asanyarray', 'ascontiguousarray', 'asfortranarray',
'isfortran', 'empty_like', 'zeros_like',
'correlate', 'convolve', 'inner', 'dot', 'outer', 'vdot',
- 'alterdot', 'restoredot', 'cross',
+ 'alterdot', 'restoredot', 'cross', 'tensordot',
'array2string', 'get_printoptions', 'set_printoptions',
'array_repr', 'array_str', 'set_string_function',
'little_endian', 'require',
@@ -252,6 +252,47 @@ except ImportError:
def restoredot():
pass
+def tensordot(a, b, axes=(-1,0))
+ """tensordot returns the product for any (ndim >= 1) arrays.
+
+ r_{xxx, yyy} = \sum_k a_{xxx,k} b_{k,yyy} where
+
+ the axes to be summed over are given by the axes argument.
+ the first element of the sequence determines the axis or axes
+ in arr1 to sum over and the second element in axes argument sequence
+ """
+ axes_a, axes_b = axes
+ try:
+ na = len(axes_a)
+ except TypeError:
+ axes_a = [axes_a]
+ na = 1
+ try:
+ nb = len(axes_b)
+ except TypeError:
+ axes_b = [axes_b]
+ nb = 1
+
+ a, b = asarray(a), asarray(b)
+ as = a.shape
+ bs = b.shape
+ equal = 1
+ if (na != nb): equal = 0
+ for k in xrange(na):
+ if as[axes_a[k]] != bs[axes_b[k]]:
+ equal = 0
+ break
+
+ if not equal:
+ raise ValueError, "shape-mismatch for sum"
+
+ olda = [ for k in aa if k not in axes_a]
+ oldb = [k for k in bs if k not in axes_b]
+
+ at = a.reshape(nd1, nd2)
+ res = dot(at, bt)
+ return res.reshape(olda + oldb)
+
def _move_axis_to_0(a, axis):
if axis == 0: