diff options
Diffstat (limited to 'numpy/lib/index_tricks.py')
-rw-r--r-- | numpy/lib/index_tricks.py | 21 |
1 files changed, 11 insertions, 10 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index 737fc0a60..3388decf0 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -669,10 +669,11 @@ s_ = IndexExpression(maketuple=False) # End contribution from Konrad. -# I'm not sure this is the best place in numpy for these functions, but since -# they handle multidimensional arrays, it seemed better than twodim_base. -def fill_diagonal(a,val): +# The following functions complement those in twodim_base, but are +# applicable to N-dimensions. + +def fill_diagonal(a, val): """Fill the main diagonal of the given array of any dimensionality. For an array with ndim > 2, the diagonal is the list of locations with @@ -737,13 +738,13 @@ def fill_diagonal(a,val): # all dimensions equal, so we check first. if not alltrue(diff(a.shape)==0): raise ValueError("All dimensions of input must be of equal length") - step = cumprod((1,)+a.shape[:-1]).sum() + step = cumprod((1,) + a.shape[:-1]).sum() # Write the value out into the diagonal. a.flat[::step] = val -def diag_indices(n,ndim=2): +def diag_indices(n, ndim=2): """Return the indices to access the main diagonal of an array. This returns a tuple of indices that can be used to access the main @@ -758,7 +759,7 @@ def diag_indices(n,ndim=2): indices can be used. ndim : int, optional - The number of dimensions + The number of dimensions. Examples -------- @@ -794,10 +795,10 @@ def diag_indices(n,ndim=2): See also -------- - diag_indices_from: create the indices based on the shape of an existing - array. + array. """ idx = arange(n) - return (idx,)*ndim + return (idx,) * ndim def diag_indices_from(arr): @@ -814,7 +815,7 @@ def diag_indices_from(arr): raise ValueError("input array must be at least 2-d") # For more than d=2, the strided formula is only valid for arrays with # all dimensions equal, so we check first. - if not alltrue(diff(a.shape)==0): + if not alltrue(diff(arr.shape) == 0): raise ValueError("All dimensions of input must be of equal length") - return diag_indices(a.shape[0],a.ndim) + return diag_indices(arr.shape[0], arr.ndim) |