author     Andi Albrecht <albrecht.andi@gmail.com>  2016-05-15 21:02:17 +0200
committer  Andi Albrecht <albrecht.andi@gmail.com>  2016-05-15 21:02:17 +0200
commit     efb6fd4bdb80b985c356bb6eb996e6e25cf63b05 (patch)
tree       5ca2ce45fc9cf0a35efbcf68bb643f3a12e2575d /sqlparse
parent     9ab1464ea9c1d0296d698d9637ed3e3cd92326f9 (diff)
parent     955996e3e5c49fb6b7f200ceecee2f8082656ac4 (diff)
download   sqlparse-efb6fd4bdb80b985c356bb6eb996e6e25cf63b05.tar.gz
Merge pull request #235 from vmuriart/refactor
Refactor
Diffstat (limited to 'sqlparse')
-rw-r--r--  sqlparse/compat.py            23
-rw-r--r--  sqlparse/engine/grouping.py  524
-rw-r--r--  sqlparse/filters.py            4
-rw-r--r--  sqlparse/sql.py              240
-rw-r--r--  sqlparse/utils.py             93
5 files changed, 383 insertions, 501 deletions
diff --git a/sqlparse/compat.py b/sqlparse/compat.py
index 6b26384..334883b 100644
--- a/sqlparse/compat.py
+++ b/sqlparse/compat.py
@@ -14,29 +14,40 @@ PY2 = sys.version_info[0] == 2
PY3 = sys.version_info[0] == 3
if PY3:
+ def u(s):
+ return str(s)
+
+
+ range = range
text_type = str
string_types = (str,)
from io import StringIO
- def u(s):
- return str(s)
elif PY2:
+ def u(s, encoding=None):
+ encoding = encoding or 'unicode-escape'
+ try:
+ return unicode(s)
+ except UnicodeDecodeError:
+ return unicode(s, encoding)
+
+
+ range = xrange
text_type = unicode
string_types = (basestring,)
- from StringIO import StringIO # flake8: noqa
-
- def u(s):
- return unicode(s)
+ from StringIO import StringIO
# Directly copied from six:
def with_metaclass(meta, *bases):
"""Create a base class with a metaclass."""
+
# This requires a bit of explanation: the basic idea is to make a dummy
# metaclass for one level of class instantiation that replaces itself with
# the actual metaclass.
class metaclass(meta):
def __new__(cls, name, this_bases, d):
return meta(name, bases, d)
+
return type.__new__(metaclass, 'temporary_class', (), {})
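
For reference, the six-style with_metaclass helper retained above can be exercised like this. A minimal sketch, not part of the patch; Meta and Base are hypothetical stand-ins:

    from sqlparse.compat import with_metaclass

    class Meta(type):
        # hypothetical metaclass: tag every class it creates
        def __new__(mcs, name, bases, d):
            d.setdefault('created_by', mcs.__name__)
            return super(Meta, mcs).__new__(mcs, name, bases, d)

    class Base(object):
        pass

    # runs the metaclass on both Python 2 and 3, which is the helper's point
    class MyClass(with_metaclass(Meta, Base)):
        pass

    assert MyClass.created_by == 'Meta'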
diff --git a/sqlparse/engine/grouping.py b/sqlparse/engine/grouping.py
index 982488b..e30abab 100644
--- a/sqlparse/engine/grouping.py
+++ b/sqlparse/engine/grouping.py
@@ -1,450 +1,268 @@
# -*- coding: utf-8 -*-
-import itertools
-
from sqlparse import sql
from sqlparse import tokens as T
+from sqlparse.utils import recurse, imt, find_matching
+
+M_ROLE = (T.Keyword, ('null', 'role'))
+M_SEMICOLON = (T.Punctuation, ';')
+M_COMMA = (T.Punctuation, ',')
+
+T_NUMERICAL = (T.Number, T.Number.Integer, T.Number.Float)
+T_STRING = (T.String, T.String.Single, T.String.Symbol)
+T_NAME = (T.Name, T.Name.Placeholder)
-def _group_left_right(tlist, ttype, value, cls,
- check_right=lambda t: True,
- check_left=lambda t: True,
- include_semicolon=False):
- [_group_left_right(sgroup, ttype, value, cls, check_right, check_left,
- include_semicolon) for sgroup in tlist.get_sublists()
- if not isinstance(sgroup, cls)]
- idx = 0
- token = tlist.token_next_match(idx, ttype, value)
+def _group_left_right(tlist, m, cls,
+ valid_left=lambda t: t is not None,
+ valid_right=lambda t: t is not None,
+ semicolon=False):
+ """Groups together tokens that are joined by a middle token. ie. x < y"""
+ [_group_left_right(sgroup, m, cls, valid_left, valid_right, semicolon)
+ for sgroup in tlist.get_sublists() if not isinstance(sgroup, cls)]
+
+ token = tlist.token_next_by(m=m)
while token:
- right = tlist.token_next(tlist.token_index(token))
- left = tlist.token_prev(tlist.token_index(token))
- if right is None or not check_right(right):
- token = tlist.token_next_match(tlist.token_index(token) + 1,
- ttype, value)
- elif left is None or not check_left(left):
- token = tlist.token_next_match(tlist.token_index(token) + 1,
- ttype, value)
- else:
- if include_semicolon:
- sright = tlist.token_next_match(tlist.token_index(right),
- T.Punctuation, ';')
- if sright is not None:
- # only overwrite "right" if a semicolon is actually
- # present.
- right = sright
- tokens = tlist.tokens_between(left, right)[1:]
- if not isinstance(left, cls):
- new = cls([left])
- new_idx = tlist.token_index(left)
- tlist.tokens.remove(left)
- tlist.tokens.insert(new_idx, new)
- left = new
- left.tokens.extend(tokens)
- for t in tokens:
- tlist.tokens.remove(t)
- token = tlist.token_next_match(tlist.token_index(left) + 1,
- ttype, value)
-
-
-def _find_matching(idx, tlist, start_ttype, start_value, end_ttype, end_value):
- depth = 1
- for tok in tlist.tokens[idx:]:
- if tok.match(start_ttype, start_value):
- depth += 1
- elif tok.match(end_ttype, end_value):
- depth -= 1
- if depth == 1:
- return tok
- return None
-
-
-def _group_matching(tlist, start_ttype, start_value, end_ttype, end_value,
- cls, include_semicolon=False, recurse=False):
-
- [_group_matching(sgroup, start_ttype, start_value, end_ttype, end_value,
- cls, include_semicolon) for sgroup in tlist.get_sublists()
- if recurse]
- if isinstance(tlist, cls):
- idx = 1
- else:
- idx = 0
- token = tlist.token_next_match(idx, start_ttype, start_value)
+ left, right = tlist.token_prev(token), tlist.token_next(token)
+
+ if valid_left(left) and valid_right(right):
+ if semicolon:
+ sright = tlist.token_next_by(m=M_SEMICOLON, idx=right)
+ right = sright or right # only overwrite if a semicolon is present.
+ tokens = tlist.tokens_between(left, right)
+ token = tlist.group_tokens(cls, tokens, extend=True)
+ token = tlist.token_next_by(m=m, idx=token)
+
+
+def _group_matching(tlist, cls):
+ """Groups Tokens that have beginning and end. ie. parenthesis, brackets.."""
+ idx = 1 if imt(tlist, i=cls) else 0
+
+ token = tlist.token_next_by(m=cls.M_OPEN, idx=idx)
while token:
- tidx = tlist.token_index(token)
- end = _find_matching(tidx, tlist, start_ttype, start_value,
- end_ttype, end_value)
- if end is None:
- idx = tidx + 1
- else:
- if include_semicolon:
- next_ = tlist.token_next(tlist.token_index(end))
- if next_ and next_.match(T.Punctuation, ';'):
- end = next_
- group = tlist.group_tokens(cls, tlist.tokens_between(token, end))
- _group_matching(group, start_ttype, start_value,
- end_ttype, end_value, cls, include_semicolon)
- idx = tlist.token_index(group) + 1
- token = tlist.token_next_match(idx, start_ttype, start_value)
+ end = find_matching(tlist, token, cls.M_OPEN, cls.M_CLOSE)
+ if end is not None:
+ token = tlist.group_tokens(cls, tlist.tokens_between(token, end))
+ _group_matching(token, cls)
+ token = tlist.token_next_by(m=cls.M_OPEN, idx=token)
def group_if(tlist):
- _group_matching(tlist, T.Keyword, 'IF', T.Keyword, 'END IF', sql.If, True)
+ _group_matching(tlist, sql.If)
def group_for(tlist):
- _group_matching(tlist, T.Keyword, 'FOR', T.Keyword, 'END LOOP',
- sql.For, True)
+ _group_matching(tlist, sql.For)
def group_foreach(tlist):
- _group_matching(tlist, T.Keyword, 'FOREACH', T.Keyword, 'END LOOP',
- sql.For, True)
+ _group_matching(tlist, sql.For)
def group_begin(tlist):
- _group_matching(tlist, T.Keyword, 'BEGIN', T.Keyword, 'END',
- sql.Begin, True)
+ _group_matching(tlist, sql.Begin)
def group_as(tlist):
-
- def _right_valid(token):
- # Currently limited to DML/DDL. Maybe additional more non SQL reserved
- # keywords should appear here (see issue8).
- return token.ttype not in (T.DML, T.DDL)
-
- def _left_valid(token):
- if token.ttype is T.Keyword and token.value in ('NULL',):
- return True
- return token.ttype is not T.Keyword
-
- _group_left_right(tlist, T.Keyword, 'AS', sql.Identifier,
- check_right=_right_valid,
- check_left=_left_valid)
+ lfunc = lambda tk: not imt(tk, t=T.Keyword) or tk.value == 'NULL'
+ rfunc = lambda tk: not imt(tk, t=(T.DML, T.DDL))
+ _group_left_right(tlist, (T.Keyword, 'AS'), sql.Identifier,
+ valid_left=lfunc, valid_right=rfunc)
def group_assignment(tlist):
- _group_left_right(tlist, T.Assignment, ':=', sql.Assignment,
- include_semicolon=True)
+ _group_left_right(tlist, (T.Assignment, ':='), sql.Assignment,
+ semicolon=True)
def group_comparison(tlist):
+ I_COMPERABLE = (sql.Parenthesis, sql.Function, sql.Identifier)
+ T_COMPERABLE = T_NUMERICAL + T_STRING + T_NAME
- def _parts_valid(token):
- return (token.ttype in (T.String.Symbol, T.String.Single,
- T.Name, T.Number, T.Number.Float,
- T.Number.Integer, T.Literal,
- T.Literal.Number.Integer, T.Name.Placeholder)
- or isinstance(token, (sql.Identifier, sql.Parenthesis,
- sql.Function))
- or (token.ttype is T.Keyword
- and token.value.upper() in ['NULL', ]))
- _group_left_right(tlist, T.Operator.Comparison, None, sql.Comparison,
- check_left=_parts_valid, check_right=_parts_valid)
+ func = lambda tk: imt(tk, t=T_COMPERABLE, i=I_COMPERABLE) or (
+ imt(tk, t=T.Keyword) and tk.value.upper() == 'NULL')
+
+ _group_left_right(tlist, (T.Operator.Comparison, None), sql.Comparison,
+ valid_left=func, valid_right=func)
def group_case(tlist):
- _group_matching(tlist, T.Keyword, 'CASE', T.Keyword, 'END', sql.Case,
- include_semicolon=True, recurse=True)
+ _group_matching(tlist, sql.Case)
+@recurse(sql.Identifier)
def group_identifier(tlist):
- def _consume_cycle(tl, i):
- # TODO: Usage of Wildcard token is ambivalent here.
- x = itertools.cycle((
- lambda y: (y.match(T.Punctuation, '.')
- or y.ttype in (T.Operator,
- T.Wildcard,
- T.Name)
- or isinstance(y, sql.SquareBrackets)),
- lambda y: (y.ttype in (T.String.Symbol,
- T.Name,
- T.Wildcard,
- T.Literal.String.Single,
- T.Literal.Number.Integer,
- T.Literal.Number.Float)
- or isinstance(y, (sql.Parenthesis,
- sql.SquareBrackets,
- sql.Function)))))
- for t in tl.tokens[i:]:
- # Don't take whitespaces into account.
- if t.ttype is T.Whitespace:
- yield t
- continue
- if next(x)(t):
- yield t
- else:
- if isinstance(t, sql.Comment) and t.is_multiline():
- yield t
- if t.ttype is T.Keyword.Order:
- yield t
- return
-
- def _next_token(tl, i):
- # chooses the next token. if two tokens are found then the
- # first is returned.
- t1 = tl.token_next_by_type(
- i, (T.String.Symbol, T.Name, T.Literal.Number.Integer,
- T.Literal.Number.Float))
-
- i1 = tl.token_index(t1, start=i) if t1 else None
- t2_end = None if i1 is None else i1 + 1
- t2 = tl.token_next_by_instance(i, (sql.Function, sql.Parenthesis),
- end=t2_end)
-
- if t1 and t2:
- i2 = tl.token_index(t2, start=i)
- if i1 > i2:
- return t2
- else:
- return t1
- elif t1:
- return t1
- else:
- return t2
+ T_IDENT = (T.String.Symbol, T.Name)
- # bottom up approach: group subgroups first
- [group_identifier(sgroup) for sgroup in tlist.get_sublists()
- if not isinstance(sgroup, sql.Identifier)]
-
- # real processing
- idx = 0
- token = _next_token(tlist, idx)
+ token = tlist.token_next_by(t=T_IDENT)
while token:
- identifier_tokens = [token] + list(
- _consume_cycle(tlist,
- tlist.token_index(token, start=idx) + 1))
- # remove trailing whitespace
- if identifier_tokens and identifier_tokens[-1].ttype is T.Whitespace:
- identifier_tokens = identifier_tokens[:-1]
- if not (len(identifier_tokens) == 1
- and (isinstance(identifier_tokens[0], (sql.Function,
- sql.Parenthesis))
- or identifier_tokens[0].ttype in (
- T.Literal.Number.Integer, T.Literal.Number.Float))):
- group = tlist.group_tokens(sql.Identifier, identifier_tokens)
- idx = tlist.token_index(group, start=idx) + 1
- else:
- idx += 1
- token = _next_token(tlist, idx)
+ token = tlist.group_tokens(sql.Identifier, [token, ])
+ token = tlist.token_next_by(t=T_IDENT, idx=token)
-def group_identifier_list(tlist):
- [group_identifier_list(sgroup) for sgroup in tlist.get_sublists()
- if not isinstance(sgroup, sql.IdentifierList)]
- # Allowed list items
- fend1_funcs = [lambda t: isinstance(t, (sql.Identifier, sql.Function,
- sql.Case)),
- lambda t: t.is_whitespace(),
- lambda t: t.ttype == T.Name,
- lambda t: t.ttype == T.Wildcard,
- lambda t: t.match(T.Keyword, 'null'),
- lambda t: t.match(T.Keyword, 'role'),
- lambda t: t.ttype == T.Number.Integer,
- lambda t: t.ttype == T.String.Single,
- lambda t: t.ttype == T.Name.Placeholder,
- lambda t: t.ttype == T.Keyword,
- lambda t: isinstance(t, sql.Comparison),
- lambda t: isinstance(t, sql.Comment),
- lambda t: t.ttype == T.Comment.Multiline,
- ]
- tcomma = tlist.token_next_match(0, T.Punctuation, ',')
- start = None
- while tcomma is not None:
- # Go back one idx to make sure to find the correct tcomma
- idx = tlist.token_index(tcomma)
- before = tlist.token_prev(idx)
- after = tlist.token_next(idx)
- # Check if the tokens around tcomma belong to a list
- bpassed = apassed = False
- for func in fend1_funcs:
- if before is not None and func(before):
- bpassed = True
- if after is not None and func(after):
- apassed = True
- if not bpassed or not apassed:
- # Something's wrong here, skip ahead to next ","
- start = None
- tcomma = tlist.token_next_match(idx + 1,
- T.Punctuation, ',')
- else:
- if start is None:
- start = before
- after_idx = tlist.token_index(after, start=idx)
- next_ = tlist.token_next(after_idx)
- if next_ is None or not next_.match(T.Punctuation, ','):
- # Reached the end of the list
- tokens = tlist.tokens_between(start, after)
- group = tlist.group_tokens(sql.IdentifierList, tokens)
- start = None
- tcomma = tlist.token_next_match(tlist.token_index(group) + 1,
- T.Punctuation, ',')
- else:
- tcomma = next_
+def group_period(tlist):
+ lfunc = lambda tk: imt(tk, i=(sql.SquareBrackets, sql.Identifier),
+ t=(T.Name, T.String.Symbol,))
+ rfunc = lambda tk: imt(tk, i=(sql.SquareBrackets, sql.Function),
+ t=(T.Name, T.String.Symbol, T.Wildcard))
-def group_brackets(tlist):
- """Group parentheses () or square brackets []
+ _group_left_right(tlist, (T.Punctuation, '.'), sql.Identifier,
+ valid_left=lfunc, valid_right=rfunc)
- This is just like _group_matching, but complicated by the fact that
- round brackets can contain square bracket groups and vice versa
- """
- if isinstance(tlist, (sql.Parenthesis, sql.SquareBrackets)):
- idx = 1
- else:
- idx = 0
+def group_arrays(tlist):
+ token = tlist.token_next_by(i=sql.SquareBrackets)
+ while token:
+ prev = tlist.token_prev(idx=token)
+ if imt(prev, i=(sql.SquareBrackets, sql.Identifier, sql.Function),
+ t=(T.Name, T.String.Symbol,)):
+ tokens = tlist.tokens_between(prev, token)
+ token = tlist.group_tokens(sql.Identifier, tokens, extend=True)
+ token = tlist.token_next_by(i=sql.SquareBrackets, idx=token)
+
+
+@recurse(sql.Identifier)
+def group_operator(tlist):
+ I_CYCLE = (sql.SquareBrackets, sql.Parenthesis, sql.Function,
+ sql.Identifier,) # sql.Operation)
+ # wildcards wouldn't have operations next to them
+ T_CYCLE = T_NUMERICAL + T_STRING + T_NAME # + T.Wildcard
+ func = lambda tk: imt(tk, i=I_CYCLE, t=T_CYCLE)
+
+ token = tlist.token_next_by(t=(T.Operator, T.Wildcard))
+ while token:
+ left, right = tlist.token_prev(token), tlist.token_next(token)
+
+ if func(left) and func(right):
+ token.ttype = T.Operator
+ tokens = tlist.tokens_between(left, right)
+ # token = tlist.group_tokens(sql.Operation, tokens)
+ token = tlist.group_tokens(sql.Identifier, tokens)
- # Find the first opening bracket
- token = tlist.token_next_match(idx, T.Punctuation, ['(', '['])
+ token = tlist.token_next_by(t=(T.Operator, T.Wildcard), idx=token)
+
+
+@recurse(sql.IdentifierList)
+def group_identifier_list(tlist):
+ I_IDENT_LIST = (sql.Function, sql.Case, sql.Identifier, sql.Comparison,
+ sql.IdentifierList) # sql.Operation
+ T_IDENT_LIST = (T_NUMERICAL + T_STRING + T_NAME +
+ (T.Keyword, T.Comment, T.Wildcard))
+
+ func = lambda t: imt(t, i=I_IDENT_LIST, m=M_ROLE, t=T_IDENT_LIST)
+ token = tlist.token_next_by(m=M_COMMA)
while token:
- start_val = token.value # either '(' or '['
- if start_val == '(':
- end_val = ')'
- group_class = sql.Parenthesis
- else:
- end_val = ']'
- group_class = sql.SquareBrackets
+ before, after = tlist.token_prev(token), tlist.token_next(token)
- tidx = tlist.token_index(token)
+ if func(before) and func(after):
+ tokens = tlist.tokens_between(before, after)
+ token = tlist.group_tokens(sql.IdentifierList, tokens, extend=True)
+ token = tlist.token_next_by(m=M_COMMA, idx=token)
- # Find the corresponding closing bracket
- end = _find_matching(tidx, tlist, T.Punctuation, start_val,
- T.Punctuation, end_val)
- if end is None:
- idx = tidx + 1
- else:
- group = tlist.group_tokens(group_class,
- tlist.tokens_between(token, end))
+def group_brackets(tlist):
+ _group_matching(tlist, sql.SquareBrackets)
- # Check for nested bracket groups within this group
- group_brackets(group)
- idx = tlist.token_index(group) + 1
- # Find the next opening bracket
- token = tlist.token_next_match(idx, T.Punctuation, ['(', '['])
+def group_parenthesis(tlist):
+ _group_matching(tlist, sql.Parenthesis)
+@recurse(sql.Comment)
def group_comments(tlist):
- [group_comments(sgroup) for sgroup in tlist.get_sublists()
- if not isinstance(sgroup, sql.Comment)]
- idx = 0
- token = tlist.token_next_by_type(idx, T.Comment)
+ token = tlist.token_next_by(t=T.Comment)
while token:
- tidx = tlist.token_index(token)
- end = tlist.token_not_matching(tidx + 1,
- [lambda t: t.ttype in T.Comment,
- lambda t: t.is_whitespace()])
- if end is None:
- idx = tidx + 1
- else:
- eidx = tlist.token_index(end)
- grp_tokens = tlist.tokens_between(token,
- tlist.token_prev(eidx, False))
- group = tlist.group_tokens(sql.Comment, grp_tokens)
- idx = tlist.token_index(group)
- token = tlist.token_next_by_type(idx, T.Comment)
+ end = tlist.token_not_matching(
+ token, lambda tk: imt(tk, t=T.Comment) or tk.is_whitespace())
+ if end is not None:
+ end = tlist.token_prev(end, False)
+ tokens = tlist.tokens_between(token, end)
+ token = tlist.group_tokens(sql.Comment, tokens)
+
+ token = tlist.token_next_by(t=T.Comment, idx=token)
+@recurse(sql.Where)
def group_where(tlist):
- [group_where(sgroup) for sgroup in tlist.get_sublists()
- if not isinstance(sgroup, sql.Where)]
- idx = 0
- token = tlist.token_next_match(idx, T.Keyword, 'WHERE')
- stopwords = ('ORDER', 'GROUP', 'LIMIT', 'UNION', 'EXCEPT', 'HAVING')
+ token = tlist.token_next_by(m=sql.Where.M_OPEN)
while token:
- tidx = tlist.token_index(token)
- end = tlist.token_next_match(tidx + 1, T.Keyword, stopwords)
+ end = tlist.token_next_by(m=sql.Where.M_CLOSE, idx=token)
+
if end is None:
- end = tlist._groupable_tokens[-1]
+ tokens = tlist.tokens_between(token, tlist._groupable_tokens[-1])
else:
- end = tlist.tokens[tlist.token_index(end) - 1]
- group = tlist.group_tokens(sql.Where,
- tlist.tokens_between(token, end),
- ignore_ws=True)
- idx = tlist.token_index(group)
- token = tlist.token_next_match(idx, T.Keyword, 'WHERE')
+ tokens = tlist.tokens_between(
+ token, tlist.tokens[tlist.token_index(end) - 1])
+
+ token = tlist.group_tokens(sql.Where, tokens)
+ token = tlist.token_next_by(m=sql.Where.M_OPEN, idx=token)
+@recurse()
def group_aliased(tlist):
- clss = (sql.Identifier, sql.Function, sql.Case)
- [group_aliased(sgroup) for sgroup in tlist.get_sublists()
- if not isinstance(sgroup, clss)]
- idx = 0
- token = tlist.token_next_by_instance(idx, clss)
+ I_ALIAS = (sql.Parenthesis, sql.Function, sql.Case, sql.Identifier,
+ ) # sql.Operation)
+
+ token = tlist.token_next_by(i=I_ALIAS, t=T.Number)
while token:
- next_ = tlist.token_next(tlist.token_index(token))
- if next_ is not None and isinstance(next_, clss):
- if not next_.value.upper().startswith('VARCHAR'):
- grp = tlist.tokens_between(token, next_)[1:]
- token.tokens.extend(grp)
- for t in grp:
- tlist.tokens.remove(t)
- idx = tlist.token_index(token) + 1
- token = tlist.token_next_by_instance(idx, clss)
+ next_ = tlist.token_next(token)
+ if imt(next_, i=sql.Identifier):
+ tokens = tlist.tokens_between(token, next_)
+ token = tlist.group_tokens(sql.Identifier, tokens, extend=True)
+ token = tlist.token_next_by(i=I_ALIAS, t=T.Number, idx=token)
def group_typecasts(tlist):
- _group_left_right(tlist, T.Punctuation, '::', sql.Identifier)
+ _group_left_right(tlist, (T.Punctuation, '::'), sql.Identifier)
+@recurse(sql.Function)
def group_functions(tlist):
- [group_functions(sgroup) for sgroup in tlist.get_sublists()
- if not isinstance(sgroup, sql.Function)]
- idx = 0
- token = tlist.token_next_by_type(idx, T.Name)
+ token = tlist.token_next_by(t=T.Name)
while token:
next_ = tlist.token_next(token)
- if not isinstance(next_, sql.Parenthesis):
- idx = tlist.token_index(token) + 1
- else:
- func = tlist.group_tokens(sql.Function,
- tlist.tokens_between(token, next_))
- idx = tlist.token_index(func) + 1
- token = tlist.token_next_by_type(idx, T.Name)
+ if imt(next_, i=sql.Parenthesis):
+ tokens = tlist.tokens_between(token, next_)
+ token = tlist.group_tokens(sql.Function, tokens)
+ token = tlist.token_next_by(t=T.Name, idx=token)
def group_order(tlist):
- idx = 0
- token = tlist.token_next_by_type(idx, T.Keyword.Order)
+ """Group together Identifier and Asc/Desc token"""
+ token = tlist.token_next_by(t=T.Keyword.Order)
while token:
prev = tlist.token_prev(token)
- if isinstance(prev, sql.Identifier):
- ido = tlist.group_tokens(sql.Identifier,
- tlist.tokens_between(prev, token))
- idx = tlist.token_index(ido) + 1
- else:
- idx = tlist.token_index(token) + 1
- token = tlist.token_next_by_type(idx, T.Keyword.Order)
+ if imt(prev, i=sql.Identifier, t=T.Number):
+ tokens = tlist.tokens_between(prev, token)
+ token = tlist.group_tokens(sql.Identifier, tokens)
+ token = tlist.token_next_by(t=T.Keyword.Order, idx=token)
+@recurse()
def align_comments(tlist):
- [align_comments(sgroup) for sgroup in tlist.get_sublists()]
- idx = 0
- token = tlist.token_next_by_instance(idx, sql.Comment)
+ token = tlist.token_next_by(i=sql.Comment)
while token:
- before = tlist.token_prev(tlist.token_index(token))
+ before = tlist.token_prev(token)
if isinstance(before, sql.TokenList):
- grp = tlist.tokens_between(before, token)[1:]
- before.tokens.extend(grp)
- for t in grp:
- tlist.tokens.remove(t)
- idx = tlist.token_index(before) + 1
- else:
- idx = tlist.token_index(token) + 1
- token = tlist.token_next_by_instance(idx, sql.Comment)
+ tokens = tlist.tokens_between(before, token)
+ token = tlist.group_tokens(sql.TokenList, tokens, extend=True)
+ token = tlist.token_next_by(i=sql.Comment, idx=token)
def group(tlist):
for func in [
group_comments,
group_brackets,
+ group_parenthesis,
group_functions,
group_where,
group_case,
+ group_period,
+ group_arrays,
group_identifier,
+ group_operator,
group_order,
group_typecasts,
group_as,
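
The effect of the rewritten grouping pass is easiest to see through the public API. A minimal sketch, not part of the patch; the printed value is indicative:

    import sqlparse
    from sqlparse import sql

    stmt = sqlparse.parse('SELECT f.bar FROM foo f WHERE f.id = 1;')[0]

    # group_where wraps everything from WHERE up to a stopword (or the end)
    where = next(t for t in stmt.tokens if isinstance(t, sql.Where))

    # group_comparison groups "f.id = 1" inside the WHERE clause
    comparison = next(t for t in where.tokens if isinstance(t, sql.Comparison))
    print(comparison.value)  # f.id = 1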
diff --git a/sqlparse/filters.py b/sqlparse/filters.py
index 68e9b1a..72f17d0 100644
--- a/sqlparse/filters.py
+++ b/sqlparse/filters.py
@@ -200,9 +200,7 @@ class StripCommentsFilter:
def _get_next_comment(self, tlist):
# TODO(andi) Comment types should be unified, see related issue38
- token = tlist.token_next_by_instance(0, sql.Comment)
- if token is None:
- token = tlist.token_next_by_type(0, T.Comment)
+ token = tlist.token_next_by(i=sql.Comment, t=T.Comment)
return token
def _process(self, tlist):
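
The consolidated lookup works because token_next_by accepts instance, match, and token-type criteria in a single call. A rough plain-Python equivalent of this particular call site (a sketch, not the library code):

    from sqlparse import sql
    from sqlparse import tokens as T

    def first_comment(tlist):
        # first child that is a grouped Comment OR carries a Comment ttype
        for token in tlist.tokens:
            if isinstance(token, sql.Comment):
                return token
            if token.ttype is not None and token.ttype in T.Comment:
                return token
        return None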
diff --git a/sqlparse/sql.py b/sqlparse/sql.py
index f357572..9afdac3 100644
--- a/sqlparse/sql.py
+++ b/sqlparse/sql.py
@@ -7,6 +7,7 @@ import sys
from sqlparse import tokens as T
from sqlparse.compat import string_types, u
+from sqlparse.utils import imt, remove_quotes
class Token(object):
@@ -77,7 +78,7 @@ class Token(object):
if regex:
if isinstance(values, string_types):
- values = set([values])
+ values = {values}
if self.ttype is T.Keyword:
values = set(re.compile(v, re.IGNORECASE) for v in values)
@@ -150,7 +151,7 @@ class TokenList(Token):
if tokens is None:
tokens = []
self.tokens = tokens
- Token.__init__(self, None, self._to_string())
+ super(TokenList, self).__init__(None, self.__str__())
def __unicode__(self):
return self._to_string()
@@ -184,14 +185,6 @@ class TokenList(Token):
if (token.is_group() and (max_depth is None or depth < max_depth)):
token._pprint_tree(max_depth, depth + 1)
- def _remove_quotes(self, val):
- """Helper that removes surrounding quotes from strings."""
- if not val:
- return val
- if val[0] in ('"', '\'') and val[-1] == val[0]:
- val = val[1:-1]
- return val
-
def get_token_at_offset(self, offset):
"""Returns the token that is on position offset."""
idx = 0
@@ -213,12 +206,12 @@ class TokenList(Token):
else:
yield token
-# def __iter__(self):
-# return self
-#
-# def next(self):
-# for token in self.tokens:
-# yield token
+ # def __iter__(self):
+ # return self
+ #
+ # def next(self):
+ # for token in self.tokens:
+ # yield token
def is_group(self):
return True
@@ -232,6 +225,27 @@ class TokenList(Token):
def _groupable_tokens(self):
return self.tokens
+ def _token_matching(self, funcs, start=0, end=None, reverse=False):
+ """next token that match functions"""
+ if start is None:
+ return None
+
+ if not isinstance(start, int):
+ start = self.token_index(start) + 1
+
+ if not isinstance(funcs, (list, tuple)):
+ funcs = (funcs,)
+
+ if reverse:
+ iterable = iter(reversed(self.tokens[end:start - 1]))
+ else:
+ iterable = self.tokens[start:end]
+
+ for token in iterable:
+ for func in funcs:
+ if func(token):
+ return token
+
def token_first(self, ignore_whitespace=True, ignore_comments=False):
"""Returns the first child token.
@@ -241,12 +255,13 @@ class TokenList(Token):
if *ignore_comments* is ``True`` (default: ``False``), comments are
ignored too.
"""
- for token in self.tokens:
- if ignore_whitespace and token.is_whitespace():
- continue
- if ignore_comments and isinstance(token, Comment):
- continue
- return token
+ funcs = lambda tk: not ((ignore_whitespace and tk.is_whitespace()) or
+ (ignore_comments and imt(tk, i=Comment)))
+ return self._token_matching(funcs)
+
+ def token_next_by(self, i=None, m=None, t=None, idx=0, end=None):
+ funcs = lambda tk: imt(tk, i, m, t)
+ return self._token_matching(funcs, idx, end)
def token_next_by_instance(self, idx, clss, end=None):
"""Returns the next token matching a class.
@@ -256,48 +271,26 @@ class TokenList(Token):
If no matching token can be found ``None`` is returned.
"""
- if not isinstance(clss, (list, tuple)):
- clss = (clss,)
-
- for token in self.tokens[idx:end]:
- if isinstance(token, clss):
- return token
+ funcs = lambda tk: imt(tk, i=clss)
+ return self._token_matching(funcs, idx, end)
def token_next_by_type(self, idx, ttypes):
"""Returns next matching token by it's token type."""
- if not isinstance(ttypes, (list, tuple)):
- ttypes = [ttypes]
-
- for token in self.tokens[idx:]:
- if token.ttype in ttypes:
- return token
+ funcs = lambda tk: imt(tk, t=ttypes)
+ return self._token_matching(funcs, idx)
def token_next_match(self, idx, ttype, value, regex=False):
"""Returns next token where it's ``match`` method returns ``True``."""
- if not isinstance(idx, int):
- idx = self.token_index(idx)
-
- for n in range(idx, len(self.tokens)):
- token = self.tokens[n]
- if token.match(ttype, value, regex):
- return token
+ funcs = lambda tk: imt(tk, m=(ttype, value, regex))
+ return self._token_matching(funcs, idx)
def token_not_matching(self, idx, funcs):
- for token in self.tokens[idx:]:
- passed = False
- for func in funcs:
- if func(token):
- passed = True
- break
-
- if not passed:
- return token
+ funcs = (funcs,) if not isinstance(funcs, (list, tuple)) else funcs
+ funcs = [lambda tk: not func(tk) for func in funcs]
+ return self._token_matching(funcs, idx)
def token_matching(self, idx, funcs):
- for token in self.tokens[idx:]:
- for func in funcs:
- if func(token):
- return token
+ return self._token_matching(funcs, idx)
def token_prev(self, idx, skip_ws=True):
"""Returns the previous token relative to *idx*.
@@ -305,17 +298,10 @@ class TokenList(Token):
If *skip_ws* is ``True`` (the default) whitespace tokens are ignored.
``None`` is returned if there's no previous token.
"""
- if idx is None:
- return None
-
- if not isinstance(idx, int):
- idx = self.token_index(idx)
-
- while idx:
- idx -= 1
- if self.tokens[idx].is_whitespace() and skip_ws:
- continue
- return self.tokens[idx]
+ if isinstance(idx, int):
+ idx += 1 # a lot of existing call sites pre-compensate for this
+ funcs = lambda tk: not (tk.is_whitespace() and skip_ws)
+ return self._token_matching(funcs, idx, reverse=True)
def token_next(self, idx, skip_ws=True):
"""Returns the next token relative to *idx*.
@@ -323,59 +309,56 @@ class TokenList(Token):
If *skip_ws* is ``True`` (the default) whitespace tokens are ignored.
``None`` is returned if there's no next token.
"""
- if idx is None:
- return None
-
- if not isinstance(idx, int):
- idx = self.token_index(idx)
-
- while idx < len(self.tokens) - 1:
- idx += 1
- if self.tokens[idx].is_whitespace() and skip_ws:
- continue
- return self.tokens[idx]
+ if isinstance(idx, int):
+ idx += 1 # a lot of existing call sites pre-compensate for this
+ funcs = lambda tk: not (tk.is_whitespace() and skip_ws)
+ return self._token_matching(funcs, idx)
def token_index(self, token, start=0):
"""Return list index of token."""
- if start > 0:
- # Performing `index` manually is much faster when starting
- # in the middle of the list of tokens and expecting to find
- # the token near to the starting index.
- for i in range(start, len(self.tokens)):
- if self.tokens[i] == token:
- return i
- return -1
- return self.tokens.index(token)
-
- def tokens_between(self, start, end, exclude_end=False):
+ start = self.token_index(start) if not isinstance(start, int) else start
+ return start + self.tokens[start:].index(token)
+
+ def tokens_between(self, start, end, include_end=True):
"""Return all tokens between (and including) start and end.
- If *exclude_end* is ``True`` (default is ``False``) the end token
- is included too.
+ If *include_end* is ``False`` (default is ``True``) the end token
+ is excluded.
"""
- # FIXME(andi): rename exclude_end to inlcude_end
- if exclude_end:
- offset = 0
- else:
- offset = 1
- end_idx = self.token_index(end) + offset
start_idx = self.token_index(start)
+ end_idx = include_end + self.token_index(end)
return self.tokens[start_idx:end_idx]
- def group_tokens(self, grp_cls, tokens, ignore_ws=False):
+ def group_tokens(self, grp_cls, tokens, ignore_ws=False, extend=False):
"""Replace tokens by an instance of *grp_cls*."""
- idx = self.token_index(tokens[0])
if ignore_ws:
while tokens and tokens[-1].is_whitespace():
tokens = tokens[:-1]
- for t in tokens:
- self.tokens.remove(t)
- grp = grp_cls(tokens)
+
+ left = tokens[0]
+ idx = self.token_index(left)
+
+ if extend:
+ if not isinstance(left, grp_cls):
+ grp = grp_cls([left])
+ self.tokens.remove(left)
+ self.tokens.insert(idx, grp)
+ left = grp
+ left.parent = self
+ tokens = tokens[1:]
+ left.tokens.extend(tokens)
+ left.value = left.__str__()
+
+ else:
+ left = grp_cls(tokens)
+ left.parent = self
+ self.tokens.insert(idx, left)
+
for token in tokens:
- token.parent = grp
- grp.parent = self
- self.tokens.insert(idx, grp)
- return grp
+ token.parent = left
+ self.tokens.remove(token)
+
+ return left
def insert_before(self, where, token):
"""Inserts *token* before *where*."""
@@ -397,13 +380,12 @@ class TokenList(Token):
"""Returns the alias for this identifier or ``None``."""
# "name AS alias"
- kw = self.token_next_match(0, T.Keyword, 'AS')
+ kw = self.token_next_by(m=(T.Keyword, 'AS'))
if kw is not None:
return self._get_first_name(kw, keywords=True)
# "name alias" or "complicated column expression alias"
- if len(self.tokens) > 2 \
- and self.token_next_by_type(0, T.Whitespace) is not None:
+ if len(self.tokens) > 2 and self.token_next_by(t=T.Whitespace):
return self._get_first_name(reverse=True)
return None
@@ -440,7 +422,7 @@ class TokenList(Token):
prev_ = self.token_prev(self.token_index(dot))
if prev_ is None: # something must be verry wrong here..
return None
- return self._remove_quotes(prev_.value)
+ return remove_quotes(prev_.value)
def _get_first_name(self, idx=None, reverse=False, keywords=False):
"""Returns the name of the first token with a name"""
@@ -457,7 +439,7 @@ class TokenList(Token):
for tok in tokens:
if tok.ttype in types:
- return self._remove_quotes(tok.value)
+ return remove_quotes(tok.value)
elif isinstance(tok, Identifier) or isinstance(tok, Function):
return tok.get_name()
return None
@@ -510,8 +492,6 @@ class Identifier(TokenList):
Identifiers may have aliases or typecasts.
"""
- __slots__ = ('value', 'ttype', 'tokens')
-
def is_wildcard(self):
"""Return ``True`` if this identifier contains a wildcard."""
token = self.token_next_by_type(0, T.Wildcard)
@@ -546,8 +526,6 @@ class Identifier(TokenList):
class IdentifierList(TokenList):
"""A list of :class:`~sqlparse.sql.Identifier`\'s."""
- __slots__ = ('value', 'ttype', 'tokens')
-
def get_identifiers(self):
"""Returns the identifiers.
@@ -560,7 +538,8 @@ class IdentifierList(TokenList):
class Parenthesis(TokenList):
"""Tokens between parenthesis."""
- __slots__ = ('value', 'ttype', 'tokens')
+ M_OPEN = (T.Punctuation, '(')
+ M_CLOSE = (T.Punctuation, ')')
@property
def _groupable_tokens(self):
@@ -569,8 +548,8 @@ class Parenthesis(TokenList):
class SquareBrackets(TokenList):
"""Tokens between square brackets"""
-
- __slots__ = ('value', 'ttype', 'tokens')
+ M_OPEN = (T.Punctuation, '[')
+ M_CLOSE = (T.Punctuation, ']')
@property
def _groupable_tokens(self):
@@ -579,22 +558,22 @@ class SquareBrackets(TokenList):
class Assignment(TokenList):
"""An assignment like 'var := val;'"""
- __slots__ = ('value', 'ttype', 'tokens')
class If(TokenList):
"""An 'if' clause with possible 'else if' or 'else' parts."""
- __slots__ = ('value', 'ttype', 'tokens')
+ M_OPEN = (T.Keyword, 'IF')
+ M_CLOSE = (T.Keyword, 'END IF')
class For(TokenList):
"""A 'FOR' loop."""
- __slots__ = ('value', 'ttype', 'tokens')
+ M_OPEN = (T.Keyword, ('FOR', 'FOREACH'))
+ M_CLOSE = (T.Keyword, 'END LOOP')
class Comparison(TokenList):
"""A comparison used for example in WHERE clauses."""
- __slots__ = ('value', 'ttype', 'tokens')
@property
def left(self):
@@ -607,7 +586,6 @@ class Comparison(TokenList):
class Comment(TokenList):
"""A comment."""
- __slots__ = ('value', 'ttype', 'tokens')
def is_multiline(self):
return self.tokens and self.tokens[0].ttype == T.Comment.Multiline
@@ -615,13 +593,15 @@ class Comment(TokenList):
class Where(TokenList):
"""A WHERE clause."""
- __slots__ = ('value', 'ttype', 'tokens')
+ M_OPEN = (T.Keyword, 'WHERE')
+ M_CLOSE = (T.Keyword,
+ ('ORDER', 'GROUP', 'LIMIT', 'UNION', 'EXCEPT', 'HAVING'))
class Case(TokenList):
"""A CASE statement with one or more WHEN and possibly an ELSE part."""
-
- __slots__ = ('value', 'ttype', 'tokens')
+ M_OPEN = (T.Keyword, 'CASE')
+ M_CLOSE = (T.Keyword, 'END')
def get_cases(self):
"""Returns a list of 2-tuples (condition, value).
@@ -671,22 +651,18 @@ class Case(TokenList):
class Function(TokenList):
"""A function or procedure call."""
- __slots__ = ('value', 'ttype', 'tokens')
-
def get_parameters(self):
"""Return a list of parameters."""
parenthesis = self.tokens[-1]
for t in parenthesis.tokens:
- if isinstance(t, IdentifierList):
+ if imt(t, i=IdentifierList):
return t.get_identifiers()
- elif (isinstance(t, Identifier) or
- isinstance(t, Function) or
- t.ttype in T.Literal):
+ elif imt(t, i=(Function, Identifier), t=T.Literal):
return [t, ]
return []
class Begin(TokenList):
"""A BEGIN/END block."""
-
- __slots__ = ('value', 'ttype', 'tokens')
+ M_OPEN = (T.Keyword, 'BEGIN')
+ M_CLOSE = (T.Keyword, 'END')
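
The new M_OPEN/M_CLOSE pairs are plain (ttype, value) tuples, so they can be unpacked straight into Token.match(). A minimal sketch of that contract, not part of the patch:

    import sqlparse
    from sqlparse import sql

    paren = sqlparse.parse('(1)')[0].tokens[0]
    assert isinstance(paren, sql.Parenthesis)
    assert paren.tokens[0].match(*sql.Parenthesis.M_OPEN)    # '('
    assert paren.tokens[-1].match(*sql.Parenthesis.M_CLOSE)  # ')'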
diff --git a/sqlparse/utils.py b/sqlparse/utils.py
index 7db9a96..90acb5c 100644
--- a/sqlparse/utils.py
+++ b/sqlparse/utils.py
@@ -1,16 +1,13 @@
-'''
-Created on 17/05/2012
-
-@author: piranna
-'''
-
+import itertools
import re
-from collections import OrderedDict
+from collections import OrderedDict, deque
+from contextlib import contextmanager
class Cache(OrderedDict):
"""Cache with LRU algorithm using an OrderedDict as basis
"""
+
def __init__(self, maxsize=100):
OrderedDict.__init__(self)
@@ -113,3 +110,85 @@ def split_unquoted_newlines(text):
else:
outputlines[-1] += line
return outputlines
+
+
+def remove_quotes(val):
+ """Helper that removes surrounding quotes from strings."""
+ if val is None:
+ return
+ if val[0] in ('"', "'") and val[0] == val[-1]:
+ val = val[1:-1]
+ return val
+
+
+def recurse(*cls):
+ """Function decorator to help with recursion
+
+ :param cls: Classes to not recurse over
+ :return: function
+ """
+ def wrap(f):
+ def wrapped_f(tlist):
+ for sgroup in tlist.get_sublists():
+ if not isinstance(sgroup, cls):
+ wrapped_f(sgroup)
+ f(tlist)
+
+ return wrapped_f
+
+ return wrap
+
+
+def imt(token, i=None, m=None, t=None):
+ """Aid function to refactor comparisons for Instance, Match and TokenType
+ Aid fun
+ :param token:
+ :param i: Class or Tuple/List of Classes
+ :param m: Tuple of TokenType & Value. Can be list of Tuple for multiple
+ :param t: TokenType or Tuple/List of TokenTypes
+ :return: bool
+ """
+ t = (t,) if t and not isinstance(t, (list, tuple)) else t
+ m = (m,) if m and not isinstance(m, (list,)) else m
+
+ if token is None:
+ return False
+ elif i is not None and isinstance(token, i):
+ return True
+ elif m is not None and any((token.match(*x) for x in m)):
+ return True
+ elif t is not None and token.ttype in t:
+ return True
+ else:
+ return False
+
+
+def find_matching(tlist, token, M1, M2):
+ idx = tlist.token_index(token)
+ depth = 0
+ for token in tlist.tokens[idx:]:
+ if token.match(*M1):
+ depth += 1
+ elif token.match(*M2):
+ depth -= 1
+ if depth == 0:
+ return token
+
+
+def consume(iterator, n):
+ """Advance the iterator n-steps ahead. If n is none, consume entirely."""
+ deque(itertools.islice(iterator, n), maxlen=0)
+
+
+@contextmanager
+def offset(filter_, n=0):
+ filter_.offset += n
+ yield
+ filter_.offset -= n
+
+
+@contextmanager
+def indent(filter_, n=1):
+ filter_.indent += n
+ yield
+ filter_.indent -= n
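
The offset and indent context managers bump a counter on a filter object for the duration of a block. A minimal sketch with a stand-in object (not one of the library's filters):

    from sqlparse.utils import indent, offset

    class _Filter(object):
        # stand-in exposing the two counters the helpers expect
        offset = 0
        indent = 0

    f = _Filter()
    with indent(f), offset(f, n=4):
        assert (f.indent, f.offset) == (1, 4)  # raised inside the block
    assert (f.indent, f.offset) == (0, 0)      # restored on exit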