summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cmd2/utils.py48
-rw-r--r--tests/test_utils.py11
2 files changed, 44 insertions, 15 deletions
diff --git a/cmd2/utils.py b/cmd2/utils.py
index 2add7c5e..376e2696 100644
--- a/cmd2/utils.py
+++ b/cmd2/utils.py
@@ -11,7 +11,7 @@ import sys
import threading
import unicodedata
from enum import Enum
-from typing import Any, Callable, Iterable, List, Optional, TextIO, Union
+from typing import Any, Callable, Iterable, List, OrderedDict, Optional, TextIO, Union
from . import constants
@@ -696,7 +696,7 @@ def align_text(text: str, alignment: TextAlignment, *, fill_char: str = ' ',
:param truncate: if True, then each line will be shortened to fit within the display width. The truncated
portions are replaced by a '…' character. Defaults to False.
:return: aligned text
- :raises: TypeError if fill_char is more than one character
+ :raises: TypeError if fill_char is more than one character (not including ANSI style sequences)
ValueError if text or fill_char contains an unprintable character
ValueError if width is less than 1
"""
@@ -716,7 +716,7 @@ def align_text(text: str, alignment: TextAlignment, *, fill_char: str = ' ',
if fill_char == '\t':
fill_char = ' '
- if len(fill_char) != 1:
+ if len(ansi.strip_style(fill_char)) != 1:
raise TypeError("Fill character must be exactly one character long")
fill_char_width = ansi.style_aware_wcswidth(fill_char)
@@ -788,7 +788,7 @@ def align_left(text: str, *, fill_char: str = ' ', width: Optional[int] = None,
:param truncate: if True, then text will be shortened to fit within the display width. The truncated portion is
replaced by a '…' character. Defaults to False.
:return: left-aligned text
- :raises: TypeError if fill_char is more than one character
+ :raises: TypeError if fill_char is more than one character (not including ANSI style sequences)
ValueError if text or fill_char contains an unprintable character
ValueError if width is less than 1
"""
@@ -811,7 +811,7 @@ def align_center(text: str, *, fill_char: str = ' ', width: Optional[int] = None
:param truncate: if True, then text will be shortened to fit within the display width. The truncated portion is
replaced by a '…' character. Defaults to False.
:return: centered text
- :raises: TypeError if fill_char is more than one character
+ :raises: TypeError if fill_char is more than one character (not including ANSI style sequences)
ValueError if text or fill_char contains an unprintable character
ValueError if width is less than 1
"""
@@ -834,7 +834,7 @@ def align_right(text: str, *, fill_char: str = ' ', width: Optional[int] = None,
:param truncate: if True, then text will be shortened to fit within the display width. The truncated portion is
replaced by a '…' character. Defaults to False.
:return: right-aligned text
- :raises: TypeError if fill_char is more than one character
+ :raises: TypeError if fill_char is more than one character (not including ANSI style sequences)
ValueError if text or fill_char contains an unprintable character
ValueError if width is less than 1
"""
@@ -878,14 +878,7 @@ def truncate_line(line: str, max_width: int, *, tab_width: int = 4) -> str:
return line
# Find all style sequences in the line
- start = 0
- styles = collections.OrderedDict()
- while True:
- match = ansi.ANSI_STYLE_RE.search(line, start)
- if match is None:
- break
- styles[match.start()] = match.group()
- start += len(match.group())
+ styles = get_styles_in_text(line)
# Add characters one by one and preserve all style sequences
done = False
@@ -919,3 +912,30 @@ def truncate_line(line: str, max_width: int, *, tab_width: int = 4) -> str:
truncated_buf.write(''.join(styles.values()))
return truncated_buf.getvalue()
+
+
+def get_styles_in_text(text: str) -> OrderedDict[int, str]:
+ """
+ Return an OrderedDict containing all ANSI style sequences found in a string
+
+ The structure of the dictionary is:
+ key: index where sequences begins
+ value: ANSI style sequence found at index in text
+
+ Keys are in ascending order
+
+ :param text: text to search for style sequences
+ """
+ from . import ansi
+
+ start = 0
+ styles = collections.OrderedDict()
+
+ while True:
+ match = ansi.ANSI_STYLE_RE.search(text, start)
+ if match is None:
+ break
+ styles[match.start()] = match.group()
+ start += len(match.group())
+
+ return styles
diff --git a/tests/test_utils.py b/tests/test_utils.py
index db432286..7546184e 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -368,6 +368,15 @@ def test_align_text_fill_char_is_tab():
aligned = cu.align_text(text, cu.TextAlignment.LEFT, fill_char=fill_char, width=width)
assert aligned == text + ' '
+def test_align_text_fill_char_has_color():
+ from cmd2 import ansi
+
+ text = 'foo'
+ fill_char = ansi.fg.bright_yellow + '-' + ansi.fg.reset
+ width = 5
+ aligned = cu.align_text(text, cu.TextAlignment.LEFT, fill_char=fill_char, width=width)
+ assert aligned == text + fill_char * 2
+
def test_align_text_width_is_too_small():
text = 'foo'
fill_char = '-'
@@ -382,7 +391,7 @@ def test_align_text_fill_char_is_too_long():
with pytest.raises(TypeError):
cu.align_text(text, cu.TextAlignment.LEFT, fill_char=fill_char, width=width)
-def test_align_text_fill_char_is_unprintable():
+def test_align_text_fill_char_is_newline():
text = 'foo'
fill_char = '\n'
width = 5