Diffstat (limited to 'sqlparse')
-rw-r--r--  sqlparse/engine/grouping.py | 10
-rw-r--r--  sqlparse/sql.py             | 20
2 files changed, 24 insertions, 6 deletions
diff --git a/sqlparse/engine/grouping.py b/sqlparse/engine/grouping.py
index 77a53ad..fddee0f 100644
--- a/sqlparse/engine/grouping.py
+++ b/sqlparse/engine/grouping.py
@@ -163,17 +163,16 @@ def group_identifier_list(tlist):
               (T.Keyword, T.Comment, T.Wildcard))
     func = lambda t: imt(t, i=I_IDENT_LIST, m=M_ROLE, t=T_IDENT_LIST)
 
-    token = tlist.token_next_by(m=M_COMMA)
+    tidx, token = tlist.token_idx_next_by(m=M_COMMA)
     while token:
-        tidx = tlist.token_index(token)
         before, after = tlist.token_prev(tidx), tlist.token_next(tidx)
         if func(before) and func(after):
             tidx = tlist.token_index(before)
             token = tlist.group_tokens_between(sql.IdentifierList, tidx, after, extend=True)
-        token = tlist.token_next_by(m=M_COMMA, idx=tidx + 1)
+        tidx, token = tlist.token_idx_next_by(m=M_COMMA, idx=tidx + 1)
 
 
 def group_brackets(tlist):
@@ -217,13 +216,12 @@ def group_aliased(tlist):
     I_ALIAS = (sql.Parenthesis, sql.Function, sql.Case, sql.Identifier,
                )  # sql.Operation)
 
-    token = tlist.token_next_by(i=I_ALIAS, t=T.Number)
+    tidx, token = tlist.token_idx_next_by(i=I_ALIAS, t=T.Number)
     while token:
-        tidx = tlist.token_index(token)
         next_ = tlist.token_next(tidx)
         if imt(next_, i=sql.Identifier):
             token = tlist.group_tokens_between(sql.Identifier, tidx, next_, extend=True)
-        token = tlist.token_next_by(i=I_ALIAS, t=T.Number, idx=tidx + 1)
+        tidx, token = tlist.token_idx_next_by(i=I_ALIAS, t=T.Number, idx=tidx + 1)
 
 
 def group_typecasts(tlist):
diff --git a/sqlparse/sql.py b/sqlparse/sql.py
index dfe0430..928b784 100644
--- a/sqlparse/sql.py
+++ b/sqlparse/sql.py
@@ -225,6 +225,22 @@ class TokenList(Token):
     def _groupable_tokens(self):
         return self.tokens
+    def _token_idx_matching(self, funcs, start=0, end=None, reverse=False):
+        """Return (idx, token) of the next token matched by *funcs*."""
+        if start is None:
+            return None, None  # keep the (idx, token) shape that callers unpack
+
+        if not isinstance(funcs, (list, tuple)):
+            funcs = (funcs,)
+
+        iterable = enumerate(self.tokens[start:end], start=start)
+
+        for idx, token in iterable:
+            for func in funcs:
+                if func(token):
+                    return idx, token
+        return None, None
+
     def _token_matching(self, funcs, start=0, end=None, reverse=False):
         """next token that match functions"""
         if start is None:
@@ -259,6 +275,10 @@ class TokenList(Token):
                                 (ignore_comments and imt(tk, i=Comment)))
         return self._token_matching(funcs)
 
+    def token_idx_next_by(self, i=None, m=None, t=None, idx=0, end=None):
+        funcs = lambda tk: imt(tk, i, m, t)
+        return self._token_idx_matching(funcs, idx, end)
+
     def token_next_by(self, i=None, m=None, t=None, idx=0, end=None):
         funcs = lambda tk: imt(tk, i, m, t)
         return self._token_matching(funcs, idx, end)
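
The change above replaces the token_next_by()/token_index() pattern with token_idx_next_by(), which hands back the index together with the token so the grouping loops can resume the scan at idx + 1 instead of re-searching the token list on every match. Below is a minimal usage sketch against the tree as of this commit; the M_COMMA pair, the sample statement, and the printout are illustrative assumptions, while token_idx_next_by() and its (index, token) return shape come straight from the diff.

# Sketch only: assumes this commit's TokenList.token_idx_next_by(); the
# M_COMMA match pair and the sample SQL are illustrative, not from the diff.
import sqlparse
from sqlparse import tokens as T

M_COMMA = (T.Punctuation, ',')   # (ttype, value) pair for the m= matcher

stmt = sqlparse.parse('SELECT a, b, c FROM t')[0]

# Old pattern: token_next_by() returned only the token, forcing a linear
# tlist.token_index(token) lookup on every iteration to recover tidx.
# New pattern: the index arrives with the token, so the next search can
# start at idx + 1 directly.
tidx, token = stmt.token_idx_next_by(m=M_COMMA)
while token:
    print(tidx, repr(token))
    tidx, token = stmt.token_idx_next_by(m=M_COMMA, idx=tidx + 1)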