summaryrefslogtreecommitdiff
path: root/numpy/core/numeric.py
diff options
context:
space:
mode:
authorTravis Oliphant <oliphant@enthought.com>2006-09-15 22:40:54 +0000
committerTravis Oliphant <oliphant@enthought.com>2006-09-15 22:40:54 +0000
commit972ae975594790108be7bd3661d0d0be007e048c (patch)
tree640649966869c164509c9db79c7ce97311b52ada /numpy/core/numeric.py
parentbef75b3f149af546a4aa592bee9a1acecc267898 (diff)
downloadnumpy-972ae975594790108be7bd3661d0d0be007e048c.tar.gz
Add rollaxis command and fix cross function
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r--numpy/core/numeric.py38
1 files changed, 30 insertions, 8 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index f7733a92a..d9959a167 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -7,7 +7,7 @@ __all__ = ['newaxis', 'ndarray', 'flatiter', 'ufunc',
'asarray', 'asanyarray', 'ascontiguousarray', 'asfortranarray',
'isfortran', 'empty_like', 'zeros_like',
'correlate', 'convolve', 'inner', 'dot', 'outer', 'vdot',
- 'alterdot', 'restoredot', 'cross', 'tensordot',
+ 'alterdot', 'restoredot', 'rollaxis', 'cross', 'tensordot',
'array2string', 'get_printoptions', 'set_printoptions',
'array_repr', 'array_str', 'set_string_function',
'little_endian', 'require',
@@ -322,16 +322,38 @@ def tensordot(a, b, axes=(-1,0)):
res = dot(at, bt)
return res.reshape(olda + oldb)
+def rollaxis(a, axis, start=0):
+ """Return transposed array so that axis is rolled before start.
-def _move_axis_to_0(a, axis):
- if axis == 0:
- return a
+ if a.shape is (3,4,5,6)
+ rollaxis(a, 3, 1).shape is (3,6,4,5)
+ rollaxis(a, 2, 0).shape is (5,3,4,6)
+ rollaxis(a, 1, 3).shape is (3,5,4,6)
+ rollaxis(a, 1, 4).shape is (3,5,6,4)
+ """
n = a.ndim
if axis < 0:
axis += n
- axes = range(1, axis+1) + [0,] + range(axis+1, n)
+ if start < 0:
+ start += n
+ msg = 'rollaxis: %s (%d) must be >=0 and < %d'
+ if not (0 <= axis < n):
+ raise ValueError, msg % ('axis', axis, n)
+ if not (0 <= start < n+1):
+ raise ValueError, msg % ('start', start, n+1)
+ if (axis < start): # it's been removed
+ start -= 1
+ if axis==start:
+ return a
+ axes = range(0,n)
+ axes.remove(axis)
+ axes.insert(start, axis)
return a.transpose(axes)
+# fix hack in scipy which imports this function
+def _move_axis_to_0(a, axis):
+ return rollaxis(a, axis, 0)
+
def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
"""Return the cross product of two (arrays of) vectors.
@@ -342,8 +364,8 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
"""
if axis is not None:
axisa,axisb,axisc=(axis,)*3
- a = _move_axis_to_0(asarray(a), axisa)
- b = _move_axis_to_0(asarray(b), axisb)
+ a = asarray(a).swapaxes(axisa, 0)
+ b = asarray(b).swapaxes(axisb, 0)
msg = "incompatible dimensions for cross product\n"\
"(dimension must be 2 or 3)"
if (a.shape[0] not in [2,3]) or (b.shape[0] not in [2,3]):
@@ -354,7 +376,7 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
if cp.ndim == 0:
return cp
else:
- return cp.swapaxes(0,axisc)
+ return cp.swapaxes(0, axisc)
else:
x = a[1]*b[2]
y = -a[0]*b[2]