summaryrefslogtreecommitdiff
path: root/numpy/core/numeric.py
diff options
context:
space:
mode:
authorTravis Oliphant <oliphant@enthought.com>2006-08-29 02:47:39 +0000
committerTravis Oliphant <oliphant@enthought.com>2006-08-29 02:47:39 +0000
commit762dacd686faa83c7de12c115e0a5874eef0481e (patch)
treeba39d11ca3cdf5f1c62e5f6fcc61f89c85af4ee5 /numpy/core/numeric.py
parent8d80b96d2aa1f498b80fc19b296ff881185cd363 (diff)
downloadnumpy-762dacd686faa83c7de12c115e0a5874eef0481e.tar.gz
Add float, int, etc. to numpy name-space. Flesh out tensordot. Fix-up getcharbuf to allow all 8-bit types to be returned as a charbuf.
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r--numpy/core/numeric.py57
1 files changed, 44 insertions, 13 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 34c8b17d0..cd23c7cc0 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -190,7 +190,7 @@ def argwhere(a):
is a sequence of indices into a. This sequence must be
converted to a tuple in order to be used to index into a.
"""
- return transpose(a.nonzero())
+ return asarray(a.nonzero()).T
def flatnonzero(a):
"""Return indicies that are not-zero in flattened version of a
@@ -252,6 +252,7 @@ except ImportError:
def restoredot():
pass
+
def tensordot(a, b, axes=(-1,0)):
"""tensordot returns the product for any (ndim >= 1) arrays.
@@ -259,37 +260,67 @@ def tensordot(a, b, axes=(-1,0)):
the axes to be summed over are given by the axes argument.
the first element of the sequence determines the axis or axes
- in arr1 to sum over and the second element in axes argument sequence
+ in arr1 to sum over, and the second element in axes argument sequence
+ determines the axis or axes in arr2 to sum over.
+
+ When there is more than one axis to sum over, the corresponding
+ arguments to axes should be sequences of the same length with the first
+ axis to sum over given first in both sequences, the second axis second,
+ and so forth.
"""
axes_a, axes_b = axes
try:
na = len(axes_a)
+ axes_a = list(axes_a)
except TypeError:
axes_a = [axes_a]
na = 1
try:
nb = len(axes_b)
+ axes_b = list(axes_b)
except TypeError:
axes_b = [axes_b]
nb = 1
a, b = asarray(a), asarray(b)
as = a.shape
+ nda = len(a.shape)
bs = b.shape
+ ndb = len(b.shape)
equal = 1
if (na != nb): equal = 0
- for k in xrange(na):
- if as[axes_a[k]] != bs[axes_b[k]]:
- equal = 0
- break
-
+ else:
+ for k in xrange(na):
+ if as[axes_a[k]] != bs[axes_b[k]]:
+ equal = 0
+ break
+ if axes_a[k] < 0:
+ axes_a[k] += nda
+ if axes_b[k] < 0:
+ axes_b[k] += ndb
if not equal:
- raise ValueError, "shape-mismatch for sum"
-
- olda = [k for k in aa if k not in axes_a]
- oldb = [k for k in bs if k not in axes_b]
-
- at = a.reshape(nd1, nd2)
+ raise ValueError, "shape-mismatch for sum"
+
+ # Move the axes to sum over to the end of "a"
+ # and to the front of "b"
+ notin = [k for k in range(nda) if k not in axes_a]
+ newaxes_a = notin + axes_a
+ N2 = 1
+ for axis in axes_a:
+ N2 *= as[axis]
+ newshape_a = (-1, N2)
+ olda = [as[axis] for axis in notin]
+
+ notin = [k for k in range(ndb) if k not in axes_b]
+ newaxes_b = axes_b + notin
+ N2 = 1
+ for axis in axes_b:
+ N2 *= bs[axis]
+ newshape_b = (N2, -1)
+ oldb = [bs[axis] for axis in notin]
+
+ at = a.transpose(newaxes_a).reshape(newshape_a)
+ bt = b.transpose(newaxes_b).reshape(newshape_b)
res = dot(at, bt)
return res.reshape(olda + oldb)