diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2015-06-05 14:24:40 -0600 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2015-06-05 14:24:40 -0600 |
commit | 943ac81b58c7c8afbfadedbdd28ab94e56ad58fa (patch) | |
tree | 75e684452eecbaf1888f0169a18768edddcac6f7 | |
parent | a99c77577850054c55850102132dcf3eb43dea6a (diff) | |
parent | a19efc85c21940cb2621e63c5012f89d7d7eb1fa (diff) | |
download | numpy-943ac81b58c7c8afbfadedbdd28ab94e56ad58fa.tar.gz |
Merge pull request #5944 from seberg/einsum-fix
Einsum fix
-rw-r--r-- | numpy/core/src/multiarray/einsum.c.src | 4 | ||||
-rw-r--r-- | numpy/core/tests/test_einsum.py | 24 |
2 files changed, 25 insertions, 3 deletions
diff --git a/numpy/core/src/multiarray/einsum.c.src b/numpy/core/src/multiarray/einsum.c.src index 5b1bb77fa..7483fb01b 100644 --- a/numpy/core/src/multiarray/einsum.c.src +++ b/numpy/core/src/multiarray/einsum.c.src @@ -1770,14 +1770,14 @@ get_sum_of_products_function(int nop, int type_num, } /* Check for all contiguous */ - for (iop = 0; iop < nop; ++iop) { + for (iop = 0; iop < nop + 1; ++iop) { if (fixed_strides[iop] != itemsize) { break; } } /* Contiguous loop */ - if (iop == nop) { + if (iop == nop + 1) { return _allcontig_specialized_table[type_num][nop <= 3 ? nop : 0]; } diff --git a/numpy/core/tests/test_einsum.py b/numpy/core/tests/test_einsum.py index e32a05df3..6aa20eb72 100644 --- a/numpy/core/tests/test_einsum.py +++ b/numpy/core/tests/test_einsum.py @@ -91,7 +91,7 @@ class TestEinSum(TestCase): b = np.einsum(a, [0, 1]) assert_(b.base is a) assert_equal(b, a) - + # output is writeable whenever input is writeable b = np.einsum("...", a) assert_(b.flags['WRITEABLE']) @@ -585,6 +585,28 @@ class TestEinSum(TestCase): y2 = x[idx[:, None], idx[:, None], idx, idx] assert_equal(y1, y2) + def test_einsum_all_contig_non_contig_output(self): + # Issue gh-5907, tests that the all contiguous special case + # actually checks the contiguity of the output + x = np.ones((5, 5)) + out = np.ones(10)[::2] + correct_base = np.ones(10) + correct_base[::2] = 5 + # Always worked (inner iteration is done with 0-stride): + np.einsum('mi,mi,mi->m', x, x, x, out=out) + assert_array_equal(out.base, correct_base) + # Example 1: + out = np.ones(10)[::2] + np.einsum('im,im,im->m', x, x, x, out=out) + assert_array_equal(out.base, correct_base) + # Example 2, buffering causes x to be contiguous but + # special cases do not catch the operation before: + out = np.ones((2, 2, 2))[..., 0] + correct_base = np.ones((2, 2, 2)) + correct_base[..., 0] = 2 + x = np.ones((2, 2), np.float32) + np.einsum('ij,jk->ik', x, x, out=out) + assert_array_equal(out.base, correct_base) if __name__ == "__main__": run_module_suite() |