diff options
author | Sebastian Berg <sebastian@sipsolutions.net> | 2013-03-01 19:49:10 +0100 |
---|---|---|
committer | Sebastian Berg <sebastian@sipsolutions.net> | 2013-03-03 12:55:35 +0100 |
commit | c11fa494fd8f7efcf70097d052b13d8b25caa1ab (patch) | |
tree | c95daa6c2e5ec3ea398920be482e6ab509a92915 | |
parent | acce195ad306529c7f083f48a48b51876168f421 (diff) | |
download | numpy-c11fa494fd8f7efcf70097d052b13d8b25caa1ab.tar.gz |
TST: Add basic tests for 0-d np.nditer/np.nested_iter support
-rw-r--r-- | numpy/core/tests/test_nditer.py | 54 |
1 files changed, 54 insertions, 0 deletions
diff --git a/numpy/core/tests/test_nditer.py b/numpy/core/tests/test_nditer.py index d537e4921..aa4c2eea9 100644 --- a/numpy/core/tests/test_nditer.py +++ b/numpy/core/tests/test_nditer.py @@ -2472,5 +2472,59 @@ def test_iter_allocated_array_dtypes(): c[1,1] = a / b assert_equal(it.operands[2], [[8, 12], [20, 5]]) + +def test_0d_iter(): + # Basic test for iteration of 0-d arrays: + i = nditer([2, 3], ['multi_index'], [['readonly']]*2) + assert_equal(i.ndim, 0) + assert_equal(i.next(), (2, 3)) + assert_equal(i.multi_index, ()) + assert_equal(i.iterindex, 0) + assert_raises(StopIteration, i.next) + # test reset: + i.reset() + assert_equal(i.next(), (2, 3)) + assert_raises(StopIteration, i.next) + + # test forcing to 0-d + i = nditer(np.arange(5), ['multi_index'], [['readonly']], op_axes=[()]) + assert_equal(i.ndim, 0) + assert_equal(len(i), 1) + # note that itershape=(), still behaves like None due to the conversions + + # Test a more complex buffered casting case (same as another test above) + sdt = [('a', 'f4'), ('b', 'i8'), ('c', 'c8', (2,3)), ('d', 'O')] + a = np.array(0.5, dtype='f4') + i = nditer(a, ['buffered','refs_ok'], ['readonly'], + casting='unsafe', op_dtypes=sdt) + vals = i.next() + assert_equal(vals['a'], 0.5) + assert_equal(vals['b'], 0) + assert_equal(vals['c'], [[(0.5)]*3]*2) + assert_equal(vals['d'], 0.5) + + +def test_0d_nested_iter(): + a = np.arange(12).reshape(2,3,2) + i, j = np.nested_iters(a, [[],[1,0,2]]) + vals = [] + for x in i: + vals.append([y for y in j]) + assert_equal(vals, [[0,1,2,3,4,5,6,7,8,9,10,11]]) + + i, j = np.nested_iters(a, [[1,0,2],[]]) + vals = [] + for x in i: + vals.append([y for y in j]) + assert_equal(vals, [[0],[1],[2],[3],[4],[5],[6],[7],[8],[9],[10],[11]]) + + i, j, k = np.nested_iters(a, [[2,0], [] ,[1]]) + vals = [] + for x in i: + for y in j: + vals.append([z for z in k]) + assert_equal(vals, [[0,2,4],[1,3,5],[6,8,10],[7,9,11]]) + + if __name__ == "__main__": run_module_suite() |