diff options
author | Travis Oliphant <oliphant@enthought.com> | 2006-08-29 02:47:39 +0000 |
---|---|---|
committer | Travis Oliphant <oliphant@enthought.com> | 2006-08-29 02:47:39 +0000 |
commit | 762dacd686faa83c7de12c115e0a5874eef0481e (patch) | |
tree | ba39d11ca3cdf5f1c62e5f6fcc61f89c85af4ee5 /numpy | |
parent | 8d80b96d2aa1f498b80fc19b296ff881185cd363 (diff) | |
download | numpy-762dacd686faa83c7de12c115e0a5874eef0481e.tar.gz |
Add float, int, etc. to numpy name-space. Flesh out tensordot. Fix-up getcharbuf to allow all 8-bit types to be returned as a charbuf.
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/__init__.py | 4 | ||||
-rw-r--r-- | numpy/core/numeric.py | 57 | ||||
-rw-r--r-- | numpy/core/src/arrayobject.c | 7 |
3 files changed, 52 insertions, 16 deletions
diff --git a/numpy/__init__.py b/numpy/__init__.py index e864fb9b2..4e8d3fcd0 100644 --- a/numpy/__init__.py +++ b/numpy/__init__.py @@ -36,6 +36,10 @@ else: from core import * import lib from lib import * + # Make these accessible from numpy name-space + # but not imported in from numpy import * + from __builtin__ import bool, int, long, float, complex, \ + object, unicode, str import linalg import fft import random diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 34c8b17d0..cd23c7cc0 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -190,7 +190,7 @@ def argwhere(a): is a sequence of indices into a. This sequence must be converted to a tuple in order to be used to index into a. """ - return transpose(a.nonzero()) + return asarray(a.nonzero()).T def flatnonzero(a): """Return indicies that are not-zero in flattened version of a @@ -252,6 +252,7 @@ except ImportError: def restoredot(): pass + def tensordot(a, b, axes=(-1,0)): """tensordot returns the product for any (ndim >= 1) arrays. @@ -259,37 +260,67 @@ def tensordot(a, b, axes=(-1,0)): the axes to be summed over are given by the axes argument. the first element of the sequence determines the axis or axes - in arr1 to sum over and the second element in axes argument sequence + in arr1 to sum over, and the second element in axes argument sequence + determines the axis or axes in arr2 to sum over. + + When there is more than one axis to sum over, the corresponding + arguments to axes should be sequences of the same length with the first + axis to sum over given first in both sequences, the second axis second, + and so forth. """ axes_a, axes_b = axes try: na = len(axes_a) + axes_a = list(axes_a) except TypeError: axes_a = [axes_a] na = 1 try: nb = len(axes_b) + axes_b = list(axes_b) except TypeError: axes_b = [axes_b] nb = 1 a, b = asarray(a), asarray(b) as = a.shape + nda = len(a.shape) bs = b.shape + ndb = len(b.shape) equal = 1 if (na != nb): equal = 0 - for k in xrange(na): - if as[axes_a[k]] != bs[axes_b[k]]: - equal = 0 - break - + else: + for k in xrange(na): + if as[axes_a[k]] != bs[axes_b[k]]: + equal = 0 + break + if axes_a[k] < 0: + axes_a[k] += nda + if axes_b[k] < 0: + axes_b[k] += ndb if not equal: - raise ValueError, "shape-mismatch for sum" - - olda = [k for k in aa if k not in axes_a] - oldb = [k for k in bs if k not in axes_b] - - at = a.reshape(nd1, nd2) + raise ValueError, "shape-mismatch for sum" + + # Move the axes to sum over to the end of "a" + # and to the front of "b" + notin = [k for k in range(nda) if k not in axes_a] + newaxes_a = notin + axes_a + N2 = 1 + for axis in axes_a: + N2 *= as[axis] + newshape_a = (-1, N2) + olda = [as[axis] for axis in notin] + + notin = [k for k in range(ndb) if k not in axes_b] + newaxes_b = axes_b + notin + N2 = 1 + for axis in axes_b: + N2 *= bs[axis] + newshape_b = (N2, -1) + oldb = [bs[axis] for axis in notin] + + at = a.transpose(newaxes_a).reshape(newshape_a) + bt = b.transpose(newaxes_b).reshape(newshape_b) res = dot(at, bt) return res.reshape(olda + oldb) diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c index abc01614a..d66e01647 100644 --- a/numpy/core/src/arrayobject.c +++ b/numpy/core/src/arrayobject.c @@ -3038,12 +3038,13 @@ static Py_ssize_t array_getcharbuf(PyArrayObject *self, Py_ssize_t segment, constchar **ptrptr) { if (self->descr->type_num == PyArray_STRING || \ - self->descr->type_num == PyArray_UNICODE) + self->descr->type_num == PyArray_UNICODE || \ + self->descr->elsize == 1) return array_getreadbuf(self, segment, (void **) ptrptr); else { PyErr_SetString(PyExc_TypeError, - "non-character array cannot be interpreted "\ - "as character buffer"); + "non-character (or 8-bit) array cannot be "\ + "interpreted as character buffer"); return -1; } } |