summaryrefslogtreecommitdiff
path: root/astroid/_ast.py
blob: 34b74c5f2395baa2d0ff86287562afeae0f493f5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import ast
from collections import namedtuple
from functools import partial
from typing import Optional
import sys

import astroid

PY38 = sys.version_info[:2] >= (3, 8)

# Parser module used when type comments are requested, or None if unavailable.
# On Python 3.8+ typed_ast was merged back into the stdlib `ast`, so the
# third-party backport is only tried on older interpreters; importing it on
# 3.8+ would be wasted work because the result would be replaced by `ast`.
if PY38:
    _ast_py3 = ast
else:
    try:
        import typed_ast.ast3 as _ast_py3
    except ImportError:
        _ast_py3 = None


# Result of parsing a function type comment (e.g. ``(int, str) -> bool``):
# ``argtypes`` holds the argument type nodes, ``returns`` the return-type node.
FunctionType = namedtuple("FunctionType", "argtypes returns")


class ParserModule(
    namedtuple(
        "ParserModule",
        (
            "module",
            "unary_op_classes",
            "cmp_op_classes",
            "bool_op_classes",
            "bin_op_classes",
            "context_classes",
        ),
    )
):
    """A parser module bundled with the node-class lookup tables built from it.

    ``module`` is the module whose ``parse`` function is used.  The
    ``*_op_classes`` fields map that module's operator node classes to their
    source symbols, and ``context_classes`` maps its expression-context node
    classes to astroid's context objects.
    """

    def parse(self, string: str, type_comments=True):
        """Parse *string* with ``self.module`` and return the resulting tree.

        Extra keyword arguments are passed only when ``self.module`` is the
        type-comment-capable parser ``_ast_py3``:

        * on Python 3.8+, ``type_comments=type_comments``;
        * otherwise (typed_ast), ``feature_version`` set to the running
          interpreter's minor version, and *type_comments* is ignored.

        For any other module, ``module.parse(string)`` is called as-is.
        """
        parse_func = self.module.parse
        if self.module is _ast_py3:
            extra_kwargs = (
                {"type_comments": type_comments}
                if PY38
                else {"feature_version": sys.version_info.minor}
            )
            parse_func = partial(parse_func, **extra_kwargs)
        return parse_func(string)


def parse_function_type_comment(type_comment: str) -> Optional[FunctionType]:
    """Parse a well-formed function type comment into a FunctionType.

    The comment is parsed by ``_ast_py3`` in ``"func_type"`` mode.  Returns
    None when no type-comment-capable parser is available (``_ast_py3`` is
    None).  Errors raised by the parser for malformed input propagate.
    """
    if _ast_py3 is not None:
        parsed = _ast_py3.parse(type_comment, "<type_comment>", "func_type")
        return FunctionType(argtypes=parsed.argtypes, returns=parsed.returns)
    return None


def get_parser_module(type_comments=True) -> ParserModule:
    """Build a ParserModule for the requested parser.

    Uses ``_ast_py3`` when *type_comments* is true and that parser is
    available, and the stdlib ``ast`` otherwise.  The operator and context
    lookup tables are derived from whichever module is chosen.
    """
    parser_module = (_ast_py3 if type_comments else ast) or ast
    return ParserModule(
        parser_module,
        _unary_operators_from_module(parser_module),
        _compare_operators_from_module(parser_module),
        _bool_operators_from_module(parser_module),
        _binary_operators_from_module(parser_module),
        _contexts_from_module(parser_module),
    )


def _unary_operators_from_module(module):
    return {module.UAdd: "+", module.USub: "-", module.Not: "not", module.Invert: "~"}


def _binary_operators_from_module(module):
    binary_operators = {
        module.Add: "+",
        module.BitAnd: "&",
        module.BitOr: "|",
        module.BitXor: "^",
        module.Div: "/",
        module.FloorDiv: "//",
        module.MatMult: "@",
        module.Mod: "%",
        module.Mult: "*",
        module.Pow: "**",
        module.Sub: "-",
        module.LShift: "<<",
        module.RShift: ">>",
    }
    return binary_operators


def _bool_operators_from_module(module):
    return {module.And: "and", module.Or: "or"}


def _compare_operators_from_module(module):
    return {
        module.Eq: "==",
        module.Gt: ">",
        module.GtE: ">=",
        module.In: "in",
        module.Is: "is",
        module.IsNot: "is not",
        module.Lt: "<",
        module.LtE: "<=",
        module.NotEq: "!=",
        module.NotIn: "not in",
    }


def _contexts_from_module(module):
    """Map *module*'s expression-context node classes to astroid contexts.

    ``Load``, ``Store`` and ``Del`` map to ``astroid.Load``, ``astroid.Store``
    and ``astroid.Del``.  ``Param`` is treated as a store context.

    ``Param`` is looked up defensively.  The stdlib ``ast`` has kept it only as
    a deprecated placeholder since Python 3.9 and documents it as slated for
    removal.  Accessing it unconditionally would make every parser-module
    lookup fail with AttributeError on an interpreter that drops it.
    """
    contexts = {
        module.Load: astroid.Load,
        module.Store: astroid.Store,
        module.Del: astroid.Del,
    }
    param_class = getattr(module, "Param", None)
    if param_class is not None:
        contexts[param_class] = astroid.Store
    return contexts