summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVik <vmuriart@gmail.com>2016-06-23 11:37:46 -0700
committerGitHub <noreply@github.com>2016-06-23 11:37:46 -0700
commit78fb2041aef1068eb6c4bac3caad5dd012186868 (patch)
treedb402888277a06ca450eba2e57dd6f94c9b369ad
parentc56652ef9fdac111dd59e26b913765719eaf1141 (diff)
parent85349e68592964e66e5dfe7e48e9f76cb93d48fd (diff)
downloadsqlparse-78fb2041aef1068eb6c4bac3caad5dd012186868.tar.gz
Merge pull request #263 from vmuriart/clean-tests
Clean-up tests. Fully migrate to Py.test
-rw-r--r--sqlparse/cli.py12
-rw-r--r--tests/conftest.py41
-rw-r--r--tests/test_cli.py49
-rw-r--r--tests/test_format.py654
-rw-r--r--tests/test_grouping.py586
-rw-r--r--tests/test_parse.py424
-rw-r--r--tests/test_regressions.py373
-rw-r--r--tests/test_split.py256
-rw-r--r--tests/test_tokenize.py292
-rw-r--r--tests/utils.py41
10 files changed, 1367 insertions, 1361 deletions
diff --git a/sqlparse/cli.py b/sqlparse/cli.py
index 03a4f8f..80d547d 100644
--- a/sqlparse/cli.py
+++ b/sqlparse/cli.py
@@ -123,7 +123,8 @@ def create_parser():
def _error(msg):
"""Print msg and optionally exit with return code exit_."""
- sys.stderr.write('[ERROR] %s\n' % msg)
+ sys.stderr.write('[ERROR] {0}\n'.format(msg))
+ return 1
def main(args=None):
@@ -137,15 +138,13 @@ def main(args=None):
# TODO: Needs to deal with encoding
data = ''.join(open(args.filename).readlines())
except IOError as e:
- _error('Failed to read %s: %s' % (args.filename, e))
- return 1
+ return _error('Failed to read {0}: {1}'.format(args.filename, e))
if args.outfile:
try:
stream = open(args.outfile, 'w')
except IOError as e:
- _error('Failed to open %s: %s' % (args.outfile, e))
- return 1
+ return _error('Failed to open {0}: {1}'.format(args.outfile, e))
else:
stream = sys.stdout
@@ -153,8 +152,7 @@ def main(args=None):
try:
formatter_opts = sqlparse.formatter.validate_options(formatter_opts)
except SQLParseError as e:
- _error('Invalid options: %s' % e)
- return 1
+ return _error('Invalid options: {0}'.format(e))
s = sqlparse.format(data, **formatter_opts)
if PY2:
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..d5621eb
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,41 @@
+# -*- coding: utf-8 -*-
+
+"""Helpers for testing."""
+
+import io
+import os
+
+import pytest
+
+DIR_PATH = os.path.dirname(__file__)
+FILES_DIR = os.path.join(DIR_PATH, 'files')
+
+
+@pytest.fixture()
+def filepath():
+ """Returns full file path for test files."""
+
+ def make_filepath(filename):
+ # http://stackoverflow.com/questions/18011902/parameter-to-a-fixture
+ # Alternate solution is to use paramtrization `inderect=True`
+ # http://stackoverflow.com/a/33879151
+ # Syntax is noisy and requires specific variable names
+ return os.path.join(FILES_DIR, filename)
+
+ return make_filepath
+
+
+@pytest.fixture()
+def load_file(filepath):
+ """Opens filename with encoding and return its contents."""
+
+ def make_load_file(filename, encoding='utf-8'):
+ # http://stackoverflow.com/questions/18011902/parameter-to-a-fixture
+ # Alternate solution is to use paramtrization `inderect=True`
+ # http://stackoverflow.com/a/33879151
+ # Syntax is noisy and requires specific variable names
+ # And seems to be limited to only 1 argument.
+ with io.open(filepath(filename), encoding=encoding) as f:
+ return f.read()
+
+ return make_load_file
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 81d8449..77a764e 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -1,16 +1,11 @@
# -*- coding: utf-8 -*-
-import os
import subprocess
import sys
import pytest
import sqlparse
-import sqlparse.__main__
-from tests.utils import FILES_DIR
-
-path = os.path.join(FILES_DIR, 'function.sql')
def test_cli_main_empty():
@@ -31,12 +26,50 @@ def test_main_help():
assert exinfo.value.code == 0
-def test_valid_args():
+def test_valid_args(filepath):
# test doesn't abort
+ path = filepath('function.sql')
assert sqlparse.cli.main([path, '-r']) is not None
+def test_invalid_choise(filepath):
+ path = filepath('function.sql')
+ with pytest.raises(SystemExit):
+ sqlparse.cli.main([path, '-l', 'spanish'])
+
+
+def test_invalid_args(filepath, capsys):
+ path = filepath('function.sql')
+ sqlparse.cli.main([path, '-r', '--indent_width', '0'])
+ _, err = capsys.readouterr()
+ assert err == ("[ERROR] Invalid options: indent_width requires "
+ "a positive integer\n")
+
+
+def test_invalid_infile(filepath, capsys):
+ path = filepath('missing.sql')
+ sqlparse.cli.main([path, '-r'])
+ _, err = capsys.readouterr()
+ assert err[:22] == "[ERROR] Failed to read"
+
+
+def test_invalid_outfile(filepath, capsys):
+ path = filepath('function.sql')
+ outpath = filepath('/missing/function.sql')
+ sqlparse.cli.main([path, '-r', '-o', outpath])
+ _, err = capsys.readouterr()
+ assert err[:22] == "[ERROR] Failed to open"
+
+
+def test_stdout(filepath, load_file, capsys):
+ path = filepath('begintag.sql')
+ expected = load_file('begintag.sql')
+ sqlparse.cli.main([path])
+ out, _ = capsys.readouterr()
+ assert out == expected
+
+
def test_script():
# Call with the --help option as a basic sanity check.
- cmdl = "{0:s} -m sqlparse.cli --help".format(sys.executable)
- assert subprocess.call(cmdl.split()) == 0
+ cmd = "{0:s} -m sqlparse.cli --help".format(sys.executable)
+ assert subprocess.call(cmd.split()) == 0
diff --git a/tests/test_format.py b/tests/test_format.py
index 74fce71..023f26d 100644
--- a/tests/test_format.py
+++ b/tests/test_format.py
@@ -2,83 +2,94 @@
import pytest
-from tests.utils import TestCaseBase
-
import sqlparse
from sqlparse.exceptions import SQLParseError
-class TestFormat(TestCaseBase):
-
+class TestFormat(object):
def test_keywordcase(self):
sql = 'select * from bar; -- select foo\n'
res = sqlparse.format(sql, keyword_case='upper')
- self.ndiffAssertEqual(res, 'SELECT * FROM bar; -- select foo\n')
+ assert res == 'SELECT * FROM bar; -- select foo\n'
res = sqlparse.format(sql, keyword_case='capitalize')
- self.ndiffAssertEqual(res, 'Select * From bar; -- select foo\n')
+ assert res == 'Select * From bar; -- select foo\n'
res = sqlparse.format(sql.upper(), keyword_case='lower')
- self.ndiffAssertEqual(res, 'select * from BAR; -- SELECT FOO\n')
- self.assertRaises(SQLParseError, sqlparse.format, sql,
- keyword_case='foo')
+ assert res == 'select * from BAR; -- SELECT FOO\n'
+
+ def test_keywordcase_invalid_option(self):
+ sql = 'select * from bar; -- select foo\n'
+ with pytest.raises(SQLParseError):
+ sqlparse.format(sql, keyword_case='foo')
def test_identifiercase(self):
sql = 'select * from bar; -- select foo\n'
res = sqlparse.format(sql, identifier_case='upper')
- self.ndiffAssertEqual(res, 'select * from BAR; -- select foo\n')
+ assert res == 'select * from BAR; -- select foo\n'
res = sqlparse.format(sql, identifier_case='capitalize')
- self.ndiffAssertEqual(res, 'select * from Bar; -- select foo\n')
+ assert res == 'select * from Bar; -- select foo\n'
res = sqlparse.format(sql.upper(), identifier_case='lower')
- self.ndiffAssertEqual(res, 'SELECT * FROM bar; -- SELECT FOO\n')
- self.assertRaises(SQLParseError, sqlparse.format, sql,
- identifier_case='foo')
+ assert res == 'SELECT * FROM bar; -- SELECT FOO\n'
+
+ def test_identifiercase_invalid_option(self):
+ sql = 'select * from bar; -- select foo\n'
+ with pytest.raises(SQLParseError):
+ sqlparse.format(sql, identifier_case='foo')
+
+ def test_identifiercase_quotes(self):
sql = 'select * from "foo"."bar"'
res = sqlparse.format(sql, identifier_case="upper")
- self.ndiffAssertEqual(res, 'select * from "foo"."bar"')
+ assert res == 'select * from "foo"."bar"'
def test_strip_comments_single(self):
sql = 'select *-- statement starts here\nfrom foo'
res = sqlparse.format(sql, strip_comments=True)
- self.ndiffAssertEqual(res, 'select * from foo')
+ assert res == 'select * from foo'
sql = 'select * -- statement starts here\nfrom foo'
res = sqlparse.format(sql, strip_comments=True)
- self.ndiffAssertEqual(res, 'select * from foo')
+ assert res == 'select * from foo'
sql = 'select-- foo\nfrom -- bar\nwhere'
res = sqlparse.format(sql, strip_comments=True)
- self.ndiffAssertEqual(res, 'select from where')
- self.assertRaises(SQLParseError, sqlparse.format, sql,
- strip_comments=None)
+ assert res == 'select from where'
+
+ def test_strip_comments_invalid_option(self):
+ sql = 'select-- foo\nfrom -- bar\nwhere'
+ with pytest.raises(SQLParseError):
+ sqlparse.format(sql, strip_comments=None)
def test_strip_comments_multi(self):
sql = '/* sql starts here */\nselect'
res = sqlparse.format(sql, strip_comments=True)
- self.ndiffAssertEqual(res, 'select')
+ assert res == 'select'
sql = '/* sql starts here */ select'
res = sqlparse.format(sql, strip_comments=True)
- self.ndiffAssertEqual(res, 'select')
+ assert res == 'select'
sql = '/*\n * sql starts here\n */\nselect'
res = sqlparse.format(sql, strip_comments=True)
- self.ndiffAssertEqual(res, 'select')
+ assert res == 'select'
sql = 'select (/* sql starts here */ select 2)'
res = sqlparse.format(sql, strip_comments=True)
- self.ndiffAssertEqual(res, 'select (select 2)')
+ assert res == 'select (select 2)'
sql = 'select (/* sql /* starts here */ select 2)'
res = sqlparse.format(sql, strip_comments=True)
- self.ndiffAssertEqual(res, 'select (select 2)')
+ assert res == 'select (select 2)'
def test_strip_ws(self):
f = lambda sql: sqlparse.format(sql, strip_whitespace=True)
s = 'select\n* from foo\n\twhere ( 1 = 2 )\n'
- self.ndiffAssertEqual(f(s), 'select * from foo where (1 = 2)')
+ assert f(s) == 'select * from foo where (1 = 2)'
+ s = 'select -- foo\nfrom bar\n'
+ assert f(s) == 'select -- foo\nfrom bar'
+
+ def test_strip_ws_invalid_option(self):
s = 'select -- foo\nfrom bar\n'
- self.ndiffAssertEqual(f(s), 'select -- foo\nfrom bar')
- self.assertRaises(SQLParseError, sqlparse.format, s,
- strip_whitespace=None)
+ with pytest.raises(SQLParseError):
+ sqlparse.format(s, strip_whitespace=None)
def test_preserve_ws(self):
# preserve at least one whitespace after subgroups
f = lambda sql: sqlparse.format(sql, strip_whitespace=True)
s = 'select\n* /* foo */ from bar '
- self.ndiffAssertEqual(f(s), 'select * /* foo */ from bar')
+ assert f(s) == 'select * /* foo */ from bar'
def test_notransform_of_quoted_crlf(self):
# Make sure that CR/CR+LF characters inside string literals don't get
@@ -92,21 +103,14 @@ class TestFormat(TestCaseBase):
f = lambda x: sqlparse.format(x)
# Because of the use of
- self.ndiffAssertEqual(f(s1), "SELECT some_column LIKE 'value\r'")
- self.ndiffAssertEqual(
- f(s2), "SELECT some_column LIKE 'value\r'\nWHERE id = 1\n")
- self.ndiffAssertEqual(
- f(s3), "SELECT some_column LIKE 'value\\'\r' WHERE id = 1\n")
- self.ndiffAssertEqual(
- f(s4), "SELECT some_column LIKE 'value\\\\\\'\r' WHERE id = 1\n")
-
- def test_outputformat(self):
- sql = 'select * from foo;'
- self.assertRaises(SQLParseError, sqlparse.format, sql,
- output_format='foo')
+ assert f(s1) == "SELECT some_column LIKE 'value\r'"
+ assert f(s2) == "SELECT some_column LIKE 'value\r'\nWHERE id = 1\n"
+ assert f(s3) == "SELECT some_column LIKE 'value\\'\r' WHERE id = 1\n"
+ assert (f(s4) ==
+ "SELECT some_column LIKE 'value\\\\\\'\r' WHERE id = 1\n")
-class TestFormatReindentAligned(TestCaseBase):
+class TestFormatReindentAligned(object):
@staticmethod
def formatter(sql):
return sqlparse.format(sql, reindent_aligned=True)
@@ -121,23 +125,21 @@ class TestFormatReindentAligned(TestCaseBase):
or d is 'blue'
limit 10
"""
- self.ndiffAssertEqual(
- self.formatter(sql),
- '\n'.join([
- 'select a,',
- ' b as bb,',
- ' c',
- ' from table',
- ' join (',
- ' select a * 2 as a',
- ' from new_table',
- ' ) other',
- ' on table.a = other.a',
- ' where c is true',
- ' and b between 3 and 4',
- " or d is 'blue'",
- ' limit 10',
- ]))
+
+ assert self.formatter(sql) == '\n'.join([
+ 'select a,',
+ ' b as bb,',
+ ' c',
+ ' from table',
+ ' join (',
+ ' select a * 2 as a',
+ ' from new_table',
+ ' ) other',
+ ' on table.a = other.a',
+ ' where c is true',
+ ' and b between 3 and 4',
+ " or d is 'blue'",
+ ' limit 10'])
def test_joins(self):
sql = """
@@ -148,22 +150,19 @@ class TestFormatReindentAligned(TestCaseBase):
cross join e on e.four = a.four
join f using (one, two, three)
"""
- self.ndiffAssertEqual(
- self.formatter(sql),
- '\n'.join([
- 'select *',
- ' from a',
- ' join b',
- ' on a.one = b.one',
- ' left join c',
- ' on c.two = a.two',
- ' and c.three = a.three',
- ' full outer join d',
- ' on d.three = a.three',
- ' cross join e',
- ' on e.four = a.four',
- ' join f using (one, two, three)',
- ]))
+ assert self.formatter(sql) == '\n'.join([
+ 'select *',
+ ' from a',
+ ' join b',
+ ' on a.one = b.one',
+ ' left join c',
+ ' on c.two = a.two',
+ ' and c.three = a.three',
+ ' full outer join d',
+ ' on d.three = a.three',
+ ' cross join e',
+ ' on e.four = a.four',
+ ' join f using (one, two, three)'])
def test_case_statement(self):
sql = """
@@ -178,20 +177,17 @@ class TestFormatReindentAligned(TestCaseBase):
where c is true
and b between 3 and 4
"""
- self.ndiffAssertEqual(
- self.formatter(sql),
- '\n'.join([
- 'select a,',
- ' case when a = 0 then 1',
- ' when bb = 1 then 1',
- ' when c = 2 then 2',
- ' else 0',
- ' end as d,',
- ' extra_col',
- ' from table',
- ' where c is true',
- ' and b between 3 and 4'
- ]))
+ assert self.formatter(sql) == '\n'.join([
+ 'select a,',
+ ' case when a = 0 then 1',
+ ' when bb = 1 then 1',
+ ' when c = 2 then 2',
+ ' else 0',
+ ' end as d,',
+ ' extra_col',
+ ' from table',
+ ' where c is true',
+ ' and b between 3 and 4'])
def test_case_statement_with_between(self):
sql = """
@@ -207,21 +203,18 @@ class TestFormatReindentAligned(TestCaseBase):
where c is true
and b between 3 and 4
"""
- self.ndiffAssertEqual(
- self.formatter(sql),
- '\n'.join([
- 'select a,',
- ' case when a = 0 then 1',
- ' when bb = 1 then 1',
- ' when c = 2 then 2',
- ' when d between 3 and 5 then 3',
- ' else 0',
- ' end as d,',
- ' extra_col',
- ' from table',
- ' where c is true',
- ' and b between 3 and 4'
- ]))
+ assert self.formatter(sql) == '\n'.join([
+ 'select a,',
+ ' case when a = 0 then 1',
+ ' when bb = 1 then 1',
+ ' when c = 2 then 2',
+ ' when d between 3 and 5 then 3',
+ ' else 0',
+ ' end as d,',
+ ' extra_col',
+ ' from table',
+ ' where c is true',
+ ' and b between 3 and 4'])
def test_group_by(self):
sql = """
@@ -232,24 +225,21 @@ class TestFormatReindentAligned(TestCaseBase):
and count(y) > 5
order by 3,2,1
"""
- self.ndiffAssertEqual(
- self.formatter(sql),
- '\n'.join([
- 'select a,',
- ' b,',
- ' c,',
- ' sum(x) as sum_x,',
- ' count(y) as cnt_y',
- ' from table',
- ' group by a,',
- ' b,',
- ' c',
- 'having sum(x) > 1',
- ' and count(y) > 5',
- ' order by 3,',
- ' 2,',
- ' 1',
- ]))
+ assert self.formatter(sql) == '\n'.join([
+ 'select a,',
+ ' b,',
+ ' c,',
+ ' sum(x) as sum_x,',
+ ' count(y) as cnt_y',
+ ' from table',
+ ' group by a,',
+ ' b,',
+ ' c',
+ 'having sum(x) > 1',
+ ' and count(y) > 5',
+ ' order by 3,',
+ ' 2,',
+ ' 1'])
def test_group_by_subquery(self):
# TODO: add subquery alias when test_identifier_list_subquery fixed
@@ -261,21 +251,18 @@ class TestFormatReindentAligned(TestCaseBase):
group by a,z)
order by 1,2
"""
- self.ndiffAssertEqual(
- self.formatter(sql),
- '\n'.join([
- 'select *,',
- ' sum_b + 2 as mod_sum',
- ' from (',
- ' select a,',
- ' sum(b) as sum_b',
- ' from table',
- ' group by a,',
- ' z',
- ' )',
- ' order by 1,',
- ' 2',
- ]))
+ assert self.formatter(sql) == '\n'.join([
+ 'select *,',
+ ' sum_b + 2 as mod_sum',
+ ' from (',
+ ' select a,',
+ ' sum(b) as sum_b',
+ ' from table',
+ ' group by a,',
+ ' z',
+ ' )',
+ ' order by 1,',
+ ' 2'])
def test_window_functions(self):
sql = """
@@ -284,21 +271,17 @@ class TestFormatReindentAligned(TestCaseBase):
BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as sum_a,
ROW_NUMBER() OVER
(PARTITION BY b, c ORDER BY d DESC) as row_num
- from table
- """
- self.ndiffAssertEqual(
- self.formatter(sql),
- '\n'.join([
- 'select a,',
- (' SUM(a) OVER (PARTITION BY b ORDER BY c ROWS '
- 'BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as sum_a,'),
- (' ROW_NUMBER() OVER '
- '(PARTITION BY b, c ORDER BY d DESC) as row_num'),
- ' from table',
- ]))
-
-
-class TestSpacesAroundOperators(TestCaseBase):
+ from table"""
+ assert self.formatter(sql) == '\n'.join([
+ 'select a,',
+ ' SUM(a) OVER (PARTITION BY b ORDER BY c ROWS '
+ 'BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as sum_a,',
+ ' ROW_NUMBER() OVER '
+ '(PARTITION BY b, c ORDER BY d DESC) as row_num',
+ ' from table'])
+
+
+class TestSpacesAroundOperators(object):
@staticmethod
def formatter(sql):
return sqlparse.format(sql, use_space_around_operators=True)
@@ -306,311 +289,330 @@ class TestSpacesAroundOperators(TestCaseBase):
def test_basic(self):
sql = ('select a+b as d from table '
'where (c-d)%2= 1 and e> 3.0/4 and z^2 <100')
- self.ndiffAssertEqual(
- self.formatter(sql), (
- 'select a + b as d from table '
- 'where (c - d) % 2 = 1 and e > 3.0 / 4 and z ^ 2 < 100')
- )
+ assert self.formatter(sql) == (
+ 'select a + b as d from table '
+ 'where (c - d) % 2 = 1 and e > 3.0 / 4 and z ^ 2 < 100')
def test_bools(self):
sql = 'select * from table where a &&b or c||d'
- self.ndiffAssertEqual(
- self.formatter(sql),
- 'select * from table where a && b or c || d'
- )
+ assert self.formatter(
+ sql) == 'select * from table where a && b or c || d'
def test_nested(self):
sql = 'select *, case when a-b then c end from table'
- self.ndiffAssertEqual(
- self.formatter(sql),
- 'select *, case when a - b then c end from table'
- )
+ assert self.formatter(
+ sql) == 'select *, case when a - b then c end from table'
def test_wildcard_vs_mult(self):
sql = 'select a*b-c from table'
- self.ndiffAssertEqual(
- self.formatter(sql),
- 'select a * b - c from table'
- )
+ assert self.formatter(sql) == 'select a * b - c from table'
-class TestFormatReindent(TestCaseBase):
-
+class TestFormatReindent(object):
def test_option(self):
- self.assertRaises(SQLParseError, sqlparse.format, 'foo',
- reindent=2)
- self.assertRaises(SQLParseError, sqlparse.format, 'foo',
- indent_tabs=2)
- self.assertRaises(SQLParseError, sqlparse.format, 'foo',
- reindent=True, indent_width='foo')
- self.assertRaises(SQLParseError, sqlparse.format, 'foo',
- reindent=True, indent_width=-12)
- self.assertRaises(SQLParseError, sqlparse.format, 'foo',
- reindent=True, wrap_after='foo')
- self.assertRaises(SQLParseError, sqlparse.format, 'foo',
- reindent=True, wrap_after=-12)
+ with pytest.raises(SQLParseError):
+ sqlparse.format('foo', reindent=2)
+ with pytest.raises(SQLParseError):
+ sqlparse.format('foo', indent_tabs=2)
+ with pytest.raises(SQLParseError):
+ sqlparse.format('foo', reindent=True, indent_width='foo')
+ with pytest.raises(SQLParseError):
+ sqlparse.format('foo', reindent=True, indent_width=-12)
+ with pytest.raises(SQLParseError):
+ sqlparse.format('foo', reindent=True, wrap_after='foo')
+ with pytest.raises(SQLParseError):
+ sqlparse.format('foo', reindent=True, wrap_after=-12)
def test_stmts(self):
f = lambda sql: sqlparse.format(sql, reindent=True)
s = 'select foo; select bar'
- self.ndiffAssertEqual(f(s), 'select foo;\n\nselect bar')
+ assert f(s) == 'select foo;\n\nselect bar'
s = 'select foo'
- self.ndiffAssertEqual(f(s), 'select foo')
+ assert f(s) == 'select foo'
s = 'select foo; -- test\n select bar'
- self.ndiffAssertEqual(f(s), 'select foo; -- test\n\nselect bar')
+ assert f(s) == 'select foo; -- test\n\nselect bar'
def test_keywords(self):
f = lambda sql: sqlparse.format(sql, reindent=True)
s = 'select * from foo union select * from bar;'
- self.ndiffAssertEqual(f(s), '\n'.join(['select *',
- 'from foo',
- 'union',
- 'select *',
- 'from bar;']))
-
- def test_keywords_between(self): # issue 14
+ assert f(s) == '\n'.join([
+ 'select *',
+ 'from foo',
+ 'union',
+ 'select *',
+ 'from bar;'])
+
+ def test_keywords_between(self):
+ # issue 14
# don't break AND after BETWEEN
f = lambda sql: sqlparse.format(sql, reindent=True)
s = 'and foo between 1 and 2 and bar = 3'
- self.ndiffAssertEqual(f(s), '\n'.join(['',
- 'and foo between 1 and 2',
- 'and bar = 3']))
+ assert f(s) == '\n'.join([
+ '',
+ 'and foo between 1 and 2',
+ 'and bar = 3'])
def test_parenthesis(self):
f = lambda sql: sqlparse.format(sql, reindent=True)
s = 'select count(*) from (select * from foo);'
- self.ndiffAssertEqual(f(s),
- '\n'.join(['select count(*)',
- 'from',
- ' (select *',
- ' from foo);',
- ])
- )
+ assert f(s) == '\n'.join([
+ 'select count(*)',
+ 'from',
+ ' (select *',
+ ' from foo);'])
def test_where(self):
f = lambda sql: sqlparse.format(sql, reindent=True)
s = 'select * from foo where bar = 1 and baz = 2 or bzz = 3;'
- self.ndiffAssertEqual(f(s), ('select *\nfrom foo\n'
- 'where bar = 1\n'
- ' and baz = 2\n'
- ' or bzz = 3;'))
+ assert f(s) == '\n'.join([
+ 'select *',
+ 'from foo',
+ 'where bar = 1',
+ ' and baz = 2',
+ ' or bzz = 3;'])
+
s = 'select * from foo where bar = 1 and (baz = 2 or bzz = 3);'
- self.ndiffAssertEqual(f(s), ('select *\nfrom foo\n'
- 'where bar = 1\n'
- ' and (baz = 2\n'
- ' or bzz = 3);'))
+ assert f(s) == '\n'.join([
+ 'select *',
+ 'from foo',
+ 'where bar = 1',
+ ' and (baz = 2',
+ ' or bzz = 3);'])
def test_join(self):
f = lambda sql: sqlparse.format(sql, reindent=True)
s = 'select * from foo join bar on 1 = 2'
- self.ndiffAssertEqual(f(s), '\n'.join(['select *',
- 'from foo',
- 'join bar on 1 = 2']))
+ assert f(s) == '\n'.join([
+ 'select *',
+ 'from foo',
+ 'join bar on 1 = 2'])
s = 'select * from foo inner join bar on 1 = 2'
- self.ndiffAssertEqual(f(s), '\n'.join(['select *',
- 'from foo',
- 'inner join bar on 1 = 2']))
+ assert f(s) == '\n'.join([
+ 'select *',
+ 'from foo',
+ 'inner join bar on 1 = 2'])
s = 'select * from foo left outer join bar on 1 = 2'
- self.ndiffAssertEqual(f(s), '\n'.join(['select *',
- 'from foo',
- 'left outer join bar on 1 = 2']
- ))
+ assert f(s) == '\n'.join([
+ 'select *',
+ 'from foo',
+ 'left outer join bar on 1 = 2'])
s = 'select * from foo straight_join bar on 1 = 2'
- self.ndiffAssertEqual(f(s), '\n'.join(['select *',
- 'from foo',
- 'straight_join bar on 1 = 2']
- ))
+ assert f(s) == '\n'.join([
+ 'select *',
+ 'from foo',
+ 'straight_join bar on 1 = 2'])
def test_identifier_list(self):
f = lambda sql: sqlparse.format(sql, reindent=True)
s = 'select foo, bar, baz from table1, table2 where 1 = 2'
- self.ndiffAssertEqual(f(s), '\n'.join(['select foo,',
- ' bar,',
- ' baz',
- 'from table1,',
- ' table2',
- 'where 1 = 2']))
+ assert f(s) == '\n'.join([
+ 'select foo,',
+ ' bar,',
+ ' baz',
+ 'from table1,',
+ ' table2',
+ 'where 1 = 2'])
s = 'select a.*, b.id from a, b'
- self.ndiffAssertEqual(f(s), '\n'.join(['select a.*,',
- ' b.id',
- 'from a,',
- ' b']))
+ assert f(s) == '\n'.join([
+ 'select a.*,',
+ ' b.id',
+ 'from a,',
+ ' b'])
def test_identifier_list_with_wrap_after(self):
f = lambda sql: sqlparse.format(sql, reindent=True, wrap_after=14)
s = 'select foo, bar, baz from table1, table2 where 1 = 2'
- self.ndiffAssertEqual(f(s), '\n'.join(['select foo, bar,',
- ' baz',
- 'from table1, table2',
- 'where 1 = 2']))
+ assert f(s) == '\n'.join([
+ 'select foo, bar,',
+ ' baz',
+ 'from table1, table2',
+ 'where 1 = 2'])
def test_identifier_list_with_functions(self):
f = lambda sql: sqlparse.format(sql, reindent=True)
s = ("select 'abc' as foo, coalesce(col1, col2)||col3 as bar,"
"col3 from my_table")
- self.ndiffAssertEqual(f(s), '\n'.join(
- ["select 'abc' as foo,",
- " coalesce(col1, col2)||col3 as bar,",
- " col3",
- "from my_table"]))
+ assert f(s) == '\n'.join([
+ "select 'abc' as foo,",
+ " coalesce(col1, col2)||col3 as bar,",
+ " col3",
+ "from my_table"])
def test_case(self):
f = lambda sql: sqlparse.format(sql, reindent=True)
s = 'case when foo = 1 then 2 when foo = 3 then 4 else 5 end'
- self.ndiffAssertEqual(f(s), '\n'.join(['case',
- ' when foo = 1 then 2',
- ' when foo = 3 then 4',
- ' else 5',
- 'end']))
+ assert f(s) == '\n'.join([
+ 'case',
+ ' when foo = 1 then 2',
+ ' when foo = 3 then 4',
+ ' else 5',
+ 'end'])
def test_case2(self):
f = lambda sql: sqlparse.format(sql, reindent=True)
s = 'case(foo) when bar = 1 then 2 else 3 end'
- self.ndiffAssertEqual(f(s), '\n'.join(['case(foo)',
- ' when bar = 1 then 2',
- ' else 3',
- 'end']))
-
- def test_nested_identifier_list(self): # issue4
+ assert f(s) == '\n'.join([
+ 'case(foo)',
+ ' when bar = 1 then 2',
+ ' else 3',
+ 'end'])
+
+ def test_nested_identifier_list(self):
+ # issue4
f = lambda sql: sqlparse.format(sql, reindent=True)
s = '(foo as bar, bar1, bar2 as bar3, b4 as b5)'
- self.ndiffAssertEqual(f(s), '\n'.join(['(foo as bar,',
- ' bar1,',
- ' bar2 as bar3,',
- ' b4 as b5)']))
-
- def test_duplicate_linebreaks(self): # issue3
+ assert f(s) == '\n'.join([
+ '(foo as bar,',
+ ' bar1,',
+ ' bar2 as bar3,',
+ ' b4 as b5)'])
+
+ def test_duplicate_linebreaks(self):
+ # issue3
f = lambda sql: sqlparse.format(sql, reindent=True)
s = 'select c1 -- column1\nfrom foo'
- self.ndiffAssertEqual(f(s), '\n'.join(['select c1 -- column1',
- 'from foo']))
+ assert f(s) == '\n'.join([
+ 'select c1 -- column1',
+ 'from foo'])
s = 'select c1 -- column1\nfrom foo'
r = sqlparse.format(s, reindent=True, strip_comments=True)
- self.ndiffAssertEqual(r, '\n'.join(['select c1',
- 'from foo']))
+ assert r == '\n'.join([
+ 'select c1',
+ 'from foo'])
s = 'select c1\nfrom foo\norder by c1'
- self.ndiffAssertEqual(f(s), '\n'.join(['select c1',
- 'from foo',
- 'order by c1']))
+ assert f(s) == '\n'.join([
+ 'select c1',
+ 'from foo',
+ 'order by c1'])
s = 'select c1 from t1 where (c1 = 1) order by c1'
- self.ndiffAssertEqual(f(s), '\n'.join(['select c1',
- 'from t1',
- 'where (c1 = 1)',
- 'order by c1']))
-
- def test_keywordfunctions(self): # issue36
+ assert f(s) == '\n'.join([
+ 'select c1',
+ 'from t1',
+ 'where (c1 = 1)',
+ 'order by c1'])
+
+ def test_keywordfunctions(self):
+ # issue36
f = lambda sql: sqlparse.format(sql, reindent=True)
s = 'select max(a) b, foo, bar'
- self.ndiffAssertEqual(f(s), '\n'.join(['select max(a) b,',
- ' foo,',
- ' bar']))
+ assert f(s) == '\n'.join([
+ 'select max(a) b,',
+ ' foo,',
+ ' bar'])
- def test_identifier_and_functions(self): # issue45
+ def test_identifier_and_functions(self):
+ # issue45
f = lambda sql: sqlparse.format(sql, reindent=True)
s = 'select foo.bar, nvl(1) from dual'
- self.ndiffAssertEqual(f(s), '\n'.join(['select foo.bar,',
- ' nvl(1)',
- 'from dual']))
+ assert f(s) == '\n'.join([
+ 'select foo.bar,',
+ ' nvl(1)',
+ 'from dual'])
-class TestOutputFormat(TestCaseBase):
-
+class TestOutputFormat(object):
def test_python(self):
sql = 'select * from foo;'
f = lambda sql: sqlparse.format(sql, output_format='python')
- self.ndiffAssertEqual(f(sql), "sql = 'select * from foo;'")
+ assert f(sql) == "sql = 'select * from foo;'"
f = lambda sql: sqlparse.format(sql, output_format='python',
reindent=True)
- self.ndiffAssertEqual(f(sql), ("sql = ('select * '\n"
- " 'from foo;')"))
+ assert f(sql) == '\n'.join([
+ "sql = ('select * '",
+ " 'from foo;')"])
def test_python_multiple_statements(self):
sql = 'select * from foo; select 1 from dual'
f = lambda sql: sqlparse.format(sql, output_format='python')
- self.ndiffAssertEqual(f(sql), ("sql = 'select * from foo; '\n"
- "sql2 = 'select 1 from dual'"))
+ assert f(sql) == '\n'.join([
+ "sql = 'select * from foo; '",
+ "sql2 = 'select 1 from dual'"])
@pytest.mark.xfail(reason="Needs fixing")
def test_python_multiple_statements_with_formatting(self):
sql = 'select * from foo; select 1 from dual'
f = lambda sql: sqlparse.format(sql, output_format='python',
reindent=True)
- self.ndiffAssertEqual(f(sql), ("sql = ('select * '\n"
- " 'from foo;')\n"
- "sql2 = ('select 1 '\n"
- " 'from dual')"))
+ assert f(sql) == '\n'.join([
+ "sql = ('select * '",
+ " 'from foo;')",
+ "sql2 = ('select 1 '",
+ " 'from dual')"])
def test_php(self):
sql = 'select * from foo;'
f = lambda sql: sqlparse.format(sql, output_format='php')
- self.ndiffAssertEqual(f(sql), '$sql = "select * from foo;";')
+ assert f(sql) == '$sql = "select * from foo;";'
f = lambda sql: sqlparse.format(sql, output_format='php',
reindent=True)
- self.ndiffAssertEqual(f(sql), ('$sql = "select * ";\n'
- '$sql .= "from foo;";'))
+ assert f(sql) == '\n'.join([
+ '$sql = "select * ";',
+ '$sql .= "from foo;";'])
- def test_sql(self): # "sql" is an allowed option but has no effect
+ def test_sql(self):
+ # "sql" is an allowed option but has no effect
sql = 'select * from foo;'
f = lambda sql: sqlparse.format(sql, output_format='sql')
- self.ndiffAssertEqual(f(sql), 'select * from foo;')
+ assert f(sql) == 'select * from foo;'
+
+ def test_invalid_option(self):
+ sql = 'select * from foo;'
+ with pytest.raises(SQLParseError):
+ sqlparse.format(sql, output_format='foo')
-def test_format_column_ordering(): # issue89
+def test_format_column_ordering():
+ # issue89
sql = 'select * from foo order by c1 desc, c2, c3;'
formatted = sqlparse.format(sql, reindent=True)
- expected = '\n'.join(['select *',
- 'from foo',
- 'order by c1 desc,',
- ' c2,',
- ' c3;'])
+ expected = '\n'.join([
+ 'select *',
+ 'from foo',
+ 'order by c1 desc,',
+ ' c2,',
+ ' c3;'])
assert formatted == expected
def test_truncate_strings():
- sql = 'update foo set value = \'' + 'x' * 1000 + '\';'
+ sql = "update foo set value = '{0}';".format('x' * 1000)
formatted = sqlparse.format(sql, truncate_strings=10)
- assert formatted == 'update foo set value = \'xxxxxxxxxx[...]\';'
+ assert formatted == "update foo set value = 'xxxxxxxxxx[...]';"
formatted = sqlparse.format(sql, truncate_strings=3, truncate_char='YYY')
- assert formatted == 'update foo set value = \'xxxYYY\';'
+ assert formatted == "update foo set value = 'xxxYYY';"
-def test_truncate_strings_invalid_option():
- pytest.raises(SQLParseError, sqlparse.format,
- 'foo', truncate_strings='bar')
- pytest.raises(SQLParseError, sqlparse.format,
- 'foo', truncate_strings=-1)
- pytest.raises(SQLParseError, sqlparse.format,
- 'foo', truncate_strings=0)
+@pytest.mark.parametrize('option', ['bar', -1, 0])
+def test_truncate_strings_invalid_option2(option):
+ with pytest.raises(SQLParseError):
+ sqlparse.format('foo', truncate_strings=option)
-@pytest.mark.parametrize('sql', ['select verrrylongcolumn from foo',
- 'select "verrrylongcolumn" from "foo"'])
+@pytest.mark.parametrize('sql', [
+ 'select verrrylongcolumn from foo',
+ 'select "verrrylongcolumn" from "foo"'])
def test_truncate_strings_doesnt_truncate_identifiers(sql):
formatted = sqlparse.format(sql, truncate_strings=2)
assert formatted == sql
def test_having_produces_newline():
- sql = (
- 'select * from foo, bar where bar.id = foo.bar_id'
- ' having sum(bar.value) > 100')
+ sql = ('select * from foo, bar where bar.id = foo.bar_id '
+ 'having sum(bar.value) > 100')
formatted = sqlparse.format(sql, reindent=True)
expected = [
'select *',
'from foo,',
' bar',
'where bar.id = foo.bar_id',
- 'having sum(bar.value) > 100'
- ]
+ 'having sum(bar.value) > 100']
assert formatted == '\n'.join(expected)
-def test_format_right_margin_invalid_input():
- with pytest.raises(SQLParseError):
- sqlparse.format('foo', right_margin=2)
-
+@pytest.mark.parametrize('right_margin', ['ten', 2])
+def test_format_right_margin_invalid_option(right_margin):
with pytest.raises(SQLParseError):
- sqlparse.format('foo', right_margin="two")
+ sqlparse.format('foo', right_margin=right_margin)
@pytest.mark.xfail(reason="Needs fixing")
diff --git a/tests/test_grouping.py b/tests/test_grouping.py
index bf6bfeb..12d7310 100644
--- a/tests/test_grouping.py
+++ b/tests/test_grouping.py
@@ -3,278 +3,294 @@
import pytest
import sqlparse
-from sqlparse import sql
-from sqlparse import tokens as T
-from sqlparse.compat import u
-
-from tests.utils import TestCaseBase
-
-
-class TestGrouping(TestCaseBase):
-
- def test_parenthesis(self):
- s = 'select (select (x3) x2) and (y2) bar'
- parsed = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, str(parsed))
- self.assertEqual(len(parsed.tokens), 7)
- self.assert_(isinstance(parsed.tokens[2], sql.Parenthesis))
- self.assert_(isinstance(parsed.tokens[-1], sql.Identifier))
- self.assertEqual(len(parsed.tokens[2].tokens), 5)
- self.assert_(isinstance(parsed.tokens[2].tokens[3], sql.Identifier))
- self.assert_(isinstance(parsed.tokens[2].tokens[3].tokens[0],
- sql.Parenthesis))
- self.assertEqual(len(parsed.tokens[2].tokens[3].tokens), 3)
-
- def test_comments(self):
- s = '/*\n * foo\n */ \n bar'
- parsed = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, u(parsed))
- self.assertEqual(len(parsed.tokens), 2)
-
- def test_assignment(self):
- s = 'foo := 1;'
- parsed = sqlparse.parse(s)[0]
- self.assertEqual(len(parsed.tokens), 1)
- self.assert_(isinstance(parsed.tokens[0], sql.Assignment))
- s = 'foo := 1'
- parsed = sqlparse.parse(s)[0]
- self.assertEqual(len(parsed.tokens), 1)
- self.assert_(isinstance(parsed.tokens[0], sql.Assignment))
-
- def test_identifiers(self):
- s = 'select foo.bar from "myscheme"."table" where fail. order'
- parsed = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, u(parsed))
- self.assert_(isinstance(parsed.tokens[2], sql.Identifier))
- self.assert_(isinstance(parsed.tokens[6], sql.Identifier))
- self.assert_(isinstance(parsed.tokens[8], sql.Where))
- s = 'select * from foo where foo.id = 1'
- parsed = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, u(parsed))
- self.assert_(isinstance(parsed.tokens[-1].tokens[-1].tokens[0],
- sql.Identifier))
- s = 'select * from (select "foo"."id" from foo)'
- parsed = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, u(parsed))
- self.assert_(isinstance(parsed.tokens[-1].tokens[3], sql.Identifier))
-
- s = "INSERT INTO `test` VALUES('foo', 'bar');"
- parsed = sqlparse.parse(s)[0]
- types = [l.ttype for l in parsed.tokens if not l.is_whitespace()]
- self.assertEquals(types, [T.DML, T.Keyword, None,
- T.Keyword, None, T.Punctuation])
-
- s = "select 1.0*(a+b) as col, sum(c)/sum(d) from myschema.mytable"
- parsed = sqlparse.parse(s)[0]
- self.assertEqual(len(parsed.tokens), 7)
- self.assert_(isinstance(parsed.tokens[2], sql.IdentifierList))
- self.assertEqual(len(parsed.tokens[2].tokens), 4)
- identifiers = list(parsed.tokens[2].get_identifiers())
- self.assertEqual(len(identifiers), 2)
- self.assertEquals(identifiers[0].get_alias(), u"col")
-
- def test_identifier_wildcard(self):
- p = sqlparse.parse('a.*, b.id')[0]
- self.assert_(isinstance(p.tokens[0], sql.IdentifierList))
- self.assert_(isinstance(p.tokens[0].tokens[0], sql.Identifier))
- self.assert_(isinstance(p.tokens[0].tokens[-1], sql.Identifier))
-
- def test_identifier_name_wildcard(self):
- p = sqlparse.parse('a.*')[0]
- t = p.tokens[0]
- self.assertEqual(t.get_name(), '*')
- self.assertEqual(t.is_wildcard(), True)
-
- def test_identifier_invalid(self):
- p = sqlparse.parse('a.')[0]
- self.assert_(isinstance(p.tokens[0], sql.Identifier))
- self.assertEqual(p.tokens[0].has_alias(), False)
- self.assertEqual(p.tokens[0].get_name(), None)
- self.assertEqual(p.tokens[0].get_real_name(), None)
- self.assertEqual(p.tokens[0].get_parent_name(), 'a')
-
- def test_identifier_invalid_in_middle(self):
- # issue261
- s = 'SELECT foo. FROM foo'
- p = sqlparse.parse(s)[0]
- assert isinstance(p[2], sql.Identifier)
- assert p[2][1].ttype == T.Punctuation
- assert p[3].ttype == T.Whitespace
- assert str(p[2]) == 'foo.'
-
- def test_identifier_as_invalid(self): # issue8
- p = sqlparse.parse('foo as select *')[0]
- self.assert_(len(p.tokens), 5)
- self.assert_(isinstance(p.tokens[0], sql.Identifier))
- self.assertEqual(len(p.tokens[0].tokens), 1)
- self.assertEqual(p.tokens[2].ttype, T.Keyword)
-
- def test_identifier_function(self):
- p = sqlparse.parse('foo() as bar')[0]
- self.assert_(isinstance(p.tokens[0], sql.Identifier))
- self.assert_(isinstance(p.tokens[0].tokens[0], sql.Function))
- p = sqlparse.parse('foo()||col2 bar')[0]
- self.assert_(isinstance(p.tokens[0], sql.Identifier))
- self.assert_(isinstance(p.tokens[0].tokens[0], sql.Operation))
- self.assert_(isinstance(p.tokens[0].tokens[0].tokens[0], sql.Function))
-
- def test_identifier_extended(self): # issue 15
- p = sqlparse.parse('foo+100')[0]
- self.assert_(isinstance(p.tokens[0], sql.Operation))
- p = sqlparse.parse('foo + 100')[0]
- self.assert_(isinstance(p.tokens[0], sql.Operation))
- p = sqlparse.parse('foo*100')[0]
- self.assert_(isinstance(p.tokens[0], sql.Operation))
-
- def test_identifier_list(self):
- p = sqlparse.parse('a, b, c')[0]
- self.assert_(isinstance(p.tokens[0], sql.IdentifierList))
- p = sqlparse.parse('(a, b, c)')[0]
- self.assert_(isinstance(p.tokens[0].tokens[1], sql.IdentifierList))
-
- def test_identifier_list_subquery(self):
- """identifier lists should still work in subqueries with aliases"""
- p = sqlparse.parse("select * from ("
- "select a, b + c as d from table) sub")[0]
- subquery = p.tokens[-1].tokens[0]
- idx, iden_list = subquery.token_next_by(i=sql.IdentifierList)
- self.assert_(iden_list is not None)
- # all the identifiers should be within the IdentifierList
- _, ilist = subquery.token_next_by(i=sql.Identifier, idx=idx)
- self.assert_(ilist is None)
-
- def test_identifier_list_case(self):
- p = sqlparse.parse('a, case when 1 then 2 else 3 end as b, c')[0]
- self.assert_(isinstance(p.tokens[0], sql.IdentifierList))
- p = sqlparse.parse('(a, case when 1 then 2 else 3 end as b, c)')[0]
- self.assert_(isinstance(p.tokens[0].tokens[1], sql.IdentifierList))
-
- def test_identifier_list_other(self): # issue2
- p = sqlparse.parse("select *, null, 1, 'foo', bar from mytable, x")[0]
- self.assert_(isinstance(p.tokens[2], sql.IdentifierList))
- l = p.tokens[2]
- self.assertEqual(len(l.tokens), 13)
-
- def test_identifier_list_with_inline_comments(self): # issue163
- p = sqlparse.parse('foo /* a comment */, bar')[0]
- self.assert_(isinstance(p.tokens[0], sql.IdentifierList))
- self.assert_(isinstance(p.tokens[0].tokens[0], sql.Identifier))
- self.assert_(isinstance(p.tokens[0].tokens[3], sql.Identifier))
-
- def test_identifiers_with_operators(self):
- p = sqlparse.parse('a+b as c from table where (d-e)%2= 1')[0]
- self.assertEqual(len([x for x in p.flatten()
- if x.ttype == sqlparse.tokens.Name]), 5)
-
- def test_identifier_list_with_order(self): # issue101
- p = sqlparse.parse('1, 2 desc, 3')[0]
- self.assert_(isinstance(p.tokens[0], sql.IdentifierList))
- self.assert_(isinstance(p.tokens[0].tokens[3], sql.Identifier))
- self.ndiffAssertEqual(u(p.tokens[0].tokens[3]), '2 desc')
-
- def test_where(self):
- s = 'select * from foo where bar = 1 order by id desc'
- p = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, u(p))
- self.assert_(len(p.tokens) == 14)
-
- s = 'select x from (select y from foo where bar = 1) z'
- p = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, u(p))
- self.assert_(isinstance(p.tokens[-1].tokens[0].tokens[-2], sql.Where))
-
- def test_typecast(self):
- s = 'select foo::integer from bar'
- p = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, u(p))
- self.assertEqual(p.tokens[2].get_typecast(), 'integer')
- self.assertEqual(p.tokens[2].get_name(), 'foo')
- s = 'select (current_database())::information_schema.sql_identifier'
- p = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, u(p))
- self.assertEqual(p.tokens[2].get_typecast(),
- 'information_schema.sql_identifier')
-
- def test_alias(self):
- s = 'select foo as bar from mytable'
- p = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, u(p))
- self.assertEqual(p.tokens[2].get_real_name(), 'foo')
- self.assertEqual(p.tokens[2].get_alias(), 'bar')
- s = 'select foo from mytable t1'
- p = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, u(p))
- self.assertEqual(p.tokens[6].get_real_name(), 'mytable')
- self.assertEqual(p.tokens[6].get_alias(), 't1')
- s = 'select foo::integer as bar from mytable'
- p = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, u(p))
- self.assertEqual(p.tokens[2].get_alias(), 'bar')
- s = ('SELECT DISTINCT '
- '(current_database())::information_schema.sql_identifier AS view')
- p = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, u(p))
- self.assertEqual(p.tokens[4].get_alias(), 'view')
-
- def test_alias_case(self): # see issue46
- p = sqlparse.parse('CASE WHEN 1 THEN 2 ELSE 3 END foo')[0]
- self.assertEqual(len(p.tokens), 1)
- self.assertEqual(p.tokens[0].get_alias(), 'foo')
-
- def test_alias_returns_none(self): # see issue185
- p = sqlparse.parse('foo.bar')[0]
- self.assertEqual(len(p.tokens), 1)
- self.assertEqual(p.tokens[0].get_alias(), None)
-
- def test_idlist_function(self): # see issue10 too
- p = sqlparse.parse('foo(1) x, bar')[0]
- self.assert_(isinstance(p.tokens[0], sql.IdentifierList))
-
- def test_comparison_exclude(self):
- # make sure operators are not handled too lazy
- p = sqlparse.parse('(=)')[0]
- self.assert_(isinstance(p.tokens[0], sql.Parenthesis))
- self.assert_(not isinstance(p.tokens[0].tokens[1], sql.Comparison))
- p = sqlparse.parse('(a=1)')[0]
- self.assert_(isinstance(p.tokens[0].tokens[1], sql.Comparison))
- p = sqlparse.parse('(a>=1)')[0]
- self.assert_(isinstance(p.tokens[0].tokens[1], sql.Comparison))
-
- def test_function(self):
- p = sqlparse.parse('foo()')[0]
- self.assert_(isinstance(p.tokens[0], sql.Function))
- p = sqlparse.parse('foo(null, bar)')[0]
- self.assert_(isinstance(p.tokens[0], sql.Function))
- self.assertEqual(len(list(p.tokens[0].get_parameters())), 2)
-
- def test_function_not_in(self): # issue183
- p = sqlparse.parse('in(1, 2)')[0]
- self.assertEqual(len(p.tokens), 2)
- self.assertEqual(p.tokens[0].ttype, T.Keyword)
- self.assert_(isinstance(p.tokens[1], sql.Parenthesis))
-
- def test_varchar(self):
- p = sqlparse.parse('"text" Varchar(50) NOT NULL')[0]
- self.assert_(isinstance(p.tokens[2], sql.Function))
-
-
-class TestStatement(TestCaseBase):
-
- def test_get_type(self):
- def f(sql):
- return sqlparse.parse(sql)[0]
- self.assertEqual(f('select * from foo').get_type(), 'SELECT')
- self.assertEqual(f('update foo').get_type(), 'UPDATE')
- self.assertEqual(f(' update foo').get_type(), 'UPDATE')
- self.assertEqual(f('\nupdate foo').get_type(), 'UPDATE')
- self.assertEqual(f('foo').get_type(), 'UNKNOWN')
- # Statements that have a whitespace after the closing semicolon
- # are parsed as two statements where later only consists of the
- # trailing whitespace.
- self.assertEqual(f('\n').get_type(), 'UNKNOWN')
-
-
-def test_identifier_with_operators(): # issue 53
+from sqlparse import sql, tokens as T
+
+
+def test_grouping_parenthesis():
+ s = 'select (select (x3) x2) and (y2) bar'
+ parsed = sqlparse.parse(s)[0]
+ assert str(parsed) == s
+ assert len(parsed.tokens) == 7
+ assert isinstance(parsed.tokens[2], sql.Parenthesis)
+ assert isinstance(parsed.tokens[-1], sql.Identifier)
+ assert len(parsed.tokens[2].tokens) == 5
+ assert isinstance(parsed.tokens[2].tokens[3], sql.Identifier)
+ assert isinstance(parsed.tokens[2].tokens[3].tokens[0], sql.Parenthesis)
+ assert len(parsed.tokens[2].tokens[3].tokens) == 3
+
+
+def test_grouping_comments():
+ s = '/*\n * foo\n */ \n bar'
+ parsed = sqlparse.parse(s)[0]
+ assert str(parsed) == s
+ assert len(parsed.tokens) == 2
+
+
+@pytest.mark.parametrize('s', ['foo := 1;', 'foo := 1'])
+def test_grouping_assignment(s):
+ parsed = sqlparse.parse(s)[0]
+ assert len(parsed.tokens) == 1
+ assert isinstance(parsed.tokens[0], sql.Assignment)
+
+
+def test_grouping_identifiers():
+ s = 'select foo.bar from "myscheme"."table" where fail. order'
+ parsed = sqlparse.parse(s)[0]
+ assert str(parsed) == s
+ assert isinstance(parsed.tokens[2], sql.Identifier)
+ assert isinstance(parsed.tokens[6], sql.Identifier)
+ assert isinstance(parsed.tokens[8], sql.Where)
+ s = 'select * from foo where foo.id = 1'
+ parsed = sqlparse.parse(s)[0]
+ assert str(parsed) == s
+ assert isinstance(parsed.tokens[-1].tokens[-1].tokens[0], sql.Identifier)
+ s = 'select * from (select "foo"."id" from foo)'
+ parsed = sqlparse.parse(s)[0]
+ assert str(parsed) == s
+ assert isinstance(parsed.tokens[-1].tokens[3], sql.Identifier)
+
+ s = "INSERT INTO `test` VALUES('foo', 'bar');"
+ parsed = sqlparse.parse(s)[0]
+ types = [l.ttype for l in parsed.tokens if not l.is_whitespace()]
+ assert types == [T.DML, T.Keyword, None, T.Keyword, None, T.Punctuation]
+
+ s = "select 1.0*(a+b) as col, sum(c)/sum(d) from myschema.mytable"
+ parsed = sqlparse.parse(s)[0]
+ assert len(parsed.tokens) == 7
+ assert isinstance(parsed.tokens[2], sql.IdentifierList)
+ assert len(parsed.tokens[2].tokens) == 4
+ identifiers = list(parsed.tokens[2].get_identifiers())
+ assert len(identifiers) == 2
+ assert identifiers[0].get_alias() == "col"
+
+
+def test_grouping_identifier_wildcard():
+ p = sqlparse.parse('a.*, b.id')[0]
+ assert isinstance(p.tokens[0], sql.IdentifierList)
+ assert isinstance(p.tokens[0].tokens[0], sql.Identifier)
+ assert isinstance(p.tokens[0].tokens[-1], sql.Identifier)
+
+
+def test_grouping_identifier_name_wildcard():
+ p = sqlparse.parse('a.*')[0]
+ t = p.tokens[0]
+ assert t.get_name() == '*'
+ assert t.is_wildcard() is True
+
+
+def test_grouping_identifier_invalid():
+ p = sqlparse.parse('a.')[0]
+ assert isinstance(p.tokens[0], sql.Identifier)
+ assert p.tokens[0].has_alias() is False
+ assert p.tokens[0].get_name() is None
+ assert p.tokens[0].get_real_name() is None
+ assert p.tokens[0].get_parent_name() == 'a'
+
+
+def test_grouping_identifier_invalid_in_middle():
+ # issue261
+ s = 'SELECT foo. FROM foo'
+ p = sqlparse.parse(s)[0]
+ assert isinstance(p[2], sql.Identifier)
+ assert p[2][1].ttype == T.Punctuation
+ assert p[3].ttype == T.Whitespace
+ assert str(p[2]) == 'foo.'
+
+
+def test_grouping_identifier_as_invalid():
+ # issue8
+ p = sqlparse.parse('foo as select *')[0]
+ assert len(p.tokens), 5
+ assert isinstance(p.tokens[0], sql.Identifier)
+ assert len(p.tokens[0].tokens) == 1
+ assert p.tokens[2].ttype == T.Keyword
+
+
+def test_grouping_identifier_function():
+ p = sqlparse.parse('foo() as bar')[0]
+ assert isinstance(p.tokens[0], sql.Identifier)
+ assert isinstance(p.tokens[0].tokens[0], sql.Function)
+ p = sqlparse.parse('foo()||col2 bar')[0]
+ assert isinstance(p.tokens[0], sql.Identifier)
+ assert isinstance(p.tokens[0].tokens[0], sql.Operation)
+ assert isinstance(p.tokens[0].tokens[0].tokens[0], sql.Function)
+
+
+@pytest.mark.parametrize('s', ['foo+100', 'foo + 100', 'foo*100'])
+def test_grouping_operation(s):
+ p = sqlparse.parse(s)[0]
+ assert isinstance(p.tokens[0], sql.Operation)
+
+
+def test_grouping_identifier_list():
+ p = sqlparse.parse('a, b, c')[0]
+ assert isinstance(p.tokens[0], sql.IdentifierList)
+ p = sqlparse.parse('(a, b, c)')[0]
+ assert isinstance(p.tokens[0].tokens[1], sql.IdentifierList)
+
+
+def test_grouping_identifier_list_subquery():
+ """identifier lists should still work in subqueries with aliases"""
+ p = sqlparse.parse("select * from ("
+ "select a, b + c as d from table) sub")[0]
+ subquery = p.tokens[-1].tokens[0]
+ idx, iden_list = subquery.token_next_by(i=sql.IdentifierList)
+ assert iden_list is not None
+ # all the identifiers should be within the IdentifierList
+ _, ilist = subquery.token_next_by(i=sql.Identifier, idx=idx)
+ assert ilist is None
+
+
+def test_grouping_identifier_list_case():
+ p = sqlparse.parse('a, case when 1 then 2 else 3 end as b, c')[0]
+ assert isinstance(p.tokens[0], sql.IdentifierList)
+ p = sqlparse.parse('(a, case when 1 then 2 else 3 end as b, c)')[0]
+ assert isinstance(p.tokens[0].tokens[1], sql.IdentifierList)
+
+
+def test_grouping_identifier_list_other():
+ # issue2
+ p = sqlparse.parse("select *, null, 1, 'foo', bar from mytable, x")[0]
+ assert isinstance(p.tokens[2], sql.IdentifierList)
+ assert len(p.tokens[2].tokens) == 13
+
+
+def test_grouping_identifier_list_with_inline_comments():
+ # issue163
+ p = sqlparse.parse('foo /* a comment */, bar')[0]
+ assert isinstance(p.tokens[0], sql.IdentifierList)
+ assert isinstance(p.tokens[0].tokens[0], sql.Identifier)
+ assert isinstance(p.tokens[0].tokens[3], sql.Identifier)
+
+
+def test_grouping_identifiers_with_operators():
+ p = sqlparse.parse('a+b as c from table where (d-e)%2= 1')[0]
+ assert len([x for x in p.flatten() if x.ttype == T.Name]) == 5
+
+
+def test_grouping_identifier_list_with_order():
+ # issue101
+ p = sqlparse.parse('1, 2 desc, 3')[0]
+ assert isinstance(p.tokens[0], sql.IdentifierList)
+ assert isinstance(p.tokens[0].tokens[3], sql.Identifier)
+ assert str(p.tokens[0].tokens[3]) == '2 desc'
+
+
+def test_grouping_where():
+ s = 'select * from foo where bar = 1 order by id desc'
+ p = sqlparse.parse(s)[0]
+ assert str(p) == s
+ assert len(p.tokens) == 14
+
+ s = 'select x from (select y from foo where bar = 1) z'
+ p = sqlparse.parse(s)[0]
+ assert str(p) == s
+ assert isinstance(p.tokens[-1].tokens[0].tokens[-2], sql.Where)
+
+
+def test_grouping_typecast():
+ s = 'select foo::integer from bar'
+ p = sqlparse.parse(s)[0]
+ assert str(p) == s
+ assert p.tokens[2].get_typecast() == 'integer'
+ assert p.tokens[2].get_name() == 'foo'
+ s = 'select (current_database())::information_schema.sql_identifier'
+ p = sqlparse.parse(s)[0]
+ assert str(p) == s
+ assert (p.tokens[2].get_typecast() == 'information_schema.sql_identifier')
+
+
+def test_grouping_alias():
+ s = 'select foo as bar from mytable'
+ p = sqlparse.parse(s)[0]
+ assert str(p) == s
+ assert p.tokens[2].get_real_name() == 'foo'
+ assert p.tokens[2].get_alias() == 'bar'
+ s = 'select foo from mytable t1'
+ p = sqlparse.parse(s)[0]
+ assert str(p) == s
+ assert p.tokens[6].get_real_name() == 'mytable'
+ assert p.tokens[6].get_alias() == 't1'
+ s = 'select foo::integer as bar from mytable'
+ p = sqlparse.parse(s)[0]
+ assert str(p) == s
+ assert p.tokens[2].get_alias() == 'bar'
+ s = ('SELECT DISTINCT '
+ '(current_database())::information_schema.sql_identifier AS view')
+ p = sqlparse.parse(s)[0]
+ assert str(p) == s
+ assert p.tokens[4].get_alias() == 'view'
+
+
+def test_grouping_alias_case():
+ # see issue46
+ p = sqlparse.parse('CASE WHEN 1 THEN 2 ELSE 3 END foo')[0]
+ assert len(p.tokens) == 1
+ assert p.tokens[0].get_alias() == 'foo'
+
+
+def test_grouping_alias_returns_none():
+ # see issue185
+ p = sqlparse.parse('foo.bar')[0]
+ assert len(p.tokens) == 1
+ assert p.tokens[0].get_alias() is None
+
+
+def test_grouping_idlist_function():
+ # see issue10 too
+ p = sqlparse.parse('foo(1) x, bar')[0]
+ assert isinstance(p.tokens[0], sql.IdentifierList)
+
+
+def test_grouping_comparison_exclude():
+ # make sure operators are not handled too lazy
+ p = sqlparse.parse('(=)')[0]
+ assert isinstance(p.tokens[0], sql.Parenthesis)
+ assert not isinstance(p.tokens[0].tokens[1], sql.Comparison)
+ p = sqlparse.parse('(a=1)')[0]
+ assert isinstance(p.tokens[0].tokens[1], sql.Comparison)
+ p = sqlparse.parse('(a>=1)')[0]
+ assert isinstance(p.tokens[0].tokens[1], sql.Comparison)
+
+
+def test_grouping_function():
+ p = sqlparse.parse('foo()')[0]
+ assert isinstance(p.tokens[0], sql.Function)
+ p = sqlparse.parse('foo(null, bar)')[0]
+ assert isinstance(p.tokens[0], sql.Function)
+ assert len(list(p.tokens[0].get_parameters())) == 2
+
+
+def test_grouping_function_not_in():
+ # issue183
+ p = sqlparse.parse('in(1, 2)')[0]
+ assert len(p.tokens) == 2
+ assert p.tokens[0].ttype == T.Keyword
+ assert isinstance(p.tokens[1], sql.Parenthesis)
+
+
+def test_grouping_varchar():
+ p = sqlparse.parse('"text" Varchar(50) NOT NULL')[0]
+ assert isinstance(p.tokens[2], sql.Function)
+
+
+def test_statement_get_type():
+ def f(sql):
+ return sqlparse.parse(sql)[0]
+
+ assert f('select * from foo').get_type() == 'SELECT'
+ assert f('update foo').get_type() == 'UPDATE'
+ assert f(' update foo').get_type() == 'UPDATE'
+ assert f('\nupdate foo').get_type() == 'UPDATE'
+ assert f('foo').get_type() == 'UNKNOWN'
+ # Statements that have a whitespace after the closing semicolon
+ # are parsed as two statements where later only consists of the
+ # trailing whitespace.
+ assert f('\n').get_type() == 'UNKNOWN'
+
+
+def test_identifier_with_operators():
+ # issue 53
p = sqlparse.parse('foo||bar')[0]
assert len(p.tokens) == 1
assert isinstance(p.tokens[0], sql.Operation)
@@ -293,7 +309,7 @@ def test_identifier_with_op_trailing_ws():
def test_identifier_with_string_literals():
- p = sqlparse.parse('foo + \'bar\'')[0]
+ p = sqlparse.parse("foo + 'bar'")[0]
assert len(p.tokens) == 1
assert isinstance(p.tokens[0], sql.Operation)
@@ -302,12 +318,13 @@ def test_identifier_with_string_literals():
# showed that this shouldn't be an identifier at all. I'm leaving this
# commented in the source for a while.
# def test_identifier_string_concat():
-# p = sqlparse.parse('\'foo\' || bar')[0]
+# p = sqlparse.parse("'foo' || bar")[0]
# assert len(p.tokens) == 1
# assert isinstance(p.tokens[0], sql.Identifier)
-def test_identifier_consumes_ordering(): # issue89
+def test_identifier_consumes_ordering():
+ # issue89
p = sqlparse.parse('select * from foo order by c1 desc, c2, c3')[0]
assert isinstance(p.tokens[-1], sql.IdentifierList)
ids = list(p.tokens[-1].get_identifiers())
@@ -318,7 +335,8 @@ def test_identifier_consumes_ordering(): # issue89
assert ids[1].get_ordering() is None
-def test_comparison_with_keywords(): # issue90
+def test_comparison_with_keywords():
+ # issue90
# in fact these are assignments, but for now we don't distinguish them
p = sqlparse.parse('foo = NULL')[0]
assert len(p.tokens) == 1
@@ -332,7 +350,8 @@ def test_comparison_with_keywords(): # issue90
assert isinstance(p.tokens[0], sql.Comparison)
-def test_comparison_with_floats(): # issue145
+def test_comparison_with_floats():
+ # issue145
p = sqlparse.parse('foo = 25.5')[0]
assert len(p.tokens) == 1
assert isinstance(p.tokens[0], sql.Comparison)
@@ -341,7 +360,8 @@ def test_comparison_with_floats(): # issue145
assert p.tokens[0].right.value == '25.5'
-def test_comparison_with_parenthesis(): # issue23
+def test_comparison_with_parenthesis():
+ # issue23
p = sqlparse.parse('(3 + 4) = 7')[0]
assert len(p.tokens) == 1
assert isinstance(p.tokens[0], sql.Comparison)
@@ -350,15 +370,17 @@ def test_comparison_with_parenthesis(): # issue23
assert comp.right.ttype is T.Number.Integer
-def test_comparison_with_strings(): # issue148
- p = sqlparse.parse('foo = \'bar\'')[0]
+def test_comparison_with_strings():
+ # issue148
+ p = sqlparse.parse("foo = 'bar'")[0]
assert len(p.tokens) == 1
assert isinstance(p.tokens[0], sql.Comparison)
- assert p.tokens[0].right.value == '\'bar\''
+ assert p.tokens[0].right.value == "'bar'"
assert p.tokens[0].right.ttype == T.String.Single
-def test_comparison_with_functions(): # issue230
+def test_comparison_with_functions():
+ # issue230
p = sqlparse.parse('foo = DATE(bar.baz)')[0]
assert len(p.tokens) == 1
assert isinstance(p.tokens[0], sql.Comparison)
diff --git a/tests/test_parse.py b/tests/test_parse.py
index 75a7ab5..2d23425 100644
--- a/tests/test_parse.py
+++ b/tests/test_parse.py
@@ -1,161 +1,150 @@
# -*- coding: utf-8 -*-
-"""Tests sqlparse function."""
+"""Tests sqlparse.parse()."""
import pytest
-from tests.utils import TestCaseBase
-
import sqlparse
-import sqlparse.sql
-from sqlparse import tokens as T
-from sqlparse.compat import u, StringIO
-
-
-class SQLParseTest(TestCaseBase):
- """Tests sqlparse.parse()."""
-
- def test_tokenize(self):
- sql = 'select * from foo;'
- stmts = sqlparse.parse(sql)
- self.assertEqual(len(stmts), 1)
- self.assertEqual(str(stmts[0]), sql)
-
- def test_multistatement(self):
- sql1 = 'select * from foo;'
- sql2 = 'select * from bar;'
- stmts = sqlparse.parse(sql1 + sql2)
- self.assertEqual(len(stmts), 2)
- self.assertEqual(str(stmts[0]), sql1)
- self.assertEqual(str(stmts[1]), sql2)
-
- def test_newlines(self):
- sql = u'select\n*from foo;'
- p = sqlparse.parse(sql)[0]
- self.assertEqual(u(p), sql)
- sql = u'select\r\n*from foo'
- p = sqlparse.parse(sql)[0]
- self.assertEqual(u(p), sql)
- sql = u'select\r*from foo'
- p = sqlparse.parse(sql)[0]
- self.assertEqual(u(p), sql)
- sql = u'select\r\n*from foo\n'
- p = sqlparse.parse(sql)[0]
- self.assertEqual(u(p), sql)
-
- def test_within(self):
- sql = 'foo(col1, col2)'
- p = sqlparse.parse(sql)[0]
- col1 = p.tokens[0].tokens[1].tokens[1].tokens[0]
- self.assert_(col1.within(sqlparse.sql.Function))
-
- def test_child_of(self):
- sql = '(col1, col2)'
- p = sqlparse.parse(sql)[0]
- self.assert_(p.tokens[0].tokens[1].is_child_of(p.tokens[0]))
- sql = 'select foo'
- p = sqlparse.parse(sql)[0]
- self.assert_(not p.tokens[2].is_child_of(p.tokens[0]))
- self.assert_(p.tokens[2].is_child_of(p))
-
- def test_has_ancestor(self):
- sql = 'foo or (bar, baz)'
- p = sqlparse.parse(sql)[0]
- baz = p.tokens[-1].tokens[1].tokens[-1]
- self.assert_(baz.has_ancestor(p.tokens[-1].tokens[1]))
- self.assert_(baz.has_ancestor(p.tokens[-1]))
- self.assert_(baz.has_ancestor(p))
-
- def test_float(self):
- t = sqlparse.parse('.5')[0].tokens
- self.assertEqual(len(t), 1)
- self.assert_(t[0].ttype is sqlparse.tokens.Number.Float)
- t = sqlparse.parse('.51')[0].tokens
- self.assertEqual(len(t), 1)
- self.assert_(t[0].ttype is sqlparse.tokens.Number.Float)
- t = sqlparse.parse('1.5')[0].tokens
- self.assertEqual(len(t), 1)
- self.assert_(t[0].ttype is sqlparse.tokens.Number.Float)
- t = sqlparse.parse('12.5')[0].tokens
- self.assertEqual(len(t), 1)
- self.assert_(t[0].ttype is sqlparse.tokens.Number.Float)
-
- def test_placeholder(self):
- def _get_tokens(sql):
- return sqlparse.parse(sql)[0].tokens[-1].tokens
-
- t = _get_tokens('select * from foo where user = ?')
- self.assert_(t[-1].ttype is sqlparse.tokens.Name.Placeholder)
- self.assertEqual(t[-1].value, '?')
- t = _get_tokens('select * from foo where user = :1')
- self.assert_(t[-1].ttype is sqlparse.tokens.Name.Placeholder)
- self.assertEqual(t[-1].value, ':1')
- t = _get_tokens('select * from foo where user = :name')
- self.assert_(t[-1].ttype is sqlparse.tokens.Name.Placeholder)
- self.assertEqual(t[-1].value, ':name')
- t = _get_tokens('select * from foo where user = %s')
- self.assert_(t[-1].ttype is sqlparse.tokens.Name.Placeholder)
- self.assertEqual(t[-1].value, '%s')
- t = _get_tokens('select * from foo where user = $a')
- self.assert_(t[-1].ttype is sqlparse.tokens.Name.Placeholder)
- self.assertEqual(t[-1].value, '$a')
-
- def test_modulo_not_placeholder(self):
- tokens = list(sqlparse.lexer.tokenize('x %3'))
- self.assertEqual(tokens[2][0], sqlparse.tokens.Operator)
-
- def test_access_symbol(self): # see issue27
- t = sqlparse.parse('select a.[foo bar] as foo')[0].tokens
- self.assert_(isinstance(t[-1], sqlparse.sql.Identifier))
- self.assertEqual(t[-1].get_name(), 'foo')
- self.assertEqual(t[-1].get_real_name(), '[foo bar]')
- self.assertEqual(t[-1].get_parent_name(), 'a')
-
- def test_square_brackets_notation_isnt_too_greedy(self): # see issue153
- t = sqlparse.parse('[foo], [bar]')[0].tokens
- self.assert_(isinstance(t[0], sqlparse.sql.IdentifierList))
- self.assertEqual(len(t[0].tokens), 4)
- self.assertEqual(t[0].tokens[0].get_real_name(), '[foo]')
- self.assertEqual(t[0].tokens[-1].get_real_name(), '[bar]')
-
- def test_keyword_like_identifier(self): # see issue47
- t = sqlparse.parse('foo.key')[0].tokens
- self.assertEqual(len(t), 1)
- self.assert_(isinstance(t[0], sqlparse.sql.Identifier))
-
- def test_function_parameter(self): # see issue94
- t = sqlparse.parse('abs(some_col)')[0].tokens[0].get_parameters()
- self.assertEqual(len(t), 1)
- self.assert_(isinstance(t[0], sqlparse.sql.Identifier))
-
- def test_function_param_single_literal(self):
- t = sqlparse.parse('foo(5)')[0].tokens[0].get_parameters()
- self.assertEqual(len(t), 1)
- self.assert_(t[0].ttype is T.Number.Integer)
-
- def test_nested_function(self):
- t = sqlparse.parse('foo(bar(5))')[0].tokens[0].get_parameters()
- self.assertEqual(len(t), 1)
- self.assert_(type(t[0]) is sqlparse.sql.Function)
+from sqlparse import sql, tokens as T
+from sqlparse.compat import StringIO
+
+
+def test_parse_tokenize():
+ s = 'select * from foo;'
+ stmts = sqlparse.parse(s)
+ assert len(stmts) == 1
+ assert str(stmts[0]) == s
+
+
+def test_parse_multistatement():
+ sql1 = 'select * from foo;'
+ sql2 = 'select * from bar;'
+ stmts = sqlparse.parse(sql1 + sql2)
+ assert len(stmts) == 2
+ assert str(stmts[0]) == sql1
+ assert str(stmts[1]) == sql2
+
+
+@pytest.mark.parametrize('s', ['select\n*from foo;',
+ 'select\r\n*from foo',
+ 'select\r*from foo',
+ 'select\r\n*from foo\n'])
+def test_parse_newlines(s):
+ p = sqlparse.parse(s)[0]
+ assert str(p) == s
+
+
+def test_parse_within():
+ s = 'foo(col1, col2)'
+ p = sqlparse.parse(s)[0]
+ col1 = p.tokens[0].tokens[1].tokens[1].tokens[0]
+ assert col1.within(sql.Function)
+
+
+def test_parse_child_of():
+ s = '(col1, col2)'
+ p = sqlparse.parse(s)[0]
+ assert p.tokens[0].tokens[1].is_child_of(p.tokens[0])
+ s = 'select foo'
+ p = sqlparse.parse(s)[0]
+ assert not p.tokens[2].is_child_of(p.tokens[0])
+ assert p.tokens[2].is_child_of(p)
+
+
+def test_parse_has_ancestor():
+ s = 'foo or (bar, baz)'
+ p = sqlparse.parse(s)[0]
+ baz = p.tokens[-1].tokens[1].tokens[-1]
+ assert baz.has_ancestor(p.tokens[-1].tokens[1])
+ assert baz.has_ancestor(p.tokens[-1])
+ assert baz.has_ancestor(p)
+
+
+@pytest.mark.parametrize('s', ['.5', '.51', '1.5', '12.5'])
+def test_parse_float(s):
+ t = sqlparse.parse(s)[0].tokens
+ assert len(t) == 1
+ assert t[0].ttype is sqlparse.tokens.Number.Float
+
+
+@pytest.mark.parametrize('s, holder', [
+ ('select * from foo where user = ?', '?'),
+ ('select * from foo where user = :1', ':1'),
+ ('select * from foo where user = :name', ':name'),
+ ('select * from foo where user = %s', '%s'),
+ ('select * from foo where user = $a', '$a')])
+def test_parse_placeholder(s, holder):
+ t = sqlparse.parse(s)[0].tokens[-1].tokens
+ assert t[-1].ttype is sqlparse.tokens.Name.Placeholder
+ assert t[-1].value == holder
+
+
+def test_parse_modulo_not_placeholder():
+ tokens = list(sqlparse.lexer.tokenize('x %3'))
+ assert tokens[2][0] == sqlparse.tokens.Operator
+
+
+def test_parse_access_symbol():
+ # see issue27
+ t = sqlparse.parse('select a.[foo bar] as foo')[0].tokens
+ assert isinstance(t[-1], sql.Identifier)
+ assert t[-1].get_name() == 'foo'
+ assert t[-1].get_real_name() == '[foo bar]'
+ assert t[-1].get_parent_name() == 'a'
+
+
+def test_parse_square_brackets_notation_isnt_too_greedy():
+ # see issue153
+ t = sqlparse.parse('[foo], [bar]')[0].tokens
+ assert isinstance(t[0], sql.IdentifierList)
+ assert len(t[0].tokens) == 4
+ assert t[0].tokens[0].get_real_name() == '[foo]'
+ assert t[0].tokens[-1].get_real_name() == '[bar]'
+
+
+def test_parse_keyword_like_identifier():
+ # see issue47
+ t = sqlparse.parse('foo.key')[0].tokens
+ assert len(t) == 1
+ assert isinstance(t[0], sql.Identifier)
+
+
+def test_parse_function_parameter():
+ # see issue94
+ t = sqlparse.parse('abs(some_col)')[0].tokens[0].get_parameters()
+ assert len(t) == 1
+ assert isinstance(t[0], sql.Identifier)
+
+
+def test_parse_function_param_single_literal():
+ t = sqlparse.parse('foo(5)')[0].tokens[0].get_parameters()
+ assert len(t) == 1
+ assert t[0].ttype is T.Number.Integer
+
+
+def test_parse_nested_function():
+ t = sqlparse.parse('foo(bar(5))')[0].tokens[0].get_parameters()
+ assert len(t) == 1
+ assert type(t[0]) is sql.Function
def test_quoted_identifier():
t = sqlparse.parse('select x.y as "z" from foo')[0].tokens
- assert isinstance(t[2], sqlparse.sql.Identifier)
+ assert isinstance(t[2], sql.Identifier)
assert t[2].get_name() == 'z'
assert t[2].get_real_name() == 'y'
-@pytest.mark.parametrize('name', [
- 'foo',
- '_foo',
-])
-def test_valid_identifier_names(name): # issue175
+@pytest.mark.parametrize('name', ['foo', '_foo'])
+def test_valid_identifier_names(name):
+ # issue175
t = sqlparse.parse(name)[0].tokens
- assert isinstance(t[0], sqlparse.sql.Identifier)
+ assert isinstance(t[0], sql.Identifier)
+
+def test_psql_quotation_marks():
+ # issue83
-def test_psql_quotation_marks(): # issue83
# regression: make sure plain $$ work
t = sqlparse.split("""
CREATE OR REPLACE FUNCTION testfunc1(integer) RETURNS integer AS $$
@@ -165,6 +154,7 @@ def test_psql_quotation_marks(): # issue83
....
$$ LANGUAGE plpgsql;""")
assert len(t) == 2
+
# make sure $SOMETHING$ works too
t = sqlparse.split("""
CREATE OR REPLACE FUNCTION testfunc1(integer) RETURNS integer AS $PROC_1$
@@ -177,8 +167,8 @@ def test_psql_quotation_marks(): # issue83
def test_double_precision_is_builtin():
- sql = 'DOUBLE PRECISION'
- t = sqlparse.parse(sql)[0].tokens
+ s = 'DOUBLE PRECISION'
+ t = sqlparse.parse(s)[0].tokens
assert len(t) == 1
assert t[0].ttype == sqlparse.tokens.Name.Builtin
assert t[0].value == 'DOUBLE PRECISION'
@@ -207,10 +197,11 @@ def test_single_quotes_are_strings():
def test_double_quotes_are_identifiers():
p = sqlparse.parse('"foo"')[0].tokens
assert len(p) == 1
- assert isinstance(p[0], sqlparse.sql.Identifier)
+ assert isinstance(p[0], sql.Identifier)
-def test_single_quotes_with_linebreaks(): # issue118
+def test_single_quotes_with_linebreaks():
+ # issue118
p = sqlparse.parse("'f\nf'")[0].tokens
assert len(p) == 1
assert p[0].ttype is T.String.Single
@@ -221,7 +212,7 @@ def test_sqlite_identifiers():
p = sqlparse.parse('[col1],[col2]')[0].tokens
id_names = [id_.get_name() for id_ in p[0].get_identifiers()]
assert len(p) == 1
- assert isinstance(p[0], sqlparse.sql.IdentifierList)
+ assert isinstance(p[0], sql.IdentifierList)
assert id_names == ['[col1]', '[col2]']
p = sqlparse.parse('[col1]+[col2]')[0]
@@ -280,35 +271,29 @@ def test_typed_array_definition():
# indentifer names
p = sqlparse.parse('x int, y int[], z int')[0]
names = [x.get_name() for x in p.get_sublists()
- if isinstance(x, sqlparse.sql.Identifier)]
+ if isinstance(x, sql.Identifier)]
assert names == ['x', 'y', 'z']
-@pytest.mark.parametrize('sql', [
- 'select 1 -- foo',
- 'select 1 # foo' # see issue178
-])
-def test_single_line_comments(sql):
- p = sqlparse.parse(sql)[0]
+@pytest.mark.parametrize('s', ['select 1 -- foo', 'select 1 # foo'])
+def test_single_line_comments(s):
+ # see issue178
+ p = sqlparse.parse(s)[0]
assert len(p.tokens) == 5
assert p.tokens[-1].ttype == T.Comment.Single
-@pytest.mark.parametrize('sql', [
- 'foo',
- '@foo',
- '#foo', # see issue192
- '##foo'
-])
-def test_names_and_special_names(sql):
- p = sqlparse.parse(sql)[0]
+@pytest.mark.parametrize('s', ['foo', '@foo', '#foo', '##foo'])
+def test_names_and_special_names(s):
+ # see issue192
+ p = sqlparse.parse(s)[0]
assert len(p.tokens) == 1
- assert isinstance(p.tokens[0], sqlparse.sql.Identifier)
+ assert isinstance(p.tokens[0], sql.Identifier)
def test_get_token_at_offset():
- # 0123456789
p = sqlparse.parse('select * from dual')[0]
+ # 0123456789
assert p.get_token_at_offset(0) == p.tokens[0]
assert p.get_token_at_offset(1) == p.tokens[0]
assert p.get_token_at_offset(6) == p.tokens[1]
@@ -324,60 +309,61 @@ def test_pprint():
output = StringIO()
p._pprint_tree(f=output)
- pprint = u'\n'.join([" 0 DML 'select'",
- " 1 Whitespace ' '",
- " 2 IdentifierList 'a0, b0...'",
- " | 0 Identifier 'a0'",
- " | | 0 Name 'a0'",
- " | 1 Punctuation ','",
- " | 2 Whitespace ' '",
- " | 3 Identifier 'b0'",
- " | | 0 Name 'b0'",
- " | 4 Punctuation ','",
- " | 5 Whitespace ' '",
- " | 6 Identifier 'c0'",
- " | | 0 Name 'c0'",
- " | 7 Punctuation ','",
- " | 8 Whitespace ' '",
- " | 9 Identifier 'd0'",
- " | | 0 Name 'd0'",
- " | 10 Punctuation ','",
- " | 11 Whitespace ' '",
- " | 12 Float 'e0'",
- " 3 Whitespace ' '",
- " 4 Keyword 'from'",
- " 5 Whitespace ' '",
- " 6 Identifier '(selec...'",
- " | 0 Parenthesis '(selec...'",
- " | | 0 Punctuation '('",
- " | | 1 DML 'select'",
- " | | 2 Whitespace ' '",
- " | | 3 Wildcard '*'",
- " | | 4 Whitespace ' '",
- " | | 5 Keyword 'from'",
- " | | 6 Whitespace ' '",
- " | | 7 Identifier 'dual'",
- " | | | 0 Name 'dual'",
- " | | 8 Punctuation ')'",
- " | 1 Whitespace ' '",
- " | 2 Identifier 'q0'",
- " | | 0 Name 'q0'",
- " 7 Whitespace ' '",
- " 8 Where 'where ...'",
- " | 0 Keyword 'where'",
- " | 1 Whitespace ' '",
- " | 2 Comparison '1=1'",
- " | | 0 Integer '1'",
- " | | 1 Comparison '='",
- " | | 2 Integer '1'",
- " | 3 Whitespace ' '",
- " | 4 Keyword 'and'",
- " | 5 Whitespace ' '",
- " | 6 Comparison '2=2'",
- " | | 0 Integer '2'",
- " | | 1 Comparison '='",
- " | | 2 Integer '2'",
- "", ])
+ pprint = '\n'.join([
+ " 0 DML 'select'",
+ " 1 Whitespace ' '",
+ " 2 IdentifierList 'a0, b0...'",
+ " | 0 Identifier 'a0'",
+ " | | 0 Name 'a0'",
+ " | 1 Punctuation ','",
+ " | 2 Whitespace ' '",
+ " | 3 Identifier 'b0'",
+ " | | 0 Name 'b0'",
+ " | 4 Punctuation ','",
+ " | 5 Whitespace ' '",
+ " | 6 Identifier 'c0'",
+ " | | 0 Name 'c0'",
+ " | 7 Punctuation ','",
+ " | 8 Whitespace ' '",
+ " | 9 Identifier 'd0'",
+ " | | 0 Name 'd0'",
+ " | 10 Punctuation ','",
+ " | 11 Whitespace ' '",
+ " | 12 Float 'e0'",
+ " 3 Whitespace ' '",
+ " 4 Keyword 'from'",
+ " 5 Whitespace ' '",
+ " 6 Identifier '(selec...'",
+ " | 0 Parenthesis '(selec...'",
+ " | | 0 Punctuation '('",
+ " | | 1 DML 'select'",
+ " | | 2 Whitespace ' '",
+ " | | 3 Wildcard '*'",
+ " | | 4 Whitespace ' '",
+ " | | 5 Keyword 'from'",
+ " | | 6 Whitespace ' '",
+ " | | 7 Identifier 'dual'",
+ " | | | 0 Name 'dual'",
+ " | | 8 Punctuation ')'",
+ " | 1 Whitespace ' '",
+ " | 2 Identifier 'q0'",
+ " | | 0 Name 'q0'",
+ " 7 Whitespace ' '",
+ " 8 Where 'where ...'",
+ " | 0 Keyword 'where'",
+ " | 1 Whitespace ' '",
+ " | 2 Comparison '1=1'",
+ " | | 0 Integer '1'",
+ " | | 1 Comparison '='",
+ " | | 2 Integer '1'",
+ " | 3 Whitespace ' '",
+ " | 4 Keyword 'and'",
+ " | 5 Whitespace ' '",
+ " | 6 Comparison '2=2'",
+ " | | 0 Integer '2'",
+ " | | 1 Comparison '='",
+ " | | 2 Integer '2'",
+ ""])
assert output.getvalue() == pprint
@@ -394,7 +380,7 @@ def test_wildcard_multiplication():
def test_stmt_tokens_parents():
# see issue 226
- sql = "CREATE TABLE test();"
- stmt = sqlparse.parse(sql)[0]
+ s = "CREATE TABLE test();"
+ stmt = sqlparse.parse(s)[0]
for token in stmt.tokens:
assert token.has_ancestor(stmt)
diff --git a/tests/test_regressions.py b/tests/test_regressions.py
index 14aab24..255493c 100644
--- a/tests/test_regressions.py
+++ b/tests/test_regressions.py
@@ -1,163 +1,146 @@
# -*- coding: utf-8 -*-
-import sys
-
-import pytest # noqa
-from tests.utils import TestCaseBase, load_file
+import pytest
import sqlparse
-from sqlparse import sql
-from sqlparse import tokens as T
-
-
-class RegressionTests(TestCaseBase):
-
- def test_issue9(self):
- # make sure where doesn't consume parenthesis
- p = sqlparse.parse('(where 1)')[0]
- self.assert_(isinstance(p, sql.Statement))
- self.assertEqual(len(p.tokens), 1)
- self.assert_(isinstance(p.tokens[0], sql.Parenthesis))
- prt = p.tokens[0]
- self.assertEqual(len(prt.tokens), 3)
- self.assertEqual(prt.tokens[0].ttype, T.Punctuation)
- self.assertEqual(prt.tokens[-1].ttype, T.Punctuation)
-
- def test_issue13(self):
- parsed = sqlparse.parse(("select 'one';\n"
- "select 'two\\'';\n"
- "select 'three';"))
- self.assertEqual(len(parsed), 3)
- self.assertEqual(str(parsed[1]).strip(), "select 'two\\'';")
-
- def test_issue26(self):
- # parse stand-alone comments
- p = sqlparse.parse('--hello')[0]
- self.assertEqual(len(p.tokens), 1)
- self.assert_(p.tokens[0].ttype is T.Comment.Single)
- p = sqlparse.parse('-- hello')[0]
- self.assertEqual(len(p.tokens), 1)
- self.assert_(p.tokens[0].ttype is T.Comment.Single)
- p = sqlparse.parse('--hello\n')[0]
- self.assertEqual(len(p.tokens), 1)
- self.assert_(p.tokens[0].ttype is T.Comment.Single)
- p = sqlparse.parse('--')[0]
- self.assertEqual(len(p.tokens), 1)
- self.assert_(p.tokens[0].ttype is T.Comment.Single)
- p = sqlparse.parse('--\n')[0]
- self.assertEqual(len(p.tokens), 1)
- self.assert_(p.tokens[0].ttype is T.Comment.Single)
-
- def test_issue34(self):
- t = sqlparse.parse("create")[0].token_first()
- self.assertEqual(t.match(T.Keyword.DDL, "create"), True)
- self.assertEqual(t.match(T.Keyword.DDL, "CREATE"), True)
-
- def test_issue35(self):
- # missing space before LIMIT
- sql = sqlparse.format("select * from foo where bar = 1 limit 1",
- reindent=True)
- self.ndiffAssertEqual(sql, "\n".join(["select *",
- "from foo",
- "where bar = 1 limit 1"]))
-
- def test_issue38(self):
- sql = sqlparse.format("SELECT foo; -- comment",
- strip_comments=True)
- self.ndiffAssertEqual(sql, "SELECT foo;")
- sql = sqlparse.format("/* foo */", strip_comments=True)
- self.ndiffAssertEqual(sql, "")
-
- def test_issue39(self):
- p = sqlparse.parse('select user.id from user')[0]
- self.assertEqual(len(p.tokens), 7)
- idt = p.tokens[2]
- self.assertEqual(idt.__class__, sql.Identifier)
- self.assertEqual(len(idt.tokens), 3)
- self.assertEqual(idt.tokens[0].match(T.Name, 'user'), True)
- self.assertEqual(idt.tokens[1].match(T.Punctuation, '.'), True)
- self.assertEqual(idt.tokens[2].match(T.Name, 'id'), True)
-
- def test_issue40(self):
- # make sure identifier lists in subselects are grouped
- p = sqlparse.parse(('SELECT id, name FROM '
- '(SELECT id, name FROM bar) as foo'))[0]
- self.assertEqual(len(p.tokens), 7)
- self.assertEqual(p.tokens[2].__class__, sql.IdentifierList)
- self.assertEqual(p.tokens[-1].__class__, sql.Identifier)
- self.assertEqual(p.tokens[-1].get_name(), u'foo')
- sp = p.tokens[-1].tokens[0]
- self.assertEqual(sp.tokens[3].__class__, sql.IdentifierList)
- # make sure that formatting works as expected
- self.ndiffAssertEqual(
- sqlparse.format(('SELECT id, name FROM '
- '(SELECT id, name FROM bar)'),
- reindent=True),
- ('SELECT id,\n'
- ' name\n'
- 'FROM\n'
- ' (SELECT id,\n'
- ' name\n'
- ' FROM bar)'))
- self.ndiffAssertEqual(
- sqlparse.format(('SELECT id, name FROM '
- '(SELECT id, name FROM bar) as foo'),
- reindent=True),
- ('SELECT id,\n'
- ' name\n'
- 'FROM\n'
- ' (SELECT id,\n'
- ' name\n'
- ' FROM bar) as foo'))
-
-
-def test_issue78():
+from sqlparse import sql, tokens as T
+from sqlparse.compat import PY2
+
+
+def test_issue9():
+ # make sure where doesn't consume parenthesis
+ p = sqlparse.parse('(where 1)')[0]
+ assert isinstance(p, sql.Statement)
+ assert len(p.tokens) == 1
+ assert isinstance(p.tokens[0], sql.Parenthesis)
+ prt = p.tokens[0]
+ assert len(prt.tokens) == 3
+ assert prt.tokens[0].ttype == T.Punctuation
+ assert prt.tokens[-1].ttype == T.Punctuation
+
+
+def test_issue13():
+ parsed = sqlparse.parse(("select 'one';\n"
+ "select 'two\\'';\n"
+ "select 'three';"))
+ assert len(parsed) == 3
+ assert str(parsed[1]).strip() == "select 'two\\'';"
+
+
+@pytest.mark.parametrize('s', ['--hello', '-- hello', '--hello\n',
+ '--', '--\n'])
+def test_issue26(s):
+ # parse stand-alone comments
+ p = sqlparse.parse(s)[0]
+ assert len(p.tokens) == 1
+ assert p.tokens[0].ttype is T.Comment.Single
+
+
+@pytest.mark.parametrize('value', ['create', 'CREATE'])
+def test_issue34(value):
+ t = sqlparse.parse("create")[0].token_first()
+ assert t.match(T.Keyword.DDL, value) is True
+
+
+def test_issue35():
+ # missing space before LIMIT
+ sql = sqlparse.format("select * from foo where bar = 1 limit 1",
+ reindent=True)
+ assert sql == "\n".join([
+ "select *",
+ "from foo",
+ "where bar = 1 limit 1"])
+
+
+def test_issue38():
+ sql = sqlparse.format("SELECT foo; -- comment", strip_comments=True)
+ assert sql == "SELECT foo;"
+ sql = sqlparse.format("/* foo */", strip_comments=True)
+ assert sql == ""
+
+
+def test_issue39():
+ p = sqlparse.parse('select user.id from user')[0]
+ assert len(p.tokens) == 7
+ idt = p.tokens[2]
+ assert idt.__class__ == sql.Identifier
+ assert len(idt.tokens) == 3
+ assert idt.tokens[0].match(T.Name, 'user') is True
+ assert idt.tokens[1].match(T.Punctuation, '.') is True
+ assert idt.tokens[2].match(T.Name, 'id') is True
+
+
+def test_issue40():
+ # make sure identifier lists in subselects are grouped
+ p = sqlparse.parse(('SELECT id, name FROM '
+ '(SELECT id, name FROM bar) as foo'))[0]
+ assert len(p.tokens) == 7
+ assert p.tokens[2].__class__ == sql.IdentifierList
+ assert p.tokens[-1].__class__ == sql.Identifier
+ assert p.tokens[-1].get_name() == 'foo'
+ sp = p.tokens[-1].tokens[0]
+ assert sp.tokens[3].__class__ == sql.IdentifierList
+ # make sure that formatting works as expected
+ s = sqlparse.format('SELECT id == name FROM '
+ '(SELECT id, name FROM bar)', reindent=True)
+ assert s == '\n'.join([
+ 'SELECT id == name',
+ 'FROM',
+ ' (SELECT id,',
+ ' name',
+ ' FROM bar)'])
+
+ s = sqlparse.format('SELECT id == name FROM '
+ '(SELECT id, name FROM bar) as foo', reindent=True)
+ assert s == '\n'.join([
+ 'SELECT id == name',
+ 'FROM',
+ ' (SELECT id,',
+ ' name',
+ ' FROM bar) as foo'])
+
+
+@pytest.mark.parametrize('s', ['select x.y::text as z from foo',
+ 'select x.y::text as "z" from foo',
+ 'select x."y"::text as z from foo',
+ 'select x."y"::text as "z" from foo',
+ 'select "x".y::text as z from foo',
+ 'select "x".y::text as "z" from foo',
+ 'select "x"."y"::text as z from foo',
+ 'select "x"."y"::text as "z" from foo'])
+@pytest.mark.parametrize('func_name, result', [('get_name', 'z'),
+ ('get_real_name', 'y'),
+ ('get_parent_name', 'x'),
+ ('get_alias', 'z'),
+ ('get_typecast', 'text')])
+def test_issue78(s, func_name, result):
# the bug author provided this nice examples, let's use them!
- def _get_identifier(sql):
- p = sqlparse.parse(sql)[0]
- return p.tokens[2]
- results = (('get_name', 'z'),
- ('get_real_name', 'y'),
- ('get_parent_name', 'x'),
- ('get_alias', 'z'),
- ('get_typecast', 'text'))
- variants = (
- 'select x.y::text as z from foo',
- 'select x.y::text as "z" from foo',
- 'select x."y"::text as z from foo',
- 'select x."y"::text as "z" from foo',
- 'select "x".y::text as z from foo',
- 'select "x".y::text as "z" from foo',
- 'select "x"."y"::text as z from foo',
- 'select "x"."y"::text as "z" from foo',
- )
- for variant in variants:
- i = _get_identifier(variant)
- assert isinstance(i, sql.Identifier)
- for func_name, result in results:
- func = getattr(i, func_name)
- assert func() == result
+ p = sqlparse.parse(s)[0]
+ i = p.tokens[2]
+ assert isinstance(i, sql.Identifier)
+
+ func = getattr(i, func_name)
+ assert func() == result
def test_issue83():
- sql = """
-CREATE OR REPLACE FUNCTION func_a(text)
- RETURNS boolean LANGUAGE plpgsql STRICT IMMUTABLE AS
-$_$
-BEGIN
- ...
-END;
-$_$;
-
-CREATE OR REPLACE FUNCTION func_b(text)
- RETURNS boolean LANGUAGE plpgsql STRICT IMMUTABLE AS
-$_$
-BEGIN
- ...
-END;
-$_$;
-
-ALTER TABLE..... ;"""
+ sql = """ CREATE OR REPLACE FUNCTION func_a(text)
+ RETURNS boolean LANGUAGE plpgsql STRICT IMMUTABLE AS
+ $_$
+ BEGIN
+ ...
+ END;
+ $_$;
+
+ CREATE OR REPLACE FUNCTION func_b(text)
+ RETURNS boolean LANGUAGE plpgsql STRICT IMMUTABLE AS
+ $_$
+ BEGIN
+ ...
+ END;
+ $_$;
+
+ ALTER TABLE..... ;"""
t = sqlparse.split(sql)
assert len(t) == 3
@@ -177,7 +160,7 @@ def test_parse_sql_with_binary():
sql = "select * from foo where bar = '{0}'".format(digest)
formatted = sqlparse.format(sql, reindent=True)
tformatted = "select *\nfrom foo\nwhere bar = '{0}'".format(digest)
- if sys.version_info < (3,):
+ if PY2:
tformatted = tformatted.decode('unicode-escape')
assert formatted == tformatted
@@ -192,7 +175,8 @@ def test_dont_alias_keywords():
assert p.tokens[2].ttype is T.Keyword
-def test_format_accepts_encoding(): # issue20
+def test_format_accepts_encoding(load_file):
+ # issue20
sql = load_file('test_cp1251.sql', 'cp1251')
formatted = sqlparse.format(sql, reindent=True, encoding='cp1251')
tformatted = u'insert into foo\nvalues (1); -- Песня про надежду\n'
@@ -206,17 +190,18 @@ def test_issue90():
' "rating_score" = 0, "thumbnail_width" = NULL,'
' "thumbnail_height" = NULL, "price" = 1, "description" = NULL')
formatted = sqlparse.format(sql, reindent=True)
- tformatted = '\n'.join(['UPDATE "gallery_photo"',
- 'SET "owner_id" = 4018,',
- ' "deleted_at" = NULL,',
- ' "width" = NULL,',
- ' "height" = NULL,',
- ' "rating_votes" = 0,',
- ' "rating_score" = 0,',
- ' "thumbnail_width" = NULL,',
- ' "thumbnail_height" = NULL,',
- ' "price" = 1,',
- ' "description" = NULL'])
+ tformatted = '\n'.join([
+ 'UPDATE "gallery_photo"',
+ 'SET "owner_id" = 4018,',
+ ' "deleted_at" = NULL,',
+ ' "width" = NULL,',
+ ' "height" = NULL,',
+ ' "rating_votes" = 0,',
+ ' "rating_score" = 0,',
+ ' "thumbnail_width" = NULL,',
+ ' "thumbnail_height" = NULL,',
+ ' "price" = 1,',
+ ' "description" = NULL'])
assert formatted == tformatted
@@ -230,8 +215,7 @@ def test_except_formatting():
'EXCEPT',
'SELECT 2',
'FROM bar',
- 'WHERE 1 = 2'
- ])
+ 'WHERE 1 = 2'])
assert formatted == tformatted
@@ -241,32 +225,31 @@ def test_null_with_as():
tformatted = '\n'.join([
'SELECT NULL AS c1,',
' NULL AS c2',
- 'FROM t1'
- ])
+ 'FROM t1'])
assert formatted == tformatted
def test_issue193_splitting_function():
- sql = """CREATE FUNCTION a(x VARCHAR(20)) RETURNS VARCHAR(20)
-BEGIN
- DECLARE y VARCHAR(20);
- RETURN x;
-END;
-SELECT * FROM a.b;"""
+ sql = """ CREATE FUNCTION a(x VARCHAR(20)) RETURNS VARCHAR(20)
+ BEGIN
+ DECLARE y VARCHAR(20);
+ RETURN x;
+ END;
+ SELECT * FROM a.b;"""
splitted = sqlparse.split(sql)
assert len(splitted) == 2
def test_issue194_splitting_function():
- sql = """CREATE FUNCTION a(x VARCHAR(20)) RETURNS VARCHAR(20)
-BEGIN
- DECLARE y VARCHAR(20);
- IF (1 = 1) THEN
- SET x = y;
- END IF;
- RETURN x;
-END;
-SELECT * FROM a.b;"""
+ sql = """ CREATE FUNCTION a(x VARCHAR(20)) RETURNS VARCHAR(20)
+ BEGIN
+ DECLARE y VARCHAR(20);
+ IF (1 = 1) THEN
+ SET x = y;
+ END IF;
+ RETURN x;
+ END;
+ SELECT * FROM a.b;"""
splitted = sqlparse.split(sql)
assert len(splitted) == 2
@@ -278,8 +261,8 @@ def test_issue186_get_type():
def test_issue212_py2unicode():
- t1 = sql.Token(T.String, u"schöner ")
- t2 = sql.Token(T.String, u"bug")
+ t1 = sql.Token(T.String, u'schöner ')
+ t2 = sql.Token(T.String, 'bug')
l = sql.TokenList([t1, t2])
assert str(l) == 'schöner bug'
@@ -295,23 +278,23 @@ def test_issue227_gettype_cte():
with_stmt = sqlparse.parse('WITH foo AS (SELECT 1, 2, 3)'
'SELECT * FROM foo;')
assert with_stmt[0].get_type() == 'SELECT'
- with2_stmt = sqlparse.parse('''
+ with2_stmt = sqlparse.parse("""
WITH foo AS (SELECT 1 AS abc, 2 AS def),
bar AS (SELECT * FROM something WHERE x > 1)
- INSERT INTO elsewhere SELECT * FROM foo JOIN bar;
- ''')
+ INSERT INTO elsewhere SELECT * FROM foo JOIN bar;""")
assert with2_stmt[0].get_type() == 'INSERT'
def test_issue207_runaway_format():
sql = 'select 1 from (select 1 as one, 2 as two, 3 from dual) t0'
p = sqlparse.format(sql, reindent=True)
- assert p == '\n'.join(["select 1",
- "from",
- " (select 1 as one,",
- " 2 as two,",
- " 3",
- " from dual) t0"])
+ assert p == '\n'.join([
+ "select 1",
+ "from",
+ " (select 1 as one,",
+ " 2 as two,",
+ " 3",
+ " from dual) t0"])
def token_next_doesnt_ignore_skip_cm():
diff --git a/tests/test_split.py b/tests/test_split.py
index 7c2645d..af7c9ce 100644
--- a/tests/test_split.py
+++ b/tests/test_split.py
@@ -4,139 +4,133 @@
import types
-from tests.utils import load_file, TestCaseBase
+import pytest
import sqlparse
-from sqlparse.compat import StringIO, u, text_type
-
-
-class SQLSplitTest(TestCaseBase):
- """Tests sqlparse.sqlsplit()."""
-
- _sql1 = 'select * from foo;'
- _sql2 = 'select * from bar;'
-
- def test_split_semicolon(self):
- sql2 = 'select * from foo where bar = \'foo;bar\';'
- stmts = sqlparse.parse(''.join([self._sql1, sql2]))
- self.assertEqual(len(stmts), 2)
- self.ndiffAssertEqual(u(stmts[0]), self._sql1)
- self.ndiffAssertEqual(u(stmts[1]), sql2)
-
- def test_split_backslash(self):
- stmts = sqlparse.parse(r"select '\\'; select '\''; select '\\\'';")
- self.assertEqual(len(stmts), 3)
-
- def test_create_function(self):
- sql = load_file('function.sql')
- stmts = sqlparse.parse(sql)
- self.assertEqual(len(stmts), 1)
- self.ndiffAssertEqual(u(stmts[0]), sql)
-
- def test_create_function_psql(self):
- sql = load_file('function_psql.sql')
- stmts = sqlparse.parse(sql)
- self.assertEqual(len(stmts), 1)
- self.ndiffAssertEqual(u(stmts[0]), sql)
-
- def test_create_function_psql3(self):
- sql = load_file('function_psql3.sql')
- stmts = sqlparse.parse(sql)
- self.assertEqual(len(stmts), 1)
- self.ndiffAssertEqual(u(stmts[0]), sql)
-
- def test_create_function_psql2(self):
- sql = load_file('function_psql2.sql')
- stmts = sqlparse.parse(sql)
- self.assertEqual(len(stmts), 1)
- self.ndiffAssertEqual(u(stmts[0]), sql)
-
- def test_dashcomments(self):
- sql = load_file('dashcomment.sql')
- stmts = sqlparse.parse(sql)
- self.assertEqual(len(stmts), 3)
- self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql)
-
- def test_dashcomments_eol(self):
- stmts = sqlparse.parse('select foo; -- comment\n')
- self.assertEqual(len(stmts), 1)
- stmts = sqlparse.parse('select foo; -- comment\r')
- self.assertEqual(len(stmts), 1)
- stmts = sqlparse.parse('select foo; -- comment\r\n')
- self.assertEqual(len(stmts), 1)
- stmts = sqlparse.parse('select foo; -- comment')
- self.assertEqual(len(stmts), 1)
-
- def test_begintag(self):
- sql = load_file('begintag.sql')
- stmts = sqlparse.parse(sql)
- self.assertEqual(len(stmts), 3)
- self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql)
-
- def test_begintag_2(self):
- sql = load_file('begintag_2.sql')
- stmts = sqlparse.parse(sql)
- self.assertEqual(len(stmts), 1)
- self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql)
-
- def test_dropif(self):
- sql = 'DROP TABLE IF EXISTS FOO;\n\nSELECT * FROM BAR;'
- stmts = sqlparse.parse(sql)
- self.assertEqual(len(stmts), 2)
- self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql)
-
- def test_comment_with_umlaut(self):
- sql = (u'select * from foo;\n'
- u'-- Testing an umlaut: ä\n'
- u'select * from bar;')
- stmts = sqlparse.parse(sql)
- self.assertEqual(len(stmts), 2)
- self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql)
-
- def test_comment_end_of_line(self):
- sql = ('select * from foo; -- foo\n'
- 'select * from bar;')
- stmts = sqlparse.parse(sql)
- self.assertEqual(len(stmts), 2)
- self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql)
- # make sure the comment belongs to first query
- self.ndiffAssertEqual(u(stmts[0]), 'select * from foo; -- foo\n')
-
- def test_casewhen(self):
- sql = ('SELECT case when val = 1 then 2 else null end as foo;\n'
- 'comment on table actor is \'The actor table.\';')
- stmts = sqlparse.split(sql)
- self.assertEqual(len(stmts), 2)
-
- def test_cursor_declare(self):
- sql = ('DECLARE CURSOR "foo" AS SELECT 1;\n'
- 'SELECT 2;')
- stmts = sqlparse.split(sql)
- self.assertEqual(len(stmts), 2)
-
- def test_if_function(self): # see issue 33
- # don't let IF as a function confuse the splitter
- sql = ('CREATE TEMPORARY TABLE tmp '
- 'SELECT IF(a=1, a, b) AS o FROM one; '
- 'SELECT t FROM two')
- stmts = sqlparse.split(sql)
- self.assertEqual(len(stmts), 2)
-
- def test_split_stream(self):
- stream = StringIO("SELECT 1; SELECT 2;")
- stmts = sqlparse.parsestream(stream)
- self.assertEqual(type(stmts), types.GeneratorType)
- self.assertEqual(len(list(stmts)), 2)
-
- def test_encoding_parsestream(self):
- stream = StringIO("SELECT 1; SELECT 2;")
- stmts = list(sqlparse.parsestream(stream))
- self.assertEqual(type(stmts[0].tokens[0].value), text_type)
-
- def test_unicode_parsestream(self):
- stream = StringIO(u"SELECT ö")
- stmts = list(sqlparse.parsestream(stream))
- self.assertEqual(str(stmts[0]), "SELECT ö")
+from sqlparse.compat import StringIO, text_type
+
+
+def test_split_semicolon():
+ sql1 = 'select * from foo;'
+ sql2 = "select * from foo where bar = 'foo;bar';"
+ stmts = sqlparse.parse(''.join([sql1, sql2]))
+ assert len(stmts) == 2
+ assert str(stmts[0]) == sql1
+ assert str(stmts[1]) == sql2
+
+
+def test_split_backslash():
+ stmts = sqlparse.parse(r"select '\\'; select '\''; select '\\\'';")
+ assert len(stmts) == 3
+
+
+@pytest.mark.parametrize('fn', ['function.sql',
+ 'function_psql.sql',
+ 'function_psql2.sql',
+ 'function_psql3.sql'])
+def test_split_create_function(load_file, fn):
+ sql = load_file(fn)
+ stmts = sqlparse.parse(sql)
+ assert len(stmts) == 1
+ assert text_type(stmts[0]) == sql
+
+
+def test_split_dashcomments(load_file):
+ sql = load_file('dashcomment.sql')
+ stmts = sqlparse.parse(sql)
+ assert len(stmts) == 3
+ assert ''.join(str(q) for q in stmts) == sql
+
+
+@pytest.mark.parametrize('s', ['select foo; -- comment\n',
+ 'select foo; -- comment\r',
+ 'select foo; -- comment\r\n',
+ 'select foo; -- comment'])
+def test_split_dashcomments_eol(s):
+ stmts = sqlparse.parse(s)
+ assert len(stmts) == 1
+
+
+def test_split_begintag(load_file):
+ sql = load_file('begintag.sql')
+ stmts = sqlparse.parse(sql)
+ assert len(stmts) == 3
+ assert ''.join(str(q) for q in stmts) == sql
+
+
+def test_split_begintag_2(load_file):
+ sql = load_file('begintag_2.sql')
+ stmts = sqlparse.parse(sql)
+ assert len(stmts) == 1
+ assert ''.join(str(q) for q in stmts) == sql
+
+
+def test_split_dropif():
+ sql = 'DROP TABLE IF EXISTS FOO;\n\nSELECT * FROM BAR;'
+ stmts = sqlparse.parse(sql)
+ assert len(stmts) == 2
+ assert ''.join(str(q) for q in stmts) == sql
+
+
+def test_split_comment_with_umlaut():
+ sql = (u'select * from foo;\n'
+ u'-- Testing an umlaut: ä\n'
+ u'select * from bar;')
+ stmts = sqlparse.parse(sql)
+ assert len(stmts) == 2
+ assert ''.join(text_type(q) for q in stmts) == sql
+
+
+def test_split_comment_end_of_line():
+ sql = ('select * from foo; -- foo\n'
+ 'select * from bar;')
+ stmts = sqlparse.parse(sql)
+ assert len(stmts) == 2
+ assert ''.join(str(q) for q in stmts) == sql
+ # make sure the comment belongs to first query
+ assert str(stmts[0]) == 'select * from foo; -- foo\n'
+
+
+def test_split_casewhen():
+ sql = ("SELECT case when val = 1 then 2 else null end as foo;\n"
+ "comment on table actor is 'The actor table.';")
+ stmts = sqlparse.split(sql)
+ assert len(stmts) == 2
+
+
+def test_split_cursor_declare():
+ sql = ('DECLARE CURSOR "foo" AS SELECT 1;\n'
+ 'SELECT 2;')
+ stmts = sqlparse.split(sql)
+ assert len(stmts) == 2
+
+
+def test_split_if_function(): # see issue 33
+ # don't let IF as a function confuse the splitter
+ sql = ('CREATE TEMPORARY TABLE tmp '
+ 'SELECT IF(a=1, a, b) AS o FROM one; '
+ 'SELECT t FROM two')
+ stmts = sqlparse.split(sql)
+ assert len(stmts) == 2
+
+
+def test_split_stream():
+ stream = StringIO("SELECT 1; SELECT 2;")
+ stmts = sqlparse.parsestream(stream)
+ assert isinstance(stmts, types.GeneratorType)
+ assert len(list(stmts)) == 2
+
+
+def test_split_encoding_parsestream():
+ stream = StringIO("SELECT 1; SELECT 2;")
+ stmts = list(sqlparse.parsestream(stream))
+ assert isinstance(stmts[0].tokens[0].value, text_type)
+
+
+def test_split_unicode_parsestream():
+ stream = StringIO(u'SELECT ö')
+ stmts = list(sqlparse.parsestream(stream))
+ assert str(stmts[0]) == 'SELECT ö'
def test_split_simple():
diff --git a/tests/test_tokenize.py b/tests/test_tokenize.py
index 61eaa3e..6cc0dfa 100644
--- a/tests/test_tokenize.py
+++ b/tests/test_tokenize.py
@@ -1,171 +1,159 @@
# -*- coding: utf-8 -*-
import types
-import unittest
import pytest
import sqlparse
from sqlparse import lexer
-from sqlparse import sql
-from sqlparse import tokens as T
+from sqlparse import sql, tokens as T
from sqlparse.compat import StringIO
-class TestTokenize(unittest.TestCase):
-
- def test_simple(self):
- s = 'select * from foo;'
- stream = lexer.tokenize(s)
- self.assert_(isinstance(stream, types.GeneratorType))
- tokens = list(stream)
- self.assertEqual(len(tokens), 8)
- self.assertEqual(len(tokens[0]), 2)
- self.assertEqual(tokens[0], (T.Keyword.DML, u'select'))
- self.assertEqual(tokens[-1], (T.Punctuation, u';'))
-
- def test_backticks(self):
- s = '`foo`.`bar`'
- tokens = list(lexer.tokenize(s))
- self.assertEqual(len(tokens), 3)
- self.assertEqual(tokens[0], (T.Name, u'`foo`'))
-
- def test_linebreaks(self): # issue1
- s = 'foo\nbar\n'
- tokens = lexer.tokenize(s)
- self.assertEqual(''.join(str(x[1]) for x in tokens), s)
- s = 'foo\rbar\r'
- tokens = lexer.tokenize(s)
- self.assertEqual(''.join(str(x[1]) for x in tokens), s)
- s = 'foo\r\nbar\r\n'
- tokens = lexer.tokenize(s)
- self.assertEqual(''.join(str(x[1]) for x in tokens), s)
- s = 'foo\r\nbar\n'
- tokens = lexer.tokenize(s)
- self.assertEqual(''.join(str(x[1]) for x in tokens), s)
-
- def test_inline_keywords(self): # issue 7
- s = "create created_foo"
- tokens = list(lexer.tokenize(s))
- self.assertEqual(len(tokens), 3)
- self.assertEqual(tokens[0][0], T.Keyword.DDL)
- self.assertEqual(tokens[2][0], T.Name)
- self.assertEqual(tokens[2][1], u'created_foo')
- s = "enddate"
- tokens = list(lexer.tokenize(s))
- self.assertEqual(len(tokens), 1)
- self.assertEqual(tokens[0][0], T.Name)
- s = "join_col"
- tokens = list(lexer.tokenize(s))
- self.assertEqual(len(tokens), 1)
- self.assertEqual(tokens[0][0], T.Name)
- s = "left join_col"
- tokens = list(lexer.tokenize(s))
- self.assertEqual(len(tokens), 3)
- self.assertEqual(tokens[2][0], T.Name)
- self.assertEqual(tokens[2][1], 'join_col')
-
- def test_negative_numbers(self):
- s = "values(-1)"
- tokens = list(lexer.tokenize(s))
- self.assertEqual(len(tokens), 4)
- self.assertEqual(tokens[2][0], T.Number.Integer)
- self.assertEqual(tokens[2][1], '-1')
-
-
-class TestToken(unittest.TestCase):
-
- def test_str(self):
- token = sql.Token(None, 'FoO')
- self.assertEqual(str(token), 'FoO')
-
- def test_repr(self):
- token = sql.Token(T.Keyword, 'foo')
- tst = "<Keyword 'foo' at 0x"
- self.assertEqual(repr(token)[:len(tst)], tst)
- token = sql.Token(T.Keyword, '1234567890')
- tst = "<Keyword '123456...' at 0x"
- self.assertEqual(repr(token)[:len(tst)], tst)
-
- def test_flatten(self):
- token = sql.Token(T.Keyword, 'foo')
- gen = token.flatten()
- self.assertEqual(type(gen), types.GeneratorType)
- lgen = list(gen)
- self.assertEqual(lgen, [token])
-
-
-class TestTokenList(unittest.TestCase):
-
- def test_repr(self):
- p = sqlparse.parse('foo, bar, baz')[0]
- tst = "<IdentifierList 'foo, b...' at 0x"
- self.assertEqual(repr(p.tokens[0])[:len(tst)], tst)
-
- def test_token_first(self):
- p = sqlparse.parse(' select foo')[0]
- first = p.token_first()
- self.assertEqual(first.value, 'select')
- self.assertEqual(p.token_first(skip_ws=False).value, ' ')
- self.assertEqual(sql.TokenList([]).token_first(), None)
-
- def test_token_matching(self):
- t1 = sql.Token(T.Keyword, 'foo')
- t2 = sql.Token(T.Punctuation, ',')
- x = sql.TokenList([t1, t2])
- self.assertEqual(x.token_matching(
- [lambda t: t.ttype is T.Keyword], 0), t1)
- self.assertEqual(x.token_matching(
- [lambda t: t.ttype is T.Punctuation], 0), t2)
- self.assertEqual(x.token_matching(
- [lambda t: t.ttype is T.Keyword], 1), None)
-
-
-class TestStream(unittest.TestCase):
- def test_simple(self):
- stream = StringIO("SELECT 1; SELECT 2;")
-
- tokens = lexer.tokenize(stream)
- self.assertEqual(len(list(tokens)), 9)
-
- stream.seek(0)
- tokens = list(lexer.tokenize(stream))
- self.assertEqual(len(tokens), 9)
-
- stream.seek(0)
- tokens = list(lexer.tokenize(stream))
- self.assertEqual(len(tokens), 9)
-
- def test_error(self):
- stream = StringIO("FOOBAR{")
-
- tokens = list(lexer.tokenize(stream))
- self.assertEqual(len(tokens), 2)
- self.assertEqual(tokens[1][0], T.Error)
-
-
-@pytest.mark.parametrize('expr', ['JOIN', 'LEFT JOIN', 'LEFT OUTER JOIN',
- 'FULL OUTER JOIN', 'NATURAL JOIN',
- 'CROSS JOIN', 'STRAIGHT JOIN',
- 'INNER JOIN', 'LEFT INNER JOIN'])
+def test_tokenize_simple():
+ s = 'select * from foo;'
+ stream = lexer.tokenize(s)
+ assert isinstance(stream, types.GeneratorType)
+ tokens = list(stream)
+ assert len(tokens) == 8
+ assert len(tokens[0]) == 2
+ assert tokens[0] == (T.Keyword.DML, 'select')
+ assert tokens[-1] == (T.Punctuation, ';')
+
+
+def test_tokenize_backticks():
+ s = '`foo`.`bar`'
+ tokens = list(lexer.tokenize(s))
+ assert len(tokens) == 3
+ assert tokens[0] == (T.Name, '`foo`')
+
+
+@pytest.mark.parametrize('s', ['foo\nbar\n', 'foo\rbar\r',
+ 'foo\r\nbar\r\n', 'foo\r\nbar\n'])
+def test_tokenize_linebreaks(s):
+ # issue1
+ tokens = lexer.tokenize(s)
+ assert ''.join(str(x[1]) for x in tokens) == s
+
+
+def test_tokenize_inline_keywords():
+ # issue 7
+ s = "create created_foo"
+ tokens = list(lexer.tokenize(s))
+ assert len(tokens) == 3
+ assert tokens[0][0] == T.Keyword.DDL
+ assert tokens[2][0] == T.Name
+ assert tokens[2][1] == 'created_foo'
+ s = "enddate"
+ tokens = list(lexer.tokenize(s))
+ assert len(tokens) == 1
+ assert tokens[0][0] == T.Name
+ s = "join_col"
+ tokens = list(lexer.tokenize(s))
+ assert len(tokens) == 1
+ assert tokens[0][0] == T.Name
+ s = "left join_col"
+ tokens = list(lexer.tokenize(s))
+ assert len(tokens) == 3
+ assert tokens[2][0] == T.Name
+ assert tokens[2][1] == 'join_col'
+
+
+def test_tokenize_negative_numbers():
+ s = "values(-1)"
+ tokens = list(lexer.tokenize(s))
+ assert len(tokens) == 4
+ assert tokens[2][0] == T.Number.Integer
+ assert tokens[2][1] == '-1'
+
+
+def test_token_str():
+ token = sql.Token(None, 'FoO')
+ assert str(token) == 'FoO'
+
+
+def test_token_repr():
+ token = sql.Token(T.Keyword, 'foo')
+ tst = "<Keyword 'foo' at 0x"
+ assert repr(token)[:len(tst)] == tst
+ token = sql.Token(T.Keyword, '1234567890')
+ tst = "<Keyword '123456...' at 0x"
+ assert repr(token)[:len(tst)] == tst
+
+
+def test_token_flatten():
+ token = sql.Token(T.Keyword, 'foo')
+ gen = token.flatten()
+ assert isinstance(gen, types.GeneratorType)
+ lgen = list(gen)
+ assert lgen == [token]
+
+
+def test_tokenlist_repr():
+ p = sqlparse.parse('foo, bar, baz')[0]
+ tst = "<IdentifierList 'foo, b...' at 0x"
+ assert repr(p.tokens[0])[:len(tst)] == tst
+
+
+def test_tokenlist_first():
+ p = sqlparse.parse(' select foo')[0]
+ first = p.token_first()
+ assert first.value == 'select'
+ assert p.token_first(skip_ws=False).value == ' '
+ assert sql.TokenList([]).token_first() is None
+
+
+def test_tokenlist_token_matching():
+ t1 = sql.Token(T.Keyword, 'foo')
+ t2 = sql.Token(T.Punctuation, ',')
+ x = sql.TokenList([t1, t2])
+ assert x.token_matching([lambda t: t.ttype is T.Keyword], 0) == t1
+ assert x.token_matching([lambda t: t.ttype is T.Punctuation], 0) == t2
+ assert x.token_matching([lambda t: t.ttype is T.Keyword], 1) is None
+
+
+def test_stream_simple():
+ stream = StringIO("SELECT 1; SELECT 2;")
+
+ tokens = lexer.tokenize(stream)
+ assert len(list(tokens)) == 9
+
+ stream.seek(0)
+ tokens = list(lexer.tokenize(stream))
+ assert len(tokens) == 9
+
+ stream.seek(0)
+ tokens = list(lexer.tokenize(stream))
+ assert len(tokens) == 9
+
+
+def test_stream_error():
+ stream = StringIO("FOOBAR{")
+
+ tokens = list(lexer.tokenize(stream))
+ assert len(tokens) == 2
+ assert tokens[1][0] == T.Error
+
+
+@pytest.mark.parametrize('expr', [
+ 'JOIN',
+ 'LEFT JOIN',
+ 'LEFT OUTER JOIN',
+ 'FULL OUTER JOIN',
+ 'NATURAL JOIN',
+ 'CROSS JOIN',
+ 'STRAIGHT JOIN',
+ 'INNER JOIN',
+ 'LEFT INNER JOIN'])
def test_parse_join(expr):
p = sqlparse.parse('{0} foo'.format(expr))[0]
assert len(p.tokens) == 3
assert p.tokens[0].ttype is T.Keyword
-def test_parse_endifloop():
- p = sqlparse.parse('END IF')[0]
- assert len(p.tokens) == 1
- assert p.tokens[0].ttype is T.Keyword
- p = sqlparse.parse('END IF')[0]
- assert len(p.tokens) == 1
- p = sqlparse.parse('END\t\nIF')[0]
- assert len(p.tokens) == 1
- assert p.tokens[0].ttype is T.Keyword
- p = sqlparse.parse('END LOOP')[0]
- assert len(p.tokens) == 1
- assert p.tokens[0].ttype is T.Keyword
- p = sqlparse.parse('END LOOP')[0]
+@pytest.mark.parametrize('s', ['END IF', 'END IF', 'END\t\nIF',
+ 'END LOOP', 'END LOOP', 'END\t\nLOOP'])
+def test_parse_endifloop(s):
+ p = sqlparse.parse(s)[0]
assert len(p.tokens) == 1
assert p.tokens[0].ttype is T.Keyword
diff --git a/tests/utils.py b/tests/utils.py
deleted file mode 100644
index 8000ae5..0000000
--- a/tests/utils.py
+++ /dev/null
@@ -1,41 +0,0 @@
-# -*- coding: utf-8 -*-
-
-"""Helpers for testing."""
-
-import difflib
-import io
-import os
-import unittest
-
-from sqlparse.utils import split_unquoted_newlines
-from sqlparse.compat import StringIO
-
-DIR_PATH = os.path.dirname(__file__)
-FILES_DIR = os.path.join(DIR_PATH, 'files')
-
-
-def load_file(filename, encoding='utf-8'):
- """Opens filename with encoding and return its contents."""
- with io.open(os.path.join(FILES_DIR, filename), encoding=encoding) as f:
- return f.read()
-
-
-class TestCaseBase(unittest.TestCase):
- """Base class for test cases with some additional checks."""
-
- # Adopted from Python's tests.
- def ndiffAssertEqual(self, first, second):
- """Like failUnlessEqual except use ndiff for readable output."""
- if first != second:
- # Using the built-in .splitlines() method here will cause incorrect
- # results when splitting statements that have quoted CR/CR+LF
- # characters.
- sfirst = split_unquoted_newlines(first)
- ssecond = split_unquoted_newlines(second)
- diff = difflib.ndiff(sfirst, ssecond)
-
- fp = StringIO()
- fp.write('\n')
- fp.write('\n'.join(diff))
-
- raise self.failureException(fp.getvalue())