diff options
author | Julian Taylor <juliantaylor108@gmail.com> | 2014-03-01 15:40:19 +0100 |
---|---|---|
committer | Julian Taylor <juliantaylor108@gmail.com> | 2014-03-01 15:40:19 +0100 |
commit | 4d2aceca643c3f8e8b3edbf9a2172ea4b94a3118 (patch) | |
tree | 98b9cea4421b527b0327584652e3746a19917155 /numpy/core/numeric.py | |
parent | 94d35b500fc1e9daacc56e5dfc2d0cda590ea76f (diff) | |
parent | 2d152aff963cc325a32c65e4dd1b0e2e87fba4b2 (diff) | |
download | numpy-4d2aceca643c3f8e8b3edbf9a2172ea4b94a3118.tar.gz |
Merge pull request #4338 from jaimefrio/cross-broadcast
BUG: Fix broadcasting in np.cross (solves #2624)
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r-- | numpy/core/numeric.py | 83 |
1 files changed, 59 insertions, 24 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 5f4553284..713f6f8f0 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -1429,6 +1429,11 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): outer : Outer product. ix_ : Construct index arrays. + Notes + ----- + .. versionadded:: 1.9.0 + Supports full broadcasting of the inputs. + Examples -------- Vector cross-product. @@ -1490,39 +1495,69 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): """ if axis is not None: - axisa, axisb, axisc=(axis,)*3 - a = asarray(a).swapaxes(axisa, 0) - b = asarray(b).swapaxes(axisb, 0) - msg = "incompatible dimensions for cross product\n"\ - "(dimension must be 2 or 3)" - if (a.shape[0] not in [2, 3]) or (b.shape[0] not in [2, 3]): + axisa, axisb, axisc = (axis,) * 3 + a = asarray(a) + b = asarray(b) + # Move working axis to the end of the shape + a = rollaxis(a, axisa, a.ndim) + b = rollaxis(b, axisb, b.ndim) + msg = ("incompatible dimensions for cross product\n" + "(dimension must be 2 or 3)") + if a.shape[-1] not in [2, 3] or b.shape[-1] not in [2, 3]: raise ValueError(msg) - if a.shape[0] == 2: - if (b.shape[0] == 2): - cp = a[0]*b[1] - a[1]*b[0] + + # Create the output array + shape = broadcast(a[..., 0], b[..., 0]).shape + if a.shape[-1] == 3 or b.shape[-1] == 3: + shape += (3,) + dtype = promote_types(a.dtype, b.dtype) + cp = empty(shape, dtype) + + if a.shape[-1] == 2: + if b.shape[-1] == 2: + # cp = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0] + multiply(a[..., 0], b[..., 1], out=cp) + cp -= a[..., 1]*b[..., 0] if cp.ndim == 0: return cp else: - return cp.swapaxes(0, axisc) + # This works because we are moving the last axis + return rollaxis(cp, -1, axisc) else: - x = a[1]*b[2] - y = -a[0]*b[2] - z = a[0]*b[1] - a[1]*b[0] - elif a.shape[0] == 3: - if (b.shape[0] == 3): - x = a[1]*b[2] - a[2]*b[1] - y = a[2]*b[0] - a[0]*b[2] - z = a[0]*b[1] - a[1]*b[0] + # cp[..., 0] = a[..., 1]*b[..., 2] + multiply(a[..., 1], b[..., 2], out=cp[..., 0]) + # cp[..., 1] = -a[..., 0]*b[..., 2] + multiply(a[..., 0], b[..., 2], out=cp[..., 1]) + cp[..., 1] *= - 1 + # cp[..., 2] = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0] + multiply(a[..., 0], b[..., 1], out=cp[..., 2]) + cp[..., 2] -= a[..., 1]*b[..., 0] + elif a.shape[-1] == 3: + if b.shape[-1] == 3: + # cp[..., 0] = a[..., 1]*b[..., 2] - a[..., 2]*b[..., 1] + multiply(a[..., 1], b[..., 2], out=cp[..., 0]) + cp[..., 0] -= a[..., 2]*b[..., 1] + # cp[..., 1] = a[..., 2]*b[..., 0] - a[..., 0]*b[..., 2] + multiply(a[..., 2], b[..., 0], out=cp[..., 1]) + cp[..., 1] -= a[..., 0]*b[..., 2] + # cp[..., 2] = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0] + multiply(a[..., 0], b[..., 1], out=cp[..., 2]) + cp[..., 2] -= a[..., 1]*b[..., 0] else: - x = -a[2]*b[1] - y = a[2]*b[0] - z = a[0]*b[1] - a[1]*b[0] - cp = array([x, y, z]) + # cp[..., 0] = -a[..., 2]*b[..., 1] + multiply(a[..., 2], b[..., 1], out=cp[..., 0]) + cp[..., 0] *= - 1 + # cp[..., 1] = a[..., 2]*b[..., 0] + multiply(a[..., 2], b[..., 0], out=cp[..., 1]) + # cp[..., 2] = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0] + multiply(a[..., 0], b[..., 1], out=cp[..., 2]) + cp[..., 2] -= a[..., 1]*b[..., 0] + if cp.ndim == 1: return cp else: - return cp.swapaxes(0, axisc) - + # This works because we are moving the last axis + return rollaxis(cp, -1, axisc) #Use numarray's printing function from .arrayprint import array2string, get_printoptions, set_printoptions |