diff options
Diffstat (limited to 'sqlparse')
| -rw-r--r-- | sqlparse/engine/grouping.py | 47 | ||||
| -rw-r--r-- | sqlparse/sql.py | 9 |
2 files changed, 52 insertions, 4 deletions
diff --git a/sqlparse/engine/grouping.py b/sqlparse/engine/grouping.py index 0be44da..73679e3 100644 --- a/sqlparse/engine/grouping.py +++ b/sqlparse/engine/grouping.py @@ -277,9 +277,48 @@ def group_identifier_list(tlist): tcomma = next_ -def group_parenthesis(tlist): - _group_matching(tlist, T.Punctuation, '(', T.Punctuation, ')', - sql.Parenthesis) +def group_brackets(tlist): + """Group parentheses () or square brackets [] + + This is just like _group_matching, but complicated by the fact that + round brackets can contain square bracket groups and vice versa + """ + + if isinstance(tlist, (sql.Parenthesis, sql.SquareBrackets)): + idx = 1 + else: + idx = 0 + + # Find the first opening bracket + token = tlist.token_next_match(idx, T.Punctuation, ['(', '[']) + + while token: + start_val = token.value # either '(' or '[' + if start_val == '(': + end_val = ')' + group_class = sql.Parenthesis + else: + end_val = ']' + group_class = sql.SquareBrackets + + tidx = tlist.token_index(token) + + # Find the corresponding closing bracket + end = _find_matching(tidx, tlist, T.Punctuation, start_val, + T.Punctuation, end_val) + + if end is None: + idx = tidx + 1 + else: + group = tlist.group_tokens(group_class, + tlist.tokens_between(token, end)) + + # Check for nested bracket groups within this group + group_brackets(group) + idx = tlist.token_index(group) + 1 + + # Find the next opening bracket + token = tlist.token_next_match(idx, T.Punctuation, ['(', '[']) def group_comments(tlist): @@ -395,7 +434,7 @@ def align_comments(tlist): def group(tlist): for func in [ group_comments, - group_parenthesis, + group_brackets, group_functions, group_where, group_case, diff --git a/sqlparse/sql.py b/sqlparse/sql.py index 8492c5e..25d5243 100644 --- a/sqlparse/sql.py +++ b/sqlparse/sql.py @@ -542,6 +542,15 @@ class Parenthesis(TokenList): return self.tokens[1:-1] +class SquareBrackets(TokenList): + """Tokens between square brackets""" + + __slots__ = ('value', 'ttype', 'tokens') + + @property + def _groupable_tokens(self): + return self.tokens[1:-1] + class Assignment(TokenList): """An assignment like 'var := val;'""" __slots__ = ('value', 'ttype', 'tokens') |
