diff options
author | Travis Oliphant <oliphant@enthought.com> | 2006-08-29 02:47:39 +0000 |
---|---|---|
committer | Travis Oliphant <oliphant@enthought.com> | 2006-08-29 02:47:39 +0000 |
commit | 762dacd686faa83c7de12c115e0a5874eef0481e (patch) | |
tree | ba39d11ca3cdf5f1c62e5f6fcc61f89c85af4ee5 /numpy/core/numeric.py | |
parent | 8d80b96d2aa1f498b80fc19b296ff881185cd363 (diff) | |
download | numpy-762dacd686faa83c7de12c115e0a5874eef0481e.tar.gz |
Add float, int, etc. to numpy name-space. Flesh out tensordot. Fix-up getcharbuf to allow all 8-bit types to be returned as a charbuf.
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r-- | numpy/core/numeric.py | 57 |
1 files changed, 44 insertions, 13 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 34c8b17d0..cd23c7cc0 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -190,7 +190,7 @@ def argwhere(a): is a sequence of indices into a. This sequence must be converted to a tuple in order to be used to index into a. """ - return transpose(a.nonzero()) + return asarray(a.nonzero()).T def flatnonzero(a): """Return indicies that are not-zero in flattened version of a @@ -252,6 +252,7 @@ except ImportError: def restoredot(): pass + def tensordot(a, b, axes=(-1,0)): """tensordot returns the product for any (ndim >= 1) arrays. @@ -259,37 +260,67 @@ def tensordot(a, b, axes=(-1,0)): the axes to be summed over are given by the axes argument. the first element of the sequence determines the axis or axes - in arr1 to sum over and the second element in axes argument sequence + in arr1 to sum over, and the second element in axes argument sequence + determines the axis or axes in arr2 to sum over. + + When there is more than one axis to sum over, the corresponding + arguments to axes should be sequences of the same length with the first + axis to sum over given first in both sequences, the second axis second, + and so forth. """ axes_a, axes_b = axes try: na = len(axes_a) + axes_a = list(axes_a) except TypeError: axes_a = [axes_a] na = 1 try: nb = len(axes_b) + axes_b = list(axes_b) except TypeError: axes_b = [axes_b] nb = 1 a, b = asarray(a), asarray(b) as = a.shape + nda = len(a.shape) bs = b.shape + ndb = len(b.shape) equal = 1 if (na != nb): equal = 0 - for k in xrange(na): - if as[axes_a[k]] != bs[axes_b[k]]: - equal = 0 - break - + else: + for k in xrange(na): + if as[axes_a[k]] != bs[axes_b[k]]: + equal = 0 + break + if axes_a[k] < 0: + axes_a[k] += nda + if axes_b[k] < 0: + axes_b[k] += ndb if not equal: - raise ValueError, "shape-mismatch for sum" - - olda = [k for k in aa if k not in axes_a] - oldb = [k for k in bs if k not in axes_b] - - at = a.reshape(nd1, nd2) + raise ValueError, "shape-mismatch for sum" + + # Move the axes to sum over to the end of "a" + # and to the front of "b" + notin = [k for k in range(nda) if k not in axes_a] + newaxes_a = notin + axes_a + N2 = 1 + for axis in axes_a: + N2 *= as[axis] + newshape_a = (-1, N2) + olda = [as[axis] for axis in notin] + + notin = [k for k in range(ndb) if k not in axes_b] + newaxes_b = axes_b + notin + N2 = 1 + for axis in axes_b: + N2 *= bs[axis] + newshape_b = (N2, -1) + oldb = [bs[axis] for axis in notin] + + at = a.transpose(newaxes_a).reshape(newshape_a) + bt = b.transpose(newaxes_b).reshape(newshape_b) res = dot(at, bt) return res.reshape(olda + oldb) |