diff options
Diffstat (limited to 'tests/test_split.py')
| -rw-r--r-- | tests/test_split.py | 256 |
1 files changed, 125 insertions, 131 deletions
diff --git a/tests/test_split.py b/tests/test_split.py index 7c2645d..af7c9ce 100644 --- a/tests/test_split.py +++ b/tests/test_split.py @@ -4,139 +4,133 @@ import types -from tests.utils import load_file, TestCaseBase +import pytest import sqlparse -from sqlparse.compat import StringIO, u, text_type - - -class SQLSplitTest(TestCaseBase): - """Tests sqlparse.sqlsplit().""" - - _sql1 = 'select * from foo;' - _sql2 = 'select * from bar;' - - def test_split_semicolon(self): - sql2 = 'select * from foo where bar = \'foo;bar\';' - stmts = sqlparse.parse(''.join([self._sql1, sql2])) - self.assertEqual(len(stmts), 2) - self.ndiffAssertEqual(u(stmts[0]), self._sql1) - self.ndiffAssertEqual(u(stmts[1]), sql2) - - def test_split_backslash(self): - stmts = sqlparse.parse(r"select '\\'; select '\''; select '\\\'';") - self.assertEqual(len(stmts), 3) - - def test_create_function(self): - sql = load_file('function.sql') - stmts = sqlparse.parse(sql) - self.assertEqual(len(stmts), 1) - self.ndiffAssertEqual(u(stmts[0]), sql) - - def test_create_function_psql(self): - sql = load_file('function_psql.sql') - stmts = sqlparse.parse(sql) - self.assertEqual(len(stmts), 1) - self.ndiffAssertEqual(u(stmts[0]), sql) - - def test_create_function_psql3(self): - sql = load_file('function_psql3.sql') - stmts = sqlparse.parse(sql) - self.assertEqual(len(stmts), 1) - self.ndiffAssertEqual(u(stmts[0]), sql) - - def test_create_function_psql2(self): - sql = load_file('function_psql2.sql') - stmts = sqlparse.parse(sql) - self.assertEqual(len(stmts), 1) - self.ndiffAssertEqual(u(stmts[0]), sql) - - def test_dashcomments(self): - sql = load_file('dashcomment.sql') - stmts = sqlparse.parse(sql) - self.assertEqual(len(stmts), 3) - self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql) - - def test_dashcomments_eol(self): - stmts = sqlparse.parse('select foo; -- comment\n') - self.assertEqual(len(stmts), 1) - stmts = sqlparse.parse('select foo; -- comment\r') - self.assertEqual(len(stmts), 1) - stmts = sqlparse.parse('select foo; -- comment\r\n') - self.assertEqual(len(stmts), 1) - stmts = sqlparse.parse('select foo; -- comment') - self.assertEqual(len(stmts), 1) - - def test_begintag(self): - sql = load_file('begintag.sql') - stmts = sqlparse.parse(sql) - self.assertEqual(len(stmts), 3) - self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql) - - def test_begintag_2(self): - sql = load_file('begintag_2.sql') - stmts = sqlparse.parse(sql) - self.assertEqual(len(stmts), 1) - self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql) - - def test_dropif(self): - sql = 'DROP TABLE IF EXISTS FOO;\n\nSELECT * FROM BAR;' - stmts = sqlparse.parse(sql) - self.assertEqual(len(stmts), 2) - self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql) - - def test_comment_with_umlaut(self): - sql = (u'select * from foo;\n' - u'-- Testing an umlaut: ä\n' - u'select * from bar;') - stmts = sqlparse.parse(sql) - self.assertEqual(len(stmts), 2) - self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql) - - def test_comment_end_of_line(self): - sql = ('select * from foo; -- foo\n' - 'select * from bar;') - stmts = sqlparse.parse(sql) - self.assertEqual(len(stmts), 2) - self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql) - # make sure the comment belongs to first query - self.ndiffAssertEqual(u(stmts[0]), 'select * from foo; -- foo\n') - - def test_casewhen(self): - sql = ('SELECT case when val = 1 then 2 else null end as foo;\n' - 'comment on table actor is \'The actor table.\';') - stmts = sqlparse.split(sql) - self.assertEqual(len(stmts), 2) - - def test_cursor_declare(self): - sql = ('DECLARE CURSOR "foo" AS SELECT 1;\n' - 'SELECT 2;') - stmts = sqlparse.split(sql) - self.assertEqual(len(stmts), 2) - - def test_if_function(self): # see issue 33 - # don't let IF as a function confuse the splitter - sql = ('CREATE TEMPORARY TABLE tmp ' - 'SELECT IF(a=1, a, b) AS o FROM one; ' - 'SELECT t FROM two') - stmts = sqlparse.split(sql) - self.assertEqual(len(stmts), 2) - - def test_split_stream(self): - stream = StringIO("SELECT 1; SELECT 2;") - stmts = sqlparse.parsestream(stream) - self.assertEqual(type(stmts), types.GeneratorType) - self.assertEqual(len(list(stmts)), 2) - - def test_encoding_parsestream(self): - stream = StringIO("SELECT 1; SELECT 2;") - stmts = list(sqlparse.parsestream(stream)) - self.assertEqual(type(stmts[0].tokens[0].value), text_type) - - def test_unicode_parsestream(self): - stream = StringIO(u"SELECT ö") - stmts = list(sqlparse.parsestream(stream)) - self.assertEqual(str(stmts[0]), "SELECT ö") +from sqlparse.compat import StringIO, text_type + + +def test_split_semicolon(): + sql1 = 'select * from foo;' + sql2 = "select * from foo where bar = 'foo;bar';" + stmts = sqlparse.parse(''.join([sql1, sql2])) + assert len(stmts) == 2 + assert str(stmts[0]) == sql1 + assert str(stmts[1]) == sql2 + + +def test_split_backslash(): + stmts = sqlparse.parse(r"select '\\'; select '\''; select '\\\'';") + assert len(stmts) == 3 + + +@pytest.mark.parametrize('fn', ['function.sql', + 'function_psql.sql', + 'function_psql2.sql', + 'function_psql3.sql']) +def test_split_create_function(load_file, fn): + sql = load_file(fn) + stmts = sqlparse.parse(sql) + assert len(stmts) == 1 + assert text_type(stmts[0]) == sql + + +def test_split_dashcomments(load_file): + sql = load_file('dashcomment.sql') + stmts = sqlparse.parse(sql) + assert len(stmts) == 3 + assert ''.join(str(q) for q in stmts) == sql + + +@pytest.mark.parametrize('s', ['select foo; -- comment\n', + 'select foo; -- comment\r', + 'select foo; -- comment\r\n', + 'select foo; -- comment']) +def test_split_dashcomments_eol(s): + stmts = sqlparse.parse(s) + assert len(stmts) == 1 + + +def test_split_begintag(load_file): + sql = load_file('begintag.sql') + stmts = sqlparse.parse(sql) + assert len(stmts) == 3 + assert ''.join(str(q) for q in stmts) == sql + + +def test_split_begintag_2(load_file): + sql = load_file('begintag_2.sql') + stmts = sqlparse.parse(sql) + assert len(stmts) == 1 + assert ''.join(str(q) for q in stmts) == sql + + +def test_split_dropif(): + sql = 'DROP TABLE IF EXISTS FOO;\n\nSELECT * FROM BAR;' + stmts = sqlparse.parse(sql) + assert len(stmts) == 2 + assert ''.join(str(q) for q in stmts) == sql + + +def test_split_comment_with_umlaut(): + sql = (u'select * from foo;\n' + u'-- Testing an umlaut: ä\n' + u'select * from bar;') + stmts = sqlparse.parse(sql) + assert len(stmts) == 2 + assert ''.join(text_type(q) for q in stmts) == sql + + +def test_split_comment_end_of_line(): + sql = ('select * from foo; -- foo\n' + 'select * from bar;') + stmts = sqlparse.parse(sql) + assert len(stmts) == 2 + assert ''.join(str(q) for q in stmts) == sql + # make sure the comment belongs to first query + assert str(stmts[0]) == 'select * from foo; -- foo\n' + + +def test_split_casewhen(): + sql = ("SELECT case when val = 1 then 2 else null end as foo;\n" + "comment on table actor is 'The actor table.';") + stmts = sqlparse.split(sql) + assert len(stmts) == 2 + + +def test_split_cursor_declare(): + sql = ('DECLARE CURSOR "foo" AS SELECT 1;\n' + 'SELECT 2;') + stmts = sqlparse.split(sql) + assert len(stmts) == 2 + + +def test_split_if_function(): # see issue 33 + # don't let IF as a function confuse the splitter + sql = ('CREATE TEMPORARY TABLE tmp ' + 'SELECT IF(a=1, a, b) AS o FROM one; ' + 'SELECT t FROM two') + stmts = sqlparse.split(sql) + assert len(stmts) == 2 + + +def test_split_stream(): + stream = StringIO("SELECT 1; SELECT 2;") + stmts = sqlparse.parsestream(stream) + assert isinstance(stmts, types.GeneratorType) + assert len(list(stmts)) == 2 + + +def test_split_encoding_parsestream(): + stream = StringIO("SELECT 1; SELECT 2;") + stmts = list(sqlparse.parsestream(stream)) + assert isinstance(stmts[0].tokens[0].value, text_type) + + +def test_split_unicode_parsestream(): + stream = StringIO(u'SELECT ö') + stmts = list(sqlparse.parsestream(stream)) + assert str(stmts[0]) == 'SELECT ö' def test_split_simple(): |
