summaryrefslogtreecommitdiff
path: root/numpy/lib/twodim_base.py
diff options
context:
space:
mode:
authorTravis Oliphant <oliphant@enthought.com>2006-06-30 20:04:28 +0000
committerTravis Oliphant <oliphant@enthought.com>2006-06-30 20:04:28 +0000
commit7d4c3ed2a0caebf1ce9e0da3473fdbde005699e5 (patch)
treefa4801786b5c59f13deff0bdf1cf25c3c8656224 /numpy/lib/twodim_base.py
parentec1662fb0182a87ebf39ec476109becfc7a8cdb1 (diff)
downloadnumpy-7d4c3ed2a0caebf1ce9e0da3473fdbde005699e5.tar.gz
Make the default array type float.
Diffstat (limited to 'numpy/lib/twodim_base.py')
-rw-r--r--numpy/lib/twodim_base.py14
1 files changed, 8 insertions, 6 deletions
diff --git a/numpy/lib/twodim_base.py b/numpy/lib/twodim_base.py
index 832d3f576..a063ddfea 100644
--- a/numpy/lib/twodim_base.py
+++ b/numpy/lib/twodim_base.py
@@ -40,13 +40,14 @@ def rot90(m, k=1):
elif k == 2: return fliplr(flipud(m))
else: return fliplr(m.transpose()) # k==3
-def eye(N, M=None, k=0, dtype=int_):
+def eye(N, M=None, k=0, dtype=float):
""" eye returns a N-by-M 2-d array where the k-th diagonal is all ones,
and everything else is zeros.
"""
if M is None: M = N
m = equal(subtract.outer(arange(N), arange(M)),-k)
- return m.astype(dtype)
+ if m.dtype != dtype:
+ return m.astype(dtype)
def diag(v, k=0):
""" returns a copy of the the k-th diagonal if v is a 2-d array
@@ -81,20 +82,21 @@ def diag(v, k=0):
raise ValueError, "Input must be 1- or 2-d."
-def tri(N, M=None, k=0, dtype=int_):
+def tri(N, M=None, k=0, dtype=float):
""" returns a N-by-M array where all the diagonals starting from
lower left corner up to the k-th are all ones.
"""
if M is None: M = N
m = greater_equal(subtract.outer(arange(N), arange(M)),-k)
- return m.astype(dtype)
+ if m.dtype != dtype:
+ return m.astype(dtype)
def tril(m, k=0):
""" returns the elements on and below the k-th diagonal of m. k=0 is the
main diagonal, k > 0 is above and k < 0 is below the main diagonal.
"""
m = asanyarray(m)
- out = multiply(tri(m.shape[0], m.shape[1], k=k, dtype=m.dtype),m)
+ out = multiply(tri(m.shape[0], m.shape[1], k=k, dtype=int),m)
return out
def triu(m, k=0):
@@ -102,7 +104,7 @@ def triu(m, k=0):
main diagonal, k > 0 is above and k < 0 is below the main diagonal.
"""
m = asanyarray(m)
- out = multiply((1-tri(m.shape[0], m.shape[1], k-1, m.dtype)),m)
+ out = multiply((1-tri(m.shape[0], m.shape[1], k-1, int)),m)
return out
# borrowed from John Hunter and matplotlib