summaryrefslogtreecommitdiff
path: root/numpy/lib/index_tricks.py
diff options
context:
space:
mode:
authorStefan van der Walt <stefan@sun.ac.za>2009-07-04 12:13:38 +0000
committerStefan van der Walt <stefan@sun.ac.za>2009-07-04 12:13:38 +0000
commit133e5c29958ef7090a9ca80665c9436cdcebb7f9 (patch)
treeab0bbfeab61d4003ffd224c7ac318fe7213de7ff /numpy/lib/index_tricks.py
parent9397ecd192974fa623492a677d7b2fb2d715c137 (diff)
downloadnumpy-133e5c29958ef7090a9ca80665c9436cdcebb7f9.tar.gz
Add indexing functions by Fernando Perez.
Diffstat (limited to 'numpy/lib/index_tricks.py')
-rw-r--r--numpy/lib/index_tricks.py157
1 files changed, 155 insertions, 2 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py
index b6eaae29f..737fc0a60 100644
--- a/numpy/lib/index_tricks.py
+++ b/numpy/lib/index_tricks.py
@@ -3,16 +3,19 @@ __all__ = ['unravel_index',
'ogrid',
'r_', 'c_', 's_',
'index_exp', 'ix_',
- 'ndenumerate','ndindex']
+ 'ndenumerate','ndindex',
+ 'fill_diagonal','diag_indices','diag_indices_from']
import sys
import numpy.core.numeric as _nx
-from numpy.core.numeric import asarray, ScalarType, array
+from numpy.core.numeric import ( asarray, ScalarType, array, alltrue, cumprod,
+ arange )
from numpy.core.numerictypes import find_common_type
import math
import function_base
import numpy.core.defmatrix as matrix
+from function_base import diff
makemat = matrix.matrix
# contributed by Stefan van der Walt
@@ -665,3 +668,153 @@ index_exp = IndexExpression(maketuple=True)
s_ = IndexExpression(maketuple=False)
# End contribution from Konrad.
+
+# I'm not sure this is the best place in numpy for these functions, but since
+# they handle multidimensional arrays, it seemed better than twodim_base.
+
+def fill_diagonal(a,val):
+ """Fill the main diagonal of the given array of any dimensionality.
+
+ For an array with ndim > 2, the diagonal is the list of locations with
+ indices a[i,i,...,i], all identical.
+
+ This function modifies the input array in-place, it does not return a
+ value.
+
+ This functionality can be obtained via diag_indices(), but internally this
+ version uses a much faster implementation that never constructs the indices
+ and uses simple slicing.
+
+ Parameters
+ ----------
+ a : array, at least 2-dimensional.
+ Array whose diagonal is to be filled, it gets modified in-place.
+
+ val : scalar
+ Value to be written on the diagonal, its type must be compatible with
+ that of the array a.
+
+ Examples
+ --------
+ >>> a = zeros((3,3),int)
+ >>> fill_diagonal(a,5)
+ >>> a
+ array([[5, 0, 0],
+ [0, 5, 0],
+ [0, 0, 5]])
+
+ The same function can operate on a 4-d array:
+ >>> a = zeros((3,3,3,3),int)
+ >>> fill_diagonal(a,4)
+
+ We only show a few blocks for clarity:
+ >>> a[0,0]
+ array([[4, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0]])
+ >>> a[1,1]
+ array([[0, 0, 0],
+ [0, 4, 0],
+ [0, 0, 0]])
+ >>> a[2,2]
+ array([[0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 4]])
+
+ See also
+ --------
+ - diag_indices: indices to access diagonals given shape information.
+ - diag_indices_from: indices to access diagonals given an array.
+ """
+ if a.ndim < 2:
+ raise ValueError("array must be at least 2-d")
+ if a.ndim == 2:
+ # Explicit, fast formula for the common case. For 2-d arrays, we
+ # accept rectangular ones.
+ step = a.shape[1] + 1
+ else:
+ # For more than d=2, the strided formula is only valid for arrays with
+ # all dimensions equal, so we check first.
+ if not alltrue(diff(a.shape)==0):
+ raise ValueError("All dimensions of input must be of equal length")
+ step = cumprod((1,)+a.shape[:-1]).sum()
+
+ # Write the value out into the diagonal.
+ a.flat[::step] = val
+
+
+def diag_indices(n,ndim=2):
+ """Return the indices to access the main diagonal of an array.
+
+ This returns a tuple of indices that can be used to access the main
+ diagonal of an array with ndim (>=2) dimensions and shape (n,n,...,n). For
+ ndim=2 this is the usual diagonal, for ndim>2 this is the set of indices
+ to access A[i,i,...,i] for i=[0..n-1].
+
+ Parameters
+ ----------
+ n : int
+ The size, along each dimension, of the arrays for which the returned
+ indices can be used.
+
+ ndim : int, optional
+ The number of dimensions
+
+ Examples
+ --------
+ Create a set of indices to access the diagonal of a (4,4) array:
+ >>> di = diag_indices(4)
+
+ >>> a = np.array([[1,2,3,4],[5,6,7,8],[9,10,11,12],[13,14,15,16]])
+ >>> a
+ array([[ 1, 2, 3, 4],
+ [ 5, 6, 7, 8],
+ [ 9, 10, 11, 12],
+ [13, 14, 15, 16]])
+ >>> a[di] = 100
+ >>> a
+ array([[100, 2, 3, 4],
+ [ 5, 100, 7, 8],
+ [ 9, 10, 100, 12],
+ [ 13, 14, 15, 100]])
+
+ Now, we create indices to manipulate a 3-d array:
+ >>> d3 = diag_indices(2,3)
+
+ And use it to set the diagonal of a zeros array to 1:
+ >>> a = zeros((2,2,2),int)
+ >>> a[d3] = 1
+ >>> a
+ array([[[1, 0],
+ [0, 0]],
+
+ [[0, 0],
+ [0, 1]]])
+
+ See also
+ --------
+ - diag_indices_from: create the indices based on the shape of an existing
+ array.
+ """
+ idx = arange(n)
+ return (idx,)*ndim
+
+
+def diag_indices_from(arr):
+ """Return the indices to access the main diagonal of an n-dimensional array.
+
+ See diag_indices() for full details.
+
+ Parameters
+ ----------
+ arr : array, at least 2-d
+ """
+
+ if not arr.ndim >= 2:
+ raise ValueError("input array must be at least 2-d")
+ # For more than d=2, the strided formula is only valid for arrays with
+ # all dimensions equal, so we check first.
+ if not alltrue(diff(a.shape)==0):
+ raise ValueError("All dimensions of input must be of equal length")
+
+ return diag_indices(a.shape[0],a.ndim)