diff options
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r-- | numpy/core/numeric.py | 52 |
1 files changed, 28 insertions, 24 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 3b8e52e71..ee14f2af6 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -1490,39 +1490,43 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): """ if axis is not None: - axisa, axisb, axisc=(axis,)*3 - a = asarray(a).swapaxes(axisa, 0) - b = asarray(b).swapaxes(axisb, 0) - msg = "incompatible dimensions for cross product\n"\ - "(dimension must be 2 or 3)" - if (a.shape[0] not in [2, 3]) or (b.shape[0] not in [2, 3]): + axisa, axisb, axisc = (axis,) * 3 + a = asarray(a) + b = asarray(b) + # Move working axis to the end of the shape + a = rollaxis(a, axisa, a.ndim) + b = rollaxis(b, axisb, b.ndim) + msg = ("incompatible dimensions for cross product\n" + "(dimension must be 2 or 3)") + if a.shape[-1] not in [2, 3] or b.shape[-1] not in [2, 3]: raise ValueError(msg) - if a.shape[0] == 2: - if (b.shape[0] == 2): - cp = a[0]*b[1] - a[1]*b[0] + if a.shape[-1] == 2: + if b.shape[-1] == 2: + cp = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0] if cp.ndim == 0: return cp else: - return cp.swapaxes(0, axisc) + # This works because we are moving the last axis + return rollaxis(cp, -1, axisc) else: - x = a[1]*b[2] - y = -a[0]*b[2] - z = a[0]*b[1] - a[1]*b[0] - elif a.shape[0] == 3: - if (b.shape[0] == 3): - x = a[1]*b[2] - a[2]*b[1] - y = a[2]*b[0] - a[0]*b[2] - z = a[0]*b[1] - a[1]*b[0] + x = a[..., 1]*b[..., 2] + y = -a[..., 0]*b[..., 2] + z = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0] + elif a.shape[-1] == 3: + if b.shape[-1] == 3: + x = a[..., 1]*b[..., 2] - a[..., 2]*b[..., 1] + y = a[..., 2]*b[..., 0] - a[..., 0]*b[..., 2] + z = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0] else: - x = -a[2]*b[1] - y = a[2]*b[0] - z = a[0]*b[1] - a[1]*b[0] - cp = array([x, y, z]) + x = -a[..., 2]*b[..., 1] + y = a[..., 2]*b[..., 0] + z = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0] + cp = concatenate((x[..., None], y[..., None], z[..., None]), axis=-1) if cp.ndim == 1: return cp else: - return cp.swapaxes(0, axisc) - + # This works because we are moving the last axis + return rollaxis(cp, -1, axisc) #Use numarray's printing function from .arrayprint import array2string, get_printoptions, set_printoptions |