diff options
| author | Andi Albrecht <albrecht.andi@gmail.com> | 2009-05-14 22:17:34 +0200 |
|---|---|---|
| committer | Andi Albrecht <albrecht.andi@gmail.com> | 2009-05-14 22:17:34 +0200 |
| commit | 691c0400e5a7d8229b7dce09bf47176539add328 (patch) | |
| tree | 066fcba8207dc039d0e9c121f668fccbbb22ff1e | |
| parent | 5ccb54dae178189623b6223ea95e261046c6bb1a (diff) | |
| download | sqlparse-691c0400e5a7d8229b7dce09bf47176539add328.tar.gz | |
Fixed grouping of semicolons within assignments.
| -rw-r--r-- | sqlparse/engine/grouping.py | 8 | ||||
| -rw-r--r-- | tests/test_grouping.py | 16 |
2 files changed, 15 insertions, 9 deletions
diff --git a/sqlparse/engine/grouping.py b/sqlparse/engine/grouping.py index 181dae4..471116e 100644 --- a/sqlparse/engine/grouping.py +++ b/sqlparse/engine/grouping.py @@ -29,8 +29,12 @@ def _group_left_right(tlist, ttype, value, cls, ttype, value) else: if include_semicolon: - right = tlist.token_next_match(tlist.token_index(right), - T.Punctuation, ';') + sright = tlist.token_next_match(tlist.token_index(right), + T.Punctuation, ';') + if sright is not None: + # only overwrite "right" if a semicolon is actually + # present. + right = sright tokens = tlist.tokens_between(left, right)[1:] if not isinstance(left, cls): new = cls([left]) diff --git a/tests/test_grouping.py b/tests/test_grouping.py index d2f08fe..6477123 100644 --- a/tests/test_grouping.py +++ b/tests/test_grouping.py @@ -26,13 +26,15 @@ class TestGrouping(TestCaseBase): self.ndiffAssertEqual(s, unicode(parsed)) self.assertEqual(len(parsed.tokens), 2) - #def test_xassignment(self): - # s = 'foo := 1;' - # parsed = sqlparse.parse(s)[0] - # self.assertEqual(len(parsed.tokens), 1) - #s = 'foo := 1' - #parsed = sqlparse.parse(s)[0] - #self.assertEqual(len(parsed.tokens), 1) + def test_assignment(self): + s = 'foo := 1;' + parsed = sqlparse.parse(s)[0] + self.assertEqual(len(parsed.tokens), 1) + self.assert_(isinstance(parsed.tokens[0], Assignment)) + s = 'foo := 1' + parsed = sqlparse.parse(s)[0] + self.assertEqual(len(parsed.tokens), 1) + self.assert_(isinstance(parsed.tokens[0], Assignment)) def test_identifiers(self): s = 'select foo.bar from "myscheme"."table" where fail. order' |
