summaryrefslogtreecommitdiff
path: root/convdtype.py
blob: ebc1ba512ab5e8b75d210440525983fe93db8d15 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from tokenize import  generate_tokens
import token
import sys
def insert(s1, s2, posn):
    """Return *s2* with a copy of *s1* spliced in before each index in *posn*.

    *posn* must be a list; its indices are applied in the given order, so
    pass them ascending.

    >>> insert("XX", "abcdef", [2, 4])
    'abXXcdXXef'
    """
    # Consecutive pairs of bounds delimit the pieces of s2 that s1 separates.
    bounds = [0] + posn + [len(s2)]
    return s1.join(s2[lo:hi] for lo, hi in zip(bounds, bounds[1:]))

# Array constructors whose calls get an explicit ``dtype=int`` when they lack one.
_DTYPE_CONSTRUCTORS = ('zeros', 'ones', 'empty')


def _splice(line, edits):
    """Return *line* with each ``(col, text)`` pair of *edits* inserted at *col*.

    *edits* must be ordered by ascending column.
    """
    pieces = []
    start = 0
    for col, text in edits:
        pieces.append(line[start:col])
        pieces.append(text)
        start = col
    pieces.append(line[start:])
    return ''.join(pieces)


def _close_argument(frame):
    """Account for the argument of call *frame* that just ended at ``,``/``)``."""
    if frame['in_arg']:
        frame['nargs'] += 1
        if not frame['keyword']:
            frame['positional'] += 1
    frame['in_arg'] = False
    frame['keyword'] = False


def insert_dtype(readline, output=None):
    """Copy Python source to *output*, adding ``dtype=int`` to bare array calls.

    Every call ``zeros(...)``, ``ones(...)`` or ``empty(...)`` (also as an
    attribute, e.g. ``np.zeros(...)``, and at any nesting depth) that passes
    neither a ``dtype=`` keyword nor a second positional argument gets
    ``dtype=int`` appended to its argument list.  Calls with no arguments or
    with ``*``/``**`` unpacking are left alone because the dtype cannot be
    decided safely, and ``def``/``class`` headers are not treated as calls.
    All other text is copied through unchanged.

    readline -- callable returning one source line per call and '' at EOF,
                as accepted by ``tokenize.generate_tokens``.
    output   -- object with a ``write`` method; defaults to ``sys.stdout``.

    >>> from io import StringIO
    >>> src = "zeros((2,3), dtype=float); zeros((2,3));"
    >>> insert_dtype(StringIO(src).readline)
    zeros((2,3), dtype=float); zeros((2,3), dtype=int);
    >>> insert_dtype(StringIO("print(ones([2, 3], order='F'))\\n").readline)
    print(ones([2, 3], order='F', dtype=int))
    """
    from tokenize import COMMENT, NL

    if output is None:
        output = sys.stdout

    lines = []    # physical source lines, recorded as the tokenizer reads them
    edits = {}    # 1-based row -> [(col, text), ...] in ascending column order
    stack = []    # one entry per open bracket: a call-state dict, or None
    prev = prev2 = (None, None)   # last two significant tokens: (type, string)
    next_row = 1  # first row not yet written to output

    def recording_readline():
        text = readline()
        if text:
            lines.append(text)
        return text

    for tok_type, tok, (srow, scol), _, _ in generate_tokens(recording_readline):
        # Tokens arrive in source order, so rows before this token can no
        # longer receive edits; write them out (this also covers rows with
        # no token ending at end-of-line, e.g. backslash continuations and
        # the inner lines of multi-line strings).
        while next_row < srow and next_row <= len(lines):
            output.write(_splice(lines[next_row - 1], edits.pop(next_row, ())))
            next_row += 1

        if tok_type in (NL, COMMENT):
            continue  # layout only; must not count as argument text

        is_op = tok_type == token.OP
        frame = stack[-1] if stack else None
        if frame is not None and not (is_op and tok in (')', ']', '}', ',')):
            # The token sits directly in a constructor call's argument list.
            if not frame['in_arg']:
                frame['in_arg'] = True
                if is_op and tok in ('*', '**'):
                    frame['unpacked'] = True
            if is_op and tok == '=':
                frame['keyword'] = True
                if prev[1] == 'dtype':
                    frame['has_dtype'] = True

        if is_op and tok in ('(', '[', '{'):
            is_call = (tok == '(' and prev[0] == token.NAME
                       and prev[1] in _DTYPE_CONSTRUCTORS
                       and prev2[1] not in ('def', 'class'))
            stack.append({'in_arg': False, 'keyword': False, 'nargs': 0,
                          'positional': 0, 'has_dtype': False,
                          'unpacked': False} if is_call else None)
        elif is_op and tok in (')', ']', '}'):
            if stack:  # tolerate stray closing brackets
                closed = stack.pop()
                if closed is not None:
                    _close_argument(closed)
                    # A second positional argument is numpy's dtype slot.
                    if not (closed['has_dtype'] or closed['unpacked']
                            or closed['nargs'] == 0
                            or closed['positional'] >= 2):
                        # After a trailing comma the separator already exists.
                        text = ' dtype=int' if prev[1] == ',' else ', dtype=int'
                        edits.setdefault(srow, []).append((scol, text))
        elif is_op and tok == ',' and frame is not None:
            _close_argument(frame)

        prev2, prev = prev, (tok_type, tok)

    while next_row <= len(lines):
        output.write(_splice(lines[next_row - 1], edits.pop(next_row, ())))
        next_row += 1

def _test():
    """Run the doctests of the ``__main__`` module.

    When this file is executed as a script, that is this module itself.
    """
    from doctest import testmod
    testmod()


if __name__ == "__main__":
    # Executed as a script: run the self-tests only.
    _test()