summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohn Kirkham <kirkhamj@janelia.hhmi.org>2016-01-09 14:40:54 -0500
committerJohn Kirkham <kirkhamj@janelia.hhmi.org>2016-01-09 15:24:09 -0500
commitaaa16ed8ecae49c41507a99c8a6c196d13327bb0 (patch)
tree1c6e414a61cc0458ad245f7d3ba131b25a2be284
parent50e4e3a2c81e6cb624c36e32b2526cec85d37efb (diff)
downloadnumpy-aaa16ed8ecae49c41507a99c8a6c196d13327bb0.tar.gz
TST: Add an `inner` test with two 3D tensors.
-rw-r--r--numpy/core/tests/test_multiarray.py27
1 files changed, 27 insertions, 0 deletions
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 7bb267cfb..04e09d37f 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -4881,6 +4881,33 @@ class TestInner(TestCase):
desired = np.array(10, dtype=dt).item()
assert_equal(np.inner(b, a), desired)
+ def test_3d_tensor(self):
+ for dt in np.typecodes['AllInteger'] + np.typecodes['AllFloat'] + '?':
+ a = np.arange(24).reshape(2,3,4).astype(dt)
+ b = np.arange(24, 48).reshape(2,3,4).astype(dt)
+ desired = np.array(
+ [[[[ 158, 182, 206],
+ [ 230, 254, 278]],
+
+ [[ 566, 654, 742],
+ [ 830, 918, 1006]],
+
+ [[ 974, 1126, 1278],
+ [1430, 1582, 1734]]],
+
+ [[[1382, 1598, 1814],
+ [2030, 2246, 2462]],
+
+ [[1790, 2070, 2350],
+ [2630, 2910, 3190]],
+
+ [[2198, 2542, 2886],
+ [3230, 3574, 3918]]]],
+ dtype=dt
+ )
+ assert_equal(np.inner(a, b), desired)
+ assert_equal(np.inner(b, a).transpose(2,3,0,1), desired)
+
class TestSummarization(TestCase):
def test_1d(self):