summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorAndi Albrecht <albrecht.andi@gmail.com>2015-10-26 19:40:29 +0100
committerAndi Albrecht <albrecht.andi@gmail.com>2015-10-26 19:40:29 +0100
commite6a51a0bc3f87e284de74cec838d3ee98c2f9cf5 (patch)
tree8ac469820a09c31e9e49543ecbbbdeebad53c85e /tests
parent8bfdaf3cc37ffe48a60c7f4ee5d5e99d0b07e696 (diff)
downloadsqlparse-e6a51a0bc3f87e284de74cec838d3ee98c2f9cf5.tar.gz
Use compat module for single Python 2/3 code base.
This change includes minor fixes and code cleanup too.
Diffstat (limited to 'tests')
-rw-r--r--tests/test_filters.py1
-rw-r--r--tests/test_grouping.py25
-rw-r--r--tests/test_parse.py10
-rw-r--r--tests/test_regressions.py1
-rw-r--r--tests/test_split.py35
-rw-r--r--tests/test_tokenize.py5
-rw-r--r--tests/utils.py10
7 files changed, 42 insertions, 45 deletions
diff --git a/tests/test_filters.py b/tests/test_filters.py
index d827454..eb61604 100644
--- a/tests/test_filters.py
+++ b/tests/test_filters.py
@@ -5,6 +5,7 @@ Created on 24/03/2012
'''
import unittest
+from sqlparse.compat import u
from sqlparse.filters import StripWhitespace, Tokens2Unicode
from sqlparse.lexer import tokenize
diff --git a/tests/test_grouping.py b/tests/test_grouping.py
index 5ade830..fa68ab2 100644
--- a/tests/test_grouping.py
+++ b/tests/test_grouping.py
@@ -5,6 +5,7 @@ import pytest
import sqlparse
from sqlparse import sql
from sqlparse import tokens as T
+from sqlparse.compat import u
from tests.utils import TestCaseBase
@@ -26,7 +27,7 @@ class TestGrouping(TestCaseBase):
def test_comments(self):
s = '/*\n * foo\n */ \n bar'
parsed = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, unicode(parsed))
+ self.ndiffAssertEqual(s, u(parsed))
self.assertEqual(len(parsed.tokens), 2)
def test_assignment(self):
@@ -42,18 +43,18 @@ class TestGrouping(TestCaseBase):
def test_identifiers(self):
s = 'select foo.bar from "myscheme"."table" where fail. order'
parsed = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, unicode(parsed))
+ self.ndiffAssertEqual(s, u(parsed))
self.assert_(isinstance(parsed.tokens[2], sql.Identifier))
self.assert_(isinstance(parsed.tokens[6], sql.Identifier))
self.assert_(isinstance(parsed.tokens[8], sql.Where))
s = 'select * from foo where foo.id = 1'
parsed = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, unicode(parsed))
+ self.ndiffAssertEqual(s, u(parsed))
self.assert_(isinstance(parsed.tokens[-1].tokens[-1].tokens[0],
sql.Identifier))
s = 'select * from (select "foo"."id" from foo)'
parsed = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, unicode(parsed))
+ self.ndiffAssertEqual(s, u(parsed))
self.assert_(isinstance(parsed.tokens[-1].tokens[3], sql.Identifier))
s = "INSERT INTO `test` VALUES('foo', 'bar');"
@@ -141,44 +142,44 @@ class TestGrouping(TestCaseBase):
def test_where(self):
s = 'select * from foo where bar = 1 order by id desc'
p = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, unicode(p))
+ self.ndiffAssertEqual(s, u(p))
self.assertTrue(len(p.tokens), 16)
s = 'select x from (select y from foo where bar = 1) z'
p = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, unicode(p))
+ self.ndiffAssertEqual(s, u(p))
self.assertTrue(isinstance(p.tokens[-1].tokens[0].tokens[-2], sql.Where))
def test_typecast(self):
s = 'select foo::integer from bar'
p = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, unicode(p))
+ self.ndiffAssertEqual(s, u(p))
self.assertEqual(p.tokens[2].get_typecast(), 'integer')
self.assertEqual(p.tokens[2].get_name(), 'foo')
s = 'select (current_database())::information_schema.sql_identifier'
p = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, unicode(p))
+ self.ndiffAssertEqual(s, u(p))
self.assertEqual(p.tokens[2].get_typecast(),
'information_schema.sql_identifier')
def test_alias(self):
s = 'select foo as bar from mytable'
p = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, unicode(p))
+ self.ndiffAssertEqual(s, u(p))
self.assertEqual(p.tokens[2].get_real_name(), 'foo')
self.assertEqual(p.tokens[2].get_alias(), 'bar')
s = 'select foo from mytable t1'
p = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, unicode(p))
+ self.ndiffAssertEqual(s, u(p))
self.assertEqual(p.tokens[6].get_real_name(), 'mytable')
self.assertEqual(p.tokens[6].get_alias(), 't1')
s = 'select foo::integer as bar from mytable'
p = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, unicode(p))
+ self.ndiffAssertEqual(s, u(p))
self.assertEqual(p.tokens[2].get_alias(), 'bar')
s = ('SELECT DISTINCT '
'(current_database())::information_schema.sql_identifier AS view')
p = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, unicode(p))
+ self.ndiffAssertEqual(s, u(p))
self.assertEqual(p.tokens[4].get_alias(), 'view')
def test_alias_case(self): # see issue46
diff --git a/tests/test_parse.py b/tests/test_parse.py
index 6c9d6a6..fb7b24b 100644
--- a/tests/test_parse.py
+++ b/tests/test_parse.py
@@ -8,8 +8,8 @@ from tests.utils import TestCaseBase
import sqlparse
import sqlparse.sql
-
from sqlparse import tokens as T
+from sqlparse.compat import u
class SQLParseTest(TestCaseBase):
@@ -32,16 +32,16 @@ class SQLParseTest(TestCaseBase):
def test_newlines(self):
sql = u'select\n*from foo;'
p = sqlparse.parse(sql)[0]
- self.assertEqual(unicode(p), sql)
+ self.assertEqual(u(p), sql)
sql = u'select\r\n*from foo'
p = sqlparse.parse(sql)[0]
- self.assertEqual(unicode(p), sql)
+ self.assertEqual(u(p), sql)
sql = u'select\r*from foo'
p = sqlparse.parse(sql)[0]
- self.assertEqual(unicode(p), sql)
+ self.assertEqual(u(p), sql)
sql = u'select\r\n*from foo\n'
p = sqlparse.parse(sql)[0]
- self.assertEqual(unicode(p), sql)
+ self.assertEqual(u(p), sql)
def test_within(self):
sql = 'foo(col1, col2)'
diff --git a/tests/test_regressions.py b/tests/test_regressions.py
index a64b400..f873c78 100644
--- a/tests/test_regressions.py
+++ b/tests/test_regressions.py
@@ -256,6 +256,7 @@ SELECT * FROM a.b;"""
splitted = sqlparse.split(sql)
assert len(splitted) == 2
+
def test_issue194_splitting_function():
sql = """CREATE FUNCTION a(x VARCHAR(20)) RETURNS VARCHAR(20)
BEGIN
diff --git a/tests/test_split.py b/tests/test_split.py
index 54e8d04..f6d5f50 100644
--- a/tests/test_split.py
+++ b/tests/test_split.py
@@ -2,11 +2,12 @@
# Tests splitting functions.
-import unittest
+import types
from tests.utils import load_file, TestCaseBase
import sqlparse
+from sqlparse.compat import StringIO, u, text_type
class SQLSplitTest(TestCaseBase):
@@ -19,8 +20,8 @@ class SQLSplitTest(TestCaseBase):
sql2 = 'select * from foo where bar = \'foo;bar\';'
stmts = sqlparse.parse(''.join([self._sql1, sql2]))
self.assertEqual(len(stmts), 2)
- self.ndiffAssertEqual(unicode(stmts[0]), self._sql1)
- self.ndiffAssertEqual(unicode(stmts[1]), sql2)
+ self.ndiffAssertEqual(u(stmts[0]), self._sql1)
+ self.ndiffAssertEqual(u(stmts[1]), sql2)
def test_split_backslash(self):
stmts = sqlparse.parse(r"select '\\'; select '\''; select '\\\'';")
@@ -30,31 +31,31 @@ class SQLSplitTest(TestCaseBase):
sql = load_file('function.sql')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 1)
- self.ndiffAssertEqual(unicode(stmts[0]), sql)
+ self.ndiffAssertEqual(u(stmts[0]), sql)
def test_create_function_psql(self):
sql = load_file('function_psql.sql')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 1)
- self.ndiffAssertEqual(unicode(stmts[0]), sql)
+ self.ndiffAssertEqual(u(stmts[0]), sql)
def test_create_function_psql3(self):
sql = load_file('function_psql3.sql')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 1)
- self.ndiffAssertEqual(unicode(stmts[0]), sql)
+ self.ndiffAssertEqual(u(stmts[0]), sql)
def test_create_function_psql2(self):
sql = load_file('function_psql2.sql')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 1)
- self.ndiffAssertEqual(unicode(stmts[0]), sql)
+ self.ndiffAssertEqual(u(stmts[0]), sql)
def test_dashcomments(self):
sql = load_file('dashcomment.sql')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 3)
- self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql)
+ self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql)
def test_dashcomments_eol(self):
stmts = sqlparse.parse('select foo; -- comment\n')
@@ -70,19 +71,19 @@ class SQLSplitTest(TestCaseBase):
sql = load_file('begintag.sql')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 3)
- self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql)
+ self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql)
def test_begintag_2(self):
sql = load_file('begintag_2.sql')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 1)
- self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql)
+ self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql)
def test_dropif(self):
sql = 'DROP TABLE IF EXISTS FOO;\n\nSELECT * FROM BAR;'
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 2)
- self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql)
+ self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql)
def test_comment_with_umlaut(self):
sql = (u'select * from foo;\n'
@@ -90,16 +91,16 @@ class SQLSplitTest(TestCaseBase):
u'select * from bar;')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 2)
- self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql)
+ self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql)
def test_comment_end_of_line(self):
sql = ('select * from foo; -- foo\n'
'select * from bar;')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 2)
- self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql)
+ self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql)
# make sure the comment belongs to first query
- self.ndiffAssertEqual(unicode(stmts[0]), 'select * from foo; -- foo\n')
+ self.ndiffAssertEqual(u(stmts[0]), 'select * from foo; -- foo\n')
def test_casewhen(self):
sql = ('SELECT case when val = 1 then 2 else null end as foo;\n'
@@ -122,19 +123,15 @@ class SQLSplitTest(TestCaseBase):
self.assertEqual(len(stmts), 2)
def test_split_stream(self):
- import types
- from cStringIO import StringIO
-
stream = StringIO("SELECT 1; SELECT 2;")
stmts = sqlparse.parsestream(stream)
self.assertEqual(type(stmts), types.GeneratorType)
self.assertEqual(len(list(stmts)), 2)
def test_encoding_parsestream(self):
- from cStringIO import StringIO
stream = StringIO("SELECT 1; SELECT 2;")
stmts = list(sqlparse.parsestream(stream))
- self.assertEqual(type(stmts[0].tokens[0].value), unicode)
+ self.assertEqual(type(stmts[0].tokens[0].value), text_type)
def test_split_simple():
diff --git a/tests/test_tokenize.py b/tests/test_tokenize.py
index 0b23fa8..ceaf24e 100644
--- a/tests/test_tokenize.py
+++ b/tests/test_tokenize.py
@@ -9,6 +9,7 @@ import pytest
import sqlparse
from sqlparse import lexer
from sqlparse import sql
+from sqlparse.compat import StringIO
from sqlparse.tokens import *
@@ -133,8 +134,6 @@ class TestTokenList(unittest.TestCase):
class TestStream(unittest.TestCase):
def test_simple(self):
- from cStringIO import StringIO
-
stream = StringIO("SELECT 1; SELECT 2;")
lex = lexer.Lexer()
@@ -152,8 +151,6 @@ class TestStream(unittest.TestCase):
self.assertEqual(len(tokens), 9)
def test_error(self):
- from cStringIO import StringIO
-
stream = StringIO("FOOBAR{")
lex = lexer.Lexer()
diff --git a/tests/utils.py b/tests/utils.py
index 9eb46bf..b596ff4 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -6,9 +6,9 @@ import codecs
import difflib
import os
import unittest
-from StringIO import StringIO
import sqlparse.utils
+from sqlparse.compat import u, StringIO
NL = '\n'
DIR_PATH = os.path.abspath(os.path.dirname(__file__))
@@ -31,8 +31,8 @@ class TestCaseBase(unittest.TestCase):
def ndiffAssertEqual(self, first, second):
"""Like failUnlessEqual except use ndiff for readable output."""
if first != second:
- sfirst = unicode(first)
- ssecond = unicode(second)
+ sfirst = u(first)
+ ssecond = u(second)
# Using the built-in .splitlines() method here will cause incorrect
# results when splitting statements that have quoted CR/CR+LF
# characters.
@@ -42,5 +42,5 @@ class TestCaseBase(unittest.TestCase):
fp = StringIO()
fp.write(NL)
fp.write(NL.join(diff))
- print fp.getvalue()
- raise self.failureException, fp.getvalue()
+ # print(fp.getvalue())
+ raise self.failureException(fp.getvalue())