diff options
| author | Victor Uriarte <victor.m.uriarte@intel.com> | 2016-06-14 06:26:41 -0700 |
|---|---|---|
| committer | Victor Uriarte <victor.m.uriarte@intel.com> | 2016-06-15 13:29:18 -0700 |
| commit | af9b82e0b2d00732704fedf7d7b03dcb598dca84 (patch) | |
| tree | 7be73b65279a63ecb9fe8c904f0c556c0ab53323 /sqlparse | |
| parent | 56b28dc15023d36bab8764bea6df75e28651646e (diff) | |
| download | sqlparse-af9b82e0b2d00732704fedf7d7b03dcb598dca84.tar.gz | |
Reorder grouping code and func call order
Remove repeated for-each/for grouping
Diffstat (limited to 'sqlparse')
| -rw-r--r-- | sqlparse/engine/grouping.py | 133 |
1 files changed, 66 insertions, 67 deletions
diff --git a/sqlparse/engine/grouping.py b/sqlparse/engine/grouping.py index c52a759..7879f76 100644 --- a/sqlparse/engine/grouping.py +++ b/sqlparse/engine/grouping.py @@ -9,41 +9,11 @@ from sqlparse import sql from sqlparse import tokens as T from sqlparse.utils import recurse, imt -M_ROLE = (T.Keyword, ('null', 'role')) -M_SEMICOLON = (T.Punctuation, ';') -M_COMMA = (T.Punctuation, ',') - T_NUMERICAL = (T.Number, T.Number.Integer, T.Number.Float) T_STRING = (T.String, T.String.Single, T.String.Symbol) T_NAME = (T.Name, T.Name.Placeholder) -def _group_left_right(tlist, m, cls, - valid_left=lambda t: t is not None, - valid_right=lambda t: t is not None, - semicolon=False): - """Groups together tokens that are joined by a middle token. ie. x < y""" - for token in list(tlist): - if token.is_group() and not isinstance(token, cls): - _group_left_right(token, m, cls, valid_left, valid_right, - semicolon) - continue - if not token.match(*m): - continue - - tidx = tlist.token_index(token) - pidx, prev_ = tlist.token_prev(tidx) - nidx, next_ = tlist.token_next(tidx) - - if valid_left(prev_) and valid_right(next_): - if semicolon: - # only overwrite if a semicolon present. - snidx, _ = tlist.token_next_by(m=M_SEMICOLON, idx=nidx) - nidx = snidx or nidx - # Luckily, this leaves the position of `token` intact. - tlist.group_tokens(cls, pidx, nidx, extend=True) - - def _group_matching(tlist, cls): """Groups Tokens that have beginning and end.""" opens = [] @@ -69,6 +39,18 @@ def _group_matching(tlist, cls): tlist.group_tokens(cls, oidx, cidx) +def group_brackets(tlist): + _group_matching(tlist, sql.SquareBrackets) + + +def group_parenthesis(tlist): + _group_matching(tlist, sql.Parenthesis) + + +def group_case(tlist): + _group_matching(tlist, sql.Case) + + def group_if(tlist): _group_matching(tlist, sql.If) @@ -77,16 +59,54 @@ def group_for(tlist): _group_matching(tlist, sql.For) -def group_foreach(tlist): - _group_matching(tlist, sql.For) - - def group_begin(tlist): _group_matching(tlist, sql.Begin) +def _group_left_right(tlist, m, cls, + valid_left=lambda t: t is not None, + valid_right=lambda t: t is not None, + semicolon=False): + """Groups together tokens that are joined by a middle token. ie. x < y""" + for token in list(tlist): + if token.is_group() and not isinstance(token, cls): + _group_left_right(token, m, cls, valid_left, valid_right, + semicolon) + continue + if not token.match(*m): + continue + + tidx = tlist.token_index(token) + pidx, prev_ = tlist.token_prev(tidx) + nidx, next_ = tlist.token_next(tidx) + + if valid_left(prev_) and valid_right(next_): + if semicolon: + # only overwrite if a semicolon present. + m_semicolon = T.Punctuation, ';' + snidx, _ = tlist.token_next_by(m=m_semicolon, idx=nidx) + nidx = snidx or nidx + # Luckily, this leaves the position of `token` intact. + tlist.group_tokens(cls, pidx, nidx, extend=True) + + +def group_typecasts(tlist): + _group_left_right(tlist, (T.Punctuation, '::'), sql.Identifier) + + +def group_period(tlist): + lfunc = lambda tk: imt(tk, i=(sql.SquareBrackets, sql.Identifier), + t=(T.Name, T.String.Symbol,)) + + rfunc = lambda tk: imt(tk, i=(sql.SquareBrackets, sql.Function), + t=(T.Name, T.String.Symbol, T.Wildcard)) + + _group_left_right(tlist, (T.Punctuation, '.'), sql.Identifier, + valid_left=lfunc, valid_right=rfunc) + + def group_as(tlist): - lfunc = lambda tk: not imt(tk, t=T.Keyword) or tk.value == 'NULL' + lfunc = lambda tk: not imt(tk, t=T.Keyword) or tk.normalized == 'NULL' rfunc = lambda tk: not imt(tk, t=(T.DML, T.DDL)) _group_left_right(tlist, (T.Keyword, 'AS'), sql.Identifier, valid_left=lfunc, valid_right=rfunc) @@ -109,10 +129,6 @@ def group_comparison(tlist): valid_left=func, valid_right=func) -def group_case(tlist): - _group_matching(tlist, sql.Case) - - @recurse(sql.Identifier) def group_identifier(tlist): T_IDENT = (T.String.Symbol, T.Name) @@ -123,17 +139,6 @@ def group_identifier(tlist): tidx, token = tlist.token_next_by(t=T_IDENT, idx=tidx) -def group_period(tlist): - lfunc = lambda tk: imt(tk, i=(sql.SquareBrackets, sql.Identifier), - t=(T.Name, T.String.Symbol,)) - - rfunc = lambda tk: imt(tk, i=(sql.SquareBrackets, sql.Function), - t=(T.Name, T.String.Symbol, T.Wildcard)) - - _group_left_right(tlist, (T.Punctuation, '.'), sql.Identifier, - valid_left=lfunc, valid_right=rfunc) - - def group_arrays(tlist): tidx, token = tlist.token_next_by(i=sql.SquareBrackets) while token: @@ -168,6 +173,9 @@ def group_operator(tlist): @recurse(sql.IdentifierList) def group_identifier_list(tlist): + M_ROLE = T.Keyword, ('null', 'role') + M_COMMA = T.Punctuation, ',' + I_IDENT_LIST = (sql.Function, sql.Case, sql.Identifier, sql.Comparison, sql.IdentifierList, sql.Operation) T_IDENT_LIST = (T_NUMERICAL + T_STRING + T_NAME + @@ -186,14 +194,6 @@ def group_identifier_list(tlist): tidx, token = tlist.token_next_by(m=M_COMMA, idx=tidx) -def group_brackets(tlist): - _group_matching(tlist, sql.SquareBrackets) - - -def group_parenthesis(tlist): - _group_matching(tlist, sql.Parenthesis) - - @recurse(sql.Comment) def group_comments(tlist): tidx, token = tlist.token_next_by(t=T.Comment) @@ -237,10 +237,6 @@ def group_aliased(tlist): tidx, token = tlist.token_next_by(i=I_ALIAS, t=T.Number, idx=tidx) -def group_typecasts(tlist): - _group_left_right(tlist, (T.Punctuation, '::'), sql.Identifier) - - @recurse(sql.Function) def group_functions(tlist): has_create = False @@ -286,11 +282,17 @@ def align_comments(tlist): def group(stmt): for func in [ group_comments, + + # _group_matching group_brackets, group_parenthesis, + group_case, + group_if, + group_for, + group_begin, + group_functions, group_where, - group_case, group_period, group_arrays, group_identifier, @@ -301,12 +303,9 @@ def group(stmt): group_aliased, group_assignment, group_comparison, + align_comments, group_identifier_list, - group_if, - group_for, - group_foreach, - group_begin, ]: func(stmt) return stmt |
