summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/tests/test_multiarray.py16
1 files changed, 16 insertions, 0 deletions
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 6fa4b6bbb..7bb267cfb 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -4825,6 +4825,22 @@ if sys.version_info[:2] >= (3, 5):
class TestInner(TestCase):
+ def test_inner_scalar_and_vector(self):
+ for dt in np.typecodes['AllInteger'] + np.typecodes['AllFloat'] + '?':
+ sca = np.array(3, dtype=dt)[()]
+ vec = np.array([1, 2], dtype=dt)
+ desired = np.array([3, 6], dtype=dt)
+ assert_equal(np.inner(vec, sca), desired)
+ assert_equal(np.inner(sca, vec), desired)
+
+ def test_inner_scalar_and_matrix(self):
+ for dt in np.typecodes['AllInteger'] + np.typecodes['AllFloat'] + '?':
+ sca = np.array(3, dtype=dt)[()]
+ arr = np.matrix([[1, 2], [3, 4]], dtype=dt)
+ desired = np.matrix([[3, 6], [9, 12]], dtype=dt)
+ assert_equal(np.inner(arr, sca), desired)
+ assert_equal(np.inner(sca, arr), desired)
+
def test_inner_scalar_and_matrix_of_objects(self):
# Ticket #4482
arr = np.matrix([1, 2], dtype=object)