diff options
author | Travis E. Oliphant <teoliphant@gmail.com> | 2012-06-12 00:39:40 -0700 |
---|---|---|
committer | Travis E. Oliphant <teoliphant@gmail.com> | 2012-06-12 00:39:40 -0700 |
commit | a8f1612c75cc120b6d22896c36a68fea330f5fbe (patch) | |
tree | 6bf01153b283efdd513770418c1f8e880e297f5c /numpy/lib | |
parent | 637fa6233b4e5af81e0c8e629adcbee1918b4ebd (diff) | |
parent | e909e4eafba23b6dd6391c8ea6aeb003c6192ef4 (diff) | |
download | numpy-a8f1612c75cc120b6d22896c36a68fea330f5fbe.tar.gz |
Merge pull request #306 from nouiz/fill_diagonal
fix the wrapping problem of fill_diagonal with tall matrix.
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/index_tricks.py | 40 | ||||
-rw-r--r-- | numpy/lib/tests/test_index_tricks.py | 38 |
2 files changed, 74 insertions, 4 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index c29f3a6d3..6f2aa1d02 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -658,9 +658,8 @@ s_ = IndexExpression(maketuple=False) # The following functions complement those in twodim_base, but are # applicable to N-dimensions. -def fill_diagonal(a, val): - """ - Fill the main diagonal of the given array of any dimensionality. +def fill_diagonal(a, val, wrap=False): + """Fill the main diagonal of the given array of any dimensionality. For an array `a` with ``a.ndim > 2``, the diagonal is the list of locations with indices ``a[i, i, ..., i]`` all identical. This function @@ -675,6 +674,10 @@ def fill_diagonal(a, val): Value to be written on the diagonal, its type must be compatible with that of the array a. + wrap: bool For tall matrices in NumPy version up to 1.6.2, the + diagonal "wrapped" after N columns. You can have this behavior + with this option. This affect only tall matrices. + See also -------- diag_indices, diag_indices_from @@ -716,13 +719,42 @@ def fill_diagonal(a, val): [0, 0, 0], [0, 0, 4]]) + # tall matrices no wrap + >>> a = np.zeros((5, 3),int) + >>> fill_diagonal(a, 4) + array([[4, 0, 0], + [0, 4, 0], + [0, 0, 4], + [0, 0, 0], + [0, 0, 0]]) + + # tall matrices wrap + >>> a = np.zeros((5, 3),int) + >>> fill_diagonal(a, 4) + array([[4, 0, 0], + [0, 4, 0], + [0, 0, 4], + [0, 0, 0], + [4, 0, 0]]) + + # wide matrices + >>> a = np.zeros((3, 5),int) + >>> fill_diagonal(a, 4) + array([[4, 0, 0, 0, 0], + [0, 4, 0, 0, 0], + [0, 0, 4, 0, 0]]) + """ if a.ndim < 2: raise ValueError("array must be at least 2-d") + end = None if a.ndim == 2: # Explicit, fast formula for the common case. For 2-d arrays, we # accept rectangular ones. step = a.shape[1] + 1 + #This is needed to don't have tall matrix have the diagonal wrap. + if not wrap: + end = a.shape[1] * a.shape[1] else: # For more than d=2, the strided formula is only valid for arrays with # all dimensions equal, so we check first. @@ -731,7 +763,7 @@ def fill_diagonal(a, val): step = 1 + (cumprod(a.shape[:-1])).sum() # Write the value out into the diagonal. - a.flat[::step] = val + a.flat[:end:step] = val def diag_indices(n, ndim=2): diff --git a/numpy/lib/tests/test_index_tricks.py b/numpy/lib/tests/test_index_tricks.py index 2c6500a57..beda2d146 100644 --- a/numpy/lib/tests/test_index_tricks.py +++ b/numpy/lib/tests/test_index_tricks.py @@ -159,6 +159,44 @@ def test_fill_diagonal(): [0, 5, 0], [0, 0, 5]])) + #Test tall matrix + a = zeros((10, 3),int) + fill_diagonal(a, 5) + yield (assert_array_equal, a, + array([[5, 0, 0], + [0, 5, 0], + [0, 0, 5], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0]])) + + #Test tall matrix wrap + a = zeros((10, 3),int) + fill_diagonal(a, 5, True) + yield (assert_array_equal, a, + array([[5, 0, 0], + [0, 5, 0], + [0, 0, 5], + [0, 0, 0], + [5, 0, 0], + [0, 5, 0], + [0, 0, 5], + [0, 0, 0], + [5, 0, 0], + [0, 5, 0]])) + + #Test wide matrix + a = zeros((3, 10),int) + fill_diagonal(a, 5) + yield (assert_array_equal, a, + array([[5, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 5, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 5, 0, 0, 0, 0, 0, 0, 0]])) + # The same function can operate on a 4-d array: a = zeros((3, 3, 3, 3), int) fill_diagonal(a, 4) |