diff options
| author | Adam Greenhall <agreenhall@lyft.com> | 2015-09-12 14:49:47 -0700 |
|---|---|---|
| committer | Victor Uriarte <victor.m.uriarte@intel.com> | 2016-06-06 06:31:35 -0700 |
| commit | 2928a7c8f1192b8376795368825c2cf2dae243c3 (patch) | |
| tree | 00e1eb738b20012ed671e0e36ee27c99f619ee14 | |
| parent | 4c70aeaa2a8b1652553cc5a10bd593694cb3073f (diff) | |
| download | sqlparse-2928a7c8f1192b8376795368825c2cf2dae243c3.tar.gz | |
Add filter `Spaces around Operators`
| -rw-r--r-- | sqlparse/filters.py | 30 | ||||
| -rw-r--r-- | sqlparse/formatter.py | 9 | ||||
| -rw-r--r-- | tests/test_format.py | 34 | ||||
| -rw-r--r-- | tests/test_grouping.py | 4 |
4 files changed, 77 insertions, 0 deletions
diff --git a/sqlparse/filters.py b/sqlparse/filters.py index 193029f..464a570 100644 --- a/sqlparse/filters.py +++ b/sqlparse/filters.py @@ -139,6 +139,36 @@ class StripWhitespaceFilter(object): stmt.tokens.pop(-1) +class SpacesAroundOperatorsFilter: + whitelist = (sql.Identifier, sql.Comparison, sql.Where) + + def _process(self, tlist): + def next_token(idx): + # HACK: distinguish between real wildcard from multiplication operator + return tlist.token_next_by_type(idx, (T.Operator, T.Comparison, T.Wildcard)) + idx = 0 + token = next_token(idx) + while token: + idx = tlist.token_index(token) + if idx > 0 and tlist.tokens[idx - 1].ttype != T.Whitespace: + tlist.tokens.insert(idx, sql.Token(T.Whitespace, ' ')) # insert before + idx += 1 + if idx < len(tlist.tokens) - 1: + if token.ttype == T.Wildcard and tlist.tokens[idx + 1].match(T.Punctuation, ','): + pass # this must have been a real wildcard, not multiplication + elif tlist.tokens[idx + 1].ttype != T.Whitespace: + tlist.tokens.insert(idx + 1, sql.Token(T.Whitespace, ' ')) + + idx += 1 + token = next_token(idx) + + for sgroup in tlist.get_sublists(): + self._process(sgroup) + + def process(self, stack, stmt): + self._process(stmt) + + class ReindentFilter(object): def __init__(self, width=2, char=' ', line_width=None, wrap_after=0): self.width = width diff --git a/sqlparse/formatter.py b/sqlparse/formatter.py index 5af8743..0fa563c 100644 --- a/sqlparse/formatter.py +++ b/sqlparse/formatter.py @@ -30,6 +30,11 @@ def validate_options(options): raise SQLParseError('Invalid value for strip_comments: %r' % strip_comments) + use_space_around_operators = options.get('use_space_around_operators', False) + if use_space_around_operators not in [True, False]: + raise SQLParseError('Invalid value for use_space_around_operators: %r' + % use_space_around_operators) + strip_ws = options.get('strip_whitespace', False) if strip_ws not in [True, False]: raise SQLParseError('Invalid value for strip_whitespace: %r' @@ -121,6 +126,10 @@ def build_filter_stack(stack, options): stack.preprocess.append(filters.TruncateStringFilter( width=options['truncate_strings'], char=options['truncate_char'])) + if options.get('use_space_around_operators', False): + stack.enable_grouping() + stack.stmtprocess.append(filters.SpacesAroundOperatorsFilter()) + # After grouping if options.get('strip_comments'): stack.enable_grouping() diff --git a/tests/test_format.py b/tests/test_format.py index 8151bb4..22ab5b6 100644 --- a/tests/test_format.py +++ b/tests/test_format.py @@ -294,6 +294,40 @@ class TestFormatReindentAligned(TestCaseBase): ])) +class TestSpacesAroundOperators(TestCaseBase): + @staticmethod + def formatter(sql): + return sqlparse.format(sql, use_space_around_operators=True) + + def test_basic(self): + sql = 'select a+b as d from table where (c-d)%2= 1 and e> 3.0/4 and z^2 <100' + self.ndiffAssertEqual( + self.formatter(sql), + 'select a + b as d from table where (c - d) % 2 = 1 and e > 3.0 / 4 and z ^ 2 < 100' + ) + + def test_bools(self): + sql = 'select * from table where a &&b or c||d' + self.ndiffAssertEqual( + self.formatter(sql), + 'select * from table where a && b or c || d' + ) + + def test_nested(self): + sql = 'select *, case when a-b then c end from table' + self.ndiffAssertEqual( + self.formatter(sql), + 'select *, case when a - b then c end from table' + ) + + def test_wildcard_vs_mult(self): + sql = 'select a*b-c from table' + self.ndiffAssertEqual( + self.formatter(sql), + 'select a * b - c from table' + ) + + class TestFormatReindent(TestCaseBase): def test_option(self): diff --git a/tests/test_grouping.py b/tests/test_grouping.py index 1971fb7..40a35cf 100644 --- a/tests/test_grouping.py +++ b/tests/test_grouping.py @@ -150,6 +150,10 @@ class TestGrouping(TestCaseBase): self.assert_(isinstance(p.tokens[0].tokens[0], sql.Identifier)) self.assert_(isinstance(p.tokens[0].tokens[3], sql.Identifier)) + def test_identifiers_with_operators(self): + p = sqlparse.parse('a+b as c from table where (d-e)%2= 1')[0] + self.assertEqual(len([x for x in p.flatten() if x.ttype == sqlparse.tokens.Name]), 5) + def test_identifier_list_with_order(self): # issue101 p = sqlparse.parse('1, 2 desc, 3')[0] self.assert_(isinstance(p.tokens[0], sql.IdentifierList)) |
