summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/index_tricks.py21
-rw-r--r--numpy/lib/tests/test_index_tricks.py7
2 files changed, 18 insertions, 10 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py
index 737fc0a60..3388decf0 100644
--- a/numpy/lib/index_tricks.py
+++ b/numpy/lib/index_tricks.py
@@ -669,10 +669,11 @@ s_ = IndexExpression(maketuple=False)
# End contribution from Konrad.
-# I'm not sure this is the best place in numpy for these functions, but since
-# they handle multidimensional arrays, it seemed better than twodim_base.
-def fill_diagonal(a,val):
+# The following functions complement those in twodim_base, but are
+# applicable to N-dimensions.
+
+def fill_diagonal(a, val):
"""Fill the main diagonal of the given array of any dimensionality.
For an array with ndim > 2, the diagonal is the list of locations with
@@ -737,13 +738,13 @@ def fill_diagonal(a,val):
# all dimensions equal, so we check first.
if not alltrue(diff(a.shape)==0):
raise ValueError("All dimensions of input must be of equal length")
- step = cumprod((1,)+a.shape[:-1]).sum()
+ step = cumprod((1,) + a.shape[:-1]).sum()
# Write the value out into the diagonal.
a.flat[::step] = val
-def diag_indices(n,ndim=2):
+def diag_indices(n, ndim=2):
"""Return the indices to access the main diagonal of an array.
This returns a tuple of indices that can be used to access the main
@@ -758,7 +759,7 @@ def diag_indices(n,ndim=2):
indices can be used.
ndim : int, optional
- The number of dimensions
+ The number of dimensions.
Examples
--------
@@ -794,10 +795,10 @@ def diag_indices(n,ndim=2):
See also
--------
- diag_indices_from: create the indices based on the shape of an existing
- array.
+ array.
"""
idx = arange(n)
- return (idx,)*ndim
+ return (idx,) * ndim
def diag_indices_from(arr):
@@ -814,7 +815,7 @@ def diag_indices_from(arr):
raise ValueError("input array must be at least 2-d")
# For more than d=2, the strided formula is only valid for arrays with
# all dimensions equal, so we check first.
- if not alltrue(diff(a.shape)==0):
+ if not alltrue(diff(arr.shape) == 0):
raise ValueError("All dimensions of input must be of equal length")
- return diag_indices(a.shape[0],a.ndim)
+ return diag_indices(arr.shape[0], arr.ndim)
diff --git a/numpy/lib/tests/test_index_tricks.py b/numpy/lib/tests/test_index_tricks.py
index afa94396a..d7e61799a 100644
--- a/numpy/lib/tests/test_index_tricks.py
+++ b/numpy/lib/tests/test_index_tricks.py
@@ -1,4 +1,5 @@
from numpy.testing import *
+import numpy as np
from numpy import ( array, ones, r_, mgrid, unravel_index, zeros, where,
ndenumerate, fill_diagonal, diag_indices,
diag_indices_from )
@@ -112,6 +113,12 @@ def test_diag_indices():
[[0, 0],
[0, 1]]]) )
+def test_diag_indices_from():
+ x = np.random.random((4, 4))
+ r, c = diag_indices_from(x)
+ assert_array_equal(r, np.arange(4))
+ assert_array_equal(c, np.arange(4))
+
if __name__ == "__main__":
run_module_suite()