summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndi Albrecht <albrecht.andi@gmail.com>2009-05-14 22:17:34 +0200
committerAndi Albrecht <albrecht.andi@gmail.com>2009-05-14 22:17:34 +0200
commit691c0400e5a7d8229b7dce09bf47176539add328 (patch)
tree066fcba8207dc039d0e9c121f668fccbbb22ff1e
parent5ccb54dae178189623b6223ea95e261046c6bb1a (diff)
downloadsqlparse-691c0400e5a7d8229b7dce09bf47176539add328.tar.gz
Fixed grouping of semicolons within assignments.
-rw-r--r--sqlparse/engine/grouping.py8
-rw-r--r--tests/test_grouping.py16
2 files changed, 15 insertions, 9 deletions
diff --git a/sqlparse/engine/grouping.py b/sqlparse/engine/grouping.py
index 181dae4..471116e 100644
--- a/sqlparse/engine/grouping.py
+++ b/sqlparse/engine/grouping.py
@@ -29,8 +29,12 @@ def _group_left_right(tlist, ttype, value, cls,
ttype, value)
else:
if include_semicolon:
- right = tlist.token_next_match(tlist.token_index(right),
- T.Punctuation, ';')
+ sright = tlist.token_next_match(tlist.token_index(right),
+ T.Punctuation, ';')
+ if sright is not None:
+ # only overwrite "right" if a semicolon is actually
+ # present.
+ right = sright
tokens = tlist.tokens_between(left, right)[1:]
if not isinstance(left, cls):
new = cls([left])
diff --git a/tests/test_grouping.py b/tests/test_grouping.py
index d2f08fe..6477123 100644
--- a/tests/test_grouping.py
+++ b/tests/test_grouping.py
@@ -26,13 +26,15 @@ class TestGrouping(TestCaseBase):
self.ndiffAssertEqual(s, unicode(parsed))
self.assertEqual(len(parsed.tokens), 2)
- #def test_xassignment(self):
- # s = 'foo := 1;'
- # parsed = sqlparse.parse(s)[0]
- # self.assertEqual(len(parsed.tokens), 1)
- #s = 'foo := 1'
- #parsed = sqlparse.parse(s)[0]
- #self.assertEqual(len(parsed.tokens), 1)
+ def test_assignment(self):
+ s = 'foo := 1;'
+ parsed = sqlparse.parse(s)[0]
+ self.assertEqual(len(parsed.tokens), 1)
+ self.assert_(isinstance(parsed.tokens[0], Assignment))
+ s = 'foo := 1'
+ parsed = sqlparse.parse(s)[0]
+ self.assertEqual(len(parsed.tokens), 1)
+ self.assert_(isinstance(parsed.tokens[0], Assignment))
def test_identifiers(self):
s = 'select foo.bar from "myscheme"."table" where fail. order'