diff options
| -rw-r--r-- | CHANGES | 8 | ||||
| -rw-r--r-- | sqlparse/engine/grouping.py | 25 | ||||
| -rw-r--r-- | tests/test_format.py | 7 |
3 files changed, 33 insertions, 7 deletions
@@ -1,3 +1,11 @@ +Development +----------- + +Bug Fixes + * Avoid "stair case" effects when identifiers and functions are mixed + in identifier lists (issue45). + + Release 0.1.3 (Jul 29, 2011) ---------------------------- diff --git a/sqlparse/engine/grouping.py b/sqlparse/engine/grouping.py index 9bc9612..6e99782 100644 --- a/sqlparse/engine/grouping.py +++ b/sqlparse/engine/grouping.py @@ -144,15 +144,30 @@ def group_identifier(tlist): else: raise StopIteration + def _next_token(tl, i): + # chooses the next token. if two tokens are found then the + # first is returned. + t1 = tl.token_next_by_type(i, (T.String.Symbol, T.Name)) + t2 = tl.token_next_by_instance(i, sql.Function) + if t1 and t2: + i1 = tl.token_index(t1) + i2 = tl.token_index(t2) + if i1 > i2: + return t2 + else: + return t1 + elif t1: + return t1 + else: + return t2 + # bottom up approach: group subgroups first [group_identifier(sgroup) for sgroup in tlist.get_sublists() if not isinstance(sgroup, sql.Identifier)] # real processing idx = 0 - token = tlist.token_next_by_instance(idx, sql.Function) - if token is None: - token = tlist.token_next_by_type(idx, (T.String.Symbol, T.Name)) + token = _next_token(tlist, idx) while token: identifier_tokens = [token] + list( _consume_cycle(tlist, @@ -163,9 +178,7 @@ def group_identifier(tlist): idx = tlist.token_index(group) + 1 else: idx += 1 - token = tlist.token_next_by_instance(idx, sql.Function) - if token is None: - token = tlist.token_next_by_type(idx, (T.String.Symbol, T.Name)) + token = _next_token(tlist, idx) def group_identifier_list(tlist): diff --git a/tests/test_format.py b/tests/test_format.py index e41b6b6..7a2c655 100644 --- a/tests/test_format.py +++ b/tests/test_format.py @@ -228,7 +228,12 @@ class TestFormatReindent(TestCaseBase): ' foo,', ' bar'])) - + def test_identifier_and_functions(self): # issue45 + f = lambda sql: sqlparse.format(sql, reindent=True) + s = 'select foo.bar, nvl(1) from dual' + self.ndiffAssertEqual(f(s), '\n'.join(['select foo.bar,', + ' nvl(1)', + 'from dual'])) class TestOutputFormat(TestCaseBase): |
