summaryrefslogtreecommitdiff
path: root/tests/test_split.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_split.py')
-rw-r--r--tests/test_split.py35
1 files changed, 16 insertions, 19 deletions
diff --git a/tests/test_split.py b/tests/test_split.py
index 54e8d04..f6d5f50 100644
--- a/tests/test_split.py
+++ b/tests/test_split.py
@@ -2,11 +2,12 @@
# Tests splitting functions.
-import unittest
+import types
from tests.utils import load_file, TestCaseBase
import sqlparse
+from sqlparse.compat import StringIO, u, text_type
class SQLSplitTest(TestCaseBase):
@@ -19,8 +20,8 @@ class SQLSplitTest(TestCaseBase):
sql2 = 'select * from foo where bar = \'foo;bar\';'
stmts = sqlparse.parse(''.join([self._sql1, sql2]))
self.assertEqual(len(stmts), 2)
- self.ndiffAssertEqual(unicode(stmts[0]), self._sql1)
- self.ndiffAssertEqual(unicode(stmts[1]), sql2)
+ self.ndiffAssertEqual(u(stmts[0]), self._sql1)
+ self.ndiffAssertEqual(u(stmts[1]), sql2)
def test_split_backslash(self):
stmts = sqlparse.parse(r"select '\\'; select '\''; select '\\\'';")
@@ -30,31 +31,31 @@ class SQLSplitTest(TestCaseBase):
sql = load_file('function.sql')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 1)
- self.ndiffAssertEqual(unicode(stmts[0]), sql)
+ self.ndiffAssertEqual(u(stmts[0]), sql)
def test_create_function_psql(self):
sql = load_file('function_psql.sql')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 1)
- self.ndiffAssertEqual(unicode(stmts[0]), sql)
+ self.ndiffAssertEqual(u(stmts[0]), sql)
def test_create_function_psql3(self):
sql = load_file('function_psql3.sql')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 1)
- self.ndiffAssertEqual(unicode(stmts[0]), sql)
+ self.ndiffAssertEqual(u(stmts[0]), sql)
def test_create_function_psql2(self):
sql = load_file('function_psql2.sql')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 1)
- self.ndiffAssertEqual(unicode(stmts[0]), sql)
+ self.ndiffAssertEqual(u(stmts[0]), sql)
def test_dashcomments(self):
sql = load_file('dashcomment.sql')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 3)
- self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql)
+ self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql)
def test_dashcomments_eol(self):
stmts = sqlparse.parse('select foo; -- comment\n')
@@ -70,19 +71,19 @@ class SQLSplitTest(TestCaseBase):
sql = load_file('begintag.sql')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 3)
- self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql)
+ self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql)
def test_begintag_2(self):
sql = load_file('begintag_2.sql')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 1)
- self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql)
+ self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql)
def test_dropif(self):
sql = 'DROP TABLE IF EXISTS FOO;\n\nSELECT * FROM BAR;'
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 2)
- self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql)
+ self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql)
def test_comment_with_umlaut(self):
sql = (u'select * from foo;\n'
@@ -90,16 +91,16 @@ class SQLSplitTest(TestCaseBase):
u'select * from bar;')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 2)
- self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql)
+ self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql)
def test_comment_end_of_line(self):
sql = ('select * from foo; -- foo\n'
'select * from bar;')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 2)
- self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql)
+ self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql)
# make sure the comment belongs to first query
- self.ndiffAssertEqual(unicode(stmts[0]), 'select * from foo; -- foo\n')
+ self.ndiffAssertEqual(u(stmts[0]), 'select * from foo; -- foo\n')
def test_casewhen(self):
sql = ('SELECT case when val = 1 then 2 else null end as foo;\n'
@@ -122,19 +123,15 @@ class SQLSplitTest(TestCaseBase):
self.assertEqual(len(stmts), 2)
def test_split_stream(self):
- import types
- from cStringIO import StringIO
-
stream = StringIO("SELECT 1; SELECT 2;")
stmts = sqlparse.parsestream(stream)
self.assertEqual(type(stmts), types.GeneratorType)
self.assertEqual(len(list(stmts)), 2)
def test_encoding_parsestream(self):
- from cStringIO import StringIO
stream = StringIO("SELECT 1; SELECT 2;")
stmts = list(sqlparse.parsestream(stream))
- self.assertEqual(type(stmts[0].tokens[0].value), unicode)
+ self.assertEqual(type(stmts[0].tokens[0].value), text_type)
def test_split_simple():