diff options
| author | Andi Albrecht <albrecht.andi@gmail.com> | 2015-10-26 19:40:29 +0100 |
|---|---|---|
| committer | Andi Albrecht <albrecht.andi@gmail.com> | 2015-10-26 19:40:29 +0100 |
| commit | e6a51a0bc3f87e284de74cec838d3ee98c2f9cf5 (patch) | |
| tree | 8ac469820a09c31e9e49543ecbbbdeebad53c85e | |
| parent | 8bfdaf3cc37ffe48a60c7f4ee5d5e99d0b07e696 (diff) | |
| download | sqlparse-e6a51a0bc3f87e284de74cec838d3ee98c2f9cf5.tar.gz | |
Use compat module for single Python 2/3 code base.
This change includes minor fixes and code cleanup too.
| -rw-r--r-- | .travis.yml | 1 | ||||
| -rw-r--r-- | docs/source/intro.rst | 8 | ||||
| -rw-r--r-- | setup.py | 5 | ||||
| -rw-r--r-- | sqlparse/__init__.py | 3 | ||||
| -rw-r--r-- | sqlparse/compat.py | 4 | ||||
| -rw-r--r-- | sqlparse/filters.py | 23 | ||||
| -rw-r--r-- | sqlparse/lexer.py | 29 | ||||
| -rw-r--r-- | sqlparse/sql.py | 23 | ||||
| -rw-r--r-- | tests/test_filters.py | 1 | ||||
| -rw-r--r-- | tests/test_grouping.py | 25 | ||||
| -rw-r--r-- | tests/test_parse.py | 10 | ||||
| -rw-r--r-- | tests/test_regressions.py | 1 | ||||
| -rw-r--r-- | tests/test_split.py | 35 | ||||
| -rw-r--r-- | tests/test_tokenize.py | 5 | ||||
| -rw-r--r-- | tests/utils.py | 10 | ||||
| -rw-r--r-- | tox.ini | 2 |
16 files changed, 93 insertions, 92 deletions
diff --git a/.travis.yml b/.travis.yml index 78bdc90..febc5bd 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,7 +4,6 @@ python: - "3.5" - "3.4" - "3.3" - - "3.2" - "2.7" - "pypy" - "pypy3" diff --git a/docs/source/intro.rst b/docs/source/intro.rst index 76d8fba..2a2d27d 100644 --- a/docs/source/intro.rst +++ b/docs/source/intro.rst @@ -100,10 +100,10 @@ Each object can be converted back to a string at any time: .. code-block:: python - >>> unicode(stmt) # str(stmt) for Python 3 - u'select * from "someschema"."mytable" where id = 1' - >>> unicode(stmt.tokens[-1]) # or just the WHERE part - u'where id = 1' + >>> str(stmt) # str(stmt) for Python 3 + 'select * from "someschema"."mytable" where id = 1' + >>> str(stmt.tokens[-1]) # or just the WHERE part + 'where id = 1' Details of the returned objects are described in :ref:`analyze`. @@ -67,8 +67,8 @@ Parsing:: >>> res (<Statement 'select...' at 0x9ad08ec>,) >>> stmt = res[0] - >>> unicode(stmt) # converting it back to unicode - u'select * from someschema.mytable where id = 1' + >>> str(stmt) # converting it back to unicode + 'select * from someschema.mytable where id = 1' >>> # This is how the internal representation looks like: >>> stmt.tokens (<DML 'select' at 0x9b63c34>, @@ -110,7 +110,6 @@ setup( 'Programming Language :: Python :: 2', 'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.2', 'Programming Language :: Python :: 3.3', 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', diff --git a/sqlparse/__init__.py b/sqlparse/__init__.py index 77a09f4..e8252d7 100644 --- a/sqlparse/__init__.py +++ b/sqlparse/__init__.py @@ -14,6 +14,7 @@ from sqlparse import engine from sqlparse import filters from sqlparse import formatter +from sqlparse.compat import u # Deprecated in 0.1.5. Will be removed in 0.2.0 from sqlparse.exceptions import SQLParseError @@ -67,7 +68,7 @@ def split(sql, encoding=None): """ stack = engine.FilterStack() stack.split_statements = True - return [unicode(stmt).strip() for stmt in stack.run(sql, encoding)] + return [u(stmt).strip() for stmt in stack.run(sql, encoding)] from sqlparse.engine.filter import StatementFilter diff --git a/sqlparse/compat.py b/sqlparse/compat.py index 9efae26..6b26384 100644 --- a/sqlparse/compat.py +++ b/sqlparse/compat.py @@ -19,7 +19,7 @@ if PY3: from io import StringIO def u(s): - return s + return str(s) elif PY2: text_type = unicode @@ -27,7 +27,7 @@ elif PY2: from StringIO import StringIO # flake8: noqa def u(s): - return unicode(s, 'unicode_escape') + return unicode(s) # Directly copied from six: diff --git a/sqlparse/filters.py b/sqlparse/filters.py index 676344f..eabf863 100644 --- a/sqlparse/filters.py +++ b/sqlparse/filters.py @@ -5,6 +5,7 @@ import re from os.path import abspath, join from sqlparse import sql, tokens as T +from sqlparse.compat import u, text_type from sqlparse.engine import FilterStack from sqlparse.lexer import tokenize from sqlparse.pipeline import Pipeline @@ -25,7 +26,7 @@ class _CaseFilter: if case is None: case = 'upper' assert case in ['lower', 'upper', 'capitalize'] - self.convert = getattr(unicode, case) + self.convert = getattr(text_type, case) def process(self, stack, stream): for ttype, value in stream: @@ -52,7 +53,7 @@ class TruncateStringFilter: def __init__(self, width, char): self.width = max(width, 1) - self.char = unicode(char) + self.char = u(char) def process(self, stack, stream): for ttype, value in stream: @@ -154,7 +155,7 @@ class IncludeStatement: f.close() # There was a problem loading the include file - except IOError, err: + except IOError as err: # Raise the exception to the interpreter if self.raiseexceptions: raise @@ -171,7 +172,7 @@ class IncludeStatement: self.raiseexceptions) # Max recursion limit reached - except ValueError, err: + except ValueError as err: # Raise the exception to the interpreter if self.raiseexceptions: raise @@ -300,7 +301,7 @@ class ReindentFilter: raise StopIteration def _get_offset(self, token): - raw = ''.join(map(unicode, self._flatten_up_to_token(token))) + raw = ''.join(map(text_type, self._flatten_up_to_token(token))) line = raw.splitlines()[-1] # Now take current offset into account and return relative offset. full_offset = len(line) - len(self.char * (self.width * self.indent)) @@ -340,7 +341,7 @@ class ReindentFilter: if prev and prev.is_whitespace() and prev not in added: tlist.tokens.pop(tlist.token_index(prev)) offset += 1 - uprev = unicode(prev) + uprev = u(prev) if (prev and (uprev.endswith('\n') or uprev.endswith('\r'))): nl = tlist.token_next(token) else: @@ -462,7 +463,7 @@ class ReindentFilter: self._process(stmt) if isinstance(stmt, sql.Statement): if self._last_stmt is not None: - if unicode(self._last_stmt).endswith('\n'): + if u(self._last_stmt).endswith('\n'): nl = '\n' else: nl = '\n\n' @@ -494,7 +495,7 @@ class RightMarginFilter: and not token.__class__ in self.keep_together): token.tokens = self._process(stack, token, token.tokens) else: - val = unicode(token) + val = u(token) if len(self.line) + len(val) > self.width: match = re.search('^ +', self.line) if match is not None: @@ -568,7 +569,7 @@ class ColumnsSelect: class SerializerUnicode: def process(self, stack, stmt): - raw = unicode(stmt) + raw = u(stmt) lines = split_unquoted_newlines(raw) res = '\n'.join(line.rstrip() for line in lines) return res @@ -578,7 +579,7 @@ def Tokens2Unicode(stream): result = "" for _, value in stream: - result += unicode(value) + result += u(value) return result @@ -600,7 +601,7 @@ class OutputFilter: else: varname = self.varname - has_nl = len(unicode(stmt).strip().splitlines()) > 1 + has_nl = len(u(stmt).strip().splitlines()) > 1 stmt.tokens = self._process(stmt.tokens, varname, has_nl) return stmt diff --git a/sqlparse/lexer.py b/sqlparse/lexer.py index fd29f5c..2b0688a 100644 --- a/sqlparse/lexer.py +++ b/sqlparse/lexer.py @@ -17,7 +17,7 @@ import sys from sqlparse import tokens from sqlparse.keywords import KEYWORDS, KEYWORDS_COMMON -from cStringIO import StringIO +from sqlparse.compat import StringIO, string_types, with_metaclass, text_type class include(str): @@ -81,14 +81,14 @@ class LexerMeta(type): try: rex = re.compile(tdef[0], rflags).match - except Exception, err: + except Exception as err: raise ValueError(("uncompilable regex %r in state" " %r of %r: %s" % (tdef[0], state, cls, err))) assert type(tdef[1]) is tokens._TokenType or callable(tdef[1]), \ - ('token type must be simple type or callable, not %r' - % (tdef[1],)) + ('token type must be simple type or callable, not %r' + % (tdef[1],)) if len(tdef) == 2: new_state = None @@ -113,7 +113,7 @@ class LexerMeta(type): itokens = [] for istate in tdef2: assert istate != state, \ - 'circular state ref %r' % istate + 'circular state ref %r' % istate itokens.extend(cls._process_state(unprocessed, processed, istate)) processed[new_state] = itokens @@ -123,7 +123,7 @@ class LexerMeta(type): for state in tdef2: assert (state in unprocessed or state in ('#pop', '#push')), \ - 'unknown new state ' + state + 'unknown new state ' + state new_state = tdef2 else: assert False, 'unknown new state def %r' % tdef2 @@ -134,7 +134,7 @@ class LexerMeta(type): cls._all_tokens = {} cls._tmpname = 0 processed = cls._all_tokens[cls.__name__] = {} - #tokendefs = tokendefs or cls.tokens[name] + # tokendefs = tokendefs or cls.tokens[name] for state in cls.tokens.keys(): cls._process_state(cls.tokens, processed, state) return processed @@ -152,9 +152,7 @@ class LexerMeta(type): return type.__call__(cls, *args, **kwds) -class Lexer(object): - - __metaclass__ = LexerMeta +class _Lexer(object): encoding = 'utf-8' stripall = False @@ -201,7 +199,8 @@ class Lexer(object): # cannot be preceded by word character or a right bracket -- # otherwise it's probably an array index (r'(?<![\w\])])(\[[^\]]+\])', tokens.Name), - (r'((LEFT\s+|RIGHT\s+|FULL\s+)?(INNER\s+|OUTER\s+|STRAIGHT\s+)?|(CROSS\s+|NATURAL\s+)?)?JOIN\b', tokens.Keyword), + (r'((LEFT\s+|RIGHT\s+|FULL\s+)?(INNER\s+|OUTER\s+|STRAIGHT\s+)?' + r'|(CROSS\s+|NATURAL\s+)?)?JOIN\b', tokens.Keyword), (r'END(\s+IF|\s+LOOP)?\b', tokens.Keyword), (r'NOT NULL\b', tokens.Keyword), (r'CREATE(\s+OR\s+REPLACE)?\b', tokens.Keyword.DDL), @@ -258,13 +257,13 @@ class Lexer(object): Also preprocess the text, i.e. expand tabs and strip it if wanted and applies registered filters. """ - if isinstance(text, basestring): + if isinstance(text, string_types): if self.stripall: text = text.strip() elif self.stripnl: text = text.strip('\n') - if sys.version_info[0] < 3 and isinstance(text, unicode): + if sys.version_info[0] < 3 and isinstance(text, text_type): text = StringIO(text.encode('utf-8')) self.encoding = 'utf-8' else: @@ -350,6 +349,10 @@ class Lexer(object): break +class Lexer(with_metaclass(LexerMeta, _Lexer)): + pass + + def tokenize(sql, encoding=None): """Tokenize sql. diff --git a/sqlparse/sql.py b/sqlparse/sql.py index 7325712..97dd24e 100644 --- a/sqlparse/sql.py +++ b/sqlparse/sql.py @@ -6,6 +6,7 @@ import re import sys from sqlparse import tokens as T +from sqlparse.compat import string_types, u class Token(object): @@ -32,7 +33,7 @@ class Token(object): if sys.version_info[0] == 3: return self.value else: - return unicode(self).encode('utf-8') + return u(self).encode('utf-8') def __repr__(self): short = self._get_repr_value() @@ -51,13 +52,13 @@ class Token(object): .. deprecated:: 0.1.5 Use ``unicode(token)`` (for Python 3: ``str(token)``) instead. """ - return unicode(self) + return u(self) def _get_repr_name(self): return str(self.ttype).split('.')[-1] def _get_repr_value(self): - raw = unicode(self) + raw = u(self) if len(raw) > 7: raw = raw[:6] + u'...' return re.sub('\s+', ' ', raw) @@ -83,7 +84,7 @@ class Token(object): return type_matched if regex: - if isinstance(values, basestring): + if isinstance(values, string_types): values = set([values]) if self.ttype is T.Keyword: @@ -96,7 +97,7 @@ class Token(object): return True return False - if isinstance(values, basestring): + if isinstance(values, string_types): if self.is_keyword: return values.upper() == self.normalized return values == self.value @@ -172,7 +173,7 @@ class TokenList(Token): if sys.version_info[0] == 3: return ''.join(x.value for x in self.flatten()) else: - return ''.join(unicode(x) for x in self.flatten()) + return ''.join(u(x) for x in self.flatten()) def _get_repr_name(self): return self.__class__.__name__ @@ -185,9 +186,9 @@ class TokenList(Token): pre = ' +-' else: pre = ' | ' - print '%s%s%d %s \'%s\'' % (indent, pre, idx, + print('%s%s%d %s \'%s\'' % (indent, pre, idx, token._get_repr_name(), - token._get_repr_value()) + token._get_repr_value())) if (token.is_group() and (max_depth is None or depth < max_depth)): token._pprint_tree(max_depth, depth + 1) @@ -285,7 +286,7 @@ class TokenList(Token): if not isinstance(idx, int): idx = self.token_index(idx) - for n in xrange(idx, len(self.tokens)): + for n in range(idx, len(self.tokens)): token = self.tokens[n] if token.match(ttype, value, regex): return token @@ -349,7 +350,7 @@ class TokenList(Token): # Performing `index` manually is much faster when starting in the middle # of the list of tokens and expecting to find the token near to the starting # index. - for i in xrange(start, len(self.tokens)): + for i in range(start, len(self.tokens)): if self.tokens[i] == token: return i return -1 @@ -518,7 +519,7 @@ class Identifier(TokenList): next_ = self.token_next(self.token_index(marker), False) if next_ is None: return None - return unicode(next_) + return u(next_) def get_ordering(self): """Returns the ordering or ``None`` as uppercase string.""" diff --git a/tests/test_filters.py b/tests/test_filters.py index d827454..eb61604 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -5,6 +5,7 @@ Created on 24/03/2012 ''' import unittest +from sqlparse.compat import u from sqlparse.filters import StripWhitespace, Tokens2Unicode from sqlparse.lexer import tokenize diff --git a/tests/test_grouping.py b/tests/test_grouping.py index 5ade830..fa68ab2 100644 --- a/tests/test_grouping.py +++ b/tests/test_grouping.py @@ -5,6 +5,7 @@ import pytest import sqlparse from sqlparse import sql from sqlparse import tokens as T +from sqlparse.compat import u from tests.utils import TestCaseBase @@ -26,7 +27,7 @@ class TestGrouping(TestCaseBase): def test_comments(self): s = '/*\n * foo\n */ \n bar' parsed = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, unicode(parsed)) + self.ndiffAssertEqual(s, u(parsed)) self.assertEqual(len(parsed.tokens), 2) def test_assignment(self): @@ -42,18 +43,18 @@ class TestGrouping(TestCaseBase): def test_identifiers(self): s = 'select foo.bar from "myscheme"."table" where fail. order' parsed = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, unicode(parsed)) + self.ndiffAssertEqual(s, u(parsed)) self.assert_(isinstance(parsed.tokens[2], sql.Identifier)) self.assert_(isinstance(parsed.tokens[6], sql.Identifier)) self.assert_(isinstance(parsed.tokens[8], sql.Where)) s = 'select * from foo where foo.id = 1' parsed = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, unicode(parsed)) + self.ndiffAssertEqual(s, u(parsed)) self.assert_(isinstance(parsed.tokens[-1].tokens[-1].tokens[0], sql.Identifier)) s = 'select * from (select "foo"."id" from foo)' parsed = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, unicode(parsed)) + self.ndiffAssertEqual(s, u(parsed)) self.assert_(isinstance(parsed.tokens[-1].tokens[3], sql.Identifier)) s = "INSERT INTO `test` VALUES('foo', 'bar');" @@ -141,44 +142,44 @@ class TestGrouping(TestCaseBase): def test_where(self): s = 'select * from foo where bar = 1 order by id desc' p = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, unicode(p)) + self.ndiffAssertEqual(s, u(p)) self.assertTrue(len(p.tokens), 16) s = 'select x from (select y from foo where bar = 1) z' p = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, unicode(p)) + self.ndiffAssertEqual(s, u(p)) self.assertTrue(isinstance(p.tokens[-1].tokens[0].tokens[-2], sql.Where)) def test_typecast(self): s = 'select foo::integer from bar' p = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, unicode(p)) + self.ndiffAssertEqual(s, u(p)) self.assertEqual(p.tokens[2].get_typecast(), 'integer') self.assertEqual(p.tokens[2].get_name(), 'foo') s = 'select (current_database())::information_schema.sql_identifier' p = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, unicode(p)) + self.ndiffAssertEqual(s, u(p)) self.assertEqual(p.tokens[2].get_typecast(), 'information_schema.sql_identifier') def test_alias(self): s = 'select foo as bar from mytable' p = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, unicode(p)) + self.ndiffAssertEqual(s, u(p)) self.assertEqual(p.tokens[2].get_real_name(), 'foo') self.assertEqual(p.tokens[2].get_alias(), 'bar') s = 'select foo from mytable t1' p = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, unicode(p)) + self.ndiffAssertEqual(s, u(p)) self.assertEqual(p.tokens[6].get_real_name(), 'mytable') self.assertEqual(p.tokens[6].get_alias(), 't1') s = 'select foo::integer as bar from mytable' p = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, unicode(p)) + self.ndiffAssertEqual(s, u(p)) self.assertEqual(p.tokens[2].get_alias(), 'bar') s = ('SELECT DISTINCT ' '(current_database())::information_schema.sql_identifier AS view') p = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, unicode(p)) + self.ndiffAssertEqual(s, u(p)) self.assertEqual(p.tokens[4].get_alias(), 'view') def test_alias_case(self): # see issue46 diff --git a/tests/test_parse.py b/tests/test_parse.py index 6c9d6a6..fb7b24b 100644 --- a/tests/test_parse.py +++ b/tests/test_parse.py @@ -8,8 +8,8 @@ from tests.utils import TestCaseBase import sqlparse import sqlparse.sql - from sqlparse import tokens as T +from sqlparse.compat import u class SQLParseTest(TestCaseBase): @@ -32,16 +32,16 @@ class SQLParseTest(TestCaseBase): def test_newlines(self): sql = u'select\n*from foo;' p = sqlparse.parse(sql)[0] - self.assertEqual(unicode(p), sql) + self.assertEqual(u(p), sql) sql = u'select\r\n*from foo' p = sqlparse.parse(sql)[0] - self.assertEqual(unicode(p), sql) + self.assertEqual(u(p), sql) sql = u'select\r*from foo' p = sqlparse.parse(sql)[0] - self.assertEqual(unicode(p), sql) + self.assertEqual(u(p), sql) sql = u'select\r\n*from foo\n' p = sqlparse.parse(sql)[0] - self.assertEqual(unicode(p), sql) + self.assertEqual(u(p), sql) def test_within(self): sql = 'foo(col1, col2)' diff --git a/tests/test_regressions.py b/tests/test_regressions.py index a64b400..f873c78 100644 --- a/tests/test_regressions.py +++ b/tests/test_regressions.py @@ -256,6 +256,7 @@ SELECT * FROM a.b;""" splitted = sqlparse.split(sql) assert len(splitted) == 2 + def test_issue194_splitting_function(): sql = """CREATE FUNCTION a(x VARCHAR(20)) RETURNS VARCHAR(20) BEGIN diff --git a/tests/test_split.py b/tests/test_split.py index 54e8d04..f6d5f50 100644 --- a/tests/test_split.py +++ b/tests/test_split.py @@ -2,11 +2,12 @@ # Tests splitting functions. -import unittest +import types from tests.utils import load_file, TestCaseBase import sqlparse +from sqlparse.compat import StringIO, u, text_type class SQLSplitTest(TestCaseBase): @@ -19,8 +20,8 @@ class SQLSplitTest(TestCaseBase): sql2 = 'select * from foo where bar = \'foo;bar\';' stmts = sqlparse.parse(''.join([self._sql1, sql2])) self.assertEqual(len(stmts), 2) - self.ndiffAssertEqual(unicode(stmts[0]), self._sql1) - self.ndiffAssertEqual(unicode(stmts[1]), sql2) + self.ndiffAssertEqual(u(stmts[0]), self._sql1) + self.ndiffAssertEqual(u(stmts[1]), sql2) def test_split_backslash(self): stmts = sqlparse.parse(r"select '\\'; select '\''; select '\\\'';") @@ -30,31 +31,31 @@ class SQLSplitTest(TestCaseBase): sql = load_file('function.sql') stmts = sqlparse.parse(sql) self.assertEqual(len(stmts), 1) - self.ndiffAssertEqual(unicode(stmts[0]), sql) + self.ndiffAssertEqual(u(stmts[0]), sql) def test_create_function_psql(self): sql = load_file('function_psql.sql') stmts = sqlparse.parse(sql) self.assertEqual(len(stmts), 1) - self.ndiffAssertEqual(unicode(stmts[0]), sql) + self.ndiffAssertEqual(u(stmts[0]), sql) def test_create_function_psql3(self): sql = load_file('function_psql3.sql') stmts = sqlparse.parse(sql) self.assertEqual(len(stmts), 1) - self.ndiffAssertEqual(unicode(stmts[0]), sql) + self.ndiffAssertEqual(u(stmts[0]), sql) def test_create_function_psql2(self): sql = load_file('function_psql2.sql') stmts = sqlparse.parse(sql) self.assertEqual(len(stmts), 1) - self.ndiffAssertEqual(unicode(stmts[0]), sql) + self.ndiffAssertEqual(u(stmts[0]), sql) def test_dashcomments(self): sql = load_file('dashcomment.sql') stmts = sqlparse.parse(sql) self.assertEqual(len(stmts), 3) - self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql) + self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql) def test_dashcomments_eol(self): stmts = sqlparse.parse('select foo; -- comment\n') @@ -70,19 +71,19 @@ class SQLSplitTest(TestCaseBase): sql = load_file('begintag.sql') stmts = sqlparse.parse(sql) self.assertEqual(len(stmts), 3) - self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql) + self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql) def test_begintag_2(self): sql = load_file('begintag_2.sql') stmts = sqlparse.parse(sql) self.assertEqual(len(stmts), 1) - self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql) + self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql) def test_dropif(self): sql = 'DROP TABLE IF EXISTS FOO;\n\nSELECT * FROM BAR;' stmts = sqlparse.parse(sql) self.assertEqual(len(stmts), 2) - self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql) + self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql) def test_comment_with_umlaut(self): sql = (u'select * from foo;\n' @@ -90,16 +91,16 @@ class SQLSplitTest(TestCaseBase): u'select * from bar;') stmts = sqlparse.parse(sql) self.assertEqual(len(stmts), 2) - self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql) + self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql) def test_comment_end_of_line(self): sql = ('select * from foo; -- foo\n' 'select * from bar;') stmts = sqlparse.parse(sql) self.assertEqual(len(stmts), 2) - self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql) + self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql) # make sure the comment belongs to first query - self.ndiffAssertEqual(unicode(stmts[0]), 'select * from foo; -- foo\n') + self.ndiffAssertEqual(u(stmts[0]), 'select * from foo; -- foo\n') def test_casewhen(self): sql = ('SELECT case when val = 1 then 2 else null end as foo;\n' @@ -122,19 +123,15 @@ class SQLSplitTest(TestCaseBase): self.assertEqual(len(stmts), 2) def test_split_stream(self): - import types - from cStringIO import StringIO - stream = StringIO("SELECT 1; SELECT 2;") stmts = sqlparse.parsestream(stream) self.assertEqual(type(stmts), types.GeneratorType) self.assertEqual(len(list(stmts)), 2) def test_encoding_parsestream(self): - from cStringIO import StringIO stream = StringIO("SELECT 1; SELECT 2;") stmts = list(sqlparse.parsestream(stream)) - self.assertEqual(type(stmts[0].tokens[0].value), unicode) + self.assertEqual(type(stmts[0].tokens[0].value), text_type) def test_split_simple(): diff --git a/tests/test_tokenize.py b/tests/test_tokenize.py index 0b23fa8..ceaf24e 100644 --- a/tests/test_tokenize.py +++ b/tests/test_tokenize.py @@ -9,6 +9,7 @@ import pytest import sqlparse from sqlparse import lexer from sqlparse import sql +from sqlparse.compat import StringIO from sqlparse.tokens import * @@ -133,8 +134,6 @@ class TestTokenList(unittest.TestCase): class TestStream(unittest.TestCase): def test_simple(self): - from cStringIO import StringIO - stream = StringIO("SELECT 1; SELECT 2;") lex = lexer.Lexer() @@ -152,8 +151,6 @@ class TestStream(unittest.TestCase): self.assertEqual(len(tokens), 9) def test_error(self): - from cStringIO import StringIO - stream = StringIO("FOOBAR{") lex = lexer.Lexer() diff --git a/tests/utils.py b/tests/utils.py index 9eb46bf..b596ff4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -6,9 +6,9 @@ import codecs import difflib import os import unittest -from StringIO import StringIO import sqlparse.utils +from sqlparse.compat import u, StringIO NL = '\n' DIR_PATH = os.path.abspath(os.path.dirname(__file__)) @@ -31,8 +31,8 @@ class TestCaseBase(unittest.TestCase): def ndiffAssertEqual(self, first, second): """Like failUnlessEqual except use ndiff for readable output.""" if first != second: - sfirst = unicode(first) - ssecond = unicode(second) + sfirst = u(first) + ssecond = u(second) # Using the built-in .splitlines() method here will cause incorrect # results when splitting statements that have quoted CR/CR+LF # characters. @@ -42,5 +42,5 @@ class TestCaseBase(unittest.TestCase): fp = StringIO() fp.write(NL) fp.write(NL.join(diff)) - print fp.getvalue() - raise self.failureException, fp.getvalue() + # print(fp.getvalue()) + raise self.failureException(fp.getvalue()) @@ -1,5 +1,5 @@ [tox] -envlist=py27,py32,py33,py34,py35,pypy,pypy3 +envlist=py27,py33,py34,py35,pypy,pypy3 [testenv] deps= |
