diff options
-rw-r--r-- | numpy/core/src/multiarray/nditer.c.src | 7 | ||||
-rw-r--r-- | numpy/core/tests/test_einsum.py | 6 | ||||
-rw-r--r-- | numpy/core/tests/test_iterator.py | 11 |
3 files changed, 21 insertions, 3 deletions
diff --git a/numpy/core/src/multiarray/nditer.c.src b/numpy/core/src/multiarray/nditer.c.src index dfc0baa6f..ce16a53d0 100644 --- a/numpy/core/src/multiarray/nditer.c.src +++ b/numpy/core/src/multiarray/nditer.c.src @@ -6080,10 +6080,11 @@ npyiter_checkreducesize(NpyIter *iter, npy_intp count, } /* - * If there was any non-zero coordinate or the reduction inner - * loop doesn't fit in the buffersize, can't do the double loop + * If there was any non-zero coordinate, the reduction inner + * loop doesn't fit in the buffersize, or the reduction inner loop + * covered the entire iteration size, can't do the double loop. */ - if (nonzerocoord || count < reducespace) { + if (nonzerocoord || count < reducespace || idim == ndim) { if (reducespace < count) { count = reducespace; } diff --git a/numpy/core/tests/test_einsum.py b/numpy/core/tests/test_einsum.py index 6a03a5913..d0161b366 100644 --- a/numpy/core/tests/test_einsum.py +++ b/numpy/core/tests/test_einsum.py @@ -421,6 +421,12 @@ class TestEinSum(TestCase): assert_equal(b, np.sum(a)) assert_equal(b.dtype, np.dtype(dtype)) + # A case which was failing (ticket #1885) + p = np.arange(2) + 1 + q = np.arange(4).reshape(2,2) + 3 + r = np.arange(4).reshape(2,2) + 7 + assert_equal(np.einsum('z,mz,zm->', p, q, r), 253) + def test_einsum_sums_int8(self): self.check_einsum_sums('i1'); diff --git a/numpy/core/tests/test_iterator.py b/numpy/core/tests/test_iterator.py index fe6871cd3..f39d9584f 100644 --- a/numpy/core/tests/test_iterator.py +++ b/numpy/core/tests/test_iterator.py @@ -2284,5 +2284,16 @@ def test_iter_buffering_reduction(): y[...] += x assert_equal(b, np.sum(a, axis=1)) + # Iterator inner double loop was wrong on this one + p = np.arange(2) + 1 + it = np.nditer([p,None], + ['delay_bufalloc','reduce_ok','buffered','external_loop'], + [['readonly'],['readwrite','allocate']], + op_axes=[[-1,0],[-1,-1]], + itershape=(2,2)) + it.operands[1].fill(0) + it.reset() + assert_equal(it[0], [1,2,1,2]) + if __name__ == "__main__": run_module_suite() |