From b9454f50f23516234c325490913224c3a69fb122 Mon Sep 17 00:00:00 2001 From: jaimefrio Date: Thu, 20 Feb 2014 22:15:47 -0800 Subject: BUG: Fix broadcasting in np.cross (solves #2624) Changed np.cross to move the axis it's operating on to the end, not the beginning of the shape tuple, and used rollaxis, not swapaxes, so that normal broadcasting rules apply. There were no tests in place, so I added some validation of results and broadcasted shapes, including the one reported in #2624 --- numpy/core/numeric.py | 52 +++++++++++++++++++++++++++------------------------ 1 file changed, 28 insertions(+), 24 deletions(-) (limited to 'numpy/core/numeric.py') diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 3b8e52e71..ee14f2af6 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -1490,39 +1490,43 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): """ if axis is not None: - axisa, axisb, axisc=(axis,)*3 - a = asarray(a).swapaxes(axisa, 0) - b = asarray(b).swapaxes(axisb, 0) - msg = "incompatible dimensions for cross product\n"\ - "(dimension must be 2 or 3)" - if (a.shape[0] not in [2, 3]) or (b.shape[0] not in [2, 3]): + axisa, axisb, axisc = (axis,) * 3 + a = asarray(a) + b = asarray(b) + # Move working axis to the end of the shape + a = rollaxis(a, axisa, a.ndim) + b = rollaxis(b, axisb, b.ndim) + msg = ("incompatible dimensions for cross product\n" + "(dimension must be 2 or 3)") + if a.shape[-1] not in [2, 3] or b.shape[-1] not in [2, 3]: raise ValueError(msg) - if a.shape[0] == 2: - if (b.shape[0] == 2): - cp = a[0]*b[1] - a[1]*b[0] + if a.shape[-1] == 2: + if b.shape[-1] == 2: + cp = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0] if cp.ndim == 0: return cp else: - return cp.swapaxes(0, axisc) + # This works because we are moving the last axis + return rollaxis(cp, -1, axisc) else: - x = a[1]*b[2] - y = -a[0]*b[2] - z = a[0]*b[1] - a[1]*b[0] - elif a.shape[0] == 3: - if (b.shape[0] == 3): - x = a[1]*b[2] - a[2]*b[1] - y = a[2]*b[0] - a[0]*b[2] - z = a[0]*b[1] - a[1]*b[0] + x = a[..., 1]*b[..., 2] + y = -a[..., 0]*b[..., 2] + z = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0] + elif a.shape[-1] == 3: + if b.shape[-1] == 3: + x = a[..., 1]*b[..., 2] - a[..., 2]*b[..., 1] + y = a[..., 2]*b[..., 0] - a[..., 0]*b[..., 2] + z = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0] else: - x = -a[2]*b[1] - y = a[2]*b[0] - z = a[0]*b[1] - a[1]*b[0] - cp = array([x, y, z]) + x = -a[..., 2]*b[..., 1] + y = a[..., 2]*b[..., 0] + z = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0] + cp = concatenate((x[..., None], y[..., None], z[..., None]), axis=-1) if cp.ndim == 1: return cp else: - return cp.swapaxes(0, axisc) - + # This works because we are moving the last axis + return rollaxis(cp, -1, axisc) #Use numarray's printing function from .arrayprint import array2string, get_printoptions, set_printoptions -- cgit v1.2.1 From e31da35c5f23968e2cbfae06642ce32dc54e7627 Mon Sep 17 00:00:00 2001 From: jaimefrio Date: Fri, 21 Feb 2014 10:58:06 -0800 Subject: TST: Added out of bounds axis checking to the tests. ENH: Minimize temporary intermediate arrays using a preallocated output array and in-place operations. DOC: Added 'Supports full broadcasting' note to the docstring. --- numpy/core/numeric.py | 53 ++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 42 insertions(+), 11 deletions(-) (limited to 'numpy/core/numeric.py') diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index ee14f2af6..e371797b6 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -1429,6 +1429,11 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): outer : Outer product. ix_ : Construct index arrays. + Notes + ----- + .. versionadded:: 1.9.0 + Supports full broadcasting of the inputs. + Examples -------- Vector cross-product. @@ -1500,28 +1505,54 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): "(dimension must be 2 or 3)") if a.shape[-1] not in [2, 3] or b.shape[-1] not in [2, 3]: raise ValueError(msg) + + # Create the output array + shape = broadcast(a[..., 0], b[..., 0]).shape + if a.shape[-1] == 3 or b.shape[-1] == 3: + shape += (3,) + dtype = promote_types(a.dtype, b.dtype) + cp = empty(shape, dtype) + if a.shape[-1] == 2: if b.shape[-1] == 2: - cp = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0] + # cp = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0] + multiply(a[..., 0], b[..., 1], out=cp) + cp -= a[..., 1]*b[..., 0] if cp.ndim == 0: return cp else: # This works because we are moving the last axis return rollaxis(cp, -1, axisc) else: - x = a[..., 1]*b[..., 2] - y = -a[..., 0]*b[..., 2] - z = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0] + # cp[..., 0] = a[..., 1]*b[..., 2] + multiply(a[..., 1], b[..., 2], out=cp[..., 0]) + # cp[..., 1] = -a[..., 0]*b[..., 2] + multiply(a[..., 0], b[..., 2], out=cp[..., 1]) + cp[..., 1] *= - 1 + # cp[..., 2] = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0] + multiply(a[..., 0], b[..., 1], out=cp[..., 2]) + cp[..., 2] -= a[..., 1]*b[..., 0] elif a.shape[-1] == 3: if b.shape[-1] == 3: - x = a[..., 1]*b[..., 2] - a[..., 2]*b[..., 1] - y = a[..., 2]*b[..., 0] - a[..., 0]*b[..., 2] - z = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0] + # cp[..., 0] = a[..., 1]*b[..., 2] - a[..., 2]*b[..., 1] + multiply(a[..., 1], b[..., 2], out=cp[..., 0]) + cp[..., 0] -= a[..., 2]*b[..., 1] + # cp[..., 1] = a[..., 2]*b[..., 0] - a[..., 0]*b[..., 2] + multiply(a[..., 2], b[..., 0], out=cp[..., 1]) + cp[..., 1] -= a[..., 0]*b[..., 2] + # cp[..., 2] = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0] + multiply(a[..., 0], b[..., 1], out=cp[..., 2]) + cp[..., 2] -= a[..., 1]*b[..., 0] else: - x = -a[..., 2]*b[..., 1] - y = a[..., 2]*b[..., 0] - z = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0] - cp = concatenate((x[..., None], y[..., None], z[..., None]), axis=-1) + # cp[..., 0] = -a[..., 2]*b[..., 1] + multiply(a[..., 2], b[..., 1], out=cp[..., 0]) + cp[..., 0] *= - 1 + # cp[..., 1] = a[..., 2]*b[..., 0] + multiply(a[..., 2], b[..., 0], out=cp[..., 1]) + # cp[..., 2] = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0] + multiply(a[..., 0], b[..., 1], out=cp[..., 2]) + cp[..., 2] -= a[..., 1]*b[..., 0] + if cp.ndim == 1: return cp else: -- cgit v1.2.1