summaryrefslogtreecommitdiff
path: root/cmd2/rl_utils.py
blob: b2dc764960410c909bd5436fe6b8fc483bb2e2bb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
# coding=utf-8
"""
Imports the proper Readline for the platform and provides utility functions for it
"""
import sys
from enum import (
    Enum,
)
from typing import (
    Union,
    cast,
)

# Prefer statically linked gnureadline if available (for macOS compatibility due to issues with libedit)
try:
    # noinspection PyPackageRequirements
    import gnureadline as readline  # type: ignore[import]
except ImportError:
    # gnureadline is not installed, so fall back to whatever module is importable as "readline".
    # Allow this to fail too, for convenience in Windows unit testing.
    # Note: If this actually fails, you should install readline on Linux or Mac or pyreadline on Windows
    try:
        # noinspection PyUnresolvedReferences
        import readline  # type: ignore[no-redef]
    except ImportError:  # pragma: no cover
        # Neither module could be imported. The name "readline" stays undefined and the
        # sys.modules checks further down leave rl_type set to RlType.NONE.
        pass


class RlType(Enum):
    """Readline library types we recognize

    The module-level ``rl_type`` variable holds one of these members and the helper
    functions in this module branch on it.
    """

    # A non-libedit readline whose shared library could be loaded through ctypes
    GNU = 1
    # pyreadline or pyreadline3 is present in sys.modules
    PYREADLINE = 2
    # No supported readline implementation was detected
    NONE = 3


# The implementation of Readline we are using. Starts as NONE and is updated by the detection logic below.
rl_type = RlType.NONE

# Tells if the terminal we are running in supports vt100 control characters
vt100_support = False

# Explanation for why Readline wasn't loaded. Stays empty unless the detection logic records a specific reason.
_rl_warn_reason = ''

# Work out which readline implementation was imported and record it in rl_type.
# The order of this check matters since importing pyreadline/pyreadline3 will also show readline in the modules list
if 'pyreadline' in sys.modules or 'pyreadline3' in sys.modules:
    rl_type = RlType.PYREADLINE

    import atexit
    from ctypes import (
        byref,
    )
    from ctypes.wintypes import (
        DWORD,
        HANDLE,
    )

    # Check if we are running in a terminal
    if sys.stdout.isatty():  # pragma: no cover
        # noinspection PyPep8Naming,PyUnresolvedReferences
        def enable_win_vt100(handle: HANDLE) -> bool:
            """
            Enables VT100 character sequences in a Windows console
            This only works on Windows 10 and up
            :param handle: the handle on which to enable vt100
            :return: True if vt100 characters are enabled for the handle
            """
            ENABLE_VIRTUAL_TERMINAL_PROCESSING = 0x0004

            # Get the current mode for this handle in the console
            cur_mode = DWORD(0)
            readline.rl.console.GetConsoleMode(handle, byref(cur_mode))

            retVal = False

            # Check if ENABLE_VIRTUAL_TERMINAL_PROCESSING is already enabled
            if (cur_mode.value & ENABLE_VIRTUAL_TERMINAL_PROCESSING) != 0:
                retVal = True

            elif readline.rl.console.SetConsoleMode(handle, cur_mode.value | ENABLE_VIRTUAL_TERMINAL_PROCESSING):
                # Restore the original mode when we exit
                atexit.register(readline.rl.console.SetConsoleMode, handle, cur_mode)
                retVal = True

            return retVal

        # Enable VT100 sequences for stdout and stderr
        STD_OUT_HANDLE = -11
        STD_ERROR_HANDLE = -12
        # noinspection PyUnresolvedReferences
        vt100_stdout_support = enable_win_vt100(readline.rl.console.GetStdHandle(STD_OUT_HANDLE))
        # noinspection PyUnresolvedReferences
        vt100_stderr_support = enable_win_vt100(readline.rl.console.GetStdHandle(STD_ERROR_HANDLE))
        # Only report vt100 support when both streams accept the sequences
        vt100_support = vt100_stdout_support and vt100_stderr_support

    ############################################################################################################
    # pyreadline is incomplete in terms of the Python readline API. Add the missing functions we need.
    ############################################################################################################
    # readline.redisplay()
    if not hasattr(readline, 'redisplay'):
        # noinspection PyProtectedMember,PyUnresolvedReferences
        readline.redisplay = readline.rl.mode._update_line

    # readline.remove_history_item()
    if not hasattr(readline, 'remove_history_item'):
        # noinspection PyProtectedMember,PyUnresolvedReferences
        def pyreadline_remove_history_item(pos: int) -> None:
            """
            An implementation of remove_history_item() for pyreadline
            :param pos: The 0-based position in history to remove
            """
            # Save the current location of the history cursor
            saved_cursor = readline.rl.mode._history.history_cursor

            # Delete the history item
            del readline.rl.mode._history.history[pos]

            # Update the cursor if needed
            if saved_cursor > pos:
                readline.rl.mode._history.history_cursor -= 1

        readline.remove_history_item = pyreadline_remove_history_item

elif 'gnureadline' in sys.modules or 'readline' in sys.modules:
    # We don't support libedit. The module docstring is how libedit is told apart from GNU readline,
    # so a module without a docstring cannot be identified and is treated as unsupported rather than
    # letting the membership test raise a TypeError on None at import time.
    if readline.__doc__ is not None and 'libedit' not in readline.__doc__:
        try:
            # Load the readline lib so we can access members of it
            import ctypes

            readline_lib = ctypes.CDLL(readline.__file__)
        except (AttributeError, OSError):  # pragma: no cover
            _rl_warn_reason = (
                "this application is running in a non-standard Python environment in\n"
                "which GNU readline is not loaded dynamically from a shared library file."
            )
        else:
            rl_type = RlType.GNU
            vt100_support = sys.stdout.isatty()

# Build the warning text for the case where no usable readline was detected.
# It stays empty when a supported readline was loaded.
rl_warning = ''
if rl_type == RlType.NONE:  # pragma: no cover
    # Fall back to the generic explanation when detection did not record a specific one
    _rl_warn_reason = _rl_warn_reason or (
        "no supported version of readline was found. To resolve this, install\n"
        "pyreadline on Windows or gnureadline on Mac."
    )
    rl_warning = ''.join(
        (
            "Readline features including tab completion have been disabled because\n",
            _rl_warn_reason,
            '\n\n',
        )
    )


# noinspection PyProtectedMember,PyUnresolvedReferences
def rl_force_redisplay() -> None:  # pragma: no cover
    """
    Redraws the prompt and the pending input text at the cursor's current position and resumes
    reading input from there. Use this to restore the input line after something has been
    printed to the screen. Nothing is done when stdout is not a terminal.
    """
    if sys.stdout.isatty():
        if rl_type == RlType.PYREADLINE:
            # _print_prompt() has to run first because it sets the new location of the prompt
            readline.rl.mode._print_prompt()
            readline.rl.mode._update_line()

        elif rl_type == RlType.GNU:
            readline_lib.rl_forced_update_display()

            # After manually updating the display, readline asks that rl_display_fixed be set to 1 for efficiency
            ctypes.c_int.in_dll(readline_lib, "rl_display_fixed").value = 1


# noinspection PyProtectedMember, PyUnresolvedReferences
def rl_get_point() -> int:  # pragma: no cover
    """
    Reports where the cursor currently sits within rl_line_buffer

    :return: the cursor offset, or 0 when no supported readline is loaded
    """
    if rl_type == RlType.PYREADLINE:
        return int(readline.rl.mode.l_buffer.point)

    if rl_type == RlType.GNU:
        # rl_point is an int variable exported by the readline shared library
        cursor_offset = ctypes.c_int.in_dll(readline_lib, "rl_point")
        return cursor_offset.value

    return 0


# noinspection PyUnresolvedReferences
def rl_get_prompt() -> str:  # pragma: no cover
    """
    Gets Readline's current prompt

    :return: the prompt after it has been passed through rl_unescape_prompt(). An empty string is
             used when GNU readline has no prompt set or when no supported readline is loaded.
    """
    if rl_type == RlType.GNU:
        # rl_prompt is a C char pointer. ctypes reports a NULL pointer as None (e.g. before
        # readline has been given a prompt), so treat that as an empty prompt instead of
        # calling decode() on None.
        encoded_prompt = ctypes.c_char_p.in_dll(readline_lib, "rl_prompt").value
        if encoded_prompt is None:
            prompt = ''
        else:
            prompt = encoded_prompt.decode(encoding='utf-8')

    elif rl_type == RlType.PYREADLINE:
        # pyreadline may hold the prompt as either str or bytes
        prompt_data: Union[str, bytes] = readline.rl.prompt
        if isinstance(prompt_data, bytes):
            prompt = prompt_data.decode(encoding='utf-8')
        else:
            prompt = prompt_data

    else:
        prompt = ''

    return rl_unescape_prompt(prompt)


# noinspection PyUnresolvedReferences
def rl_set_prompt(prompt: str) -> None:  # pragma: no cover
    """
    Hands a new prompt to Readline
    :param prompt: the new prompt value; it is passed through rl_escape_prompt() before being installed
    """
    safe_prompt = rl_escape_prompt(prompt)

    if rl_type == RlType.PYREADLINE:
        readline.rl.prompt = safe_prompt

    elif rl_type == RlType.GNU:
        # The C function is given the prompt as UTF-8 encoded bytes
        readline_lib.rl_set_prompt(safe_prompt.encode(encoding='utf-8'))


def rl_escape_prompt(prompt: str) -> str:
    """Work around GNU Readline miscalculating the prompt length when it contains ANSI escape codes

    Each run of characters from an escape character up to and including the next alphabetic
    character is wrapped in the markers GNU Readline uses for invisible text. For any other
    readline type the prompt is returned unchanged.

    :param prompt: original prompt
    :return: prompt safe to pass to GNU Readline
    """
    if rl_type != RlType.GNU:
        return prompt

    # Markers telling GNU Readline where a run of invisible characters begins and ends
    invisible_start = "\x01"
    invisible_end = "\x02"

    pieces = []
    in_sequence = False

    for ch in prompt:
        if ch == "\x1b" and not in_sequence:
            # An escape character opens an invisible run
            pieces.append(invisible_start + ch)
            in_sequence = True
        elif in_sequence and ch.isalpha():
            # The first alphabetic character after the escape closes the run
            pieces.append(ch + invisible_end)
            in_sequence = False
        else:
            pieces.append(ch)

    return "".join(pieces)


def rl_unescape_prompt(prompt: str) -> str:
    """Strip the invisible-text markers from a Readline prompt

    :param prompt: prompt that may contain GNU Readline's start/end markers
    :return: the prompt without those markers; returned unchanged when GNU Readline is not in use
    """
    if rl_type != RlType.GNU:
        return prompt

    # Drop the start marker first, then the end marker
    for marker in ("\x01", "\x02"):
        prompt = prompt.replace(marker, "")

    return prompt