diff options
| author | Vik <vmuriart@gmail.com> | 2016-06-06 07:45:06 -0700 |
|---|---|---|
| committer | Vik <vmuriart@gmail.com> | 2016-06-06 07:45:06 -0700 |
| commit | d296ead1ddd5373ccac5e03279c782d538f30f98 (patch) | |
| tree | 95e887de388e5f6f1516b1ea862e9d48a4d174eb | |
| parent | b9d81ac4fe49114f57dc33c0d635f99ff56e62f2 (diff) | |
| parent | a7c7d9586208516de372cb01203b48a53f7095fb (diff) | |
| download | sqlparse-d296ead1ddd5373ccac5e03279c782d538f30f98.tar.gz | |
Merge pull request #252 from vmuriart/rb-aligned-format
Rebased and Updated Aligned-Indent Format
| mode | file | lines changed |
|---|---|---|
| -rwxr-xr-x | bin/sqlformat | 9 |
| -rw-r--r-- | sqlparse/filters.py | 178 |
| -rw-r--r-- | sqlparse/formatter.py | 21 |
| -rw-r--r-- | sqlparse/sql.py | 5 |
| -rw-r--r-- | tests/test_format.py | 228 |
| -rw-r--r-- | tests/test_grouping.py | 16 |
6 files changed, 453 insertions, 4 deletions
diff --git a/bin/sqlformat b/bin/sqlformat index fa20ded..3f61064 100755 --- a/bin/sqlformat +++ b/bin/sqlformat @@ -7,16 +7,13 @@ # the BSD License: http://www.opensource.org/licenses/bsd-license.php import optparse -import os import sys import sqlparse from sqlparse.exceptions import SQLParseError - _CASE_CHOICES = ['upper', 'lower', 'capitalize'] - parser = optparse.OptionParser(usage='%prog [OPTIONS] FILE, ...', version='%%prog %s' % sqlparse.__version__) parser.set_description(('Format FILE according to OPTIONS. Use "-" as FILE ' @@ -45,6 +42,12 @@ group.add_option('-r', '--reindent', dest='reindent', help='reindent statements') group.add_option('--indent_width', dest='indent_width', default=2, help='indentation width (defaults to 2 spaces)') +group.add_option('-a', '--reindent_aligned', + action='store_true', default=False, + help='reindent statements to aligned format') +group.add_option('-s', '--use_space_around_operators', + action='store_true', default=False, + help='place spaces around mathematical operators') group.add_option('--wrap_after', dest='wrap_after', default=0, help='Column after which lists should be wrapped') diff --git a/sqlparse/filters.py b/sqlparse/filters.py index 95ac74c..20f61a0 100644 --- a/sqlparse/filters.py +++ b/sqlparse/filters.py @@ -139,6 +139,35 @@ class StripWhitespaceFilter(object): stmt.tokens.pop(-1) +class SpacesAroundOperatorsFilter(object): + whitelist = (sql.Identifier, sql.Comparison, sql.Where) + + def _process(self, tlist): + def next_token(idx): + return tlist.token_next_by(t=(T.Operator, T.Comparison), idx=idx) + + idx = 0 + token = next_token(idx) + while token: + idx = tlist.token_index(token) + if idx > 0 and tlist.tokens[idx - 1].ttype != T.Whitespace: + # insert before + tlist.tokens.insert(idx, sql.Token(T.Whitespace, ' ')) + idx += 1 + if idx < len(tlist.tokens) - 1: + if tlist.tokens[idx + 1].ttype != T.Whitespace: + tlist.tokens.insert(idx + 1, sql.Token(T.Whitespace, ' ')) + + idx += 1 + token = 
next_token(idx) + + for sgroup in tlist.get_sublists(): + self._process(sgroup) + + def process(self, stmt): + self._process(stmt) + + class ReindentFilter(object): def __init__(self, width=2, char=' ', line_width=None, wrap_after=0): self.width = width @@ -335,6 +364,155 @@ class ReindentFilter(object): self._last_stmt = stmt +class AlignedIndentFilter(object): + join_words = (r'((LEFT\s+|RIGHT\s+|FULL\s+)?' + r'(INNER\s+|OUTER\s+|STRAIGHT\s+)?|' + r'(CROSS\s+|NATURAL\s+)?)?JOIN\b') + split_words = ('FROM', + join_words, 'ON', + 'WHERE', 'AND', 'OR', + 'GROUP', 'HAVING', 'LIMIT', + 'ORDER', 'UNION', 'VALUES', + 'SET', 'BETWEEN', 'EXCEPT') + + def __init__(self, char=' ', line_width=None): + self.char = char + self._max_kwd_len = len('select') + + def newline(self): + return sql.Token(T.Newline, '\n') + + def whitespace(self, chars=0, newline_before=False, newline_after=False): + return sql.Token(T.Whitespace, ('\n' if newline_before else '') + + self.char * chars + ('\n' if newline_after else '')) + + def _process_statement(self, tlist, base_indent=0): + if tlist.tokens[0].is_whitespace() and base_indent == 0: + tlist.tokens.pop(0) + + # process the main query body + return self._process(sql.TokenList(tlist.tokens), + base_indent=base_indent) + + def _process_parenthesis(self, tlist, base_indent=0): + if not tlist.token_next_by(m=(T.DML, 'SELECT')): + # if this isn't a subquery, don't re-indent + return tlist + + # add two for the space and parens + sub_indent = base_indent + self._max_kwd_len + 2 + tlist.insert_after(tlist.tokens[0], + self.whitespace(sub_indent, newline_before=True)) + # de-indent the last parenthesis + tlist.insert_before(tlist.tokens[-1], + self.whitespace(sub_indent - 1, + newline_before=True)) + + # process the inside of the parantheses + tlist.tokens = ( + [tlist.tokens[0]] + + self._process(sql.TokenList(tlist._groupable_tokens), + base_indent=sub_indent).tokens + + [tlist.tokens[-1]] + ) + return tlist + + def 
_process_identifierlist(self, tlist, base_indent=0): + # columns being selected + new_tokens = [] + identifiers = list(filter( + lambda t: t.ttype not in (T.Punctuation, T.Whitespace, T.Newline), + tlist.tokens)) + for i, token in enumerate(identifiers): + if i > 0: + new_tokens.append(self.newline()) + new_tokens.append( + self.whitespace(self._max_kwd_len + base_indent + 1)) + new_tokens.append(token) + if i < len(identifiers) - 1: + # if not last column in select, add a comma seperator + new_tokens.append(sql.Token(T.Punctuation, ',')) + tlist.tokens = new_tokens + + # process any sub-sub statements (like case statements) + for sgroup in tlist.get_sublists(): + self._process(sgroup, base_indent=base_indent) + return tlist + + def _process_case(self, tlist, base_indent=0): + base_offset = base_indent + self._max_kwd_len + len('case ') + case_offset = len('when ') + cases = tlist.get_cases(skip_ws=True) + # align the end as well + end_token = tlist.token_next_by(m=(T.Keyword, 'END')) + cases.append((None, [end_token])) + + condition_width = max( + len(' '.join(map(str, cond))) for cond, value in cases if cond) + for i, (cond, value) in enumerate(cases): + if cond is None: # else or end + stmt = value[0] + line = value + else: + stmt = cond[0] + line = cond + value + if i > 0: + tlist.insert_before(stmt, self.whitespace( + base_offset + case_offset - len(str(stmt)))) + if cond: + tlist.insert_after(cond[-1], self.whitespace( + condition_width - len(' '.join(map(str, cond))))) + + if i < len(cases) - 1: + # if not the END add a newline + tlist.insert_after(line[-1], self.newline()) + + def _process_substatement(self, tlist, base_indent=0): + def _next_token(i): + t = tlist.token_next_by(m=(T.Keyword, self.split_words, True), + idx=i) + # treat "BETWEEN x and y" as a single statement + if t and t.value.upper() == 'BETWEEN': + t = _next_token(tlist.token_index(t) + 1) + if t and t.value.upper() == 'AND': + t = _next_token(tlist.token_index(t) + 1) + return t + + idx = 
0 + token = _next_token(idx) + while token: + # joins are special case. only consider the first word as aligner + if token.match(T.Keyword, self.join_words, regex=True): + token_indent = len(token.value.split()[0]) + else: + token_indent = len(str(token)) + tlist.insert_before(token, self.whitespace( + self._max_kwd_len - token_indent + base_indent, + newline_before=True)) + next_idx = tlist.token_index(token) + 1 + token = _next_token(next_idx) + + # process any sub-sub statements + for sgroup in tlist.get_sublists(): + prev_token = tlist.token_prev(tlist.token_index(sgroup)) + indent_offset = 0 + # HACK: make "group/order by" work. Longer than _max_kwd_len. + if prev_token and prev_token.match(T.Keyword, 'BY'): + # TODO: generalize this + indent_offset = 3 + self._process(sgroup, base_indent=base_indent + indent_offset) + return tlist + + def _process(self, tlist, base_indent=0): + token_name = tlist.__class__.__name__.lower() + func_name = '_process_%s' % token_name + func = getattr(self, func_name, self._process_substatement) + return func(tlist, base_indent=base_indent) + + def process(self, stmt): + self._process(stmt) + + # FIXME: Doesn't work class RightMarginFilter(object): keep_together = ( diff --git a/sqlparse/formatter.py b/sqlparse/formatter.py index 7441313..069109b 100644 --- a/sqlparse/formatter.py +++ b/sqlparse/formatter.py @@ -30,6 +30,11 @@ def validate_options(options): raise SQLParseError('Invalid value for strip_comments: %r' % strip_comments) + space_around_operators = options.get('use_space_around_operators', False) + if space_around_operators not in [True, False]: + raise SQLParseError('Invalid value for use_space_around_operators: %r' + % space_around_operators) + strip_ws = options.get('strip_whitespace', False) if strip_ws not in [True, False]: raise SQLParseError('Invalid value for strip_whitespace: %r' @@ -55,6 +60,13 @@ def validate_options(options): elif reindent: options['strip_whitespace'] = True + reindent_aligned = 
options.get('reindent_aligned', False) + if reindent_aligned not in [True, False]: + raise SQLParseError('Invalid value for reindent_aligned: %r' + % reindent) + elif reindent_aligned: + options['strip_whitespace'] = True + indent_tabs = options.get('indent_tabs', False) if indent_tabs not in [True, False]: raise SQLParseError('Invalid value for indent_tabs: %r' % indent_tabs) @@ -114,6 +126,10 @@ def build_filter_stack(stack, options): stack.preprocess.append(filters.TruncateStringFilter( width=options['truncate_strings'], char=options['truncate_char'])) + if options.get('use_space_around_operators', False): + stack.enable_grouping() + stack.stmtprocess.append(filters.SpacesAroundOperatorsFilter()) + # After grouping if options.get('strip_comments'): stack.enable_grouping() @@ -130,6 +146,11 @@ def build_filter_stack(stack, options): width=options['indent_width'], wrap_after=options['wrap_after'])) + if options.get('reindent_aligned', False): + stack.enable_grouping() + stack.stmtprocess.append( + filters.AlignedIndentFilter(char=options['indent_char'])) + if options.get('right_margin'): stack.enable_grouping() stack.stmtprocess.append( diff --git a/sqlparse/sql.py b/sqlparse/sql.py index 57bf1e7..daa5cf5 100644 --- a/sqlparse/sql.py +++ b/sqlparse/sql.py @@ -538,7 +538,7 @@ class Case(TokenList): M_OPEN = T.Keyword, 'CASE' M_CLOSE = T.Keyword, 'END' - def get_cases(self): + def get_cases(self, skip_ws=False): """Returns a list of 2-tuples (condition, value). If an ELSE exists condition is None. 
@@ -554,6 +554,9 @@ class Case(TokenList): if token.match(T.Keyword, 'CASE'): continue + elif skip_ws and token.ttype in T.Whitespace: + continue + elif token.match(T.Keyword, 'WHEN'): ret.append(([], [])) mode = CONDITION diff --git a/tests/test_format.py b/tests/test_format.py index 9043e76..7b5af06 100644 --- a/tests/test_format.py +++ b/tests/test_format.py @@ -106,6 +106,234 @@ class TestFormat(TestCaseBase): output_format='foo') +class TestFormatReindentAligned(TestCaseBase): + @staticmethod + def formatter(sql): + return sqlparse.format(sql, reindent_aligned=True) + + def test_basic(self): + sql = """ + select a, b as bb,c from table + join (select a * 2 as a from new_table) other + on table.a = other.a + where c is true + and b between 3 and 4 + or d is 'blue' + limit 10 + """ + self.ndiffAssertEqual( + self.formatter(sql), + '\n'.join([ + 'select a,', + ' b as bb,', + ' c', + ' from table', + ' join (', + ' select a * 2 as a', + ' from new_table', + ' ) other', + ' on table.a = other.a', + ' where c is true', + ' and b between 3 and 4', + " or d is 'blue'", + ' limit 10', + ])) + + def test_joins(self): + sql = """ + select * from a + join b on a.one = b.one + left join c on c.two = a.two and c.three = a.three + full outer join d on d.three = a.three + cross join e on e.four = a.four + join f using (one, two, three) + """ + self.ndiffAssertEqual( + self.formatter(sql), + '\n'.join([ + 'select *', + ' from a', + ' join b', + ' on a.one = b.one', + ' left join c', + ' on c.two = a.two', + ' and c.three = a.three', + ' full outer join d', + ' on d.three = a.three', + ' cross join e', + ' on e.four = a.four', + ' join f using (one, two, three)', + ])) + + def test_case_statement(self): + sql = """ + select a, + case when a = 0 + then 1 + when bb = 1 then 1 + when c = 2 then 2 + else 0 end as d, + extra_col + from table + where c is true + and b between 3 and 4 + """ + self.ndiffAssertEqual( + self.formatter(sql), + '\n'.join([ + 'select a,', + ' case when a = 
0 then 1', + ' when bb = 1 then 1', + ' when c = 2 then 2', + ' else 0', + ' end as d,', + ' extra_col', + ' from table', + ' where c is true', + ' and b between 3 and 4' + ])) + + def test_case_statement_with_between(self): + sql = """ + select a, + case when a = 0 + then 1 + when bb = 1 then 1 + when c = 2 then 2 + when d between 3 and 5 then 3 + else 0 end as d, + extra_col + from table + where c is true + and b between 3 and 4 + """ + self.ndiffAssertEqual( + self.formatter(sql), + '\n'.join([ + 'select a,', + ' case when a = 0 then 1', + ' when bb = 1 then 1', + ' when c = 2 then 2', + ' when d between 3 and 5 then 3', + ' else 0', + ' end as d,', + ' extra_col', + ' from table', + ' where c is true', + ' and b between 3 and 4' + ])) + + def test_group_by(self): + sql = """ + select a, b, c, sum(x) as sum_x, count(y) as cnt_y + from table + group by a,b,c + having sum(x) > 1 + and count(y) > 5 + order by 3,2,1 + """ + self.ndiffAssertEqual( + self.formatter(sql), + '\n'.join([ + 'select a,', + ' b,', + ' c,', + ' sum(x) as sum_x,', + ' count(y) as cnt_y', + ' from table', + ' group by a,', + ' b,', + ' c', + 'having sum(x) > 1', + ' and count(y) > 5', + ' order by 3,', + ' 2,', + ' 1', + ])) + + def test_group_by_subquery(self): + # TODO: add subquery alias when test_identifier_list_subquery fixed + sql = """ + select *, sum_b + 2 as mod_sum + from ( + select a, sum(b) as sum_b + from table + group by a,z) + order by 1,2 + """ + self.ndiffAssertEqual( + self.formatter(sql), + '\n'.join([ + 'select *,', + ' sum_b + 2 as mod_sum', + ' from (', + ' select a,', + ' sum(b) as sum_b', + ' from table', + ' group by a,', + ' z', + ' )', + ' order by 1,', + ' 2', + ])) + + def test_window_functions(self): + sql = """ + select a, + SUM(a) OVER (PARTITION BY b ORDER BY c ROWS + BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as sum_a, + ROW_NUMBER() OVER + (PARTITION BY b, c ORDER BY d DESC) as row_num + from table + """ + self.ndiffAssertEqual( + self.formatter(sql), + 
'\n'.join([ + 'select a,', + (' SUM(a) OVER (PARTITION BY b ORDER BY c ROWS ' + 'BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as sum_a,'), + (' ROW_NUMBER() OVER ' + '(PARTITION BY b, c ORDER BY d DESC) as row_num'), + ' from table', + ])) + + +class TestSpacesAroundOperators(TestCaseBase): + @staticmethod + def formatter(sql): + return sqlparse.format(sql, use_space_around_operators=True) + + def test_basic(self): + sql = ('select a+b as d from table ' + 'where (c-d)%2= 1 and e> 3.0/4 and z^2 <100') + self.ndiffAssertEqual( + self.formatter(sql), ( + 'select a + b as d from table ' + 'where (c - d) % 2 = 1 and e > 3.0 / 4 and z ^ 2 < 100') + ) + + def test_bools(self): + sql = 'select * from table where a &&b or c||d' + self.ndiffAssertEqual( + self.formatter(sql), + 'select * from table where a && b or c || d' + ) + + def test_nested(self): + sql = 'select *, case when a-b then c end from table' + self.ndiffAssertEqual( + self.formatter(sql), + 'select *, case when a - b then c end from table' + ) + + def test_wildcard_vs_mult(self): + sql = 'select a*b-c from table' + self.ndiffAssertEqual( + self.formatter(sql), + 'select a * b - c from table' + ) + + class TestFormatReindent(TestCaseBase): def test_option(self): diff --git a/tests/test_grouping.py b/tests/test_grouping.py index c35c61b..fdcd4a7 100644 --- a/tests/test_grouping.py +++ b/tests/test_grouping.py @@ -122,6 +122,17 @@ class TestGrouping(TestCaseBase): p = sqlparse.parse('(a, b, c)')[0] self.assert_(isinstance(p.tokens[0].tokens[1], sql.IdentifierList)) + def test_identifier_list_subquery(self): + """identifier lists should still work in subqueries with aliases""" + p = sqlparse.parse("select * from (" + "select a, b + c as d from table) sub")[0] + subquery = p.tokens[-1].tokens[0] + iden_list = subquery.token_next_by(i=sql.IdentifierList) + self.assert_(iden_list is not None) + # all the identifiers should be within the IdentifierList + self.assert_(subquery.token_next_by(i=sql.Identifier, + 
idx=iden_list) is None) + def test_identifier_list_case(self): p = sqlparse.parse('a, case when 1 then 2 else 3 end as b, c')[0] self.assert_(isinstance(p.tokens[0], sql.IdentifierList)) @@ -140,6 +151,11 @@ class TestGrouping(TestCaseBase): self.assert_(isinstance(p.tokens[0].tokens[0], sql.Identifier)) self.assert_(isinstance(p.tokens[0].tokens[3], sql.Identifier)) + def test_identifiers_with_operators(self): + p = sqlparse.parse('a+b as c from table where (d-e)%2= 1')[0] + self.assertEqual(len([x for x in p.flatten() + if x.ttype == sqlparse.tokens.Name]), 5) + def test_identifier_list_with_order(self): # issue101 p = sqlparse.parse('1, 2 desc, 3')[0] self.assert_(isinstance(p.tokens[0], sql.IdentifierList)) |
