diff options
Diffstat (limited to 'numpy/lib/tests/test_index_tricks.py')
-rw-r--r-- | numpy/lib/tests/test_index_tricks.py | 46 |
1 files changed, 45 insertions, 1 deletions
diff --git a/numpy/lib/tests/test_index_tricks.py b/numpy/lib/tests/test_index_tricks.py index 641737d43..50dc5ac15 100644 --- a/numpy/lib/tests/test_index_tricks.py +++ b/numpy/lib/tests/test_index_tricks.py @@ -1,5 +1,6 @@ from numpy.testing import * -from numpy import array, ones, r_, mgrid, unravel_index, ndenumerate +from numpy import ( array, ones, r_, mgrid, unravel_index, zeros, where, + fill_diagonal, diag_indices, diag_indices_from ) class TestUnravelIndex(TestCase): def test_basic(self): @@ -68,5 +69,48 @@ class TestNdenumerate(TestCase): assert_equal(list(ndenumerate(a)), [((0,0), 1), ((0,1), 2), ((1,0), 3), ((1,1), 4)]) + +def test_fill_diagonal(): + a = zeros((3, 3),int) + fill_diagonal(a, 5) + yield (assert_array_equal, a, + array([[5, 0, 0], + [0, 5, 0], + [0, 0, 5]])) + + # The same function can operate on a 4-d array: + a = zeros((3, 3, 3, 3), int) + fill_diagonal(a, 4) + i = array([0, 1, 2]) + yield (assert_equal, where(a != 0), (i, i, i, i)) + + +def test_diag_indices(): + di = diag_indices(4) + a = array([[1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16]]) + a[di] = 100 + yield (assert_array_equal, a, + array([[100, 2, 3, 4], + [ 5, 100, 7, 8], + [ 9, 10, 100, 12], + [ 13, 14, 15, 100]])) + + # Now, we create indices to manipulate a 3-d array: + d3 = diag_indices(2, 3) + + # And use it to set the diagonal of a zeros array to 1: + a = zeros((2, 2, 2),int) + a[d3] = 1 + yield (assert_array_equal, a, + array([[[1, 0], + [0, 0]], + + [[0, 0], + [0, 1]]]) ) + + if __name__ == "__main__": run_module_suite() |