summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorJulian Taylor <juliantaylor108@gmail.com>2014-03-01 15:40:19 +0100
committerJulian Taylor <juliantaylor108@gmail.com>2014-03-01 15:40:19 +0100
commit4d2aceca643c3f8e8b3edbf9a2172ea4b94a3118 (patch)
tree98b9cea4421b527b0327584652e3746a19917155 /numpy/core
parent94d35b500fc1e9daacc56e5dfc2d0cda590ea76f (diff)
parent2d152aff963cc325a32c65e4dd1b0e2e87fba4b2 (diff)
downloadnumpy-4d2aceca643c3f8e8b3edbf9a2172ea4b94a3118.tar.gz
Merge pull request #4338 from jaimefrio/cross-broadcast
BUG: Fix broadcasting in np.cross (solves #2624)
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/numeric.py83
-rw-r--r--numpy/core/tests/test_numeric.py44
2 files changed, 103 insertions, 24 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 5f4553284..713f6f8f0 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -1429,6 +1429,11 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
outer : Outer product.
ix_ : Construct index arrays.
+ Notes
+ -----
+ .. versionadded:: 1.9.0
+ Supports full broadcasting of the inputs.
+
Examples
--------
Vector cross-product.
@@ -1490,39 +1495,69 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
"""
if axis is not None:
- axisa, axisb, axisc=(axis,)*3
- a = asarray(a).swapaxes(axisa, 0)
- b = asarray(b).swapaxes(axisb, 0)
- msg = "incompatible dimensions for cross product\n"\
- "(dimension must be 2 or 3)"
- if (a.shape[0] not in [2, 3]) or (b.shape[0] not in [2, 3]):
+ axisa, axisb, axisc = (axis,) * 3
+ a = asarray(a)
+ b = asarray(b)
+ # Move working axis to the end of the shape
+ a = rollaxis(a, axisa, a.ndim)
+ b = rollaxis(b, axisb, b.ndim)
+ msg = ("incompatible dimensions for cross product\n"
+ "(dimension must be 2 or 3)")
+ if a.shape[-1] not in [2, 3] or b.shape[-1] not in [2, 3]:
raise ValueError(msg)
- if a.shape[0] == 2:
- if (b.shape[0] == 2):
- cp = a[0]*b[1] - a[1]*b[0]
+
+ # Create the output array
+ shape = broadcast(a[..., 0], b[..., 0]).shape
+ if a.shape[-1] == 3 or b.shape[-1] == 3:
+ shape += (3,)
+ dtype = promote_types(a.dtype, b.dtype)
+ cp = empty(shape, dtype)
+
+ if a.shape[-1] == 2:
+ if b.shape[-1] == 2:
+ # cp = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0]
+ multiply(a[..., 0], b[..., 1], out=cp)
+ cp -= a[..., 1]*b[..., 0]
if cp.ndim == 0:
return cp
else:
- return cp.swapaxes(0, axisc)
+ # This works because we are moving the last axis
+ return rollaxis(cp, -1, axisc)
else:
- x = a[1]*b[2]
- y = -a[0]*b[2]
- z = a[0]*b[1] - a[1]*b[0]
- elif a.shape[0] == 3:
- if (b.shape[0] == 3):
- x = a[1]*b[2] - a[2]*b[1]
- y = a[2]*b[0] - a[0]*b[2]
- z = a[0]*b[1] - a[1]*b[0]
+ # cp[..., 0] = a[..., 1]*b[..., 2]
+ multiply(a[..., 1], b[..., 2], out=cp[..., 0])
+ # cp[..., 1] = -a[..., 0]*b[..., 2]
+ multiply(a[..., 0], b[..., 2], out=cp[..., 1])
+ cp[..., 1] *= - 1
+ # cp[..., 2] = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0]
+ multiply(a[..., 0], b[..., 1], out=cp[..., 2])
+ cp[..., 2] -= a[..., 1]*b[..., 0]
+ elif a.shape[-1] == 3:
+ if b.shape[-1] == 3:
+ # cp[..., 0] = a[..., 1]*b[..., 2] - a[..., 2]*b[..., 1]
+ multiply(a[..., 1], b[..., 2], out=cp[..., 0])
+ cp[..., 0] -= a[..., 2]*b[..., 1]
+ # cp[..., 1] = a[..., 2]*b[..., 0] - a[..., 0]*b[..., 2]
+ multiply(a[..., 2], b[..., 0], out=cp[..., 1])
+ cp[..., 1] -= a[..., 0]*b[..., 2]
+ # cp[..., 2] = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0]
+ multiply(a[..., 0], b[..., 1], out=cp[..., 2])
+ cp[..., 2] -= a[..., 1]*b[..., 0]
else:
- x = -a[2]*b[1]
- y = a[2]*b[0]
- z = a[0]*b[1] - a[1]*b[0]
- cp = array([x, y, z])
+ # cp[..., 0] = -a[..., 2]*b[..., 1]
+ multiply(a[..., 2], b[..., 1], out=cp[..., 0])
+ cp[..., 0] *= - 1
+ # cp[..., 1] = a[..., 2]*b[..., 0]
+ multiply(a[..., 2], b[..., 0], out=cp[..., 1])
+ # cp[..., 2] = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0]
+ multiply(a[..., 0], b[..., 1], out=cp[..., 2])
+ cp[..., 2] -= a[..., 1]*b[..., 0]
+
if cp.ndim == 1:
return cp
else:
- return cp.swapaxes(0, axisc)
-
+ # This works because we are moving the last axis
+ return rollaxis(cp, -1, axisc)
#Use numarray's printing function
from .arrayprint import array2string, get_printoptions, set_printoptions
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index 2a698d1c2..faa128e03 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -1958,6 +1958,50 @@ class TestRoll(TestCase):
x = np.array([])
assert_equal(np.roll(x, 1), np.array([]))
+class TestCross(TestCase):
+ def test_2x2(self):
+ u = [1, 2]
+ v = [3, 4]
+ z = -2
+ cp = np.cross(u, v)
+ assert_equal(cp, z)
+ cp = np.cross(v, u)
+ assert_equal(cp, -z)
+
+ def test_2x3(self):
+ u = [1, 2]
+ v = [3, 4, 5]
+ z = np.array([10, -5, -2])
+ cp = np.cross(u, v)
+ assert_equal(cp, z)
+ cp = np.cross(v, u)
+ assert_equal(cp, -z)
+
+ def test_3x3(self):
+ u = [1, 2, 3]
+ v = [4, 5, 6]
+ z = np.array([-3, 6, -3])
+ cp = cross(u, v)
+ assert_equal(cp, z)
+ cp = np.cross(v, u)
+ assert_equal(cp, -z)
+
+ def test_broadcasting(self):
+ # Ticket #2624 (Trac #2032)
+ u = np.ones((2, 1, 3))
+ v = np.ones((5, 3))
+ assert_equal(np.cross(u, v).shape, (2, 5, 3))
+ u = np.ones((10, 3, 5))
+ v = np.ones((2, 5))
+ assert_equal(np.cross(u, v, axisa=1, axisb=0).shape, (10, 5, 3))
+ assert_raises(ValueError, np.cross, u, v, axisa=1, axisb=2)
+ assert_raises(ValueError, np.cross, u, v, axisa=3, axisb=0)
+ u = np.ones((10, 3, 5, 7))
+ v = np.ones((5, 7, 2))
+ assert_equal(np.cross(u, v, axisa=1, axisc=2).shape, (10, 5, 3, 7))
+ assert_raises(ValueError, np.cross, u, v, axisa=-5, axisb=2)
+ assert_raises(ValueError, np.cross, u, v, axisa=1, axisb=-4)
+
if __name__ == "__main__":
run_module_suite()