summaryrefslogtreecommitdiff
path: root/tests/test_split.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_split.py')
-rw-r--r--tests/test_split.py256
1 files changed, 125 insertions, 131 deletions
diff --git a/tests/test_split.py b/tests/test_split.py
index 7c2645d..af7c9ce 100644
--- a/tests/test_split.py
+++ b/tests/test_split.py
@@ -4,139 +4,133 @@
import types
-from tests.utils import load_file, TestCaseBase
+import pytest
import sqlparse
-from sqlparse.compat import StringIO, u, text_type
-
-
-class SQLSplitTest(TestCaseBase):
- """Tests sqlparse.sqlsplit()."""
-
- _sql1 = 'select * from foo;'
- _sql2 = 'select * from bar;'
-
- def test_split_semicolon(self):
- sql2 = 'select * from foo where bar = \'foo;bar\';'
- stmts = sqlparse.parse(''.join([self._sql1, sql2]))
- self.assertEqual(len(stmts), 2)
- self.ndiffAssertEqual(u(stmts[0]), self._sql1)
- self.ndiffAssertEqual(u(stmts[1]), sql2)
-
- def test_split_backslash(self):
- stmts = sqlparse.parse(r"select '\\'; select '\''; select '\\\'';")
- self.assertEqual(len(stmts), 3)
-
- def test_create_function(self):
- sql = load_file('function.sql')
- stmts = sqlparse.parse(sql)
- self.assertEqual(len(stmts), 1)
- self.ndiffAssertEqual(u(stmts[0]), sql)
-
- def test_create_function_psql(self):
- sql = load_file('function_psql.sql')
- stmts = sqlparse.parse(sql)
- self.assertEqual(len(stmts), 1)
- self.ndiffAssertEqual(u(stmts[0]), sql)
-
- def test_create_function_psql3(self):
- sql = load_file('function_psql3.sql')
- stmts = sqlparse.parse(sql)
- self.assertEqual(len(stmts), 1)
- self.ndiffAssertEqual(u(stmts[0]), sql)
-
- def test_create_function_psql2(self):
- sql = load_file('function_psql2.sql')
- stmts = sqlparse.parse(sql)
- self.assertEqual(len(stmts), 1)
- self.ndiffAssertEqual(u(stmts[0]), sql)
-
- def test_dashcomments(self):
- sql = load_file('dashcomment.sql')
- stmts = sqlparse.parse(sql)
- self.assertEqual(len(stmts), 3)
- self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql)
-
- def test_dashcomments_eol(self):
- stmts = sqlparse.parse('select foo; -- comment\n')
- self.assertEqual(len(stmts), 1)
- stmts = sqlparse.parse('select foo; -- comment\r')
- self.assertEqual(len(stmts), 1)
- stmts = sqlparse.parse('select foo; -- comment\r\n')
- self.assertEqual(len(stmts), 1)
- stmts = sqlparse.parse('select foo; -- comment')
- self.assertEqual(len(stmts), 1)
-
- def test_begintag(self):
- sql = load_file('begintag.sql')
- stmts = sqlparse.parse(sql)
- self.assertEqual(len(stmts), 3)
- self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql)
-
- def test_begintag_2(self):
- sql = load_file('begintag_2.sql')
- stmts = sqlparse.parse(sql)
- self.assertEqual(len(stmts), 1)
- self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql)
-
- def test_dropif(self):
- sql = 'DROP TABLE IF EXISTS FOO;\n\nSELECT * FROM BAR;'
- stmts = sqlparse.parse(sql)
- self.assertEqual(len(stmts), 2)
- self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql)
-
- def test_comment_with_umlaut(self):
- sql = (u'select * from foo;\n'
- u'-- Testing an umlaut: ä\n'
- u'select * from bar;')
- stmts = sqlparse.parse(sql)
- self.assertEqual(len(stmts), 2)
- self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql)
-
- def test_comment_end_of_line(self):
- sql = ('select * from foo; -- foo\n'
- 'select * from bar;')
- stmts = sqlparse.parse(sql)
- self.assertEqual(len(stmts), 2)
- self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql)
- # make sure the comment belongs to first query
- self.ndiffAssertEqual(u(stmts[0]), 'select * from foo; -- foo\n')
-
- def test_casewhen(self):
- sql = ('SELECT case when val = 1 then 2 else null end as foo;\n'
- 'comment on table actor is \'The actor table.\';')
- stmts = sqlparse.split(sql)
- self.assertEqual(len(stmts), 2)
-
- def test_cursor_declare(self):
- sql = ('DECLARE CURSOR "foo" AS SELECT 1;\n'
- 'SELECT 2;')
- stmts = sqlparse.split(sql)
- self.assertEqual(len(stmts), 2)
-
- def test_if_function(self): # see issue 33
- # don't let IF as a function confuse the splitter
- sql = ('CREATE TEMPORARY TABLE tmp '
- 'SELECT IF(a=1, a, b) AS o FROM one; '
- 'SELECT t FROM two')
- stmts = sqlparse.split(sql)
- self.assertEqual(len(stmts), 2)
-
- def test_split_stream(self):
- stream = StringIO("SELECT 1; SELECT 2;")
- stmts = sqlparse.parsestream(stream)
- self.assertEqual(type(stmts), types.GeneratorType)
- self.assertEqual(len(list(stmts)), 2)
-
- def test_encoding_parsestream(self):
- stream = StringIO("SELECT 1; SELECT 2;")
- stmts = list(sqlparse.parsestream(stream))
- self.assertEqual(type(stmts[0].tokens[0].value), text_type)
-
- def test_unicode_parsestream(self):
- stream = StringIO(u"SELECT ö")
- stmts = list(sqlparse.parsestream(stream))
- self.assertEqual(str(stmts[0]), "SELECT ö")
+from sqlparse.compat import StringIO, text_type
+
+
+def test_split_semicolon():
+ sql1 = 'select * from foo;'
+ sql2 = "select * from foo where bar = 'foo;bar';"
+ stmts = sqlparse.parse(''.join([sql1, sql2]))
+ assert len(stmts) == 2
+ assert str(stmts[0]) == sql1
+ assert str(stmts[1]) == sql2
+
+
+def test_split_backslash():
+ stmts = sqlparse.parse(r"select '\\'; select '\''; select '\\\'';")
+ assert len(stmts) == 3
+
+
+@pytest.mark.parametrize('fn', ['function.sql',
+ 'function_psql.sql',
+ 'function_psql2.sql',
+ 'function_psql3.sql'])
+def test_split_create_function(load_file, fn):
+ sql = load_file(fn)
+ stmts = sqlparse.parse(sql)
+ assert len(stmts) == 1
+ assert text_type(stmts[0]) == sql
+
+
+def test_split_dashcomments(load_file):
+ sql = load_file('dashcomment.sql')
+ stmts = sqlparse.parse(sql)
+ assert len(stmts) == 3
+ assert ''.join(str(q) for q in stmts) == sql
+
+
+@pytest.mark.parametrize('s', ['select foo; -- comment\n',
+ 'select foo; -- comment\r',
+ 'select foo; -- comment\r\n',
+ 'select foo; -- comment'])
+def test_split_dashcomments_eol(s):
+ stmts = sqlparse.parse(s)
+ assert len(stmts) == 1
+
+
+def test_split_begintag(load_file):
+ sql = load_file('begintag.sql')
+ stmts = sqlparse.parse(sql)
+ assert len(stmts) == 3
+ assert ''.join(str(q) for q in stmts) == sql
+
+
+def test_split_begintag_2(load_file):
+ sql = load_file('begintag_2.sql')
+ stmts = sqlparse.parse(sql)
+ assert len(stmts) == 1
+ assert ''.join(str(q) for q in stmts) == sql
+
+
+def test_split_dropif():
+ sql = 'DROP TABLE IF EXISTS FOO;\n\nSELECT * FROM BAR;'
+ stmts = sqlparse.parse(sql)
+ assert len(stmts) == 2
+ assert ''.join(str(q) for q in stmts) == sql
+
+
+def test_split_comment_with_umlaut():
+ sql = (u'select * from foo;\n'
+ u'-- Testing an umlaut: ä\n'
+ u'select * from bar;')
+ stmts = sqlparse.parse(sql)
+ assert len(stmts) == 2
+ assert ''.join(text_type(q) for q in stmts) == sql
+
+
+def test_split_comment_end_of_line():
+ sql = ('select * from foo; -- foo\n'
+ 'select * from bar;')
+ stmts = sqlparse.parse(sql)
+ assert len(stmts) == 2
+ assert ''.join(str(q) for q in stmts) == sql
+ # make sure the comment belongs to first query
+ assert str(stmts[0]) == 'select * from foo; -- foo\n'
+
+
+def test_split_casewhen():
+ sql = ("SELECT case when val = 1 then 2 else null end as foo;\n"
+ "comment on table actor is 'The actor table.';")
+ stmts = sqlparse.split(sql)
+ assert len(stmts) == 2
+
+
+def test_split_cursor_declare():
+ sql = ('DECLARE CURSOR "foo" AS SELECT 1;\n'
+ 'SELECT 2;')
+ stmts = sqlparse.split(sql)
+ assert len(stmts) == 2
+
+
+def test_split_if_function(): # see issue 33
+ # don't let IF as a function confuse the splitter
+ sql = ('CREATE TEMPORARY TABLE tmp '
+ 'SELECT IF(a=1, a, b) AS o FROM one; '
+ 'SELECT t FROM two')
+ stmts = sqlparse.split(sql)
+ assert len(stmts) == 2
+
+
+def test_split_stream():
+ stream = StringIO("SELECT 1; SELECT 2;")
+ stmts = sqlparse.parsestream(stream)
+ assert isinstance(stmts, types.GeneratorType)
+ assert len(list(stmts)) == 2
+
+
+def test_split_encoding_parsestream():
+ stream = StringIO("SELECT 1; SELECT 2;")
+ stmts = list(sqlparse.parsestream(stream))
+ assert isinstance(stmts[0].tokens[0].value, text_type)
+
+
+def test_split_unicode_parsestream():
+ stream = StringIO(u'SELECT ö')
+ stmts = list(sqlparse.parsestream(stream))
+ assert str(stmts[0]) == 'SELECT ö'
def test_split_simple():