diff options
Diffstat (limited to 'cmd2/utils.py')
-rw-r--r-- | cmd2/utils.py | 125 |
1 files changed, 105 insertions, 20 deletions
diff --git a/cmd2/utils.py b/cmd2/utils.py index 5f2ceaf4..855ad23e 100644 --- a/cmd2/utils.py +++ b/cmd2/utils.py @@ -737,6 +737,87 @@ class RedirectionSavedState: self.saved_redirecting = saved_redirecting +def _remove_overridden_styles(styles_to_parse: List[str]) -> List[str]: + """ + Utility function for align_text() / truncate_line() which filters a style list down + to only those which would still be in effect if all were processed in order. + + This is mainly used to reduce how many style strings are stored in memory when + building large multiline strings with ANSI styles. We only need to carry over + styles from previous lines that are still in effect. + + :param styles_to_parse: list of styles to evaluate. + :return: list of styles that are still in effect. + """ + from . import ( + ansi, + ) + + class StyleState: + """Keeps track of what text styles are enabled""" + + def __init__(self) -> None: + # Contains styles still in effect, keyed by their index in styles_to_parse + self.style_dict: Dict[int, str] = dict() + + # Indexes into style_dict + self.reset_all: Optional[int] = None + self.fg: Optional[int] = None + self.bg: Optional[int] = None + self.intensity: Optional[int] = None + self.italic: Optional[int] = None + self.overline: Optional[int] = None + self.strikethrough: Optional[int] = None + self.underline: Optional[int] = None + + # Read the previous styles in order and keep track of their states + style_state = StyleState() + + for index, style in enumerate(styles_to_parse): + # For styles types that we recognize, only keep their latest value from styles_to_parse. + # All unrecognized style types will be retained and their order preserved. + if style in (str(ansi.TextStyle.RESET_ALL), str(ansi.TextStyle.ALT_RESET_ALL)): + style_state = StyleState() + style_state.reset_all = index + elif ansi.STD_FG_RE.match(style) or ansi.EIGHT_BIT_FG_RE.match(style) or ansi.RGB_FG_RE.match(style): + if style_state.fg is not None: + style_state.style_dict.pop(style_state.fg) + style_state.fg = index + elif ansi.STD_BG_RE.match(style) or ansi.EIGHT_BIT_BG_RE.match(style) or ansi.RGB_BG_RE.match(style): + if style_state.bg is not None: + style_state.style_dict.pop(style_state.bg) + style_state.bg = index + elif style in ( + str(ansi.TextStyle.INTENSITY_BOLD), + str(ansi.TextStyle.INTENSITY_DIM), + str(ansi.TextStyle.INTENSITY_NORMAL), + ): + if style_state.intensity is not None: + style_state.style_dict.pop(style_state.intensity) + style_state.intensity = index + elif style in (str(ansi.TextStyle.ITALIC_ENABLE), str(ansi.TextStyle.ITALIC_DISABLE)): + if style_state.italic is not None: + style_state.style_dict.pop(style_state.italic) + style_state.italic = index + elif style in (str(ansi.TextStyle.OVERLINE_ENABLE), str(ansi.TextStyle.OVERLINE_DISABLE)): + if style_state.overline is not None: + style_state.style_dict.pop(style_state.overline) + style_state.overline = index + elif style in (str(ansi.TextStyle.STRIKETHROUGH_ENABLE), str(ansi.TextStyle.STRIKETHROUGH_DISABLE)): + if style_state.strikethrough is not None: + style_state.style_dict.pop(style_state.strikethrough) + style_state.strikethrough = index + elif style in (str(ansi.TextStyle.UNDERLINE_ENABLE), str(ansi.TextStyle.UNDERLINE_DISABLE)): + if style_state.underline is not None: + style_state.style_dict.pop(style_state.underline) + style_state.underline = index + + # Store this style and its location in the dictionary + style_state.style_dict[index] = style + + return list(style_state.style_dict.values()) + + class TextAlignment(Enum): """Horizontal text alignment""" @@ -801,7 +882,7 @@ def align_text( raise (ValueError("Fill character is an unprintable character")) # Isolate the style chars before and after the fill character. We will use them when building sequences of - # of fill characters. Instead of repeating the style characters for each fill character, we'll wrap each sequence. + # fill characters. Instead of repeating the style characters for each fill character, we'll wrap each sequence. fill_char_style_begin, fill_char_style_end = fill_char.split(stripped_fill_char) if text: @@ -811,10 +892,10 @@ def align_text( text_buf = io.StringIO() - # ANSI style sequences that may affect future lines will be cancelled by the fill_char's style. - # To avoid this, we save the state of a line's style so we can restore it when beginning the next line. - # This also allows the lines to be used independently and still have their style. TableCreator does this. - aggregate_styles = '' + # ANSI style sequences that may affect subsequent lines will be cancelled by the fill_char's style. + # To avoid this, we save styles which are still in effect so we can restore them when beginning the next line. + # This also allows lines to be used independently and still have their style. TableCreator does this. + previous_styles: List[str] = [] for index, line in enumerate(lines): if index > 0: @@ -827,8 +908,8 @@ def align_text( if line_width == -1: raise (ValueError("Text to align contains an unprintable character")) - # Get the styles in this line - line_styles = get_styles_in_text(line) + # Get list of styles in this line + line_styles = list(get_styles_dict(line).values()) # Calculate how wide each side of filling needs to be if line_width >= width: @@ -858,7 +939,7 @@ def align_text( right_fill += ' ' * (right_fill_width - ansi.style_aware_wcswidth(right_fill)) # Don't allow styles in fill characters and text to affect one another - if fill_char_style_begin or fill_char_style_end or aggregate_styles or line_styles: + if fill_char_style_begin or fill_char_style_end or previous_styles or line_styles: if left_fill: left_fill = ansi.TextStyle.RESET_ALL + fill_char_style_begin + left_fill + fill_char_style_end left_fill += ansi.TextStyle.RESET_ALL @@ -867,11 +948,12 @@ def align_text( right_fill = ansi.TextStyle.RESET_ALL + fill_char_style_begin + right_fill + fill_char_style_end right_fill += ansi.TextStyle.RESET_ALL - # Write the line and restore any styles from previous lines - text_buf.write(left_fill + aggregate_styles + line + right_fill) + # Write the line and restore styles from previous lines which are still in effect + text_buf.write(left_fill + ''.join(previous_styles) + line + right_fill) - # Update the aggregate with styles in this line - aggregate_styles += ''.join(line_styles.values()) + # Update list of styles that are still in effect for the next line + previous_styles.extend(line_styles) + previous_styles = _remove_overridden_styles(previous_styles) return text_buf.getvalue() @@ -985,7 +1067,7 @@ def truncate_line(line: str, max_width: int, *, tab_width: int = 4) -> str: return line # Find all style sequences in the line - styles = get_styles_in_text(line) + styles_dict = get_styles_dict(line) # Add characters one by one and preserve all style sequences done = False @@ -995,10 +1077,10 @@ def truncate_line(line: str, max_width: int, *, tab_width: int = 4) -> str: while not done: # Check if a style sequence is at this index. These don't count toward display width. - if index in styles: - truncated_buf.write(styles[index]) - style_len = len(styles[index]) - styles.pop(index) + if index in styles_dict: + truncated_buf.write(styles_dict[index]) + style_len = len(styles_dict[index]) + styles_dict.pop(index) index += style_len continue @@ -1015,13 +1097,16 @@ def truncate_line(line: str, max_width: int, *, tab_width: int = 4) -> str: truncated_buf.write(char) index += 1 - # Append remaining style sequences from original string - truncated_buf.write(''.join(styles.values())) + # Filter out overridden styles from the remaining ones + remaining_styles = _remove_overridden_styles(list(styles_dict.values())) + + # Append the remaining styles to the truncated text + truncated_buf.write(''.join(remaining_styles)) return truncated_buf.getvalue() -def get_styles_in_text(text: str) -> Dict[int, str]: +def get_styles_dict(text: str) -> Dict[int, str]: """ Return an OrderedDict containing all ANSI style sequences found in a string |