diff options
Diffstat (limited to 'numpy/lib/twodim_base.py')
-rw-r--r-- | numpy/lib/twodim_base.py | 31 |
1 files changed, 10 insertions, 21 deletions
diff --git a/numpy/lib/twodim_base.py b/numpy/lib/twodim_base.py index d95a59e3f..12bba99a6 100644 --- a/numpy/lib/twodim_base.py +++ b/numpy/lib/twodim_base.py @@ -9,7 +9,7 @@ __all__ = ['diag','diagflat','eye','fliplr','flipud','rot90','tri','triu', from numpy.core.numeric import asanyarray, equal, subtract, arange, \ zeros, greater_equal, multiply, ones, asarray, alltrue, where, \ - empty + empty, diagonal def fliplr(m): """ @@ -166,7 +166,7 @@ def rot90(m, k=1): # k == 3 return fliplr(m.swapaxes(0,1)) -def eye(N, M=None, k=0, dtype=float): +def eye(N, M=None, k=0, dtype=float, maskna=False): """ Return a 2-D array with ones on the diagonal and zeros elsewhere. @@ -182,6 +182,8 @@ def eye(N, M=None, k=0, dtype=float): to a lower diagonal. dtype : data-type, optional Data-type of the returned array. + maskna : boolean + If this is true, the returned array will have an NA mask. Returns ------- @@ -207,24 +209,20 @@ def eye(N, M=None, k=0, dtype=float): """ if M is None: M = N - m = zeros((N, M), dtype=dtype) - if k >= M: - return m - if k >= 0: - i = k - else: - i = (-k) * M - m[:M-k].flat[i::M+1] = 1 + m = zeros((N, M), dtype=dtype, maskna=maskna) + diagonal(m, k)[...] = 1 return m def diag(v, k=0): """ Extract a diagonal or construct a diagonal array. + As of NumPy 1.7, extracting a diagonal always returns a view into `v`. + Parameters ---------- v : array_like - If `v` is a 2-D array, return a copy of its `k`-th diagonal. + If `v` is a 2-D array, return a view of its `k`-th diagonal. If `v` is a 1-D array, return a 2-D array with `v` on the `k`-th diagonal. k : int, optional @@ -278,16 +276,7 @@ def diag(v, k=0): res[:n-k].flat[i::n+1] = v return res elif len(s) == 2: - if k >= s[1]: - return empty(0, dtype=v.dtype) - if v.flags.f_contiguous: - # faster slicing - v, k, s = v.T, -k, s[::-1] - if k >= 0: - i = k - else: - i = (-k) * s[1] - return v[:s[1]-k].flat[i::s[1]+1] + return v.diagonal(k) else: raise ValueError("Input must be 1- or 2-d.") |