diff options
author | Travis Oliphant <oliphant@enthought.com> | 2006-08-28 19:46:08 +0000 |
---|---|---|
committer | Travis Oliphant <oliphant@enthought.com> | 2006-08-28 19:46:08 +0000 |
commit | 0f23250d0f7f84f1c69ccd85d13338714343d4c1 (patch) | |
tree | 48d8356353a0738a31a8a73fb55241c03bfd8471 /numpy/core/numeric.py | |
parent | 59286e9133cfc1d71b92313e0b0443e31f03e8ce (diff) | |
download | numpy-0f23250d0f7f84f1c69ccd85d13338714343d4c1.tar.gz |
Merge changes mistakenly added to 1.0b4 tag to the main trunk
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r-- | numpy/core/numeric.py | 43 |
1 files changed, 42 insertions, 1 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index b77fed3db..b89fc0586 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -7,7 +7,7 @@ __all__ = ['newaxis', 'ndarray', 'flatiter', 'ufunc', 'asarray', 'asanyarray', 'ascontiguousarray', 'asfortranarray', 'isfortran', 'empty_like', 'zeros_like', 'correlate', 'convolve', 'inner', 'dot', 'outer', 'vdot', - 'alterdot', 'restoredot', 'cross', + 'alterdot', 'restoredot', 'cross', 'tensordot', 'array2string', 'get_printoptions', 'set_printoptions', 'array_repr', 'array_str', 'set_string_function', 'little_endian', 'require', @@ -252,6 +252,47 @@ except ImportError: def restoredot(): pass +def tensordot(a, b, axes=(-1,0)) + """tensordot returns the product for any (ndim >= 1) arrays. + + r_{xxx, yyy} = \sum_k a_{xxx,k} b_{k,yyy} where + + the axes to be summed over are given by the axes argument. + the first element of the sequence determines the axis or axes + in arr1 to sum over and the second element in axes argument sequence + """ + axes_a, axes_b = axes + try: + na = len(axes_a) + except TypeError: + axes_a = [axes_a] + na = 1 + try: + nb = len(axes_b) + except TypeError: + axes_b = [axes_b] + nb = 1 + + a, b = asarray(a), asarray(b) + as = a.shape + bs = b.shape + equal = 1 + if (na != nb): equal = 0 + for k in xrange(na): + if as[axes_a[k]] != bs[axes_b[k]]: + equal = 0 + break + + if not equal: + raise ValueError, "shape-mismatch for sum" + + olda = [ for k in aa if k not in axes_a] + oldb = [k for k in bs if k not in axes_b] + + at = a.reshape(nd1, nd2) + res = dot(at, bt) + return res.reshape(olda + oldb) + def _move_axis_to_0(a, axis): if axis == 0: |