summaryrefslogtreecommitdiff
path: root/numpy/lib/twodim_base.py
diff options
context:
space:
mode:
authorPauli Virtanen <pav@iki.fi>2009-07-12 22:22:45 +0000
committerPauli Virtanen <pav@iki.fi>2009-07-12 22:22:45 +0000
commit07af346db0370c2709e545fe283123d4d9b2bd97 (patch)
tree855c01372d464b3d95ca96cc1120e2fc7a437feb /numpy/lib/twodim_base.py
parentc3a0fb9711a119ac303d18719cea702efeff408f (diff)
downloadnumpy-07af346db0370c2709e545fe283123d4d9b2bd97.tar.gz
Address #1167: faster twodim_base.diag/eye implementation by Luca Citi + tests
Diffstat (limited to 'numpy/lib/twodim_base.py')
-rw-r--r--numpy/lib/twodim_base.py47
1 files changed, 26 insertions, 21 deletions
diff --git a/numpy/lib/twodim_base.py b/numpy/lib/twodim_base.py
index 6c9eb5dbb..e794d4144 100644
--- a/numpy/lib/twodim_base.py
+++ b/numpy/lib/twodim_base.py
@@ -8,7 +8,8 @@ __all__ = ['diag','diagflat','eye','fliplr','flipud','rot90','tri','triu',
]
from numpy.core.numeric import asanyarray, equal, subtract, arange, \
- zeros, greater_equal, multiply, ones, asarray, alltrue, where
+ zeros, greater_equal, multiply, ones, asarray, alltrue, where, \
+ empty
def fliplr(m):
"""
@@ -197,10 +198,16 @@ def eye(N, M=None, k=0, dtype=float):
[ 0., 0., 0.]])
"""
- if M is None: M = N
- m = equal(subtract.outer(arange(N), arange(M)),-k)
- if m.dtype != dtype:
- m = m.astype(dtype)
+ if M is None:
+ M = N
+ m = zeros((N, M), dtype=dtype)
+ if k >= M:
+ return m
+ if k >= 0:
+ i = k
+ else:
+ i = (-k) * M
+ m[:M-k].flat[i::M+1] = 1
return m
def diag(v, k=0):
@@ -246,28 +253,26 @@ def diag(v, k=0):
"""
v = asarray(v)
s = v.shape
- if len(s)==1:
+ if len(s) == 1:
n = s[0]+abs(k)
res = zeros((n,n), v.dtype)
- if (k>=0):
- i = arange(0,n-k)
- fi = i+k+i*n
+ if k >= 0:
+ i = k
else:
- i = arange(0,n+k)
- fi = i+(i-k)*n
- res.flat[fi] = v
+ i = (-k) * n
+ res[:n-k].flat[i::n+1] = v
return res
- elif len(s)==2:
- N1,N2 = s
+ elif len(s) == 2:
+ if k >= s[1]:
+ return empty(0, dtype=v.dtype)
+ if v.flags.f_contiguous:
+ # faster slicing
+ v, k, s = v.T, -k, s[::-1]
if k >= 0:
- M = min(N1,N2-k)
- i = arange(0,M)
- fi = i+k+i*N2
+ i = k
else:
- M = min(N1+k,N2)
- i = arange(0,M)
- fi = i + (i-k)*N2
- return v.flat[fi]
+ i = (-k) * s[1]
+ return v[:s[1]-k].flat[i::s[1]+1]
else:
raise ValueError, "Input must be 1- or 2-d."