summaryrefslogtreecommitdiff
path: root/sqlparse/engine
diff options
context:
space:
mode:
Diffstat (limited to 'sqlparse/engine')
-rw-r--r--sqlparse/engine/grouping.py80
1 files changed, 61 insertions, 19 deletions
diff --git a/sqlparse/engine/grouping.py b/sqlparse/engine/grouping.py
index 9314b89..a317044 100644
--- a/sqlparse/engine/grouping.py
+++ b/sqlparse/engine/grouping.py
@@ -51,19 +51,21 @@ def _group_left_right(tlist, ttype, value, cls,
ttype, value)
+def _find_matching(idx, tlist, start_ttype, start_value, end_ttype, end_value):
+ depth = 1
+ for tok in tlist.tokens[idx:]:
+ if tok.match(start_ttype, start_value):
+ depth += 1
+ elif tok.match(end_ttype, end_value):
+ depth -= 1
+ if depth == 1:
+ return tok
+ return None
+
+
def _group_matching(tlist, start_ttype, start_value, end_ttype, end_value,
cls, include_semicolon=False, recurse=False):
- def _find_matching(i, tl, stt, sva, ett, eva):
- depth = 1
- for n in xrange(i, len(tl.tokens)):
- t = tl.tokens[n]
- if t.match(stt, sva):
- depth += 1
- elif t.match(ett, eva):
- depth -= 1
- if depth == 1:
- return t
- return None
+
[_group_matching(sgroup, start_ttype, start_value, end_ttype, end_value,
cls, include_semicolon) for sgroup in tlist.get_sublists()
if recurse]
@@ -157,16 +159,17 @@ def group_identifier(tlist):
lambda y: (y.match(T.Punctuation, '.')
or y.ttype in (T.Operator,
T.Wildcard,
- T.ArrayIndex,
- T.Name)),
+ T.Name)
+ or isinstance(y, sql.SquareBrackets)),
lambda y: (y.ttype in (T.String.Symbol,
T.Name,
T.Wildcard,
- T.ArrayIndex,
T.Literal.String.Single,
T.Literal.Number.Integer,
T.Literal.Number.Float)
- or isinstance(y, (sql.Parenthesis, sql.Function)))))
+ or isinstance(y, (sql.Parenthesis,
+ sql.SquareBrackets,
+ sql.Function)))))
for t in tl.tokens[i:]:
# Don't take whitespaces into account.
if t.ttype is T.Whitespace:
@@ -275,9 +278,48 @@ def group_identifier_list(tlist):
tcomma = next_
-def group_parenthesis(tlist):
- _group_matching(tlist, T.Punctuation, '(', T.Punctuation, ')',
- sql.Parenthesis)
+def group_brackets(tlist):
+ """Group parentheses () or square brackets []
+
+ This is just like _group_matching, but complicated by the fact that
+ round brackets can contain square bracket groups and vice versa
+ """
+
+ if isinstance(tlist, (sql.Parenthesis, sql.SquareBrackets)):
+ idx = 1
+ else:
+ idx = 0
+
+ # Find the first opening bracket
+ token = tlist.token_next_match(idx, T.Punctuation, ['(', '['])
+
+ while token:
+ start_val = token.value # either '(' or '['
+ if start_val == '(':
+ end_val = ')'
+ group_class = sql.Parenthesis
+ else:
+ end_val = ']'
+ group_class = sql.SquareBrackets
+
+ tidx = tlist.token_index(token)
+
+ # Find the corresponding closing bracket
+ end = _find_matching(tidx, tlist, T.Punctuation, start_val,
+ T.Punctuation, end_val)
+
+ if end is None:
+ idx = tidx + 1
+ else:
+ group = tlist.group_tokens(group_class,
+ tlist.tokens_between(token, end))
+
+ # Check for nested bracket groups within this group
+ group_brackets(group)
+ idx = tlist.token_index(group) + 1
+
+ # Find the next opening bracket
+ token = tlist.token_next_match(idx, T.Punctuation, ['(', '['])
def group_comments(tlist):
@@ -393,7 +435,7 @@ def align_comments(tlist):
def group(tlist):
for func in [
group_comments,
- group_parenthesis,
+ group_brackets,
group_functions,
group_where,
group_case,