summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2016-01-09 14:50:53 -0700
committerCharles Harris <charlesr.harris@gmail.com>2016-01-09 14:50:53 -0700
commite89f45756e36082e2ceac5c129cd83a7f76c3ee8 (patch)
tree1c6e414a61cc0458ad245f7d3ba131b25a2be284
parentdccda3fc9acd6c1f1d3084cffb4fe6f65c007275 (diff)
parentaaa16ed8ecae49c41507a99c8a6c196d13327bb0 (diff)
downloadnumpy-e89f45756e36082e2ceac5c129cd83a7f76c3ee8.tar.gz
Merge pull request #6986 from jakirkham/test_innerproduct
TST: `inner` with different dimensions
-rw-r--r--numpy/core/tests/test_multiarray.py52
1 files changed, 52 insertions, 0 deletions
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index c66e49e5f..04e09d37f 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -4825,6 +4825,22 @@ if sys.version_info[:2] >= (3, 5):
class TestInner(TestCase):
+ def test_inner_scalar_and_vector(self):
+ for dt in np.typecodes['AllInteger'] + np.typecodes['AllFloat'] + '?':
+ sca = np.array(3, dtype=dt)[()]
+ vec = np.array([1, 2], dtype=dt)
+ desired = np.array([3, 6], dtype=dt)
+ assert_equal(np.inner(vec, sca), desired)
+ assert_equal(np.inner(sca, vec), desired)
+
+ def test_inner_scalar_and_matrix(self):
+ for dt in np.typecodes['AllInteger'] + np.typecodes['AllFloat'] + '?':
+ sca = np.array(3, dtype=dt)[()]
+ arr = np.matrix([[1, 2], [3, 4]], dtype=dt)
+ desired = np.matrix([[3, 6], [9, 12]], dtype=dt)
+ assert_equal(np.inner(arr, sca), desired)
+ assert_equal(np.inner(sca, arr), desired)
+
def test_inner_scalar_and_matrix_of_objects(self):
# Ticket #4482
arr = np.matrix([1, 2], dtype=object)
@@ -4849,13 +4865,49 @@ class TestInner(TestCase):
C = np.array([1, 1], dtype=dt)
desired = np.array([4, 6], dtype=dt)
assert_equal(np.inner(A.T, C), desired)
+ assert_equal(np.inner(C, A.T), desired)
assert_equal(np.inner(B, C), desired)
+ assert_equal(np.inner(C, B), desired)
+ # check a matrix product
+ desired = np.array([[7, 10], [15, 22]], dtype=dt)
+ assert_equal(np.inner(A, B), desired)
+ # check the syrk vs. gemm paths
+ desired = np.array([[5, 11], [11, 25]], dtype=dt)
+ assert_equal(np.inner(A, A), desired)
+ assert_equal(np.inner(A, A.copy()), desired)
# check an inner product involving an aliased and reversed view
a = np.arange(5).astype(dt)
b = a[::-1]
desired = np.array(10, dtype=dt).item()
assert_equal(np.inner(b, a), desired)
+ def test_3d_tensor(self):
+ for dt in np.typecodes['AllInteger'] + np.typecodes['AllFloat'] + '?':
+ a = np.arange(24).reshape(2,3,4).astype(dt)
+ b = np.arange(24, 48).reshape(2,3,4).astype(dt)
+ desired = np.array(
+ [[[[ 158, 182, 206],
+ [ 230, 254, 278]],
+
+ [[ 566, 654, 742],
+ [ 830, 918, 1006]],
+
+ [[ 974, 1126, 1278],
+ [1430, 1582, 1734]]],
+
+ [[[1382, 1598, 1814],
+ [2030, 2246, 2462]],
+
+ [[1790, 2070, 2350],
+ [2630, 2910, 3190]],
+
+ [[2198, 2542, 2886],
+ [3230, 3574, 3918]]]],
+ dtype=dt
+ )
+ assert_equal(np.inner(a, b), desired)
+ assert_equal(np.inner(b, a).transpose(2,3,0,1), desired)
+
class TestSummarization(TestCase):
def test_1d(self):