summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2015-06-05 14:24:40 -0600
committerCharles Harris <charlesr.harris@gmail.com>2015-06-05 14:24:40 -0600
commit943ac81b58c7c8afbfadedbdd28ab94e56ad58fa (patch)
tree75e684452eecbaf1888f0169a18768edddcac6f7
parenta99c77577850054c55850102132dcf3eb43dea6a (diff)
parenta19efc85c21940cb2621e63c5012f89d7d7eb1fa (diff)
downloadnumpy-943ac81b58c7c8afbfadedbdd28ab94e56ad58fa.tar.gz
Merge pull request #5944 from seberg/einsum-fix
Einsum fix
-rw-r--r--numpy/core/src/multiarray/einsum.c.src4
-rw-r--r--numpy/core/tests/test_einsum.py24
2 files changed, 25 insertions, 3 deletions
diff --git a/numpy/core/src/multiarray/einsum.c.src b/numpy/core/src/multiarray/einsum.c.src
index 5b1bb77fa..7483fb01b 100644
--- a/numpy/core/src/multiarray/einsum.c.src
+++ b/numpy/core/src/multiarray/einsum.c.src
@@ -1770,14 +1770,14 @@ get_sum_of_products_function(int nop, int type_num,
}
/* Check for all contiguous */
- for (iop = 0; iop < nop; ++iop) {
+ for (iop = 0; iop < nop + 1; ++iop) {
if (fixed_strides[iop] != itemsize) {
break;
}
}
/* Contiguous loop */
- if (iop == nop) {
+ if (iop == nop + 1) {
return _allcontig_specialized_table[type_num][nop <= 3 ? nop : 0];
}
diff --git a/numpy/core/tests/test_einsum.py b/numpy/core/tests/test_einsum.py
index e32a05df3..6aa20eb72 100644
--- a/numpy/core/tests/test_einsum.py
+++ b/numpy/core/tests/test_einsum.py
@@ -91,7 +91,7 @@ class TestEinSum(TestCase):
b = np.einsum(a, [0, 1])
assert_(b.base is a)
assert_equal(b, a)
-
+
# output is writeable whenever input is writeable
b = np.einsum("...", a)
assert_(b.flags['WRITEABLE'])
@@ -585,6 +585,28 @@ class TestEinSum(TestCase):
y2 = x[idx[:, None], idx[:, None], idx, idx]
assert_equal(y1, y2)
+ def test_einsum_all_contig_non_contig_output(self):
+ # Issue gh-5907, tests that the all contiguous special case
+ # actually checks the contiguity of the output
+ x = np.ones((5, 5))
+ out = np.ones(10)[::2]
+ correct_base = np.ones(10)
+ correct_base[::2] = 5
+ # Always worked (inner iteration is done with 0-stride):
+ np.einsum('mi,mi,mi->m', x, x, x, out=out)
+ assert_array_equal(out.base, correct_base)
+ # Example 1:
+ out = np.ones(10)[::2]
+ np.einsum('im,im,im->m', x, x, x, out=out)
+ assert_array_equal(out.base, correct_base)
+ # Example 2, buffering causes x to be contiguous but
+ # special cases do not catch the operation before:
+ out = np.ones((2, 2, 2))[..., 0]
+ correct_base = np.ones((2, 2, 2))
+ correct_base[..., 0] = 2
+ x = np.ones((2, 2), np.float32)
+ np.einsum('ij,jk->ik', x, x, out=out)
+ assert_array_equal(out.base, correct_base)
if __name__ == "__main__":
run_module_suite()