summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdam Greenhall <agreenhall@lyft.com>2015-09-12 14:49:47 -0700
committerVictor Uriarte <victor.m.uriarte@intel.com>2016-06-06 06:31:35 -0700
commit2928a7c8f1192b8376795368825c2cf2dae243c3 (patch)
tree00e1eb738b20012ed671e0e36ee27c99f619ee14
parent4c70aeaa2a8b1652553cc5a10bd593694cb3073f (diff)
downloadsqlparse-2928a7c8f1192b8376795368825c2cf2dae243c3.tar.gz
Add filter `Spaces around Operators`
-rw-r--r--sqlparse/filters.py30
-rw-r--r--sqlparse/formatter.py9
-rw-r--r--tests/test_format.py34
-rw-r--r--tests/test_grouping.py4
4 files changed, 77 insertions, 0 deletions
diff --git a/sqlparse/filters.py b/sqlparse/filters.py
index 193029f..464a570 100644
--- a/sqlparse/filters.py
+++ b/sqlparse/filters.py
@@ -139,6 +139,36 @@ class StripWhitespaceFilter(object):
stmt.tokens.pop(-1)
+class SpacesAroundOperatorsFilter:
+ whitelist = (sql.Identifier, sql.Comparison, sql.Where)
+
+ def _process(self, tlist):
+ def next_token(idx):
+ # HACK: distinguish between real wildcard from multiplication operator
+ return tlist.token_next_by_type(idx, (T.Operator, T.Comparison, T.Wildcard))
+ idx = 0
+ token = next_token(idx)
+ while token:
+ idx = tlist.token_index(token)
+ if idx > 0 and tlist.tokens[idx - 1].ttype != T.Whitespace:
+ tlist.tokens.insert(idx, sql.Token(T.Whitespace, ' ')) # insert before
+ idx += 1
+ if idx < len(tlist.tokens) - 1:
+ if token.ttype == T.Wildcard and tlist.tokens[idx + 1].match(T.Punctuation, ','):
+ pass # this must have been a real wildcard, not multiplication
+ elif tlist.tokens[idx + 1].ttype != T.Whitespace:
+ tlist.tokens.insert(idx + 1, sql.Token(T.Whitespace, ' '))
+
+ idx += 1
+ token = next_token(idx)
+
+ for sgroup in tlist.get_sublists():
+ self._process(sgroup)
+
+ def process(self, stack, stmt):
+ self._process(stmt)
+
+
class ReindentFilter(object):
def __init__(self, width=2, char=' ', line_width=None, wrap_after=0):
self.width = width
diff --git a/sqlparse/formatter.py b/sqlparse/formatter.py
index 5af8743..0fa563c 100644
--- a/sqlparse/formatter.py
+++ b/sqlparse/formatter.py
@@ -30,6 +30,11 @@ def validate_options(options):
raise SQLParseError('Invalid value for strip_comments: %r'
% strip_comments)
+ use_space_around_operators = options.get('use_space_around_operators', False)
+ if use_space_around_operators not in [True, False]:
+ raise SQLParseError('Invalid value for use_space_around_operators: %r'
+ % use_space_around_operators)
+
strip_ws = options.get('strip_whitespace', False)
if strip_ws not in [True, False]:
raise SQLParseError('Invalid value for strip_whitespace: %r'
@@ -121,6 +126,10 @@ def build_filter_stack(stack, options):
stack.preprocess.append(filters.TruncateStringFilter(
width=options['truncate_strings'], char=options['truncate_char']))
+ if options.get('use_space_around_operators', False):
+ stack.enable_grouping()
+ stack.stmtprocess.append(filters.SpacesAroundOperatorsFilter())
+
# After grouping
if options.get('strip_comments'):
stack.enable_grouping()
diff --git a/tests/test_format.py b/tests/test_format.py
index 8151bb4..22ab5b6 100644
--- a/tests/test_format.py
+++ b/tests/test_format.py
@@ -294,6 +294,40 @@ class TestFormatReindentAligned(TestCaseBase):
]))
+class TestSpacesAroundOperators(TestCaseBase):
+ @staticmethod
+ def formatter(sql):
+ return sqlparse.format(sql, use_space_around_operators=True)
+
+ def test_basic(self):
+ sql = 'select a+b as d from table where (c-d)%2= 1 and e> 3.0/4 and z^2 <100'
+ self.ndiffAssertEqual(
+ self.formatter(sql),
+ 'select a + b as d from table where (c - d) % 2 = 1 and e > 3.0 / 4 and z ^ 2 < 100'
+ )
+
+ def test_bools(self):
+ sql = 'select * from table where a &&b or c||d'
+ self.ndiffAssertEqual(
+ self.formatter(sql),
+ 'select * from table where a && b or c || d'
+ )
+
+ def test_nested(self):
+ sql = 'select *, case when a-b then c end from table'
+ self.ndiffAssertEqual(
+ self.formatter(sql),
+ 'select *, case when a - b then c end from table'
+ )
+
+ def test_wildcard_vs_mult(self):
+ sql = 'select a*b-c from table'
+ self.ndiffAssertEqual(
+ self.formatter(sql),
+ 'select a * b - c from table'
+ )
+
+
class TestFormatReindent(TestCaseBase):
def test_option(self):
diff --git a/tests/test_grouping.py b/tests/test_grouping.py
index 1971fb7..40a35cf 100644
--- a/tests/test_grouping.py
+++ b/tests/test_grouping.py
@@ -150,6 +150,10 @@ class TestGrouping(TestCaseBase):
self.assert_(isinstance(p.tokens[0].tokens[0], sql.Identifier))
self.assert_(isinstance(p.tokens[0].tokens[3], sql.Identifier))
+ def test_identifiers_with_operators(self):
+ p = sqlparse.parse('a+b as c from table where (d-e)%2= 1')[0]
+ self.assertEqual(len([x for x in p.flatten() if x.ttype == sqlparse.tokens.Name]), 5)
+
def test_identifier_list_with_order(self): # issue101
p = sqlparse.parse('1, 2 desc, 3')[0]
self.assert_(isinstance(p.tokens[0], sql.IdentifierList))