summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/lib/tests/test_twodim_base.py38
-rw-r--r--numpy/lib/twodim_base.py47
2 files changed, 62 insertions, 23 deletions
diff --git a/numpy/lib/tests/test_twodim_base.py b/numpy/lib/tests/test_twodim_base.py
index 45ab54fbc..5d850f9fd 100644
--- a/numpy/lib/tests/test_twodim_base.py
+++ b/numpy/lib/tests/test_twodim_base.py
@@ -53,6 +53,23 @@ class TestEye(TestCase):
[1,0,0],
[0,1,0]]))
+ def test_eye_bounds(self):
+ assert_equal(eye(2, 2, 1), [[0, 1], [0, 0]])
+ assert_equal(eye(2, 2, -1), [[0, 0], [1, 0]])
+ assert_equal(eye(2, 2, 2), [[0, 0], [0, 0]])
+ assert_equal(eye(2, 2, -2), [[0, 0], [0, 0]])
+ assert_equal(eye(3, 2, 2), [[0, 0], [0, 0], [0, 0]])
+ assert_equal(eye(3, 2, 1), [[0, 1], [0, 0], [0, 0]])
+ assert_equal(eye(3, 2, -1), [[0, 0], [1, 0], [0, 1]])
+ assert_equal(eye(3, 2, -2), [[0, 0], [0, 0], [1, 0]])
+ assert_equal(eye(3, 2, -3), [[0, 0], [0, 0], [0, 0]])
+
+ def test_strings(self):
+ assert_equal(eye(2, 2, dtype='S3'), [['1', ''], ['', '1']])
+
+ def test_bool(self):
+ assert_equal(eye(2, 2, dtype=bool), [[True, False], [False, True]])
+
class TestDiag(TestCase):
def test_vector(self):
vals = (100 * arange(5)).astype('l')
@@ -68,8 +85,9 @@ class TestDiag(TestCase):
assert_equal(diag(vals, k=2), b)
assert_equal(diag(vals, k=-2), c)
- def test_matrix(self):
- vals = (100 * get_mat(5) + 1).astype('l')
+ def test_matrix(self, vals=None):
+ if vals is None:
+ vals = (100 * get_mat(5) + 1).astype('l')
b = zeros((5,))
for k in range(5):
b[k] = vals[k,k]
@@ -82,6 +100,22 @@ class TestDiag(TestCase):
b[k] = vals[k + 2, k]
assert_equal(diag(vals, -2), b[:3])
+ def test_fortran_order(self):
+ vals = array((100 * get_mat(5) + 1), order='F', dtype='l')
+ self.test_matrix(vals)
+
+ def test_diag_bounds(self):
+ A = [[1, 2], [3, 4], [5, 6]]
+ assert_equal(diag(A, k=2), [])
+ assert_equal(diag(A, k=1), [2])
+ assert_equal(diag(A, k=0), [1, 4])
+ assert_equal(diag(A, k=-1), [3, 6])
+ assert_equal(diag(A, k=-2), [5])
+ assert_equal(diag(A, k=-3), [])
+
+ def test_failure(self):
+ self.failUnlessRaises(ValueError, diag, [[[1]]])
+
class TestFliplr(TestCase):
def test_basic(self):
self.failUnlessRaises(ValueError, fliplr, ones(4))
diff --git a/numpy/lib/twodim_base.py b/numpy/lib/twodim_base.py
index 6c9eb5dbb..e794d4144 100644
--- a/numpy/lib/twodim_base.py
+++ b/numpy/lib/twodim_base.py
@@ -8,7 +8,8 @@ __all__ = ['diag','diagflat','eye','fliplr','flipud','rot90','tri','triu',
]
from numpy.core.numeric import asanyarray, equal, subtract, arange, \
- zeros, greater_equal, multiply, ones, asarray, alltrue, where
+ zeros, greater_equal, multiply, ones, asarray, alltrue, where, \
+ empty
def fliplr(m):
"""
@@ -197,10 +198,16 @@ def eye(N, M=None, k=0, dtype=float):
[ 0., 0., 0.]])
"""
- if M is None: M = N
- m = equal(subtract.outer(arange(N), arange(M)),-k)
- if m.dtype != dtype:
- m = m.astype(dtype)
+ if M is None:
+ M = N
+ m = zeros((N, M), dtype=dtype)
+ if k >= M:
+ return m
+ if k >= 0:
+ i = k
+ else:
+ i = (-k) * M
+ m[:M-k].flat[i::M+1] = 1
return m
def diag(v, k=0):
@@ -246,28 +253,26 @@ def diag(v, k=0):
"""
v = asarray(v)
s = v.shape
- if len(s)==1:
+ if len(s) == 1:
n = s[0]+abs(k)
res = zeros((n,n), v.dtype)
- if (k>=0):
- i = arange(0,n-k)
- fi = i+k+i*n
+ if k >= 0:
+ i = k
else:
- i = arange(0,n+k)
- fi = i+(i-k)*n
- res.flat[fi] = v
+ i = (-k) * n
+ res[:n-k].flat[i::n+1] = v
return res
- elif len(s)==2:
- N1,N2 = s
+ elif len(s) == 2:
+ if k >= s[1]:
+ return empty(0, dtype=v.dtype)
+ if v.flags.f_contiguous:
+ # faster slicing
+ v, k, s = v.T, -k, s[::-1]
if k >= 0:
- M = min(N1,N2-k)
- i = arange(0,M)
- fi = i+k+i*N2
+ i = k
else:
- M = min(N1+k,N2)
- i = arange(0,M)
- fi = i + (i-k)*N2
- return v.flat[fi]
+ i = (-k) * s[1]
+ return v[:s[1]-k].flat[i::s[1]+1]
else:
raise ValueError, "Input must be 1- or 2-d."