summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJulian Taylor <juliantaylor108@gmail.com>2014-06-08 14:56:58 +0200
committerJulian Taylor <juliantaylor108@gmail.com>2014-06-08 14:56:58 +0200
commit49b06284149a57922ffe36cf441ad19df922fe22 (patch)
treecfd6d0ac0326429a8b1484fa664e183cbac97741
parent3087de303e70d405e5a201d72e62fd747c96e4a6 (diff)
parent22df0769eeb180326a657d850faa98e27b70eea5 (diff)
downloadnumpy-49b06284149a57922ffe36cf441ad19df922fe22.tar.gz
Merge pull request #4786 from juliantaylor/cross-style
MAINT: improve readablility of cross and improve test coverage
-rw-r--r--numpy/core/numeric.py75
-rw-r--r--numpy/core/tests/test_numeric.py29
2 files changed, 75 insertions, 29 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 38c28da6d..a85e8514c 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -1522,7 +1522,7 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
b = rollaxis(b, axisb, b.ndim)
msg = ("incompatible dimensions for cross product\n"
"(dimension must be 2 or 3)")
- if a.shape[-1] not in [2, 3] or b.shape[-1] not in [2, 3]:
+ if a.shape[-1] not in (2, 3) or b.shape[-1] not in (2, 3):
raise ValueError(msg)
# Create the output array
@@ -1532,45 +1532,62 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
dtype = promote_types(a.dtype, b.dtype)
cp = empty(shape, dtype)
+ # create local aliases for readability
+ a0 = a[..., 0]
+ a1 = a[..., 1]
+ if a.shape[-1] == 3:
+ a2 = a[..., 2]
+ b0 = b[..., 0]
+ b1 = b[..., 1]
+ if b.shape[-1] == 3:
+ b2 = b[..., 2]
+ if cp.ndim != 0 and cp.shape[-1] == 3:
+ cp0 = cp[..., 0]
+ cp1 = cp[..., 1]
+ cp2 = cp[..., 2]
+
if a.shape[-1] == 2:
if b.shape[-1] == 2:
- # cp = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0]
- multiply(a[..., 0], b[..., 1], out=cp)
- cp -= a[..., 1]*b[..., 0]
+ # a0 * b1 - a1 * b0
+ multiply(a0, b1, out=cp)
+ cp -= a1 * b0
if cp.ndim == 0:
return cp
else:
# This works because we are moving the last axis
return rollaxis(cp, -1, axisc)
else:
- # cp[..., 0] = a[..., 1]*b[..., 2]
- multiply(a[..., 1], b[..., 2], out=cp[..., 0])
- # cp[..., 1] = -a[..., 0]*b[..., 2]
- multiply(a[..., 0], b[..., 2], out=cp[..., 1])
- cp[..., 1] *= - 1
- # cp[..., 2] = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0]
- multiply(a[..., 0], b[..., 1], out=cp[..., 2])
- cp[..., 2] -= a[..., 1]*b[..., 0]
+ # cp0 = a1 * b2 - 0 (a2 = 0)
+ # cp1 = 0 - a0 * b2 (a2 = 0)
+ # cp2 = a0 * b1 - a1 * b0
+ multiply(a1, b2, out=cp0)
+ multiply(a0, b2, out=cp1)
+ negative(cp1, out=cp1)
+ multiply(a0, b1, out=cp2)
+ cp2 -= a1 * b0
elif a.shape[-1] == 3:
if b.shape[-1] == 3:
- # cp[..., 0] = a[..., 1]*b[..., 2] - a[..., 2]*b[..., 1]
- multiply(a[..., 1], b[..., 2], out=cp[..., 0])
- cp[..., 0] -= a[..., 2]*b[..., 1]
- # cp[..., 1] = a[..., 2]*b[..., 0] - a[..., 0]*b[..., 2]
- multiply(a[..., 2], b[..., 0], out=cp[..., 1])
- cp[..., 1] -= a[..., 0]*b[..., 2]
- # cp[..., 2] = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0]
- multiply(a[..., 0], b[..., 1], out=cp[..., 2])
- cp[..., 2] -= a[..., 1]*b[..., 0]
+ # cp0 = a1 * b2 - a2 * b1
+ # cp1 = a2 * b0 - a0 * b2
+ # cp2 = a0 * b1 - a1 * b0
+ multiply(a1, b2, out=cp0)
+ tmp = array(a2 * b1)
+ cp0 -= tmp
+ multiply(a2, b0, out=cp1)
+ multiply(a0, b2, out=tmp)
+ cp1 -= tmp
+ multiply(a0, b1, out=cp2)
+ multiply(a1, b0, out=tmp)
+ cp2 -= tmp
else:
- # cp[..., 0] = -a[..., 2]*b[..., 1]
- multiply(a[..., 2], b[..., 1], out=cp[..., 0])
- cp[..., 0] *= - 1
- # cp[..., 1] = a[..., 2]*b[..., 0]
- multiply(a[..., 2], b[..., 0], out=cp[..., 1])
- # cp[..., 2] = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0]
- multiply(a[..., 0], b[..., 1], out=cp[..., 2])
- cp[..., 2] -= a[..., 1]*b[..., 0]
+ # cp0 = 0 - a2 * b1 (b2 = 0)
+ # cp1 = a2 * b0 - 0 (b2 = 0)
+ # cp2 = a0 * b1 - a1 * b0
+ multiply(a2, b1, out=cp0)
+ negative(cp0, out=cp0)
+ multiply(a2, b0, out=cp1)
+ multiply(a0, b1, out=cp2)
+ cp2 -= a1 * b0
if cp.ndim == 1:
return cp
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index d08fd983e..40bbe5aec 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -2023,6 +2023,35 @@ class TestCross(TestCase):
def test_broadcasting(self):
# Ticket #2624 (Trac #2032)
+ u = np.tile([1, 2], (11, 1))
+ v = np.tile([3, 4], (11, 1))
+ z = -2
+ assert_equal(np.cross(u, v), z)
+ assert_equal(np.cross(v, u), -z)
+ assert_equal(np.cross(u, u), 0)
+
+ u = np.tile([1, 2], (11, 1)).T
+ v = np.tile([3, 4, 5], (11, 1))
+ z = np.tile([10, -5, -2], (11, 1))
+ assert_equal(np.cross(u, v, axisa=0), z)
+ assert_equal(np.cross(v, u.T), -z)
+ assert_equal(np.cross(v, v), 0)
+
+ u = np.tile([1, 2, 3], (11, 1)).T
+ v = np.tile([3, 4], (11, 1)).T
+ z = np.tile([-12, 9, -2], (11, 1))
+ assert_equal(np.cross(u, v, axisa=0, axisb=0), z)
+ assert_equal(np.cross(v.T, u.T), -z)
+ assert_equal(np.cross(u.T, u.T), 0)
+
+ u = np.tile([1, 2, 3], (5, 1))
+ v = np.tile([4, 5, 6], (5, 1)).T
+ z = np.tile([-3, 6, -3], (5, 1))
+ assert_equal(np.cross(u, v, axisb=0), z)
+ assert_equal(np.cross(v.T, u), -z)
+ assert_equal(np.cross(u, u), 0)
+
+ def test_broadcasting_shapes(self):
u = np.ones((2, 1, 3))
v = np.ones((5, 3))
assert_equal(np.cross(u, v).shape, (2, 5, 3))