summaryrefslogtreecommitdiff
path: root/numpy/core/numeric.py
diff options
context:
space:
mode:
authorTravis Oliphant <oliphant@enthought.com>2006-08-28 19:46:08 +0000
committerTravis Oliphant <oliphant@enthought.com>2006-08-28 19:46:08 +0000
commit0f23250d0f7f84f1c69ccd85d13338714343d4c1 (patch)
tree48d8356353a0738a31a8a73fb55241c03bfd8471 /numpy/core/numeric.py
parent59286e9133cfc1d71b92313e0b0443e31f03e8ce (diff)
downloadnumpy-0f23250d0f7f84f1c69ccd85d13338714343d4c1.tar.gz
Merge changes mistakenly added to 1.0b4 tag to the main trunk
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r--numpy/core/numeric.py43
1 files changed, 42 insertions, 1 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index b77fed3db..b89fc0586 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -7,7 +7,7 @@ __all__ = ['newaxis', 'ndarray', 'flatiter', 'ufunc',
'asarray', 'asanyarray', 'ascontiguousarray', 'asfortranarray',
'isfortran', 'empty_like', 'zeros_like',
'correlate', 'convolve', 'inner', 'dot', 'outer', 'vdot',
- 'alterdot', 'restoredot', 'cross',
+ 'alterdot', 'restoredot', 'cross', 'tensordot',
'array2string', 'get_printoptions', 'set_printoptions',
'array_repr', 'array_str', 'set_string_function',
'little_endian', 'require',
@@ -252,6 +252,47 @@ except ImportError:
def restoredot():
pass
+def tensordot(a, b, axes=(-1,0))
+ """tensordot returns the product for any (ndim >= 1) arrays.
+
+ r_{xxx, yyy} = \sum_k a_{xxx,k} b_{k,yyy} where
+
+ the axes to be summed over are given by the axes argument.
+ the first element of the sequence determines the axis or axes
+ in arr1 to sum over and the second element in axes argument sequence
+ """
+ axes_a, axes_b = axes
+ try:
+ na = len(axes_a)
+ except TypeError:
+ axes_a = [axes_a]
+ na = 1
+ try:
+ nb = len(axes_b)
+ except TypeError:
+ axes_b = [axes_b]
+ nb = 1
+
+ a, b = asarray(a), asarray(b)
+ as = a.shape
+ bs = b.shape
+ equal = 1
+ if (na != nb): equal = 0
+ for k in xrange(na):
+ if as[axes_a[k]] != bs[axes_b[k]]:
+ equal = 0
+ break
+
+ if not equal:
+ raise ValueError, "shape-mismatch for sum"
+
+ olda = [ for k in aa if k not in axes_a]
+ oldb = [k for k in bs if k not in axes_b]
+
+ at = a.reshape(nd1, nd2)
+ res = dot(at, bt)
+ return res.reshape(olda + oldb)
+
def _move_axis_to_0(a, axis):
if axis == 0: