diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/tests/test_einsum.py | 24 |
1 files changed, 23 insertions, 1 deletions
diff --git a/numpy/core/tests/test_einsum.py b/numpy/core/tests/test_einsum.py index e32a05df3..6aa20eb72 100644 --- a/numpy/core/tests/test_einsum.py +++ b/numpy/core/tests/test_einsum.py @@ -91,7 +91,7 @@ class TestEinSum(TestCase): b = np.einsum(a, [0, 1]) assert_(b.base is a) assert_equal(b, a) - + # output is writeable whenever input is writeable b = np.einsum("...", a) assert_(b.flags['WRITEABLE']) @@ -585,6 +585,28 @@ class TestEinSum(TestCase): y2 = x[idx[:, None], idx[:, None], idx, idx] assert_equal(y1, y2) + def test_einsum_all_contig_non_contig_output(self): + # Issue gh-5907, tests that the all contiguous special case + # actually checks the contiguity of the output + x = np.ones((5, 5)) + out = np.ones(10)[::2] + correct_base = np.ones(10) + correct_base[::2] = 5 + # Always worked (inner iteration is done with 0-stride): + np.einsum('mi,mi,mi->m', x, x, x, out=out) + assert_array_equal(out.base, correct_base) + # Example 1: + out = np.ones(10)[::2] + np.einsum('im,im,im->m', x, x, x, out=out) + assert_array_equal(out.base, correct_base) + # Example 2, buffering causes x to be contiguous but + # special cases do not catch the operation before: + out = np.ones((2, 2, 2))[..., 0] + correct_base = np.ones((2, 2, 2)) + correct_base[..., 0] = 2 + x = np.ones((2, 2), np.float32) + np.einsum('ij,jk->ik', x, x, out=out) + assert_array_equal(out.base, correct_base) if __name__ == "__main__": run_module_suite() |