diff options
-rw-r--r-- | numpy/lib/index_tricks.py | 37 | ||||
-rw-r--r-- | numpy/lib/tests/test_index_tricks.py | 16 |
2 files changed, 49 insertions, 4 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index e248bfaea..6f2aa1d02 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -658,9 +658,8 @@ s_ = IndexExpression(maketuple=False) # The following functions complement those in twodim_base, but are # applicable to N-dimensions. -def fill_diagonal(a, val): - """ - Fill the main diagonal of the given array of any dimensionality. +def fill_diagonal(a, val, wrap=False): + """Fill the main diagonal of the given array of any dimensionality. For an array `a` with ``a.ndim > 2``, the diagonal is the list of locations with indices ``a[i, i, ..., i]`` all identical. This function @@ -675,6 +674,10 @@ def fill_diagonal(a, val): Value to be written on the diagonal, its type must be compatible with that of the array a. + wrap: bool For tall matrices in NumPy version up to 1.6.2, the + diagonal "wrapped" after N columns. You can have this behavior + with this option. This affect only tall matrices. + See also -------- diag_indices, diag_indices_from @@ -716,6 +719,31 @@ def fill_diagonal(a, val): [0, 0, 0], [0, 0, 4]]) + # tall matrices no wrap + >>> a = np.zeros((5, 3),int) + >>> fill_diagonal(a, 4) + array([[4, 0, 0], + [0, 4, 0], + [0, 0, 4], + [0, 0, 0], + [0, 0, 0]]) + + # tall matrices wrap + >>> a = np.zeros((5, 3),int) + >>> fill_diagonal(a, 4) + array([[4, 0, 0], + [0, 4, 0], + [0, 0, 4], + [0, 0, 0], + [4, 0, 0]]) + + # wide matrices + >>> a = np.zeros((3, 5),int) + >>> fill_diagonal(a, 4) + array([[4, 0, 0, 0, 0], + [0, 4, 0, 0, 0], + [0, 0, 4, 0, 0]]) + """ if a.ndim < 2: raise ValueError("array must be at least 2-d") @@ -725,7 +753,8 @@ def fill_diagonal(a, val): # accept rectangular ones. step = a.shape[1] + 1 #This is needed to don't have tall matrix have the diagonal wrap. - end = a.shape[1] * a.shape[1] + if not wrap: + end = a.shape[1] * a.shape[1] else: # For more than d=2, the strided formula is only valid for arrays with # all dimensions equal, so we check first. diff --git a/numpy/lib/tests/test_index_tricks.py b/numpy/lib/tests/test_index_tricks.py index aaedd83ea..beda2d146 100644 --- a/numpy/lib/tests/test_index_tricks.py +++ b/numpy/lib/tests/test_index_tricks.py @@ -158,6 +158,7 @@ def test_fill_diagonal(): array([[5, 0, 0], [0, 5, 0], [0, 0, 5]])) + #Test tall matrix a = zeros((10, 3),int) fill_diagonal(a, 5) @@ -173,6 +174,21 @@ def test_fill_diagonal(): [0, 0, 0], [0, 0, 0]])) + #Test tall matrix wrap + a = zeros((10, 3),int) + fill_diagonal(a, 5, True) + yield (assert_array_equal, a, + array([[5, 0, 0], + [0, 5, 0], + [0, 0, 5], + [0, 0, 0], + [5, 0, 0], + [0, 5, 0], + [0, 0, 5], + [0, 0, 0], + [5, 0, 0], + [0, 5, 0]])) + #Test wide matrix a = zeros((3, 10),int) fill_diagonal(a, 5) |