diff options
author | Pauli Virtanen <pav@iki.fi> | 2009-07-12 22:22:45 +0000 |
---|---|---|
committer | Pauli Virtanen <pav@iki.fi> | 2009-07-12 22:22:45 +0000 |
commit | 07af346db0370c2709e545fe283123d4d9b2bd97 (patch) | |
tree | 855c01372d464b3d95ca96cc1120e2fc7a437feb /numpy/lib/twodim_base.py | |
parent | c3a0fb9711a119ac303d18719cea702efeff408f (diff) | |
download | numpy-07af346db0370c2709e545fe283123d4d9b2bd97.tar.gz |
Address #1167: faster twodim_base.diag/eye implementation by Luca Citi + tests
Diffstat (limited to 'numpy/lib/twodim_base.py')
-rw-r--r-- | numpy/lib/twodim_base.py | 47 |
1 files changed, 26 insertions, 21 deletions
diff --git a/numpy/lib/twodim_base.py b/numpy/lib/twodim_base.py index 6c9eb5dbb..e794d4144 100644 --- a/numpy/lib/twodim_base.py +++ b/numpy/lib/twodim_base.py @@ -8,7 +8,8 @@ __all__ = ['diag','diagflat','eye','fliplr','flipud','rot90','tri','triu', ] from numpy.core.numeric import asanyarray, equal, subtract, arange, \ - zeros, greater_equal, multiply, ones, asarray, alltrue, where + zeros, greater_equal, multiply, ones, asarray, alltrue, where, \ + empty def fliplr(m): """ @@ -197,10 +198,16 @@ def eye(N, M=None, k=0, dtype=float): [ 0., 0., 0.]]) """ - if M is None: M = N - m = equal(subtract.outer(arange(N), arange(M)),-k) - if m.dtype != dtype: - m = m.astype(dtype) + if M is None: + M = N + m = zeros((N, M), dtype=dtype) + if k >= M: + return m + if k >= 0: + i = k + else: + i = (-k) * M + m[:M-k].flat[i::M+1] = 1 return m def diag(v, k=0): @@ -246,28 +253,26 @@ def diag(v, k=0): """ v = asarray(v) s = v.shape - if len(s)==1: + if len(s) == 1: n = s[0]+abs(k) res = zeros((n,n), v.dtype) - if (k>=0): - i = arange(0,n-k) - fi = i+k+i*n + if k >= 0: + i = k else: - i = arange(0,n+k) - fi = i+(i-k)*n - res.flat[fi] = v + i = (-k) * n + res[:n-k].flat[i::n+1] = v return res - elif len(s)==2: - N1,N2 = s + elif len(s) == 2: + if k >= s[1]: + return empty(0, dtype=v.dtype) + if v.flags.f_contiguous: + # faster slicing + v, k, s = v.T, -k, s[::-1] if k >= 0: - M = min(N1,N2-k) - i = arange(0,M) - fi = i+k+i*N2 + i = k else: - M = min(N1+k,N2) - i = arange(0,M) - fi = i + (i-k)*N2 - return v.flat[fi] + i = (-k) * s[1] + return v[:s[1]-k].flat[i::s[1]+1] else: raise ValueError, "Input must be 1- or 2-d." |