diff options
-rw-r--r-- | numpy/core/einsumfunc.py | 2 | ||||
-rw-r--r-- | numpy/core/tests/test_einsum.py | 14 |
2 files changed, 12 insertions, 4 deletions
diff --git a/numpy/core/einsumfunc.py b/numpy/core/einsumfunc.py index bb6767c4f..a4c18d482 100644 --- a/numpy/core/einsumfunc.py +++ b/numpy/core/einsumfunc.py @@ -1109,7 +1109,7 @@ def einsum(*operands, **kwargs): # Checks have already been handled input_str, results_index = einsum_str.split('->') input_left, input_right = input_str.split(',') - if 1 in tmp_operands[0] or 1 in tmp_operands[1]: + if 1 in tmp_operands[0].shape or 1 in tmp_operands[1].shape: left_dims = {dim: size for dim, size in zip(input_left, tmp_operands[0].shape)} right_dims = {dim: size for dim, size in diff --git a/numpy/core/tests/test_einsum.py b/numpy/core/tests/test_einsum.py index d92398456..63e75ff7a 100644 --- a/numpy/core/tests/test_einsum.py +++ b/numpy/core/tests/test_einsum.py @@ -491,8 +491,16 @@ class TestEinSum(object): assert_array_equal(np.einsum('ij,ij->j', p, q, optimize=True), [10.] * 2) - p = np.ones((1, 5)) - q = np.ones((5, 5)) + # a blas-compatible contraction broadcasting case which was failing + # for optimize=True (ticket #10930) + x = np.array([2., 3.]) + y = np.array([4.]) + assert_array_equal(np.einsum("i, i", x, y, optimize=False), 20.) + assert_array_equal(np.einsum("i, i", x, y, optimize=True), 20.) + + # all-ones array was bypassing bug (ticket #10930) + p = np.ones((1, 5)) / 2 + q = np.ones((5, 5)) / 2 for optimize in (True, False): assert_array_equal(np.einsum("...ij,...jk->...ik", p, p, optimize=optimize), @@ -500,7 +508,7 @@ class TestEinSum(object): optimize=optimize)) assert_array_equal(np.einsum("...ij,...jk->...ik", p, q, optimize=optimize), - np.full((1, 5), 5)) + np.full((1, 5), 1.25)) # Cases which were failing (gh-10899) x = np.eye(2, dtype=dtype) |