diff options
-rw-r--r-- | cmd2/utils.py | 48 | ||||
-rw-r--r-- | tests/test_utils.py | 11 |
2 files changed, 44 insertions, 15 deletions
diff --git a/cmd2/utils.py b/cmd2/utils.py index 2add7c5e..376e2696 100644 --- a/cmd2/utils.py +++ b/cmd2/utils.py @@ -11,7 +11,7 @@ import sys import threading import unicodedata from enum import Enum -from typing import Any, Callable, Iterable, List, Optional, TextIO, Union +from typing import Any, Callable, Iterable, List, OrderedDict, Optional, TextIO, Union from . import constants @@ -696,7 +696,7 @@ def align_text(text: str, alignment: TextAlignment, *, fill_char: str = ' ', :param truncate: if True, then each line will be shortened to fit within the display width. The truncated portions are replaced by a '…' character. Defaults to False. :return: aligned text - :raises: TypeError if fill_char is more than one character + :raises: TypeError if fill_char is more than one character (not including ANSI style sequences) ValueError if text or fill_char contains an unprintable character ValueError if width is less than 1 """ @@ -716,7 +716,7 @@ def align_text(text: str, alignment: TextAlignment, *, fill_char: str = ' ', if fill_char == '\t': fill_char = ' ' - if len(fill_char) != 1: + if len(ansi.strip_style(fill_char)) != 1: raise TypeError("Fill character must be exactly one character long") fill_char_width = ansi.style_aware_wcswidth(fill_char) @@ -788,7 +788,7 @@ def align_left(text: str, *, fill_char: str = ' ', width: Optional[int] = None, :param truncate: if True, then text will be shortened to fit within the display width. The truncated portion is replaced by a '…' character. Defaults to False. :return: left-aligned text - :raises: TypeError if fill_char is more than one character + :raises: TypeError if fill_char is more than one character (not including ANSI style sequences) ValueError if text or fill_char contains an unprintable character ValueError if width is less than 1 """ @@ -811,7 +811,7 @@ def align_center(text: str, *, fill_char: str = ' ', width: Optional[int] = None :param truncate: if True, then text will be shortened to fit within the display width. The truncated portion is replaced by a '…' character. Defaults to False. :return: centered text - :raises: TypeError if fill_char is more than one character + :raises: TypeError if fill_char is more than one character (not including ANSI style sequences) ValueError if text or fill_char contains an unprintable character ValueError if width is less than 1 """ @@ -834,7 +834,7 @@ def align_right(text: str, *, fill_char: str = ' ', width: Optional[int] = None, :param truncate: if True, then text will be shortened to fit within the display width. The truncated portion is replaced by a '…' character. Defaults to False. :return: right-aligned text - :raises: TypeError if fill_char is more than one character + :raises: TypeError if fill_char is more than one character (not including ANSI style sequences) ValueError if text or fill_char contains an unprintable character ValueError if width is less than 1 """ @@ -878,14 +878,7 @@ def truncate_line(line: str, max_width: int, *, tab_width: int = 4) -> str: return line # Find all style sequences in the line - start = 0 - styles = collections.OrderedDict() - while True: - match = ansi.ANSI_STYLE_RE.search(line, start) - if match is None: - break - styles[match.start()] = match.group() - start += len(match.group()) + styles = get_styles_in_text(line) # Add characters one by one and preserve all style sequences done = False @@ -919,3 +912,30 @@ def truncate_line(line: str, max_width: int, *, tab_width: int = 4) -> str: truncated_buf.write(''.join(styles.values())) return truncated_buf.getvalue() + + +def get_styles_in_text(text: str) -> OrderedDict[int, str]: + """ + Return an OrderedDict containing all ANSI style sequences found in a string + + The structure of the dictionary is: + key: index where sequences begins + value: ANSI style sequence found at index in text + + Keys are in ascending order + + :param text: text to search for style sequences + """ + from . import ansi + + start = 0 + styles = collections.OrderedDict() + + while True: + match = ansi.ANSI_STYLE_RE.search(text, start) + if match is None: + break + styles[match.start()] = match.group() + start += len(match.group()) + + return styles diff --git a/tests/test_utils.py b/tests/test_utils.py index db432286..7546184e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -368,6 +368,15 @@ def test_align_text_fill_char_is_tab(): aligned = cu.align_text(text, cu.TextAlignment.LEFT, fill_char=fill_char, width=width) assert aligned == text + ' ' +def test_align_text_fill_char_has_color(): + from cmd2 import ansi + + text = 'foo' + fill_char = ansi.fg.bright_yellow + '-' + ansi.fg.reset + width = 5 + aligned = cu.align_text(text, cu.TextAlignment.LEFT, fill_char=fill_char, width=width) + assert aligned == text + fill_char * 2 + def test_align_text_width_is_too_small(): text = 'foo' fill_char = '-' @@ -382,7 +391,7 @@ def test_align_text_fill_char_is_too_long(): with pytest.raises(TypeError): cu.align_text(text, cu.TextAlignment.LEFT, fill_char=fill_char, width=width) -def test_align_text_fill_char_is_unprintable(): +def test_align_text_fill_char_is_newline(): text = 'foo' fill_char = '\n' width = 5 |