summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/einsumfunc.py2
-rw-r--r--numpy/core/tests/test_einsum.py14
2 files changed, 12 insertions, 4 deletions
diff --git a/numpy/core/einsumfunc.py b/numpy/core/einsumfunc.py
index bb6767c4f..a4c18d482 100644
--- a/numpy/core/einsumfunc.py
+++ b/numpy/core/einsumfunc.py
@@ -1109,7 +1109,7 @@ def einsum(*operands, **kwargs):
# Checks have already been handled
input_str, results_index = einsum_str.split('->')
input_left, input_right = input_str.split(',')
- if 1 in tmp_operands[0] or 1 in tmp_operands[1]:
+ if 1 in tmp_operands[0].shape or 1 in tmp_operands[1].shape:
left_dims = {dim: size for dim, size in
zip(input_left, tmp_operands[0].shape)}
right_dims = {dim: size for dim, size in
diff --git a/numpy/core/tests/test_einsum.py b/numpy/core/tests/test_einsum.py
index d92398456..63e75ff7a 100644
--- a/numpy/core/tests/test_einsum.py
+++ b/numpy/core/tests/test_einsum.py
@@ -491,8 +491,16 @@ class TestEinSum(object):
assert_array_equal(np.einsum('ij,ij->j', p, q, optimize=True),
[10.] * 2)
- p = np.ones((1, 5))
- q = np.ones((5, 5))
+ # a blas-compatible contraction broadcasting case which was failing
+ # for optimize=True (ticket #10930)
+ x = np.array([2., 3.])
+ y = np.array([4.])
+ assert_array_equal(np.einsum("i, i", x, y, optimize=False), 20.)
+ assert_array_equal(np.einsum("i, i", x, y, optimize=True), 20.)
+
+ # all-ones array was bypassing bug (ticket #10930)
+ p = np.ones((1, 5)) / 2
+ q = np.ones((5, 5)) / 2
for optimize in (True, False):
assert_array_equal(np.einsum("...ij,...jk->...ik", p, p,
optimize=optimize),
@@ -500,7 +508,7 @@ class TestEinSum(object):
optimize=optimize))
assert_array_equal(np.einsum("...ij,...jk->...ik", p, q,
optimize=optimize),
- np.full((1, 5), 5))
+ np.full((1, 5), 1.25))
# Cases which were failing (gh-10899)
x = np.eye(2, dtype=dtype)