diff options
author | jaimefrio <jaime.frio@gmail.com> | 2014-02-20 22:15:47 -0800 |
---|---|---|
committer | jaimefrio <jaime.frio@gmail.com> | 2014-02-20 22:17:43 -0800 |
commit | b9454f50f23516234c325490913224c3a69fb122 (patch) | |
tree | 78e27bdca6aaca1258f760af5db4ccfb5e4f8b41 /numpy/core/numeric.py | |
parent | 34d7bee3d472e5b97c1ea7c8d7e8d9949a9ebc16 (diff) | |
download | numpy-b9454f50f23516234c325490913224c3a69fb122.tar.gz |
BUG: Fix broadcasting in np.cross (solves #2624)
Changed np.cross to move the axis it's operating on to the end, not the
beginning of the shape tuple, and used rollaxis, not swapaxes, so that
normal broadcasting rules apply.
There were no tests in place, so I added some validation of results and
broadcasted shapes, including the one reported in #2624
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r-- | numpy/core/numeric.py | 52 |
1 files changed, 28 insertions, 24 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 3b8e52e71..ee14f2af6 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -1490,39 +1490,43 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): """ if axis is not None: - axisa, axisb, axisc=(axis,)*3 - a = asarray(a).swapaxes(axisa, 0) - b = asarray(b).swapaxes(axisb, 0) - msg = "incompatible dimensions for cross product\n"\ - "(dimension must be 2 or 3)" - if (a.shape[0] not in [2, 3]) or (b.shape[0] not in [2, 3]): + axisa, axisb, axisc = (axis,) * 3 + a = asarray(a) + b = asarray(b) + # Move working axis to the end of the shape + a = rollaxis(a, axisa, a.ndim) + b = rollaxis(b, axisb, b.ndim) + msg = ("incompatible dimensions for cross product\n" + "(dimension must be 2 or 3)") + if a.shape[-1] not in [2, 3] or b.shape[-1] not in [2, 3]: raise ValueError(msg) - if a.shape[0] == 2: - if (b.shape[0] == 2): - cp = a[0]*b[1] - a[1]*b[0] + if a.shape[-1] == 2: + if b.shape[-1] == 2: + cp = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0] if cp.ndim == 0: return cp else: - return cp.swapaxes(0, axisc) + # This works because we are moving the last axis + return rollaxis(cp, -1, axisc) else: - x = a[1]*b[2] - y = -a[0]*b[2] - z = a[0]*b[1] - a[1]*b[0] - elif a.shape[0] == 3: - if (b.shape[0] == 3): - x = a[1]*b[2] - a[2]*b[1] - y = a[2]*b[0] - a[0]*b[2] - z = a[0]*b[1] - a[1]*b[0] + x = a[..., 1]*b[..., 2] + y = -a[..., 0]*b[..., 2] + z = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0] + elif a.shape[-1] == 3: + if b.shape[-1] == 3: + x = a[..., 1]*b[..., 2] - a[..., 2]*b[..., 1] + y = a[..., 2]*b[..., 0] - a[..., 0]*b[..., 2] + z = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0] else: - x = -a[2]*b[1] - y = a[2]*b[0] - z = a[0]*b[1] - a[1]*b[0] - cp = array([x, y, z]) + x = -a[..., 2]*b[..., 1] + y = a[..., 2]*b[..., 0] + z = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0] + cp = concatenate((x[..., None], y[..., None], z[..., None]), axis=-1) if cp.ndim == 1: return cp else: - return cp.swapaxes(0, axisc) - + # This works because we are moving the last axis + return rollaxis(cp, -1, axisc) #Use numarray's printing function from .arrayprint import array2string, get_printoptions, set_printoptions |