summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTravis E. Oliphant <teoliphant@gmail.com>2012-06-12 00:39:40 -0700
committerTravis E. Oliphant <teoliphant@gmail.com>2012-06-12 00:39:40 -0700
commita8f1612c75cc120b6d22896c36a68fea330f5fbe (patch)
tree6bf01153b283efdd513770418c1f8e880e297f5c
parent637fa6233b4e5af81e0c8e629adcbee1918b4ebd (diff)
parente909e4eafba23b6dd6391c8ea6aeb003c6192ef4 (diff)
downloadnumpy-a8f1612c75cc120b6d22896c36a68fea330f5fbe.tar.gz
Merge pull request #306 from nouiz/fill_diagonal
fix the wrapping problem of fill_diagonal with tall matrix.
-rw-r--r--numpy/lib/index_tricks.py40
-rw-r--r--numpy/lib/tests/test_index_tricks.py38
2 files changed, 74 insertions, 4 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py
index c29f3a6d3..6f2aa1d02 100644
--- a/numpy/lib/index_tricks.py
+++ b/numpy/lib/index_tricks.py
@@ -658,9 +658,8 @@ s_ = IndexExpression(maketuple=False)
# The following functions complement those in twodim_base, but are
# applicable to N-dimensions.
-def fill_diagonal(a, val):
- """
- Fill the main diagonal of the given array of any dimensionality.
+def fill_diagonal(a, val, wrap=False):
+ """Fill the main diagonal of the given array of any dimensionality.
For an array `a` with ``a.ndim > 2``, the diagonal is the list of
locations with indices ``a[i, i, ..., i]`` all identical. This function
@@ -675,6 +674,10 @@ def fill_diagonal(a, val):
Value to be written on the diagonal, its type must be compatible with
that of the array a.
+ wrap: bool For tall matrices in NumPy version up to 1.6.2, the
+ diagonal "wrapped" after N columns. You can have this behavior
+ with this option. This affect only tall matrices.
+
See also
--------
diag_indices, diag_indices_from
@@ -716,13 +719,42 @@ def fill_diagonal(a, val):
[0, 0, 0],
[0, 0, 4]])
+ # tall matrices no wrap
+ >>> a = np.zeros((5, 3),int)
+ >>> fill_diagonal(a, 4)
+ array([[4, 0, 0],
+ [0, 4, 0],
+ [0, 0, 4],
+ [0, 0, 0],
+ [0, 0, 0]])
+
+ # tall matrices wrap
+ >>> a = np.zeros((5, 3),int)
+ >>> fill_diagonal(a, 4)
+ array([[4, 0, 0],
+ [0, 4, 0],
+ [0, 0, 4],
+ [0, 0, 0],
+ [4, 0, 0]])
+
+ # wide matrices
+ >>> a = np.zeros((3, 5),int)
+ >>> fill_diagonal(a, 4)
+ array([[4, 0, 0, 0, 0],
+ [0, 4, 0, 0, 0],
+ [0, 0, 4, 0, 0]])
+
"""
if a.ndim < 2:
raise ValueError("array must be at least 2-d")
+ end = None
if a.ndim == 2:
# Explicit, fast formula for the common case. For 2-d arrays, we
# accept rectangular ones.
step = a.shape[1] + 1
+ #This is needed to don't have tall matrix have the diagonal wrap.
+ if not wrap:
+ end = a.shape[1] * a.shape[1]
else:
# For more than d=2, the strided formula is only valid for arrays with
# all dimensions equal, so we check first.
@@ -731,7 +763,7 @@ def fill_diagonal(a, val):
step = 1 + (cumprod(a.shape[:-1])).sum()
# Write the value out into the diagonal.
- a.flat[::step] = val
+ a.flat[:end:step] = val
def diag_indices(n, ndim=2):
diff --git a/numpy/lib/tests/test_index_tricks.py b/numpy/lib/tests/test_index_tricks.py
index 2c6500a57..beda2d146 100644
--- a/numpy/lib/tests/test_index_tricks.py
+++ b/numpy/lib/tests/test_index_tricks.py
@@ -159,6 +159,44 @@ def test_fill_diagonal():
[0, 5, 0],
[0, 0, 5]]))
+ #Test tall matrix
+ a = zeros((10, 3),int)
+ fill_diagonal(a, 5)
+ yield (assert_array_equal, a,
+ array([[5, 0, 0],
+ [0, 5, 0],
+ [0, 0, 5],
+ [0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0]]))
+
+ #Test tall matrix wrap
+ a = zeros((10, 3),int)
+ fill_diagonal(a, 5, True)
+ yield (assert_array_equal, a,
+ array([[5, 0, 0],
+ [0, 5, 0],
+ [0, 0, 5],
+ [0, 0, 0],
+ [5, 0, 0],
+ [0, 5, 0],
+ [0, 0, 5],
+ [0, 0, 0],
+ [5, 0, 0],
+ [0, 5, 0]]))
+
+ #Test wide matrix
+ a = zeros((3, 10),int)
+ fill_diagonal(a, 5)
+ yield (assert_array_equal, a,
+ array([[5, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 5, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 5, 0, 0, 0, 0, 0, 0, 0]]))
+
# The same function can operate on a 4-d array:
a = zeros((3, 3, 3, 3), int)
fill_diagonal(a, 4)