diff options
author | Pauli Virtanen <pav@iki.fi> | 2009-07-12 22:22:45 +0000 |
---|---|---|
committer | Pauli Virtanen <pav@iki.fi> | 2009-07-12 22:22:45 +0000 |
commit | 07af346db0370c2709e545fe283123d4d9b2bd97 (patch) | |
tree | 855c01372d464b3d95ca96cc1120e2fc7a437feb /numpy/lib | |
parent | c3a0fb9711a119ac303d18719cea702efeff408f (diff) | |
download | numpy-07af346db0370c2709e545fe283123d4d9b2bd97.tar.gz |
Address #1167: faster twodim_base.diag/eye implementation by Luca Citi + tests
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/tests/test_twodim_base.py | 38 | ||||
-rw-r--r-- | numpy/lib/twodim_base.py | 47 |
2 files changed, 62 insertions, 23 deletions
diff --git a/numpy/lib/tests/test_twodim_base.py b/numpy/lib/tests/test_twodim_base.py index 45ab54fbc..5d850f9fd 100644 --- a/numpy/lib/tests/test_twodim_base.py +++ b/numpy/lib/tests/test_twodim_base.py @@ -53,6 +53,23 @@ class TestEye(TestCase): [1,0,0], [0,1,0]])) + def test_eye_bounds(self): + assert_equal(eye(2, 2, 1), [[0, 1], [0, 0]]) + assert_equal(eye(2, 2, -1), [[0, 0], [1, 0]]) + assert_equal(eye(2, 2, 2), [[0, 0], [0, 0]]) + assert_equal(eye(2, 2, -2), [[0, 0], [0, 0]]) + assert_equal(eye(3, 2, 2), [[0, 0], [0, 0], [0, 0]]) + assert_equal(eye(3, 2, 1), [[0, 1], [0, 0], [0, 0]]) + assert_equal(eye(3, 2, -1), [[0, 0], [1, 0], [0, 1]]) + assert_equal(eye(3, 2, -2), [[0, 0], [0, 0], [1, 0]]) + assert_equal(eye(3, 2, -3), [[0, 0], [0, 0], [0, 0]]) + + def test_strings(self): + assert_equal(eye(2, 2, dtype='S3'), [['1', ''], ['', '1']]) + + def test_bool(self): + assert_equal(eye(2, 2, dtype=bool), [[True, False], [False, True]]) + class TestDiag(TestCase): def test_vector(self): vals = (100 * arange(5)).astype('l') @@ -68,8 +85,9 @@ class TestDiag(TestCase): assert_equal(diag(vals, k=2), b) assert_equal(diag(vals, k=-2), c) - def test_matrix(self): - vals = (100 * get_mat(5) + 1).astype('l') + def test_matrix(self, vals=None): + if vals is None: + vals = (100 * get_mat(5) + 1).astype('l') b = zeros((5,)) for k in range(5): b[k] = vals[k,k] @@ -82,6 +100,22 @@ class TestDiag(TestCase): b[k] = vals[k + 2, k] assert_equal(diag(vals, -2), b[:3]) + def test_fortran_order(self): + vals = array((100 * get_mat(5) + 1), order='F', dtype='l') + self.test_matrix(vals) + + def test_diag_bounds(self): + A = [[1, 2], [3, 4], [5, 6]] + assert_equal(diag(A, k=2), []) + assert_equal(diag(A, k=1), [2]) + assert_equal(diag(A, k=0), [1, 4]) + assert_equal(diag(A, k=-1), [3, 6]) + assert_equal(diag(A, k=-2), [5]) + assert_equal(diag(A, k=-3), []) + + def test_failure(self): + self.failUnlessRaises(ValueError, diag, [[[1]]]) + class TestFliplr(TestCase): def test_basic(self): self.failUnlessRaises(ValueError, fliplr, ones(4)) diff --git a/numpy/lib/twodim_base.py b/numpy/lib/twodim_base.py index 6c9eb5dbb..e794d4144 100644 --- a/numpy/lib/twodim_base.py +++ b/numpy/lib/twodim_base.py @@ -8,7 +8,8 @@ __all__ = ['diag','diagflat','eye','fliplr','flipud','rot90','tri','triu', ] from numpy.core.numeric import asanyarray, equal, subtract, arange, \ - zeros, greater_equal, multiply, ones, asarray, alltrue, where + zeros, greater_equal, multiply, ones, asarray, alltrue, where, \ + empty def fliplr(m): """ @@ -197,10 +198,16 @@ def eye(N, M=None, k=0, dtype=float): [ 0., 0., 0.]]) """ - if M is None: M = N - m = equal(subtract.outer(arange(N), arange(M)),-k) - if m.dtype != dtype: - m = m.astype(dtype) + if M is None: + M = N + m = zeros((N, M), dtype=dtype) + if k >= M: + return m + if k >= 0: + i = k + else: + i = (-k) * M + m[:M-k].flat[i::M+1] = 1 return m def diag(v, k=0): @@ -246,28 +253,26 @@ def diag(v, k=0): """ v = asarray(v) s = v.shape - if len(s)==1: + if len(s) == 1: n = s[0]+abs(k) res = zeros((n,n), v.dtype) - if (k>=0): - i = arange(0,n-k) - fi = i+k+i*n + if k >= 0: + i = k else: - i = arange(0,n+k) - fi = i+(i-k)*n - res.flat[fi] = v + i = (-k) * n + res[:n-k].flat[i::n+1] = v return res - elif len(s)==2: - N1,N2 = s + elif len(s) == 2: + if k >= s[1]: + return empty(0, dtype=v.dtype) + if v.flags.f_contiguous: + # faster slicing + v, k, s = v.T, -k, s[::-1] if k >= 0: - M = min(N1,N2-k) - i = arange(0,M) - fi = i+k+i*N2 + i = k else: - M = min(N1+k,N2) - i = arange(0,M) - fi = i + (i-k)*N2 - return v.flat[fi] + i = (-k) * s[1] + return v[:s[1]-k].flat[i::s[1]+1] else: raise ValueError, "Input must be 1- or 2-d." |