summaryrefslogtreecommitdiff
path: root/isort/sorting.py
blob: 7ad70b2217c5b5b0419d0f94cfd1de52915860e9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import re
from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Optional

if TYPE_CHECKING:
    from .settings import Config
else:
    Config = Any

_import_line_intro_re = re.compile("^(?:from|import) ")
_import_line_midline_import_re = re.compile(" import ")


def module_key(
    module_name: str,
    config: Config,
    sub_imports: bool = False,
    ignore_case: bool = False,
    section_name: Optional[Any] = None,
    straight_import: Optional[bool] = False,
) -> str:
    """Build the sort key for a module name (or, with ``sub_imports``, an imported name).

    The key is ``<force flag><type prefix><name>`` where:

    * the force flag is ``"A"`` when the (possibly case-folded) name is in
      ``config.force_to_top`` and ``"B"`` otherwise;
    * the type prefix is only emitted when both ``sub_imports`` and
      ``config.order_by_type`` are set: ``"A"`` for constants, ``"B"`` for
      classes, ``"C"`` for everything else;
    * the name is preceded by ``"<length>:"`` when length sorting applies.

    Args:
        module_name: Module or imported name. Leading dots of a relative import
            are joined to the rest of the name with ``" "`` when
            ``config.reverse_relative`` is set, otherwise with ``"_"``.
        config: isort configuration.
        sub_imports: Whether ``module_name`` is a name imported *from* a module,
            which enables the constant/class/variable type prefix.
        ignore_case: Lower-case the name before classifying it.
        section_name: Section of the import; compared case-insensitively with
            ``config.length_sort_sections``.
        straight_import: Whether this is a plain ``import x`` statement; enables
            ``config.length_sort_straight``.

    Returns:
        The sort key string.
    """
    match = re.match(r"^(\.+)\s*(.*)", module_name)
    if match:
        # "_" sorts after "." so deeper relative imports ("..x") come first;
        # " " sorts before "." so shallower ones (".x") come first.
        sep = " " if config.reverse_relative else "_"
        module_name = sep.join(match.groups())

    module_name = str(module_name)
    if ignore_case:
        module_name = module_name.lower()

    prefix = ""
    if sub_imports and config.order_by_type:
        # Explicit configuration wins over the naming-convention heuristics.
        if module_name in config.constants:
            prefix = "A"
        elif module_name in config.classes:
            prefix = "B"
        elif module_name in config.variables:
            prefix = "C"
        elif module_name.isupper() and len(module_name) > 1:  # see issue #376
            prefix = "A"
        elif module_name[0:1].isupper():
            prefix = "B"
        else:
            prefix = "C"

    if not config.case_sensitive:
        module_name = module_name.lower()

    length_sort = (
        config.length_sort
        or (config.length_sort_straight and straight_import)
        or str(section_name).lower() in config.length_sort_sections
    )
    sort_name = f"{len(module_name)}:{module_name}" if length_sort else module_name
    force_flag = "A" if module_name in config.force_to_top else "B"
    return f"{force_flag}{prefix}{sort_name}"


def section_key(line: str, config: Config) -> str:
    """Return the key used to order complete import lines against each other.

    The line is normalised in this order:

    1. When ``config.reverse_relative`` is set and
       ``config.sort_relative_in_force_sorted_sections`` is not, the dots of a
       ``from .`` relative import are space-separated from the module path.
    2. With ``config.group_by_package``, a ``from`` line is cut at `` import``
       so only the package part is compared.
    3. The leading ``from ``/``import `` keyword is removed; in
       ``config.lexicographical`` mode `` import `` is first replaced by ``.``
       so ``from a import b`` compares as ``a.b``.
    4. With ``config.sort_relative_in_force_sorted_sections``, a separator
       (``" "`` if ``config.reverse_relative`` else ``"_"``) is inserted
       after the leading dots.
    5. Case is folded according to ``case_sensitive``, ``order_by_type`` and
       ``honor_case_in_force_sorted_sections``.

    Returns:
        ``"A"`` if the first word of the line (checked before case folding) is
        in ``config.force_to_top``, else ``"B"``; then the line length when
        ``config.length_sort`` is set; then the normalised line.
    """
    if (
        not config.sort_relative_in_force_sorted_sections
        and config.reverse_relative
        and line.startswith("from .")
    ):
        relative = re.match(r"^from (\.+)\s*(.*)", line)
        if relative:  # pragma: no cover - regex always matches if line starts with "from ."
            dots, remainder = relative.groups()
            line = f"from {dots} {remainder}"

    if config.group_by_package and line.strip().startswith("from"):
        line = line.partition(" import")[0]

    if config.lexicographical:
        dotted = _import_line_midline_import_re.sub(".", line)
        line = _import_line_intro_re.sub("", dotted)
    else:
        # Strip "from " first, then "import ", exactly once each.
        for keyword_pattern in ("^from ", "^import "):
            line = re.sub(keyword_pattern, "", line)

    if config.sort_relative_in_force_sorted_sections:
        separator = " " if config.reverse_relative else "_"
        line = re.sub(r"^(\.+)", rf"\1{separator}", line)

    first_word = line.partition(" ")[0]
    section = "A" if first_word in config.force_to_top else "B"

    # * If honor_case_in_force_sorted_sections is true, and case_sensitive and
    #   order_by_type are different, only ignore case in part of the line.
    # * Otherwise, let order_by_type decide the sorting of the whole line. This
    #   is only "correct" if case_sensitive and order_by_type have the same value.
    if config.honor_case_in_force_sorted_sections and config.case_sensitive != config.order_by_type:
        module_part, found, names_part = line.partition(" import ")
        if found:
            if not config.case_sensitive:
                module_part = module_part.lower()
            if not config.order_by_type:
                names_part = names_part.lower()
            line = f"{module_part} import {names_part}"
        elif not config.case_sensitive:
            line = line.lower()
    elif not config.order_by_type:
        line = line.lower()

    length_part = str(len(line)) if config.length_sort else ""
    return f"{section}{length_part}{line}"


def sort(
    config: Config,
    to_sort: Iterable[str],
    key: Optional[Callable[[str], Any]] = None,
    reverse: bool = False,
) -> List[str]:
    """Sort ``to_sort`` using the sorting function configured on ``config``.

    Args:
        config: isort configuration; its ``sorting_function`` does the work.
        to_sort: Strings to sort.
        key: Optional key function, forwarded unchanged.
        reverse: Sort in descending order, forwarded unchanged.

    Returns:
        Whatever ``config.sorting_function`` returns for these arguments.
    """
    sorting_function = config.sorting_function
    return sorting_function(to_sort, key=key, reverse=reverse)


def naturally(
    to_sort: Iterable[str], key: Optional[Callable[[str], Any]] = None, reverse: bool = False
) -> List[str]:
    """Return ``to_sort`` as a new list in natural order.

    Embedded digit runs compare numerically (via ``_natural_keys``). When ``key``
    is given, it is applied first and the natural key is computed on its result.
    """
    if key is None:
        return sorted(to_sort, key=_natural_keys, reverse=reverse)

    # A non-Optional local lets the closure below type-check without an ignore.
    user_key: Callable[[str], Any] = key

    def composed_key(text: str) -> List[Any]:
        return _natural_keys(user_key(text))

    return sorted(to_sort, key=composed_key, reverse=reverse)


def _atoi(text: str) -> Any:
    return int(text) if text.isdigit() else text


def _natural_keys(text: str) -> List[Any]:
    """Split ``text`` into alternating non-digit and digit-run chunks.

    Each chunk is passed through ``_atoi`` so digit runs can compare as
    integers, which makes e.g. "a10" order after "a9".
    """
    chunks = re.split(r"(\d+)", text)
    return list(map(_atoi, chunks))