summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVik <vmuriart@users.noreply.github.com>2016-06-11 05:29:39 -0700
committerGitHub <noreply@github.com>2016-06-11 05:29:39 -0700
commit751933d3abdce2234bd869ee65a1ebc7ccbf6b53 (patch)
tree77fd5be087ae64c4416c0bd775f6c1843b45a007
parent00304afc15a554f2ac8decca1d916ba66c143b45 (diff)
parent1fd3da42bd55bfc1e916e3c3f301f0364b0ded21 (diff)
downloadsqlparse-751933d3abdce2234bd869ee65a1ebc7ccbf6b53.tar.gz
Merge pull request #254 from vmuriart/tests_str-format
Add various tests and change to new style str-format
-rw-r--r--docs/source/conf.py8
-rw-r--r--examples/column_defs_lowlevel.py4
-rw-r--r--examples/extract_table_names.py15
-rw-r--r--setup.cfg3
-rw-r--r--setup.py7
-rw-r--r--sqlparse/engine/grouping.py7
-rw-r--r--sqlparse/filters/output.py2
-rw-r--r--sqlparse/filters/reindent.py8
-rw-r--r--sqlparse/filters/right_margin.py2
-rw-r--r--sqlparse/formatter.py40
-rw-r--r--sqlparse/lexer.py15
-rw-r--r--sqlparse/sql.py48
-rw-r--r--sqlparse/utils.py20
-rw-r--r--tests/test_format.py30
-rw-r--r--tests/test_grouping.py2
-rw-r--r--tests/test_parse.py64
-rw-r--r--tests/test_regressions.py27
-rw-r--r--tests/test_split.py5
-rw-r--r--tests/test_tokenize.py2
19 files changed, 188 insertions, 121 deletions
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 5f7d34f..70bd69a 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -42,8 +42,8 @@ source_suffix = '.rst'
master_doc = 'index'
# General information about the project.
-project = u'python-sqlparse'
-copyright = u'%s, Andi Albrecht' % datetime.date.today().strftime('%Y')
+project = 'python-sqlparse'
+copyright = '{:%Y}, Andi Albrecht'.format(datetime.date.today())
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
@@ -177,8 +177,8 @@ htmlhelp_basename = 'python-sqlparsedoc'
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title, author, documentclass [howto/manual]).
latex_documents = [
- ('index', 'python-sqlparse.tex', ur'python-sqlparse Documentation',
- ur'Andi Albrecht', 'manual'),
+ ('index', 'python-sqlparse.tex', 'python-sqlparse Documentation',
+ 'Andi Albrecht', 'manual'),
]
# The name of an image file (relative to this directory) to place at the top of
diff --git a/examples/column_defs_lowlevel.py b/examples/column_defs_lowlevel.py
index 7cce753..5acbdec 100644
--- a/examples/column_defs_lowlevel.py
+++ b/examples/column_defs_lowlevel.py
@@ -49,5 +49,5 @@ def extract_definitions(token_list):
columns = extract_definitions(par)
for column in columns:
- print('NAME: %-12s DEFINITION: %s' % (column[0],
- ''.join(str(t) for t in column[1:])))
+ print('NAME: {name:10} DEFINITION: {definition}'.format(
+ name=column[0], definition=''.join(str(t) for t in column[1:])))
diff --git a/examples/extract_table_names.py b/examples/extract_table_names.py
index b43ee5f..c1bcf8b 100644
--- a/examples/extract_table_names.py
+++ b/examples/extract_table_names.py
@@ -12,11 +12,6 @@
# See:
# http://groups.google.com/group/sqlparse/browse_thread/thread/b0bd9a022e9d4895
-sql = """
-select K.a,K.b from (select H.b from (select G.c from (select F.d from
-(select E.e from A, B, C, D, E), F), G), H), I, J, K order by 1,2;
-"""
-
import sqlparse
from sqlparse.sql import IdentifierList, Identifier
from sqlparse.tokens import Keyword, DML
@@ -59,10 +54,16 @@ def extract_table_identifiers(token_stream):
yield item.value
-def extract_tables():
+def extract_tables(sql):
stream = extract_from_part(sqlparse.parse(sql)[0])
return list(extract_table_identifiers(stream))
if __name__ == '__main__':
- print('Tables: %s' % ', '.join(extract_tables()))
+ sql = """
+ select K.a,K.b from (select H.b from (select G.c from (select F.d from
+ (select E.e from A, B, C, D, E), F), G), H), I, J, K order by 1,2;
+ """
+
+ tables = ', '.join(extract_tables(sql))
+ print('Tables: {0}'.format(tables))
diff --git a/setup.cfg b/setup.cfg
index c3bf82b..abc0206 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,6 +1,9 @@
[wheel]
universal = 1
+[pytest]
+xfail_strict=true
+
[flake8]
exclude =
sqlparse/compat.py
diff --git a/setup.py b/setup.py
index ffdbdb9..45d560f 100644
--- a/setup.py
+++ b/setup.py
@@ -21,14 +21,13 @@ except ImportError:
def get_version():
"""Parse __init__.py for version number instead of importing the file."""
VERSIONFILE = 'sqlparse/__init__.py'
- verstrline = open(VERSIONFILE, "rt").read()
VSRE = r'^__version__ = [\'"]([^\'"]*)[\'"]'
+ with open(VERSIONFILE) as f:
+ verstrline = f.read()
mo = re.search(VSRE, verstrline, re.M)
if mo:
return mo.group(1)
- else:
- raise RuntimeError('Unable to find version string in %s.'
- % (VERSIONFILE,))
+ raise RuntimeError('Unable to find version in {fn}'.format(fn=VERSIONFILE))
LONG_DESCRIPTION = """
diff --git a/sqlparse/engine/grouping.py b/sqlparse/engine/grouping.py
index c680995..e8e9dc3 100644
--- a/sqlparse/engine/grouping.py
+++ b/sqlparse/engine/grouping.py
@@ -42,13 +42,14 @@ def _group_left_right(tlist, m, cls,
def _group_matching(tlist, cls):
"""Groups Tokens that have beginning and end."""
- idx = 1 if imt(tlist, i=cls) else 0
+ idx = 1 if isinstance(tlist, cls) else 0
token = tlist.token_next_by(m=cls.M_OPEN, idx=idx)
while token:
end = find_matching(tlist, token, cls.M_OPEN, cls.M_CLOSE)
if end is not None:
- token = tlist.group_tokens(cls, tlist.tokens_between(token, end))
+ tokens = tlist.tokens_between(token, end)
+ token = tlist.group_tokens(cls, tokens)
_group_matching(token, cls)
token = tlist.token_next_by(m=cls.M_OPEN, idx=token)
@@ -120,7 +121,7 @@ def group_period(tlist):
def group_arrays(tlist):
token = tlist.token_next_by(i=sql.SquareBrackets)
while token:
- prev = tlist.token_prev(idx=token)
+ prev = tlist.token_prev(token)
if imt(prev, i=(sql.SquareBrackets, sql.Identifier, sql.Function),
t=(T.Name, T.String.Symbol,)):
tokens = tlist.tokens_between(prev, token)
diff --git a/sqlparse/filters/output.py b/sqlparse/filters/output.py
index d4528e9..bbc5076 100644
--- a/sqlparse/filters/output.py
+++ b/sqlparse/filters/output.py
@@ -22,7 +22,7 @@ class OutputFilter(object):
def process(self, stmt):
self.count += 1
if self.count > 1:
- varname = '%s%d' % (self.varname, self.count)
+ varname = '{f.varname}{f.count}'.format(f=self)
else:
varname = self.varname
diff --git a/sqlparse/filters/reindent.py b/sqlparse/filters/reindent.py
index f7ddfc9..b490631 100644
--- a/sqlparse/filters/reindent.py
+++ b/sqlparse/filters/reindent.py
@@ -97,13 +97,13 @@ class ReindentFilter(object):
self._process_default(tlist)
def _process_parenthesis(self, tlist):
- is_DML_DLL = tlist.token_next_by(t=(T.Keyword.DML, T.Keyword.DDL))
+ is_dml_dll = tlist.token_next_by(t=(T.Keyword.DML, T.Keyword.DDL))
first = tlist.token_next_by(m=sql.Parenthesis.M_OPEN)
- with indent(self, 1 if is_DML_DLL else 0):
- tlist.tokens.insert(0, self.nl()) if is_DML_DLL else None
+ with indent(self, 1 if is_dml_dll else 0):
+ tlist.tokens.insert(0, self.nl()) if is_dml_dll else None
with offset(self, self._get_offset(first) + 1):
- self._process_default(tlist, not is_DML_DLL)
+ self._process_default(tlist, not is_dml_dll)
def _process_identifierlist(self, tlist):
identifiers = list(tlist.get_identifiers())
diff --git a/sqlparse/filters/right_margin.py b/sqlparse/filters/right_margin.py
index 4e10dc0..b3f905d 100644
--- a/sqlparse/filters/right_margin.py
+++ b/sqlparse/filters/right_margin.py
@@ -38,7 +38,7 @@ class RightMarginFilter(object):
indent = match.group()
else:
indent = ''
- yield sql.Token(T.Whitespace, '\n%s' % indent)
+ yield sql.Token(T.Whitespace, '\n{0}'.format(indent))
self.line = indent
self.line += val
yield token
diff --git a/sqlparse/formatter.py b/sqlparse/formatter.py
index 069109b..8f10557 100644
--- a/sqlparse/formatter.py
+++ b/sqlparse/formatter.py
@@ -15,61 +15,65 @@ def validate_options(options):
"""Validates options."""
kwcase = options.get('keyword_case')
if kwcase not in [None, 'upper', 'lower', 'capitalize']:
- raise SQLParseError('Invalid value for keyword_case: %r' % kwcase)
+ raise SQLParseError('Invalid value for keyword_case: '
+ '{0!r}'.format(kwcase))
idcase = options.get('identifier_case')
if idcase not in [None, 'upper', 'lower', 'capitalize']:
- raise SQLParseError('Invalid value for identifier_case: %r' % idcase)
+ raise SQLParseError('Invalid value for identifier_case: '
+ '{0!r}'.format(idcase))
ofrmt = options.get('output_format')
if ofrmt not in [None, 'sql', 'python', 'php']:
- raise SQLParseError('Unknown output format: %r' % ofrmt)
+ raise SQLParseError('Unknown output format: '
+ '{0!r}'.format(ofrmt))
strip_comments = options.get('strip_comments', False)
if strip_comments not in [True, False]:
- raise SQLParseError('Invalid value for strip_comments: %r'
- % strip_comments)
+ raise SQLParseError('Invalid value for strip_comments: '
+ '{0!r}'.format(strip_comments))
space_around_operators = options.get('use_space_around_operators', False)
if space_around_operators not in [True, False]:
- raise SQLParseError('Invalid value for use_space_around_operators: %r'
- % space_around_operators)
+ raise SQLParseError('Invalid value for use_space_around_operators: '
+ '{0!r}'.format(space_around_operators))
strip_ws = options.get('strip_whitespace', False)
if strip_ws not in [True, False]:
- raise SQLParseError('Invalid value for strip_whitespace: %r'
- % strip_ws)
+ raise SQLParseError('Invalid value for strip_whitespace: '
+ '{0!r}'.format(strip_ws))
truncate_strings = options.get('truncate_strings')
if truncate_strings is not None:
try:
truncate_strings = int(truncate_strings)
except (ValueError, TypeError):
- raise SQLParseError('Invalid value for truncate_strings: %r'
- % truncate_strings)
+ raise SQLParseError('Invalid value for truncate_strings: '
+ '{0!r}'.format(truncate_strings))
if truncate_strings <= 1:
- raise SQLParseError('Invalid value for truncate_strings: %r'
- % truncate_strings)
+ raise SQLParseError('Invalid value for truncate_strings: '
+ '{0!r}'.format(truncate_strings))
options['truncate_strings'] = truncate_strings
options['truncate_char'] = options.get('truncate_char', '[...]')
reindent = options.get('reindent', False)
if reindent not in [True, False]:
- raise SQLParseError('Invalid value for reindent: %r'
- % reindent)
+ raise SQLParseError('Invalid value for reindent: '
+ '{0!r}'.format(reindent))
elif reindent:
options['strip_whitespace'] = True
reindent_aligned = options.get('reindent_aligned', False)
if reindent_aligned not in [True, False]:
- raise SQLParseError('Invalid value for reindent_aligned: %r'
- % reindent)
+ raise SQLParseError('Invalid value for reindent_aligned: '
+ '{0!r}'.format(reindent))
elif reindent_aligned:
options['strip_whitespace'] = True
indent_tabs = options.get('indent_tabs', False)
if indent_tabs not in [True, False]:
- raise SQLParseError('Invalid value for indent_tabs: %r' % indent_tabs)
+ raise SQLParseError('Invalid value for indent_tabs: '
+ '{0!r}'.format(indent_tabs))
elif indent_tabs:
options['indent_char'] = '\t'
else:
diff --git a/sqlparse/lexer.py b/sqlparse/lexer.py
index dd15212..0fb8936 100644
--- a/sqlparse/lexer.py
+++ b/sqlparse/lexer.py
@@ -14,7 +14,7 @@
from sqlparse import tokens
from sqlparse.keywords import SQL_REGEX
-from sqlparse.compat import StringIO, string_types, text_type
+from sqlparse.compat import StringIO, string_types, u
from sqlparse.utils import consume
@@ -37,17 +37,10 @@ class Lexer(object):
``stack`` is the inital stack (default: ``['root']``)
"""
- encoding = encoding or 'utf-8'
-
if isinstance(text, string_types):
- text = StringIO(text)
-
- text = text.read()
- if not isinstance(text, text_type):
- try:
- text = text.decode(encoding)
- except UnicodeDecodeError:
- text = text.decode('unicode-escape')
+ text = u(text, encoding)
+ elif isinstance(text, StringIO):
+ text = u(text.read(), encoding)
iterable = enumerate(text)
for pos, char in iterable:
diff --git a/sqlparse/sql.py b/sqlparse/sql.py
index daa5cf5..eadd04f 100644
--- a/sqlparse/sql.py
+++ b/sqlparse/sql.py
@@ -167,7 +167,7 @@ class TokenList(Token):
idx = 0
for token in self.flatten():
end = idx + len(token.value)
- if idx <= offset <= end:
+ if idx <= offset < end:
return token
idx = end
@@ -248,8 +248,6 @@ class TokenList(Token):
If *skip_ws* is ``True`` (the default) whitespace tokens are ignored.
``None`` is returned if there's no previous token.
"""
- if isinstance(idx, int):
- idx += 1 # alot of code usage current pre-compensates for this
funcs = lambda tk: not ((skip_ws and tk.is_whitespace()) or
(skip_cm and imt(tk, t=T.Comment)))
return self._token_matching(funcs, idx, reverse=True)
@@ -260,8 +258,6 @@ class TokenList(Token):
If *skip_ws* is ``True`` (the default) whitespace tokens are ignored.
``None`` is returned if there's no next token.
"""
- if isinstance(idx, int):
- idx += 1 # alot of code usage current pre-compensates for this
funcs = lambda tk: not ((skip_ws and tk.is_whitespace()) or
(skip_cm and imt(tk, t=T.Comment)))
return self._token_matching(funcs, idx)
@@ -283,34 +279,26 @@ class TokenList(Token):
def group_tokens(self, grp_cls, tokens, skip_ws=False, extend=False):
"""Replace tokens by an instance of *grp_cls*."""
- if skip_ws:
- while tokens and tokens[-1].is_whitespace():
- tokens = tokens[:-1]
+
+ while skip_ws and tokens and tokens[-1].is_whitespace():
+ tokens = tokens[:-1]
left = tokens[0]
idx = self.token_index(left)
- if extend:
- if not isinstance(left, grp_cls):
- grp = grp_cls([left])
- self.tokens.remove(left)
- self.tokens.insert(idx, grp)
- left = grp
- left.parent = self
- tokens = tokens[1:]
- left.tokens.extend(tokens)
- left.value = str(left)
-
+ if extend and isinstance(left, grp_cls):
+ grp = left
+ grp.tokens.extend(tokens[1:])
else:
- left = grp_cls(tokens)
- left.parent = self
- self.tokens.insert(idx, left)
+ grp = grp_cls(tokens)
for token in tokens:
- token.parent = left
+ token.parent = grp
self.tokens.remove(token)
- return left
+ self.tokens.insert(idx, grp)
+ grp.parent = self
+ return grp
def insert_before(self, where, token):
"""Inserts *token* before *where*."""
@@ -322,7 +310,7 @@ class TokenList(Token):
if next_token is None:
self.tokens.append(token)
else:
- self.tokens.insert(self.token_index(next_token), token)
+ self.insert_before(next_token, token)
def has_alias(self):
"""Returns ``True`` if an alias is present."""
@@ -435,19 +423,13 @@ class Identifier(TokenList):
def get_typecast(self):
"""Returns the typecast or ``None`` of this object as a string."""
marker = self.token_next_by(m=(T.Punctuation, '::'))
- if marker is None:
- return None
next_ = self.token_next(marker, False)
- if next_ is None:
- return None
- return next_.value
+ return next_.value if next_ else None
def get_ordering(self):
"""Returns the ordering or ``None`` as uppercase string."""
ordering = self.token_next_by(t=T.Keyword.Order)
- if ordering is None:
- return None
- return ordering.normalized
+ return ordering.normalized if ordering else None
def get_array_indices(self):
"""Returns an iterator of index token lists"""
diff --git a/sqlparse/utils.py b/sqlparse/utils.py
index 4da44c6..8253e0b 100644
--- a/sqlparse/utils.py
+++ b/sqlparse/utils.py
@@ -78,36 +78,36 @@ def recurse(*cls):
def imt(token, i=None, m=None, t=None):
- """Aid function to refactor comparisons for Instance, Match and TokenType
- Aid fun
+ """Helper function to simplify comparisons Instance, Match and TokenType
:param token:
:param i: Class or Tuple/List of Classes
:param m: Tuple of TokenType & Value. Can be list of Tuple for multiple
:param t: TokenType or Tuple/List of TokenTypes
:return: bool
"""
- t = (t,) if t and not isinstance(t, (list, tuple)) else t
- m = (m,) if m and not isinstance(m, (list,)) else m
+ clss = i
+ types = [t, ] if t and not isinstance(t, list) else t
+ mpatterns = [m, ] if m and not isinstance(m, list) else m
if token is None:
return False
- elif i is not None and isinstance(token, i):
+ elif clss and isinstance(token, clss):
return True
- elif m is not None and any((token.match(*x) for x in m)):
+ elif mpatterns and any((token.match(*pattern) for pattern in mpatterns)):
return True
- elif t is not None and token.ttype in t:
+ elif types and any([token.ttype in ttype for ttype in types]):
return True
else:
return False
-def find_matching(tlist, token, M1, M2):
+def find_matching(tlist, token, open_pattern, close_pattern):
idx = tlist.token_index(token)
depth = 0
for token in tlist.tokens[idx:]:
- if token.match(*M1):
+ if token.match(*open_pattern):
depth += 1
- elif token.match(*M2):
+ elif token.match(*close_pattern):
depth -= 1
if depth == 0:
return token
diff --git a/tests/test_format.py b/tests/test_format.py
index 7b5af06..74fce71 100644
--- a/tests/test_format.py
+++ b/tests/test_format.py
@@ -524,6 +524,22 @@ class TestOutputFormat(TestCaseBase):
self.ndiffAssertEqual(f(sql), ("sql = ('select * '\n"
" 'from foo;')"))
+ def test_python_multiple_statements(self):
+ sql = 'select * from foo; select 1 from dual'
+ f = lambda sql: sqlparse.format(sql, output_format='python')
+ self.ndiffAssertEqual(f(sql), ("sql = 'select * from foo; '\n"
+ "sql2 = 'select 1 from dual'"))
+
+ @pytest.mark.xfail(reason="Needs fixing")
+ def test_python_multiple_statements_with_formatting(self):
+ sql = 'select * from foo; select 1 from dual'
+ f = lambda sql: sqlparse.format(sql, output_format='python',
+ reindent=True)
+ self.ndiffAssertEqual(f(sql), ("sql = ('select * '\n"
+ " 'from foo;')\n"
+ "sql2 = ('select 1 '\n"
+ " 'from dual')"))
+
def test_php(self):
sql = 'select * from foo;'
f = lambda sql: sqlparse.format(sql, output_format='php')
@@ -587,3 +603,17 @@ def test_having_produces_newline():
'having sum(bar.value) > 100'
]
assert formatted == '\n'.join(expected)
+
+
+def test_format_right_margin_invalid_input():
+ with pytest.raises(SQLParseError):
+ sqlparse.format('foo', right_margin=2)
+
+ with pytest.raises(SQLParseError):
+ sqlparse.format('foo', right_margin="two")
+
+
+@pytest.mark.xfail(reason="Needs fixing")
+def test_format_right_margin():
+ # TODO: Needs better test, only raises exception right now
+ sqlparse.format('foo', right_margin="79")
diff --git a/tests/test_grouping.py b/tests/test_grouping.py
index fdcd4a7..7ea1c75 100644
--- a/tests/test_grouping.py
+++ b/tests/test_grouping.py
@@ -373,7 +373,7 @@ def test_comparison_with_functions(): # issue230
@pytest.mark.parametrize('start', ['FOR', 'FOREACH'])
def test_forloops(start):
- p = sqlparse.parse('%s foo in bar LOOP foobar END LOOP' % start)[0]
+ p = sqlparse.parse('{0} foo in bar LOOP foobar END LOOP'.format(start))[0]
assert (len(p.tokens)) == 1
assert isinstance(p.tokens[0], sql.For)
diff --git a/tests/test_parse.py b/tests/test_parse.py
index 2ea0f40..8654ec4 100644
--- a/tests/test_parse.py
+++ b/tests/test_parse.py
@@ -9,7 +9,7 @@ from tests.utils import TestCaseBase
import sqlparse
import sqlparse.sql
from sqlparse import tokens as T
-from sqlparse.compat import u
+from sqlparse.compat import u, StringIO
class SQLParseTest(TestCaseBase):
@@ -178,9 +178,9 @@ def test_psql_quotation_marks(): # issue83
def test_double_precision_is_builtin():
sql = 'DOUBLE PRECISION'
t = sqlparse.parse(sql)[0].tokens
- assert (len(t) == 1
- and t[0].ttype == sqlparse.tokens.Name.Builtin
- and t[0].value == 'DOUBLE PRECISION')
+ assert len(t) == 1
+ assert t[0].ttype == sqlparse.tokens.Name.Builtin
+ assert t[0].value == 'DOUBLE PRECISION'
@pytest.mark.parametrize('ph', ['?', ':1', ':foo', '%s', '%(foo)s'])
@@ -218,10 +218,10 @@ def test_single_quotes_with_linebreaks(): # issue118
def test_sqlite_identifiers():
# Make sure we still parse sqlite style escapes
p = sqlparse.parse('[col1],[col2]')[0].tokens
- assert (len(p) == 1
- and isinstance(p[0], sqlparse.sql.IdentifierList)
- and [id.get_name() for id in p[0].get_identifiers()]
- == ['[col1]', '[col2]'])
+ id_names = [id.get_name() for id in p[0].get_identifiers()]
+ assert len(p) == 1
+ assert isinstance(p[0], sqlparse.sql.IdentifierList)
+ assert id_names == ['[col1]', '[col2]']
p = sqlparse.parse('[col1]+[col2]')[0]
types = [tok.ttype for tok in p.flatten()]
@@ -233,9 +233,9 @@ def test_simple_1d_array_index():
assert len(p) == 1
assert p[0].get_name() == 'col'
indices = list(p[0].get_array_indices())
- assert (len(indices) == 1 # 1-dimensional index
- and len(indices[0]) == 1 # index is single token
- and indices[0][0].value == '1')
+ assert len(indices) == 1 # 1-dimensional index
+ assert len(indices[0]) == 1 # index is single token
+ assert indices[0][0].value == '1'
def test_2d_array_index():
@@ -303,3 +303,45 @@ def test_names_and_special_names(sql):
p = sqlparse.parse(sql)[0]
assert len(p.tokens) == 1
assert isinstance(p.tokens[0], sqlparse.sql.Identifier)
+
+
+def test_get_token_at_offset():
+ # 0123456789
+ p = sqlparse.parse('select * from dual')[0]
+ assert p.get_token_at_offset(0) == p.tokens[0]
+ assert p.get_token_at_offset(1) == p.tokens[0]
+ assert p.get_token_at_offset(6) == p.tokens[1]
+ assert p.get_token_at_offset(7) == p.tokens[2]
+ assert p.get_token_at_offset(8) == p.tokens[3]
+ assert p.get_token_at_offset(9) == p.tokens[4]
+ assert p.get_token_at_offset(10) == p.tokens[4]
+
+
+def test_pprint():
+ p = sqlparse.parse('select * from dual')[0]
+ output = StringIO()
+
+ p._pprint_tree(f=output)
+ pprint = u'\n'.join([
+ " | 0 DML 'select'",
+ " | 1 Whitespace ' '",
+ " | 2 Wildcard '*'",
+ " | 3 Whitespace ' '",
+ " | 4 Keyword 'from'",
+ " | 5 Whitespace ' '",
+ " +-6 Identifier 'dual'",
+ " | 0 Name 'dual'",
+ "",
+ ])
+ assert output.getvalue() == pprint
+
+
+def test_wildcard_multiplication():
+ p = sqlparse.parse('select * from dual')[0]
+ assert p.tokens[2].ttype == T.Wildcard
+
+ p = sqlparse.parse('select a0.* from dual a0')[0]
+ assert p.tokens[2][2].ttype == T.Wildcard
+
+ p = sqlparse.parse('select 1 * 2 from dual')[0]
+ assert p.tokens[2][2].ttype == T.Operator
diff --git a/tests/test_regressions.py b/tests/test_regressions.py
index 13ca04b..9f3c9a9 100644
--- a/tests/test_regressions.py
+++ b/tests/test_regressions.py
@@ -171,10 +171,11 @@ def test_comment_encoding_when_reindent():
def test_parse_sql_with_binary():
# See https://github.com/andialbrecht/sqlparse/pull/88
+ # digest = '‚|ËêŠplL4¡h‘øN{'
digest = '\x82|\xcb\x0e\xea\x8aplL4\xa1h\x91\xf8N{'
- sql = 'select * from foo where bar = \'%s\'' % digest
+ sql = "select * from foo where bar = '{0}'".format(digest)
formatted = sqlparse.format(sql, reindent=True)
- tformatted = 'select *\nfrom foo\nwhere bar = \'%s\'' % digest
+ tformatted = "select *\nfrom foo\nwhere bar = '{0}'".format(digest)
if sys.version_info < (3,):
tformatted = tformatted.decode('unicode-escape')
assert formatted == tformatted
@@ -193,10 +194,8 @@ def test_dont_alias_keywords():
def test_format_accepts_encoding(): # issue20
sql = load_file('test_cp1251.sql', 'cp1251')
formatted = sqlparse.format(sql, reindent=True, encoding='cp1251')
- if sys.version_info < (3,):
- tformatted = u'insert into foo\nvalues (1); -- Песня про надежду\n'
- else:
- tformatted = 'insert into foo\nvalues (1); -- Песня про надежду\n'
+ tformatted = u'insert into foo\nvalues (1); -- Песня про надежду\n'
+
assert formatted == tformatted
@@ -278,10 +277,7 @@ def test_issue186_get_type():
def test_issue212_py2unicode():
- if sys.version_info < (3,):
- t1 = sql.Token(T.String, u"schöner ")
- else:
- t1 = sql.Token(T.String, "schöner ")
+ t1 = sql.Token(T.String, u"schöner ")
t2 = sql.Token(T.String, u"bug")
l = sql.TokenList([t1, t2])
assert str(l) == 'schöner bug'
@@ -304,3 +300,14 @@ def test_issue227_gettype_cte():
INSERT INTO elsewhere SELECT * FROM foo JOIN bar;
''')
assert with2_stmt[0].get_type() == 'INSERT'
+
+
+def test_issue207_runaway_format():
+ sql = 'select 1 from (select 1 as one, 2 as two, 3 from dual) t0'
+ p = sqlparse.format(sql, reindent=True)
+ assert p == '\n'.join(["select 1",
+ "from",
+ " (select 1 as one,",
+ " 2 as two,",
+ " 3",
+ " from dual) t0"])
diff --git a/tests/test_split.py b/tests/test_split.py
index f6d5f50..7c2645d 100644
--- a/tests/test_split.py
+++ b/tests/test_split.py
@@ -133,6 +133,11 @@ class SQLSplitTest(TestCaseBase):
stmts = list(sqlparse.parsestream(stream))
self.assertEqual(type(stmts[0].tokens[0].value), text_type)
+ def test_unicode_parsestream(self):
+ stream = StringIO(u"SELECT ö")
+ stmts = list(sqlparse.parsestream(stream))
+ self.assertEqual(str(stmts[0]), "SELECT ö")
+
def test_split_simple():
stmts = sqlparse.split('select * from foo; select * from bar;')
diff --git a/tests/test_tokenize.py b/tests/test_tokenize.py
index 2e931ba..adfd1ea 100644
--- a/tests/test_tokenize.py
+++ b/tests/test_tokenize.py
@@ -151,7 +151,7 @@ class TestStream(unittest.TestCase):
'CROSS JOIN', 'STRAIGHT JOIN',
'INNER JOIN', 'LEFT INNER JOIN'])
def test_parse_join(expr):
- p = sqlparse.parse('%s foo' % expr)[0]
+ p = sqlparse.parse('{0} foo'.format(expr))[0]
assert len(p.tokens) == 3
assert p.tokens[0].ttype is T.Keyword