diff options
author | Travis Oliphant <oliphant@enthought.com> | 2006-09-15 22:40:54 +0000 |
---|---|---|
committer | Travis Oliphant <oliphant@enthought.com> | 2006-09-15 22:40:54 +0000 |
commit | 972ae975594790108be7bd3661d0d0be007e048c (patch) | |
tree | 640649966869c164509c9db79c7ce97311b52ada /numpy/core/numeric.py | |
parent | bef75b3f149af546a4aa592bee9a1acecc267898 (diff) | |
download | numpy-972ae975594790108be7bd3661d0d0be007e048c.tar.gz |
Add rollaxis command and fix cross function
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r-- | numpy/core/numeric.py | 38 |
1 files changed, 30 insertions, 8 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index f7733a92a..d9959a167 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -7,7 +7,7 @@ __all__ = ['newaxis', 'ndarray', 'flatiter', 'ufunc', 'asarray', 'asanyarray', 'ascontiguousarray', 'asfortranarray', 'isfortran', 'empty_like', 'zeros_like', 'correlate', 'convolve', 'inner', 'dot', 'outer', 'vdot', - 'alterdot', 'restoredot', 'cross', 'tensordot', + 'alterdot', 'restoredot', 'rollaxis', 'cross', 'tensordot', 'array2string', 'get_printoptions', 'set_printoptions', 'array_repr', 'array_str', 'set_string_function', 'little_endian', 'require', @@ -322,16 +322,38 @@ def tensordot(a, b, axes=(-1,0)): res = dot(at, bt) return res.reshape(olda + oldb) +def rollaxis(a, axis, start=0): + """Return transposed array so that axis is rolled before start. -def _move_axis_to_0(a, axis): - if axis == 0: - return a + if a.shape is (3,4,5,6) + rollaxis(a, 3, 1).shape is (3,6,4,5) + rollaxis(a, 2, 0).shape is (5,3,4,6) + rollaxis(a, 1, 3).shape is (3,5,4,6) + rollaxis(a, 1, 4).shape is (3,5,6,4) + """ n = a.ndim if axis < 0: axis += n - axes = range(1, axis+1) + [0,] + range(axis+1, n) + if start < 0: + start += n + msg = 'rollaxis: %s (%d) must be >=0 and < %d' + if not (0 <= axis < n): + raise ValueError, msg % ('axis', axis, n) + if not (0 <= start < n+1): + raise ValueError, msg % ('start', start, n+1) + if (axis < start): # it's been removed + start -= 1 + if axis==start: + return a + axes = range(0,n) + axes.remove(axis) + axes.insert(start, axis) return a.transpose(axes) +# fix hack in scipy which imports this function +def _move_axis_to_0(a, axis): + return rollaxis(a, axis, 0) + def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): """Return the cross product of two (arrays of) vectors. @@ -342,8 +364,8 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): """ if axis is not None: axisa,axisb,axisc=(axis,)*3 - a = _move_axis_to_0(asarray(a), axisa) - b = _move_axis_to_0(asarray(b), axisb) + a = asarray(a).swapaxes(axisa, 0) + b = asarray(b).swapaxes(axisb, 0) msg = "incompatible dimensions for cross product\n"\ "(dimension must be 2 or 3)" if (a.shape[0] not in [2,3]) or (b.shape[0] not in [2,3]): @@ -354,7 +376,7 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): if cp.ndim == 0: return cp else: - return cp.swapaxes(0,axisc) + return cp.swapaxes(0, axisc) else: x = a[1]*b[2] y = -a[0]*b[2] |