summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVik <vmuriart@gmail.com>2016-06-06 07:45:06 -0700
committerVik <vmuriart@gmail.com>2016-06-06 07:45:06 -0700
commitd296ead1ddd5373ccac5e03279c782d538f30f98 (patch)
tree95e887de388e5f6f1516b1ea862e9d48a4d174eb
parentb9d81ac4fe49114f57dc33c0d635f99ff56e62f2 (diff)
parenta7c7d9586208516de372cb01203b48a53f7095fb (diff)
downloadsqlparse-d296ead1ddd5373ccac5e03279c782d538f30f98.tar.gz
Merge pull request #252 from vmuriart/rb-aligned-format
Rebased and Updated Aligned-Indent Format
-rwxr-xr-xbin/sqlformat9
-rw-r--r--sqlparse/filters.py178
-rw-r--r--sqlparse/formatter.py21
-rw-r--r--sqlparse/sql.py5
-rw-r--r--tests/test_format.py228
-rw-r--r--tests/test_grouping.py16
6 files changed, 453 insertions, 4 deletions
diff --git a/bin/sqlformat b/bin/sqlformat
index fa20ded..3f61064 100755
--- a/bin/sqlformat
+++ b/bin/sqlformat
@@ -7,16 +7,13 @@
# the BSD License: http://www.opensource.org/licenses/bsd-license.php
import optparse
-import os
import sys
import sqlparse
from sqlparse.exceptions import SQLParseError
-
_CASE_CHOICES = ['upper', 'lower', 'capitalize']
-
parser = optparse.OptionParser(usage='%prog [OPTIONS] FILE, ...',
version='%%prog %s' % sqlparse.__version__)
parser.set_description(('Format FILE according to OPTIONS. Use "-" as FILE '
@@ -45,6 +42,12 @@ group.add_option('-r', '--reindent', dest='reindent',
help='reindent statements')
group.add_option('--indent_width', dest='indent_width', default=2,
help='indentation width (defaults to 2 spaces)')
+group.add_option('-a', '--reindent_aligned',
+ action='store_true', default=False,
+ help='reindent statements to aligned format')
+group.add_option('-s', '--use_space_around_operators',
+ action='store_true', default=False,
+ help='place spaces around mathematical operators')
group.add_option('--wrap_after', dest='wrap_after', default=0,
help='Column after which lists should be wrapped')
diff --git a/sqlparse/filters.py b/sqlparse/filters.py
index 95ac74c..20f61a0 100644
--- a/sqlparse/filters.py
+++ b/sqlparse/filters.py
@@ -139,6 +139,35 @@ class StripWhitespaceFilter(object):
stmt.tokens.pop(-1)
+class SpacesAroundOperatorsFilter(object):
+ whitelist = (sql.Identifier, sql.Comparison, sql.Where)
+
+ def _process(self, tlist):
+ def next_token(idx):
+ return tlist.token_next_by(t=(T.Operator, T.Comparison), idx=idx)
+
+ idx = 0
+ token = next_token(idx)
+ while token:
+ idx = tlist.token_index(token)
+ if idx > 0 and tlist.tokens[idx - 1].ttype != T.Whitespace:
+ # insert before
+ tlist.tokens.insert(idx, sql.Token(T.Whitespace, ' '))
+ idx += 1
+ if idx < len(tlist.tokens) - 1:
+ if tlist.tokens[idx + 1].ttype != T.Whitespace:
+ tlist.tokens.insert(idx + 1, sql.Token(T.Whitespace, ' '))
+
+ idx += 1
+ token = next_token(idx)
+
+ for sgroup in tlist.get_sublists():
+ self._process(sgroup)
+
+ def process(self, stmt):
+ self._process(stmt)
+
+
class ReindentFilter(object):
def __init__(self, width=2, char=' ', line_width=None, wrap_after=0):
self.width = width
@@ -335,6 +364,155 @@ class ReindentFilter(object):
self._last_stmt = stmt
+class AlignedIndentFilter(object):
+ join_words = (r'((LEFT\s+|RIGHT\s+|FULL\s+)?'
+ r'(INNER\s+|OUTER\s+|STRAIGHT\s+)?|'
+ r'(CROSS\s+|NATURAL\s+)?)?JOIN\b')
+ split_words = ('FROM',
+ join_words, 'ON',
+ 'WHERE', 'AND', 'OR',
+ 'GROUP', 'HAVING', 'LIMIT',
+ 'ORDER', 'UNION', 'VALUES',
+ 'SET', 'BETWEEN', 'EXCEPT')
+
+ def __init__(self, char=' ', line_width=None):
+ self.char = char
+ self._max_kwd_len = len('select')
+
+ def newline(self):
+ return sql.Token(T.Newline, '\n')
+
+ def whitespace(self, chars=0, newline_before=False, newline_after=False):
+ return sql.Token(T.Whitespace, ('\n' if newline_before else '') +
+ self.char * chars + ('\n' if newline_after else ''))
+
+ def _process_statement(self, tlist, base_indent=0):
+ if tlist.tokens[0].is_whitespace() and base_indent == 0:
+ tlist.tokens.pop(0)
+
+ # process the main query body
+ return self._process(sql.TokenList(tlist.tokens),
+ base_indent=base_indent)
+
+ def _process_parenthesis(self, tlist, base_indent=0):
+ if not tlist.token_next_by(m=(T.DML, 'SELECT')):
+ # if this isn't a subquery, don't re-indent
+ return tlist
+
+ # add two for the space and parens
+ sub_indent = base_indent + self._max_kwd_len + 2
+ tlist.insert_after(tlist.tokens[0],
+ self.whitespace(sub_indent, newline_before=True))
+ # de-indent the last parenthesis
+ tlist.insert_before(tlist.tokens[-1],
+ self.whitespace(sub_indent - 1,
+ newline_before=True))
+
+ # process the inside of the parantheses
+ tlist.tokens = (
+ [tlist.tokens[0]] +
+ self._process(sql.TokenList(tlist._groupable_tokens),
+ base_indent=sub_indent).tokens +
+ [tlist.tokens[-1]]
+ )
+ return tlist
+
+ def _process_identifierlist(self, tlist, base_indent=0):
+ # columns being selected
+ new_tokens = []
+ identifiers = list(filter(
+ lambda t: t.ttype not in (T.Punctuation, T.Whitespace, T.Newline),
+ tlist.tokens))
+ for i, token in enumerate(identifiers):
+ if i > 0:
+ new_tokens.append(self.newline())
+ new_tokens.append(
+ self.whitespace(self._max_kwd_len + base_indent + 1))
+ new_tokens.append(token)
+ if i < len(identifiers) - 1:
+ # if not last column in select, add a comma seperator
+ new_tokens.append(sql.Token(T.Punctuation, ','))
+ tlist.tokens = new_tokens
+
+ # process any sub-sub statements (like case statements)
+ for sgroup in tlist.get_sublists():
+ self._process(sgroup, base_indent=base_indent)
+ return tlist
+
+ def _process_case(self, tlist, base_indent=0):
+ base_offset = base_indent + self._max_kwd_len + len('case ')
+ case_offset = len('when ')
+ cases = tlist.get_cases(skip_ws=True)
+ # align the end as well
+ end_token = tlist.token_next_by(m=(T.Keyword, 'END'))
+ cases.append((None, [end_token]))
+
+ condition_width = max(
+ len(' '.join(map(str, cond))) for cond, value in cases if cond)
+ for i, (cond, value) in enumerate(cases):
+ if cond is None: # else or end
+ stmt = value[0]
+ line = value
+ else:
+ stmt = cond[0]
+ line = cond + value
+ if i > 0:
+ tlist.insert_before(stmt, self.whitespace(
+ base_offset + case_offset - len(str(stmt))))
+ if cond:
+ tlist.insert_after(cond[-1], self.whitespace(
+ condition_width - len(' '.join(map(str, cond)))))
+
+ if i < len(cases) - 1:
+ # if not the END add a newline
+ tlist.insert_after(line[-1], self.newline())
+
+ def _process_substatement(self, tlist, base_indent=0):
+ def _next_token(i):
+ t = tlist.token_next_by(m=(T.Keyword, self.split_words, True),
+ idx=i)
+ # treat "BETWEEN x and y" as a single statement
+ if t and t.value.upper() == 'BETWEEN':
+ t = _next_token(tlist.token_index(t) + 1)
+ if t and t.value.upper() == 'AND':
+ t = _next_token(tlist.token_index(t) + 1)
+ return t
+
+ idx = 0
+ token = _next_token(idx)
+ while token:
+ # joins are special case. only consider the first word as aligner
+ if token.match(T.Keyword, self.join_words, regex=True):
+ token_indent = len(token.value.split()[0])
+ else:
+ token_indent = len(str(token))
+ tlist.insert_before(token, self.whitespace(
+ self._max_kwd_len - token_indent + base_indent,
+ newline_before=True))
+ next_idx = tlist.token_index(token) + 1
+ token = _next_token(next_idx)
+
+ # process any sub-sub statements
+ for sgroup in tlist.get_sublists():
+ prev_token = tlist.token_prev(tlist.token_index(sgroup))
+ indent_offset = 0
+ # HACK: make "group/order by" work. Longer than _max_kwd_len.
+ if prev_token and prev_token.match(T.Keyword, 'BY'):
+ # TODO: generalize this
+ indent_offset = 3
+ self._process(sgroup, base_indent=base_indent + indent_offset)
+ return tlist
+
+ def _process(self, tlist, base_indent=0):
+ token_name = tlist.__class__.__name__.lower()
+ func_name = '_process_%s' % token_name
+ func = getattr(self, func_name, self._process_substatement)
+ return func(tlist, base_indent=base_indent)
+
+ def process(self, stmt):
+ self._process(stmt)
+
+
# FIXME: Doesn't work
class RightMarginFilter(object):
keep_together = (
diff --git a/sqlparse/formatter.py b/sqlparse/formatter.py
index 7441313..069109b 100644
--- a/sqlparse/formatter.py
+++ b/sqlparse/formatter.py
@@ -30,6 +30,11 @@ def validate_options(options):
raise SQLParseError('Invalid value for strip_comments: %r'
% strip_comments)
+ space_around_operators = options.get('use_space_around_operators', False)
+ if space_around_operators not in [True, False]:
+ raise SQLParseError('Invalid value for use_space_around_operators: %r'
+ % space_around_operators)
+
strip_ws = options.get('strip_whitespace', False)
if strip_ws not in [True, False]:
raise SQLParseError('Invalid value for strip_whitespace: %r'
@@ -55,6 +60,13 @@ def validate_options(options):
elif reindent:
options['strip_whitespace'] = True
+ reindent_aligned = options.get('reindent_aligned', False)
+ if reindent_aligned not in [True, False]:
+ raise SQLParseError('Invalid value for reindent_aligned: %r'
+ % reindent)
+ elif reindent_aligned:
+ options['strip_whitespace'] = True
+
indent_tabs = options.get('indent_tabs', False)
if indent_tabs not in [True, False]:
raise SQLParseError('Invalid value for indent_tabs: %r' % indent_tabs)
@@ -114,6 +126,10 @@ def build_filter_stack(stack, options):
stack.preprocess.append(filters.TruncateStringFilter(
width=options['truncate_strings'], char=options['truncate_char']))
+ if options.get('use_space_around_operators', False):
+ stack.enable_grouping()
+ stack.stmtprocess.append(filters.SpacesAroundOperatorsFilter())
+
# After grouping
if options.get('strip_comments'):
stack.enable_grouping()
@@ -130,6 +146,11 @@ def build_filter_stack(stack, options):
width=options['indent_width'],
wrap_after=options['wrap_after']))
+ if options.get('reindent_aligned', False):
+ stack.enable_grouping()
+ stack.stmtprocess.append(
+ filters.AlignedIndentFilter(char=options['indent_char']))
+
if options.get('right_margin'):
stack.enable_grouping()
stack.stmtprocess.append(
diff --git a/sqlparse/sql.py b/sqlparse/sql.py
index 57bf1e7..daa5cf5 100644
--- a/sqlparse/sql.py
+++ b/sqlparse/sql.py
@@ -538,7 +538,7 @@ class Case(TokenList):
M_OPEN = T.Keyword, 'CASE'
M_CLOSE = T.Keyword, 'END'
- def get_cases(self):
+ def get_cases(self, skip_ws=False):
"""Returns a list of 2-tuples (condition, value).
If an ELSE exists condition is None.
@@ -554,6 +554,9 @@ class Case(TokenList):
if token.match(T.Keyword, 'CASE'):
continue
+ elif skip_ws and token.ttype in T.Whitespace:
+ continue
+
elif token.match(T.Keyword, 'WHEN'):
ret.append(([], []))
mode = CONDITION
diff --git a/tests/test_format.py b/tests/test_format.py
index 9043e76..7b5af06 100644
--- a/tests/test_format.py
+++ b/tests/test_format.py
@@ -106,6 +106,234 @@ class TestFormat(TestCaseBase):
output_format='foo')
+class TestFormatReindentAligned(TestCaseBase):
+ @staticmethod
+ def formatter(sql):
+ return sqlparse.format(sql, reindent_aligned=True)
+
+ def test_basic(self):
+ sql = """
+ select a, b as bb,c from table
+ join (select a * 2 as a from new_table) other
+ on table.a = other.a
+ where c is true
+ and b between 3 and 4
+ or d is 'blue'
+ limit 10
+ """
+ self.ndiffAssertEqual(
+ self.formatter(sql),
+ '\n'.join([
+ 'select a,',
+ ' b as bb,',
+ ' c',
+ ' from table',
+ ' join (',
+ ' select a * 2 as a',
+ ' from new_table',
+ ' ) other',
+ ' on table.a = other.a',
+ ' where c is true',
+ ' and b between 3 and 4',
+ " or d is 'blue'",
+ ' limit 10',
+ ]))
+
+ def test_joins(self):
+ sql = """
+ select * from a
+ join b on a.one = b.one
+ left join c on c.two = a.two and c.three = a.three
+ full outer join d on d.three = a.three
+ cross join e on e.four = a.four
+ join f using (one, two, three)
+ """
+ self.ndiffAssertEqual(
+ self.formatter(sql),
+ '\n'.join([
+ 'select *',
+ ' from a',
+ ' join b',
+ ' on a.one = b.one',
+ ' left join c',
+ ' on c.two = a.two',
+ ' and c.three = a.three',
+ ' full outer join d',
+ ' on d.three = a.three',
+ ' cross join e',
+ ' on e.four = a.four',
+ ' join f using (one, two, three)',
+ ]))
+
+ def test_case_statement(self):
+ sql = """
+ select a,
+ case when a = 0
+ then 1
+ when bb = 1 then 1
+ when c = 2 then 2
+ else 0 end as d,
+ extra_col
+ from table
+ where c is true
+ and b between 3 and 4
+ """
+ self.ndiffAssertEqual(
+ self.formatter(sql),
+ '\n'.join([
+ 'select a,',
+ ' case when a = 0 then 1',
+ ' when bb = 1 then 1',
+ ' when c = 2 then 2',
+ ' else 0',
+ ' end as d,',
+ ' extra_col',
+ ' from table',
+ ' where c is true',
+ ' and b between 3 and 4'
+ ]))
+
+ def test_case_statement_with_between(self):
+ sql = """
+ select a,
+ case when a = 0
+ then 1
+ when bb = 1 then 1
+ when c = 2 then 2
+ when d between 3 and 5 then 3
+ else 0 end as d,
+ extra_col
+ from table
+ where c is true
+ and b between 3 and 4
+ """
+ self.ndiffAssertEqual(
+ self.formatter(sql),
+ '\n'.join([
+ 'select a,',
+ ' case when a = 0 then 1',
+ ' when bb = 1 then 1',
+ ' when c = 2 then 2',
+ ' when d between 3 and 5 then 3',
+ ' else 0',
+ ' end as d,',
+ ' extra_col',
+ ' from table',
+ ' where c is true',
+ ' and b between 3 and 4'
+ ]))
+
+ def test_group_by(self):
+ sql = """
+ select a, b, c, sum(x) as sum_x, count(y) as cnt_y
+ from table
+ group by a,b,c
+ having sum(x) > 1
+ and count(y) > 5
+ order by 3,2,1
+ """
+ self.ndiffAssertEqual(
+ self.formatter(sql),
+ '\n'.join([
+ 'select a,',
+ ' b,',
+ ' c,',
+ ' sum(x) as sum_x,',
+ ' count(y) as cnt_y',
+ ' from table',
+ ' group by a,',
+ ' b,',
+ ' c',
+ 'having sum(x) > 1',
+ ' and count(y) > 5',
+ ' order by 3,',
+ ' 2,',
+ ' 1',
+ ]))
+
+ def test_group_by_subquery(self):
+ # TODO: add subquery alias when test_identifier_list_subquery fixed
+ sql = """
+ select *, sum_b + 2 as mod_sum
+ from (
+ select a, sum(b) as sum_b
+ from table
+ group by a,z)
+ order by 1,2
+ """
+ self.ndiffAssertEqual(
+ self.formatter(sql),
+ '\n'.join([
+ 'select *,',
+ ' sum_b + 2 as mod_sum',
+ ' from (',
+ ' select a,',
+ ' sum(b) as sum_b',
+ ' from table',
+ ' group by a,',
+ ' z',
+ ' )',
+ ' order by 1,',
+ ' 2',
+ ]))
+
+ def test_window_functions(self):
+ sql = """
+ select a,
+ SUM(a) OVER (PARTITION BY b ORDER BY c ROWS
+ BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as sum_a,
+ ROW_NUMBER() OVER
+ (PARTITION BY b, c ORDER BY d DESC) as row_num
+ from table
+ """
+ self.ndiffAssertEqual(
+ self.formatter(sql),
+ '\n'.join([
+ 'select a,',
+ (' SUM(a) OVER (PARTITION BY b ORDER BY c ROWS '
+ 'BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as sum_a,'),
+ (' ROW_NUMBER() OVER '
+ '(PARTITION BY b, c ORDER BY d DESC) as row_num'),
+ ' from table',
+ ]))
+
+
+class TestSpacesAroundOperators(TestCaseBase):
+ @staticmethod
+ def formatter(sql):
+ return sqlparse.format(sql, use_space_around_operators=True)
+
+ def test_basic(self):
+ sql = ('select a+b as d from table '
+ 'where (c-d)%2= 1 and e> 3.0/4 and z^2 <100')
+ self.ndiffAssertEqual(
+ self.formatter(sql), (
+ 'select a + b as d from table '
+ 'where (c - d) % 2 = 1 and e > 3.0 / 4 and z ^ 2 < 100')
+ )
+
+ def test_bools(self):
+ sql = 'select * from table where a &&b or c||d'
+ self.ndiffAssertEqual(
+ self.formatter(sql),
+ 'select * from table where a && b or c || d'
+ )
+
+ def test_nested(self):
+ sql = 'select *, case when a-b then c end from table'
+ self.ndiffAssertEqual(
+ self.formatter(sql),
+ 'select *, case when a - b then c end from table'
+ )
+
+ def test_wildcard_vs_mult(self):
+ sql = 'select a*b-c from table'
+ self.ndiffAssertEqual(
+ self.formatter(sql),
+ 'select a * b - c from table'
+ )
+
+
class TestFormatReindent(TestCaseBase):
def test_option(self):
diff --git a/tests/test_grouping.py b/tests/test_grouping.py
index c35c61b..fdcd4a7 100644
--- a/tests/test_grouping.py
+++ b/tests/test_grouping.py
@@ -122,6 +122,17 @@ class TestGrouping(TestCaseBase):
p = sqlparse.parse('(a, b, c)')[0]
self.assert_(isinstance(p.tokens[0].tokens[1], sql.IdentifierList))
+ def test_identifier_list_subquery(self):
+ """identifier lists should still work in subqueries with aliases"""
+ p = sqlparse.parse("select * from ("
+ "select a, b + c as d from table) sub")[0]
+ subquery = p.tokens[-1].tokens[0]
+ iden_list = subquery.token_next_by(i=sql.IdentifierList)
+ self.assert_(iden_list is not None)
+ # all the identifiers should be within the IdentifierList
+ self.assert_(subquery.token_next_by(i=sql.Identifier,
+ idx=iden_list) is None)
+
def test_identifier_list_case(self):
p = sqlparse.parse('a, case when 1 then 2 else 3 end as b, c')[0]
self.assert_(isinstance(p.tokens[0], sql.IdentifierList))
@@ -140,6 +151,11 @@ class TestGrouping(TestCaseBase):
self.assert_(isinstance(p.tokens[0].tokens[0], sql.Identifier))
self.assert_(isinstance(p.tokens[0].tokens[3], sql.Identifier))
+ def test_identifiers_with_operators(self):
+ p = sqlparse.parse('a+b as c from table where (d-e)%2= 1')[0]
+ self.assertEqual(len([x for x in p.flatten()
+ if x.ttype == sqlparse.tokens.Name]), 5)
+
def test_identifier_list_with_order(self): # issue101
p = sqlparse.parse('1, 2 desc, 3')[0]
self.assert_(isinstance(p.tokens[0], sql.IdentifierList))