summaryrefslogtreecommitdiff
path: root/devtools/diffrtpy.py
blob: 6da8bc37b843d1647008be7197164e9f826170c6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
#!/usr/bin/env python
"""
This is a tool that can be used with git difftool to generate a diff that
ignores type hints and comments.

The name of this script, ``diffrtpy`` is short for "diff runtime python", as
this will only compare the parts of the python code that has a runtime impact.

This is to make it easier to review PRs that contain type hints.

To use this script

.. code-block:: bash
    task run -- python -m pip install --upgrade strip-hints black python-minifier
    PYLOGGING_LEVEL=INFO task run -- git difftool -y -x $(readlink -f devtools/diffrtpy.py) upstream/main | tee /var/tmp/compact.diff

Then attach ``/var/tmp/compact.diff`` to the PR.
"""


import argparse
import logging
import os
import sys
from dataclasses import dataclass, field
from difflib import unified_diff
from pathlib import Path
from typing import List

import black
import python_minifier
from strip_hints import strip_string_to_string


def clean_python(input: Path) -> str:
    code = input.read_text()
    try:
        code = strip_string_to_string(code, to_empty=True, strip_nl=True)
    except Exception:
        logging.warning(
            "failed to strip type hints from %s, falling back to using with type hints",
            input,
        )
        code = code
    code = python_minifier.minify(
        code,
        remove_annotations=True,
        remove_pass=False,
        remove_literal_statements=True,
        combine_imports=False,
        hoist_literals=False,
        rename_locals=False,
        rename_globals=False,
        remove_object_base=False,
        convert_posargs_to_args=False,
        preserve_shebang=True,
    )
    code = black.format_str(code, mode=black.FileMode())
    return code


@dataclass
class Application:
    parser: argparse.ArgumentParser = field(
        default_factory=lambda: argparse.ArgumentParser(add_help=True)
    )

    def __post_init__(self) -> None:
        parser = self.parser
        parser.add_argument(
            "-v",
            "--verbose",
            action="count",
            dest="verbosity",
            help="increase verbosity level",
        )
        parser.add_argument("lhs_file", nargs=1, type=str)
        parser.add_argument("rhs_file", nargs=1, type=str)
        parser.set_defaults(handler=self.handle)

    def run(self, args: List[str]) -> None:
        parse_result = self.parser.parse_args(args)

        verbosity = parse_result.verbosity
        if verbosity is not None:
            root_logger = logging.getLogger("")
            root_logger.propagate = True
            new_level = (
                root_logger.getEffectiveLevel()
                - (min(1, verbosity)) * 10
                - min(max(0, verbosity - 1), 9) * 1
            )
            root_logger.setLevel(new_level)

        logging.debug(
            "sys.executable = %s, args = %s, parse_result = %s, logging.level = %s",
            sys.executable,
            args,
            parse_result,
            logging.getLogger("").getEffectiveLevel(),
        )

        parse_result.handler(parse_result)

    def handle(self, parse_result: argparse.Namespace) -> None:
        logging.debug("entry ...")

        base = os.environ["BASE"]

        lhs_file: Path = Path(parse_result.lhs_file[0])
        rhs_file: Path = Path(parse_result.rhs_file[0])

        logging.debug(
            "base = %s, lhs_file = %s, rhs_file = %s", base, lhs_file, rhs_file
        )

        if lhs_file.name.endswith(".py") and rhs_file.name.endswith(".py"):
            lhs_file_content = clean_python(lhs_file)
            rhs_file_content = clean_python(rhs_file)
        else:
            lhs_file_content = lhs_file.read_text()
            rhs_file_content = rhs_file.read_text()

        lhs_file_lines = lhs_file_content.splitlines(keepends=True)
        rhs_file_lines = rhs_file_content.splitlines(keepends=True)

        sys.stdout.writelines(
            unified_diff(lhs_file_lines, rhs_file_lines, f"a/{base}", f"b/{base}", n=5)
        )


def main() -> None:
    logging.basicConfig(
        level=os.environ.get("PYLOGGING_LEVEL", logging.INFO),
        stream=sys.stderr,
        datefmt="%Y-%m-%dT%H:%M:%S",
        format=(
            "%(asctime)s.%(msecs)03d %(process)d %(thread)d %(levelno)03d:%(levelname)-8s "
            "%(name)-12s %(module)s:%(lineno)s:%(funcName)s %(message)s"
        ),
    )

    Application().run(sys.argv[1:])


if __name__ == "__main__":
    main()