summaryrefslogtreecommitdiff
path: root/numpy/core/numeric.py
diff options
context:
space:
mode:
authorjaimefrio <jaime.frio@gmail.com>2014-02-20 22:15:47 -0800
committerjaimefrio <jaime.frio@gmail.com>2014-02-20 22:17:43 -0800
commitb9454f50f23516234c325490913224c3a69fb122 (patch)
tree78e27bdca6aaca1258f760af5db4ccfb5e4f8b41 /numpy/core/numeric.py
parent34d7bee3d472e5b97c1ea7c8d7e8d9949a9ebc16 (diff)
downloadnumpy-b9454f50f23516234c325490913224c3a69fb122.tar.gz
BUG: Fix broadcasting in np.cross (solves #2624)
Changed np.cross to move the axis it's operating on to the end, not the beginning of the shape tuple, and used rollaxis, not swapaxes, so that normal broadcasting rules apply. There were no tests in place, so I added some validation of results and broadcasted shapes, including the one reported in #2624
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r--numpy/core/numeric.py52
1 files changed, 28 insertions, 24 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 3b8e52e71..ee14f2af6 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -1490,39 +1490,43 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
"""
if axis is not None:
- axisa, axisb, axisc=(axis,)*3
- a = asarray(a).swapaxes(axisa, 0)
- b = asarray(b).swapaxes(axisb, 0)
- msg = "incompatible dimensions for cross product\n"\
- "(dimension must be 2 or 3)"
- if (a.shape[0] not in [2, 3]) or (b.shape[0] not in [2, 3]):
+ axisa, axisb, axisc = (axis,) * 3
+ a = asarray(a)
+ b = asarray(b)
+ # Move working axis to the end of the shape
+ a = rollaxis(a, axisa, a.ndim)
+ b = rollaxis(b, axisb, b.ndim)
+ msg = ("incompatible dimensions for cross product\n"
+ "(dimension must be 2 or 3)")
+ if a.shape[-1] not in [2, 3] or b.shape[-1] not in [2, 3]:
raise ValueError(msg)
- if a.shape[0] == 2:
- if (b.shape[0] == 2):
- cp = a[0]*b[1] - a[1]*b[0]
+ if a.shape[-1] == 2:
+ if b.shape[-1] == 2:
+ cp = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0]
if cp.ndim == 0:
return cp
else:
- return cp.swapaxes(0, axisc)
+ # This works because we are moving the last axis
+ return rollaxis(cp, -1, axisc)
else:
- x = a[1]*b[2]
- y = -a[0]*b[2]
- z = a[0]*b[1] - a[1]*b[0]
- elif a.shape[0] == 3:
- if (b.shape[0] == 3):
- x = a[1]*b[2] - a[2]*b[1]
- y = a[2]*b[0] - a[0]*b[2]
- z = a[0]*b[1] - a[1]*b[0]
+ x = a[..., 1]*b[..., 2]
+ y = -a[..., 0]*b[..., 2]
+ z = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0]
+ elif a.shape[-1] == 3:
+ if b.shape[-1] == 3:
+ x = a[..., 1]*b[..., 2] - a[..., 2]*b[..., 1]
+ y = a[..., 2]*b[..., 0] - a[..., 0]*b[..., 2]
+ z = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0]
else:
- x = -a[2]*b[1]
- y = a[2]*b[0]
- z = a[0]*b[1] - a[1]*b[0]
- cp = array([x, y, z])
+ x = -a[..., 2]*b[..., 1]
+ y = a[..., 2]*b[..., 0]
+ z = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0]
+ cp = concatenate((x[..., None], y[..., None], z[..., None]), axis=-1)
if cp.ndim == 1:
return cp
else:
- return cp.swapaxes(0, axisc)
-
+ # This works because we are moving the last axis
+ return rollaxis(cp, -1, axisc)
#Use numarray's printing function
from .arrayprint import array2string, get_printoptions, set_printoptions