diff options
| author | Vik <vmuriart@gmail.com> | 2016-06-23 11:37:46 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2016-06-23 11:37:46 -0700 |
| commit | 78fb2041aef1068eb6c4bac3caad5dd012186868 (patch) | |
| tree | db402888277a06ca450eba2e57dd6f94c9b369ad | |
| parent | c56652ef9fdac111dd59e26b913765719eaf1141 (diff) | |
| parent | 85349e68592964e66e5dfe7e48e9f76cb93d48fd (diff) | |
| download | sqlparse-78fb2041aef1068eb6c4bac3caad5dd012186868.tar.gz | |
Merge pull request #263 from vmuriart/clean-tests
Clean-up tests. Fully migrate to Py.test
| -rw-r--r-- | sqlparse/cli.py | 12 | ||||
| -rw-r--r-- | tests/conftest.py | 41 | ||||
| -rw-r--r-- | tests/test_cli.py | 49 | ||||
| -rw-r--r-- | tests/test_format.py | 654 | ||||
| -rw-r--r-- | tests/test_grouping.py | 586 | ||||
| -rw-r--r-- | tests/test_parse.py | 424 | ||||
| -rw-r--r-- | tests/test_regressions.py | 373 | ||||
| -rw-r--r-- | tests/test_split.py | 256 | ||||
| -rw-r--r-- | tests/test_tokenize.py | 292 | ||||
| -rw-r--r-- | tests/utils.py | 41 |
10 files changed, 1367 insertions, 1361 deletions
diff --git a/sqlparse/cli.py b/sqlparse/cli.py index 03a4f8f..80d547d 100644 --- a/sqlparse/cli.py +++ b/sqlparse/cli.py @@ -123,7 +123,8 @@ def create_parser(): def _error(msg): """Print msg and optionally exit with return code exit_.""" - sys.stderr.write('[ERROR] %s\n' % msg) + sys.stderr.write('[ERROR] {0}\n'.format(msg)) + return 1 def main(args=None): @@ -137,15 +138,13 @@ def main(args=None): # TODO: Needs to deal with encoding data = ''.join(open(args.filename).readlines()) except IOError as e: - _error('Failed to read %s: %s' % (args.filename, e)) - return 1 + return _error('Failed to read {0}: {1}'.format(args.filename, e)) if args.outfile: try: stream = open(args.outfile, 'w') except IOError as e: - _error('Failed to open %s: %s' % (args.outfile, e)) - return 1 + return _error('Failed to open {0}: {1}'.format(args.outfile, e)) else: stream = sys.stdout @@ -153,8 +152,7 @@ def main(args=None): try: formatter_opts = sqlparse.formatter.validate_options(formatter_opts) except SQLParseError as e: - _error('Invalid options: %s' % e) - return 1 + return _error('Invalid options: {0}'.format(e)) s = sqlparse.format(data, **formatter_opts) if PY2: diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..d5621eb --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- + +"""Helpers for testing.""" + +import io +import os + +import pytest + +DIR_PATH = os.path.dirname(__file__) +FILES_DIR = os.path.join(DIR_PATH, 'files') + + +@pytest.fixture() +def filepath(): + """Returns full file path for test files.""" + + def make_filepath(filename): + # http://stackoverflow.com/questions/18011902/parameter-to-a-fixture + # Alternate solution is to use paramtrization `inderect=True` + # http://stackoverflow.com/a/33879151 + # Syntax is noisy and requires specific variable names + return os.path.join(FILES_DIR, filename) + + return make_filepath + + +@pytest.fixture() +def load_file(filepath): + """Opens filename with encoding and return its contents.""" + + def make_load_file(filename, encoding='utf-8'): + # http://stackoverflow.com/questions/18011902/parameter-to-a-fixture + # Alternate solution is to use paramtrization `inderect=True` + # http://stackoverflow.com/a/33879151 + # Syntax is noisy and requires specific variable names + # And seems to be limited to only 1 argument. + with io.open(filepath(filename), encoding=encoding) as f: + return f.read() + + return make_load_file diff --git a/tests/test_cli.py b/tests/test_cli.py index 81d8449..77a764e 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,16 +1,11 @@ # -*- coding: utf-8 -*- -import os import subprocess import sys import pytest import sqlparse -import sqlparse.__main__ -from tests.utils import FILES_DIR - -path = os.path.join(FILES_DIR, 'function.sql') def test_cli_main_empty(): @@ -31,12 +26,50 @@ def test_main_help(): assert exinfo.value.code == 0 -def test_valid_args(): +def test_valid_args(filepath): # test doesn't abort + path = filepath('function.sql') assert sqlparse.cli.main([path, '-r']) is not None +def test_invalid_choise(filepath): + path = filepath('function.sql') + with pytest.raises(SystemExit): + sqlparse.cli.main([path, '-l', 'spanish']) + + +def test_invalid_args(filepath, capsys): + path = filepath('function.sql') + sqlparse.cli.main([path, '-r', '--indent_width', '0']) + _, err = capsys.readouterr() + assert err == ("[ERROR] Invalid options: indent_width requires " + "a positive integer\n") + + +def test_invalid_infile(filepath, capsys): + path = filepath('missing.sql') + sqlparse.cli.main([path, '-r']) + _, err = capsys.readouterr() + assert err[:22] == "[ERROR] Failed to read" + + +def test_invalid_outfile(filepath, capsys): + path = filepath('function.sql') + outpath = filepath('/missing/function.sql') + sqlparse.cli.main([path, '-r', '-o', outpath]) + _, err = capsys.readouterr() + assert err[:22] == "[ERROR] Failed to open" + + +def test_stdout(filepath, load_file, capsys): + path = filepath('begintag.sql') + expected = load_file('begintag.sql') + sqlparse.cli.main([path]) + out, _ = capsys.readouterr() + assert out == expected + + def test_script(): # Call with the --help option as a basic sanity check. - cmdl = "{0:s} -m sqlparse.cli --help".format(sys.executable) - assert subprocess.call(cmdl.split()) == 0 + cmd = "{0:s} -m sqlparse.cli --help".format(sys.executable) + assert subprocess.call(cmd.split()) == 0 diff --git a/tests/test_format.py b/tests/test_format.py index 74fce71..023f26d 100644 --- a/tests/test_format.py +++ b/tests/test_format.py @@ -2,83 +2,94 @@ import pytest -from tests.utils import TestCaseBase - import sqlparse from sqlparse.exceptions import SQLParseError -class TestFormat(TestCaseBase): - +class TestFormat(object): def test_keywordcase(self): sql = 'select * from bar; -- select foo\n' res = sqlparse.format(sql, keyword_case='upper') - self.ndiffAssertEqual(res, 'SELECT * FROM bar; -- select foo\n') + assert res == 'SELECT * FROM bar; -- select foo\n' res = sqlparse.format(sql, keyword_case='capitalize') - self.ndiffAssertEqual(res, 'Select * From bar; -- select foo\n') + assert res == 'Select * From bar; -- select foo\n' res = sqlparse.format(sql.upper(), keyword_case='lower') - self.ndiffAssertEqual(res, 'select * from BAR; -- SELECT FOO\n') - self.assertRaises(SQLParseError, sqlparse.format, sql, - keyword_case='foo') + assert res == 'select * from BAR; -- SELECT FOO\n' + + def test_keywordcase_invalid_option(self): + sql = 'select * from bar; -- select foo\n' + with pytest.raises(SQLParseError): + sqlparse.format(sql, keyword_case='foo') def test_identifiercase(self): sql = 'select * from bar; -- select foo\n' res = sqlparse.format(sql, identifier_case='upper') - self.ndiffAssertEqual(res, 'select * from BAR; -- select foo\n') + assert res == 'select * from BAR; -- select foo\n' res = sqlparse.format(sql, identifier_case='capitalize') - self.ndiffAssertEqual(res, 'select * from Bar; -- select foo\n') + assert res == 'select * from Bar; -- select foo\n' res = sqlparse.format(sql.upper(), identifier_case='lower') - self.ndiffAssertEqual(res, 'SELECT * FROM bar; -- SELECT FOO\n') - self.assertRaises(SQLParseError, sqlparse.format, sql, - identifier_case='foo') + assert res == 'SELECT * FROM bar; -- SELECT FOO\n' + + def test_identifiercase_invalid_option(self): + sql = 'select * from bar; -- select foo\n' + with pytest.raises(SQLParseError): + sqlparse.format(sql, identifier_case='foo') + + def test_identifiercase_quotes(self): sql = 'select * from "foo"."bar"' res = sqlparse.format(sql, identifier_case="upper") - self.ndiffAssertEqual(res, 'select * from "foo"."bar"') + assert res == 'select * from "foo"."bar"' def test_strip_comments_single(self): sql = 'select *-- statement starts here\nfrom foo' res = sqlparse.format(sql, strip_comments=True) - self.ndiffAssertEqual(res, 'select * from foo') + assert res == 'select * from foo' sql = 'select * -- statement starts here\nfrom foo' res = sqlparse.format(sql, strip_comments=True) - self.ndiffAssertEqual(res, 'select * from foo') + assert res == 'select * from foo' sql = 'select-- foo\nfrom -- bar\nwhere' res = sqlparse.format(sql, strip_comments=True) - self.ndiffAssertEqual(res, 'select from where') - self.assertRaises(SQLParseError, sqlparse.format, sql, - strip_comments=None) + assert res == 'select from where' + + def test_strip_comments_invalid_option(self): + sql = 'select-- foo\nfrom -- bar\nwhere' + with pytest.raises(SQLParseError): + sqlparse.format(sql, strip_comments=None) def test_strip_comments_multi(self): sql = '/* sql starts here */\nselect' res = sqlparse.format(sql, strip_comments=True) - self.ndiffAssertEqual(res, 'select') + assert res == 'select' sql = '/* sql starts here */ select' res = sqlparse.format(sql, strip_comments=True) - self.ndiffAssertEqual(res, 'select') + assert res == 'select' sql = '/*\n * sql starts here\n */\nselect' res = sqlparse.format(sql, strip_comments=True) - self.ndiffAssertEqual(res, 'select') + assert res == 'select' sql = 'select (/* sql starts here */ select 2)' res = sqlparse.format(sql, strip_comments=True) - self.ndiffAssertEqual(res, 'select (select 2)') + assert res == 'select (select 2)' sql = 'select (/* sql /* starts here */ select 2)' res = sqlparse.format(sql, strip_comments=True) - self.ndiffAssertEqual(res, 'select (select 2)') + assert res == 'select (select 2)' def test_strip_ws(self): f = lambda sql: sqlparse.format(sql, strip_whitespace=True) s = 'select\n* from foo\n\twhere ( 1 = 2 )\n' - self.ndiffAssertEqual(f(s), 'select * from foo where (1 = 2)') + assert f(s) == 'select * from foo where (1 = 2)' + s = 'select -- foo\nfrom bar\n' + assert f(s) == 'select -- foo\nfrom bar' + + def test_strip_ws_invalid_option(self): s = 'select -- foo\nfrom bar\n' - self.ndiffAssertEqual(f(s), 'select -- foo\nfrom bar') - self.assertRaises(SQLParseError, sqlparse.format, s, - strip_whitespace=None) + with pytest.raises(SQLParseError): + sqlparse.format(s, strip_whitespace=None) def test_preserve_ws(self): # preserve at least one whitespace after subgroups f = lambda sql: sqlparse.format(sql, strip_whitespace=True) s = 'select\n* /* foo */ from bar ' - self.ndiffAssertEqual(f(s), 'select * /* foo */ from bar') + assert f(s) == 'select * /* foo */ from bar' def test_notransform_of_quoted_crlf(self): # Make sure that CR/CR+LF characters inside string literals don't get @@ -92,21 +103,14 @@ class TestFormat(TestCaseBase): f = lambda x: sqlparse.format(x) # Because of the use of - self.ndiffAssertEqual(f(s1), "SELECT some_column LIKE 'value\r'") - self.ndiffAssertEqual( - f(s2), "SELECT some_column LIKE 'value\r'\nWHERE id = 1\n") - self.ndiffAssertEqual( - f(s3), "SELECT some_column LIKE 'value\\'\r' WHERE id = 1\n") - self.ndiffAssertEqual( - f(s4), "SELECT some_column LIKE 'value\\\\\\'\r' WHERE id = 1\n") - - def test_outputformat(self): - sql = 'select * from foo;' - self.assertRaises(SQLParseError, sqlparse.format, sql, - output_format='foo') + assert f(s1) == "SELECT some_column LIKE 'value\r'" + assert f(s2) == "SELECT some_column LIKE 'value\r'\nWHERE id = 1\n" + assert f(s3) == "SELECT some_column LIKE 'value\\'\r' WHERE id = 1\n" + assert (f(s4) == + "SELECT some_column LIKE 'value\\\\\\'\r' WHERE id = 1\n") -class TestFormatReindentAligned(TestCaseBase): +class TestFormatReindentAligned(object): @staticmethod def formatter(sql): return sqlparse.format(sql, reindent_aligned=True) @@ -121,23 +125,21 @@ class TestFormatReindentAligned(TestCaseBase): or d is 'blue' limit 10 """ - self.ndiffAssertEqual( - self.formatter(sql), - '\n'.join([ - 'select a,', - ' b as bb,', - ' c', - ' from table', - ' join (', - ' select a * 2 as a', - ' from new_table', - ' ) other', - ' on table.a = other.a', - ' where c is true', - ' and b between 3 and 4', - " or d is 'blue'", - ' limit 10', - ])) + + assert self.formatter(sql) == '\n'.join([ + 'select a,', + ' b as bb,', + ' c', + ' from table', + ' join (', + ' select a * 2 as a', + ' from new_table', + ' ) other', + ' on table.a = other.a', + ' where c is true', + ' and b between 3 and 4', + " or d is 'blue'", + ' limit 10']) def test_joins(self): sql = """ @@ -148,22 +150,19 @@ class TestFormatReindentAligned(TestCaseBase): cross join e on e.four = a.four join f using (one, two, three) """ - self.ndiffAssertEqual( - self.formatter(sql), - '\n'.join([ - 'select *', - ' from a', - ' join b', - ' on a.one = b.one', - ' left join c', - ' on c.two = a.two', - ' and c.three = a.three', - ' full outer join d', - ' on d.three = a.three', - ' cross join e', - ' on e.four = a.four', - ' join f using (one, two, three)', - ])) + assert self.formatter(sql) == '\n'.join([ + 'select *', + ' from a', + ' join b', + ' on a.one = b.one', + ' left join c', + ' on c.two = a.two', + ' and c.three = a.three', + ' full outer join d', + ' on d.three = a.three', + ' cross join e', + ' on e.four = a.four', + ' join f using (one, two, three)']) def test_case_statement(self): sql = """ @@ -178,20 +177,17 @@ class TestFormatReindentAligned(TestCaseBase): where c is true and b between 3 and 4 """ - self.ndiffAssertEqual( - self.formatter(sql), - '\n'.join([ - 'select a,', - ' case when a = 0 then 1', - ' when bb = 1 then 1', - ' when c = 2 then 2', - ' else 0', - ' end as d,', - ' extra_col', - ' from table', - ' where c is true', - ' and b between 3 and 4' - ])) + assert self.formatter(sql) == '\n'.join([ + 'select a,', + ' case when a = 0 then 1', + ' when bb = 1 then 1', + ' when c = 2 then 2', + ' else 0', + ' end as d,', + ' extra_col', + ' from table', + ' where c is true', + ' and b between 3 and 4']) def test_case_statement_with_between(self): sql = """ @@ -207,21 +203,18 @@ class TestFormatReindentAligned(TestCaseBase): where c is true and b between 3 and 4 """ - self.ndiffAssertEqual( - self.formatter(sql), - '\n'.join([ - 'select a,', - ' case when a = 0 then 1', - ' when bb = 1 then 1', - ' when c = 2 then 2', - ' when d between 3 and 5 then 3', - ' else 0', - ' end as d,', - ' extra_col', - ' from table', - ' where c is true', - ' and b between 3 and 4' - ])) + assert self.formatter(sql) == '\n'.join([ + 'select a,', + ' case when a = 0 then 1', + ' when bb = 1 then 1', + ' when c = 2 then 2', + ' when d between 3 and 5 then 3', + ' else 0', + ' end as d,', + ' extra_col', + ' from table', + ' where c is true', + ' and b between 3 and 4']) def test_group_by(self): sql = """ @@ -232,24 +225,21 @@ class TestFormatReindentAligned(TestCaseBase): and count(y) > 5 order by 3,2,1 """ - self.ndiffAssertEqual( - self.formatter(sql), - '\n'.join([ - 'select a,', - ' b,', - ' c,', - ' sum(x) as sum_x,', - ' count(y) as cnt_y', - ' from table', - ' group by a,', - ' b,', - ' c', - 'having sum(x) > 1', - ' and count(y) > 5', - ' order by 3,', - ' 2,', - ' 1', - ])) + assert self.formatter(sql) == '\n'.join([ + 'select a,', + ' b,', + ' c,', + ' sum(x) as sum_x,', + ' count(y) as cnt_y', + ' from table', + ' group by a,', + ' b,', + ' c', + 'having sum(x) > 1', + ' and count(y) > 5', + ' order by 3,', + ' 2,', + ' 1']) def test_group_by_subquery(self): # TODO: add subquery alias when test_identifier_list_subquery fixed @@ -261,21 +251,18 @@ class TestFormatReindentAligned(TestCaseBase): group by a,z) order by 1,2 """ - self.ndiffAssertEqual( - self.formatter(sql), - '\n'.join([ - 'select *,', - ' sum_b + 2 as mod_sum', - ' from (', - ' select a,', - ' sum(b) as sum_b', - ' from table', - ' group by a,', - ' z', - ' )', - ' order by 1,', - ' 2', - ])) + assert self.formatter(sql) == '\n'.join([ + 'select *,', + ' sum_b + 2 as mod_sum', + ' from (', + ' select a,', + ' sum(b) as sum_b', + ' from table', + ' group by a,', + ' z', + ' )', + ' order by 1,', + ' 2']) def test_window_functions(self): sql = """ @@ -284,21 +271,17 @@ class TestFormatReindentAligned(TestCaseBase): BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as sum_a, ROW_NUMBER() OVER (PARTITION BY b, c ORDER BY d DESC) as row_num - from table - """ - self.ndiffAssertEqual( - self.formatter(sql), - '\n'.join([ - 'select a,', - (' SUM(a) OVER (PARTITION BY b ORDER BY c ROWS ' - 'BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as sum_a,'), - (' ROW_NUMBER() OVER ' - '(PARTITION BY b, c ORDER BY d DESC) as row_num'), - ' from table', - ])) - - -class TestSpacesAroundOperators(TestCaseBase): + from table""" + assert self.formatter(sql) == '\n'.join([ + 'select a,', + ' SUM(a) OVER (PARTITION BY b ORDER BY c ROWS ' + 'BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as sum_a,', + ' ROW_NUMBER() OVER ' + '(PARTITION BY b, c ORDER BY d DESC) as row_num', + ' from table']) + + +class TestSpacesAroundOperators(object): @staticmethod def formatter(sql): return sqlparse.format(sql, use_space_around_operators=True) @@ -306,311 +289,330 @@ class TestSpacesAroundOperators(TestCaseBase): def test_basic(self): sql = ('select a+b as d from table ' 'where (c-d)%2= 1 and e> 3.0/4 and z^2 <100') - self.ndiffAssertEqual( - self.formatter(sql), ( - 'select a + b as d from table ' - 'where (c - d) % 2 = 1 and e > 3.0 / 4 and z ^ 2 < 100') - ) + assert self.formatter(sql) == ( + 'select a + b as d from table ' + 'where (c - d) % 2 = 1 and e > 3.0 / 4 and z ^ 2 < 100') def test_bools(self): sql = 'select * from table where a &&b or c||d' - self.ndiffAssertEqual( - self.formatter(sql), - 'select * from table where a && b or c || d' - ) + assert self.formatter( + sql) == 'select * from table where a && b or c || d' def test_nested(self): sql = 'select *, case when a-b then c end from table' - self.ndiffAssertEqual( - self.formatter(sql), - 'select *, case when a - b then c end from table' - ) + assert self.formatter( + sql) == 'select *, case when a - b then c end from table' def test_wildcard_vs_mult(self): sql = 'select a*b-c from table' - self.ndiffAssertEqual( - self.formatter(sql), - 'select a * b - c from table' - ) + assert self.formatter(sql) == 'select a * b - c from table' -class TestFormatReindent(TestCaseBase): - +class TestFormatReindent(object): def test_option(self): - self.assertRaises(SQLParseError, sqlparse.format, 'foo', - reindent=2) - self.assertRaises(SQLParseError, sqlparse.format, 'foo', - indent_tabs=2) - self.assertRaises(SQLParseError, sqlparse.format, 'foo', - reindent=True, indent_width='foo') - self.assertRaises(SQLParseError, sqlparse.format, 'foo', - reindent=True, indent_width=-12) - self.assertRaises(SQLParseError, sqlparse.format, 'foo', - reindent=True, wrap_after='foo') - self.assertRaises(SQLParseError, sqlparse.format, 'foo', - reindent=True, wrap_after=-12) + with pytest.raises(SQLParseError): + sqlparse.format('foo', reindent=2) + with pytest.raises(SQLParseError): + sqlparse.format('foo', indent_tabs=2) + with pytest.raises(SQLParseError): + sqlparse.format('foo', reindent=True, indent_width='foo') + with pytest.raises(SQLParseError): + sqlparse.format('foo', reindent=True, indent_width=-12) + with pytest.raises(SQLParseError): + sqlparse.format('foo', reindent=True, wrap_after='foo') + with pytest.raises(SQLParseError): + sqlparse.format('foo', reindent=True, wrap_after=-12) def test_stmts(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select foo; select bar' - self.ndiffAssertEqual(f(s), 'select foo;\n\nselect bar') + assert f(s) == 'select foo;\n\nselect bar' s = 'select foo' - self.ndiffAssertEqual(f(s), 'select foo') + assert f(s) == 'select foo' s = 'select foo; -- test\n select bar' - self.ndiffAssertEqual(f(s), 'select foo; -- test\n\nselect bar') + assert f(s) == 'select foo; -- test\n\nselect bar' def test_keywords(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select * from foo union select * from bar;' - self.ndiffAssertEqual(f(s), '\n'.join(['select *', - 'from foo', - 'union', - 'select *', - 'from bar;'])) - - def test_keywords_between(self): # issue 14 + assert f(s) == '\n'.join([ + 'select *', + 'from foo', + 'union', + 'select *', + 'from bar;']) + + def test_keywords_between(self): + # issue 14 # don't break AND after BETWEEN f = lambda sql: sqlparse.format(sql, reindent=True) s = 'and foo between 1 and 2 and bar = 3' - self.ndiffAssertEqual(f(s), '\n'.join(['', - 'and foo between 1 and 2', - 'and bar = 3'])) + assert f(s) == '\n'.join([ + '', + 'and foo between 1 and 2', + 'and bar = 3']) def test_parenthesis(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select count(*) from (select * from foo);' - self.ndiffAssertEqual(f(s), - '\n'.join(['select count(*)', - 'from', - ' (select *', - ' from foo);', - ]) - ) + assert f(s) == '\n'.join([ + 'select count(*)', + 'from', + ' (select *', + ' from foo);']) def test_where(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select * from foo where bar = 1 and baz = 2 or bzz = 3;' - self.ndiffAssertEqual(f(s), ('select *\nfrom foo\n' - 'where bar = 1\n' - ' and baz = 2\n' - ' or bzz = 3;')) + assert f(s) == '\n'.join([ + 'select *', + 'from foo', + 'where bar = 1', + ' and baz = 2', + ' or bzz = 3;']) + s = 'select * from foo where bar = 1 and (baz = 2 or bzz = 3);' - self.ndiffAssertEqual(f(s), ('select *\nfrom foo\n' - 'where bar = 1\n' - ' and (baz = 2\n' - ' or bzz = 3);')) + assert f(s) == '\n'.join([ + 'select *', + 'from foo', + 'where bar = 1', + ' and (baz = 2', + ' or bzz = 3);']) def test_join(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select * from foo join bar on 1 = 2' - self.ndiffAssertEqual(f(s), '\n'.join(['select *', - 'from foo', - 'join bar on 1 = 2'])) + assert f(s) == '\n'.join([ + 'select *', + 'from foo', + 'join bar on 1 = 2']) s = 'select * from foo inner join bar on 1 = 2' - self.ndiffAssertEqual(f(s), '\n'.join(['select *', - 'from foo', - 'inner join bar on 1 = 2'])) + assert f(s) == '\n'.join([ + 'select *', + 'from foo', + 'inner join bar on 1 = 2']) s = 'select * from foo left outer join bar on 1 = 2' - self.ndiffAssertEqual(f(s), '\n'.join(['select *', - 'from foo', - 'left outer join bar on 1 = 2'] - )) + assert f(s) == '\n'.join([ + 'select *', + 'from foo', + 'left outer join bar on 1 = 2']) s = 'select * from foo straight_join bar on 1 = 2' - self.ndiffAssertEqual(f(s), '\n'.join(['select *', - 'from foo', - 'straight_join bar on 1 = 2'] - )) + assert f(s) == '\n'.join([ + 'select *', + 'from foo', + 'straight_join bar on 1 = 2']) def test_identifier_list(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select foo, bar, baz from table1, table2 where 1 = 2' - self.ndiffAssertEqual(f(s), '\n'.join(['select foo,', - ' bar,', - ' baz', - 'from table1,', - ' table2', - 'where 1 = 2'])) + assert f(s) == '\n'.join([ + 'select foo,', + ' bar,', + ' baz', + 'from table1,', + ' table2', + 'where 1 = 2']) s = 'select a.*, b.id from a, b' - self.ndiffAssertEqual(f(s), '\n'.join(['select a.*,', - ' b.id', - 'from a,', - ' b'])) + assert f(s) == '\n'.join([ + 'select a.*,', + ' b.id', + 'from a,', + ' b']) def test_identifier_list_with_wrap_after(self): f = lambda sql: sqlparse.format(sql, reindent=True, wrap_after=14) s = 'select foo, bar, baz from table1, table2 where 1 = 2' - self.ndiffAssertEqual(f(s), '\n'.join(['select foo, bar,', - ' baz', - 'from table1, table2', - 'where 1 = 2'])) + assert f(s) == '\n'.join([ + 'select foo, bar,', + ' baz', + 'from table1, table2', + 'where 1 = 2']) def test_identifier_list_with_functions(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = ("select 'abc' as foo, coalesce(col1, col2)||col3 as bar," "col3 from my_table") - self.ndiffAssertEqual(f(s), '\n'.join( - ["select 'abc' as foo,", - " coalesce(col1, col2)||col3 as bar,", - " col3", - "from my_table"])) + assert f(s) == '\n'.join([ + "select 'abc' as foo,", + " coalesce(col1, col2)||col3 as bar,", + " col3", + "from my_table"]) def test_case(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = 'case when foo = 1 then 2 when foo = 3 then 4 else 5 end' - self.ndiffAssertEqual(f(s), '\n'.join(['case', - ' when foo = 1 then 2', - ' when foo = 3 then 4', - ' else 5', - 'end'])) + assert f(s) == '\n'.join([ + 'case', + ' when foo = 1 then 2', + ' when foo = 3 then 4', + ' else 5', + 'end']) def test_case2(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = 'case(foo) when bar = 1 then 2 else 3 end' - self.ndiffAssertEqual(f(s), '\n'.join(['case(foo)', - ' when bar = 1 then 2', - ' else 3', - 'end'])) - - def test_nested_identifier_list(self): # issue4 + assert f(s) == '\n'.join([ + 'case(foo)', + ' when bar = 1 then 2', + ' else 3', + 'end']) + + def test_nested_identifier_list(self): + # issue4 f = lambda sql: sqlparse.format(sql, reindent=True) s = '(foo as bar, bar1, bar2 as bar3, b4 as b5)' - self.ndiffAssertEqual(f(s), '\n'.join(['(foo as bar,', - ' bar1,', - ' bar2 as bar3,', - ' b4 as b5)'])) - - def test_duplicate_linebreaks(self): # issue3 + assert f(s) == '\n'.join([ + '(foo as bar,', + ' bar1,', + ' bar2 as bar3,', + ' b4 as b5)']) + + def test_duplicate_linebreaks(self): + # issue3 f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select c1 -- column1\nfrom foo' - self.ndiffAssertEqual(f(s), '\n'.join(['select c1 -- column1', - 'from foo'])) + assert f(s) == '\n'.join([ + 'select c1 -- column1', + 'from foo']) s = 'select c1 -- column1\nfrom foo' r = sqlparse.format(s, reindent=True, strip_comments=True) - self.ndiffAssertEqual(r, '\n'.join(['select c1', - 'from foo'])) + assert r == '\n'.join([ + 'select c1', + 'from foo']) s = 'select c1\nfrom foo\norder by c1' - self.ndiffAssertEqual(f(s), '\n'.join(['select c1', - 'from foo', - 'order by c1'])) + assert f(s) == '\n'.join([ + 'select c1', + 'from foo', + 'order by c1']) s = 'select c1 from t1 where (c1 = 1) order by c1' - self.ndiffAssertEqual(f(s), '\n'.join(['select c1', - 'from t1', - 'where (c1 = 1)', - 'order by c1'])) - - def test_keywordfunctions(self): # issue36 + assert f(s) == '\n'.join([ + 'select c1', + 'from t1', + 'where (c1 = 1)', + 'order by c1']) + + def test_keywordfunctions(self): + # issue36 f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select max(a) b, foo, bar' - self.ndiffAssertEqual(f(s), '\n'.join(['select max(a) b,', - ' foo,', - ' bar'])) + assert f(s) == '\n'.join([ + 'select max(a) b,', + ' foo,', + ' bar']) - def test_identifier_and_functions(self): # issue45 + def test_identifier_and_functions(self): + # issue45 f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select foo.bar, nvl(1) from dual' - self.ndiffAssertEqual(f(s), '\n'.join(['select foo.bar,', - ' nvl(1)', - 'from dual'])) + assert f(s) == '\n'.join([ + 'select foo.bar,', + ' nvl(1)', + 'from dual']) -class TestOutputFormat(TestCaseBase): - +class TestOutputFormat(object): def test_python(self): sql = 'select * from foo;' f = lambda sql: sqlparse.format(sql, output_format='python') - self.ndiffAssertEqual(f(sql), "sql = 'select * from foo;'") + assert f(sql) == "sql = 'select * from foo;'" f = lambda sql: sqlparse.format(sql, output_format='python', reindent=True) - self.ndiffAssertEqual(f(sql), ("sql = ('select * '\n" - " 'from foo;')")) + assert f(sql) == '\n'.join([ + "sql = ('select * '", + " 'from foo;')"]) def test_python_multiple_statements(self): sql = 'select * from foo; select 1 from dual' f = lambda sql: sqlparse.format(sql, output_format='python') - self.ndiffAssertEqual(f(sql), ("sql = 'select * from foo; '\n" - "sql2 = 'select 1 from dual'")) + assert f(sql) == '\n'.join([ + "sql = 'select * from foo; '", + "sql2 = 'select 1 from dual'"]) @pytest.mark.xfail(reason="Needs fixing") def test_python_multiple_statements_with_formatting(self): sql = 'select * from foo; select 1 from dual' f = lambda sql: sqlparse.format(sql, output_format='python', reindent=True) - self.ndiffAssertEqual(f(sql), ("sql = ('select * '\n" - " 'from foo;')\n" - "sql2 = ('select 1 '\n" - " 'from dual')")) + assert f(sql) == '\n'.join([ + "sql = ('select * '", + " 'from foo;')", + "sql2 = ('select 1 '", + " 'from dual')"]) def test_php(self): sql = 'select * from foo;' f = lambda sql: sqlparse.format(sql, output_format='php') - self.ndiffAssertEqual(f(sql), '$sql = "select * from foo;";') + assert f(sql) == '$sql = "select * from foo;";' f = lambda sql: sqlparse.format(sql, output_format='php', reindent=True) - self.ndiffAssertEqual(f(sql), ('$sql = "select * ";\n' - '$sql .= "from foo;";')) + assert f(sql) == '\n'.join([ + '$sql = "select * ";', + '$sql .= "from foo;";']) - def test_sql(self): # "sql" is an allowed option but has no effect + def test_sql(self): + # "sql" is an allowed option but has no effect sql = 'select * from foo;' f = lambda sql: sqlparse.format(sql, output_format='sql') - self.ndiffAssertEqual(f(sql), 'select * from foo;') + assert f(sql) == 'select * from foo;' + + def test_invalid_option(self): + sql = 'select * from foo;' + with pytest.raises(SQLParseError): + sqlparse.format(sql, output_format='foo') -def test_format_column_ordering(): # issue89 +def test_format_column_ordering(): + # issue89 sql = 'select * from foo order by c1 desc, c2, c3;' formatted = sqlparse.format(sql, reindent=True) - expected = '\n'.join(['select *', - 'from foo', - 'order by c1 desc,', - ' c2,', - ' c3;']) + expected = '\n'.join([ + 'select *', + 'from foo', + 'order by c1 desc,', + ' c2,', + ' c3;']) assert formatted == expected def test_truncate_strings(): - sql = 'update foo set value = \'' + 'x' * 1000 + '\';' + sql = "update foo set value = '{0}';".format('x' * 1000) formatted = sqlparse.format(sql, truncate_strings=10) - assert formatted == 'update foo set value = \'xxxxxxxxxx[...]\';' + assert formatted == "update foo set value = 'xxxxxxxxxx[...]';" formatted = sqlparse.format(sql, truncate_strings=3, truncate_char='YYY') - assert formatted == 'update foo set value = \'xxxYYY\';' + assert formatted == "update foo set value = 'xxxYYY';" -def test_truncate_strings_invalid_option(): - pytest.raises(SQLParseError, sqlparse.format, - 'foo', truncate_strings='bar') - pytest.raises(SQLParseError, sqlparse.format, - 'foo', truncate_strings=-1) - pytest.raises(SQLParseError, sqlparse.format, - 'foo', truncate_strings=0) +@pytest.mark.parametrize('option', ['bar', -1, 0]) +def test_truncate_strings_invalid_option2(option): + with pytest.raises(SQLParseError): + sqlparse.format('foo', truncate_strings=option) -@pytest.mark.parametrize('sql', ['select verrrylongcolumn from foo', - 'select "verrrylongcolumn" from "foo"']) +@pytest.mark.parametrize('sql', [ + 'select verrrylongcolumn from foo', + 'select "verrrylongcolumn" from "foo"']) def test_truncate_strings_doesnt_truncate_identifiers(sql): formatted = sqlparse.format(sql, truncate_strings=2) assert formatted == sql def test_having_produces_newline(): - sql = ( - 'select * from foo, bar where bar.id = foo.bar_id' - ' having sum(bar.value) > 100') + sql = ('select * from foo, bar where bar.id = foo.bar_id ' + 'having sum(bar.value) > 100') formatted = sqlparse.format(sql, reindent=True) expected = [ 'select *', 'from foo,', ' bar', 'where bar.id = foo.bar_id', - 'having sum(bar.value) > 100' - ] + 'having sum(bar.value) > 100'] assert formatted == '\n'.join(expected) -def test_format_right_margin_invalid_input(): - with pytest.raises(SQLParseError): - sqlparse.format('foo', right_margin=2) - +@pytest.mark.parametrize('right_margin', ['ten', 2]) +def test_format_right_margin_invalid_option(right_margin): with pytest.raises(SQLParseError): - sqlparse.format('foo', right_margin="two") + sqlparse.format('foo', right_margin=right_margin) @pytest.mark.xfail(reason="Needs fixing") diff --git a/tests/test_grouping.py b/tests/test_grouping.py index bf6bfeb..12d7310 100644 --- a/tests/test_grouping.py +++ b/tests/test_grouping.py @@ -3,278 +3,294 @@ import pytest import sqlparse -from sqlparse import sql -from sqlparse import tokens as T -from sqlparse.compat import u - -from tests.utils import TestCaseBase - - -class TestGrouping(TestCaseBase): - - def test_parenthesis(self): - s = 'select (select (x3) x2) and (y2) bar' - parsed = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, str(parsed)) - self.assertEqual(len(parsed.tokens), 7) - self.assert_(isinstance(parsed.tokens[2], sql.Parenthesis)) - self.assert_(isinstance(parsed.tokens[-1], sql.Identifier)) - self.assertEqual(len(parsed.tokens[2].tokens), 5) - self.assert_(isinstance(parsed.tokens[2].tokens[3], sql.Identifier)) - self.assert_(isinstance(parsed.tokens[2].tokens[3].tokens[0], - sql.Parenthesis)) - self.assertEqual(len(parsed.tokens[2].tokens[3].tokens), 3) - - def test_comments(self): - s = '/*\n * foo\n */ \n bar' - parsed = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, u(parsed)) - self.assertEqual(len(parsed.tokens), 2) - - def test_assignment(self): - s = 'foo := 1;' - parsed = sqlparse.parse(s)[0] - self.assertEqual(len(parsed.tokens), 1) - self.assert_(isinstance(parsed.tokens[0], sql.Assignment)) - s = 'foo := 1' - parsed = sqlparse.parse(s)[0] - self.assertEqual(len(parsed.tokens), 1) - self.assert_(isinstance(parsed.tokens[0], sql.Assignment)) - - def test_identifiers(self): - s = 'select foo.bar from "myscheme"."table" where fail. order' - parsed = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, u(parsed)) - self.assert_(isinstance(parsed.tokens[2], sql.Identifier)) - self.assert_(isinstance(parsed.tokens[6], sql.Identifier)) - self.assert_(isinstance(parsed.tokens[8], sql.Where)) - s = 'select * from foo where foo.id = 1' - parsed = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, u(parsed)) - self.assert_(isinstance(parsed.tokens[-1].tokens[-1].tokens[0], - sql.Identifier)) - s = 'select * from (select "foo"."id" from foo)' - parsed = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, u(parsed)) - self.assert_(isinstance(parsed.tokens[-1].tokens[3], sql.Identifier)) - - s = "INSERT INTO `test` VALUES('foo', 'bar');" - parsed = sqlparse.parse(s)[0] - types = [l.ttype for l in parsed.tokens if not l.is_whitespace()] - self.assertEquals(types, [T.DML, T.Keyword, None, - T.Keyword, None, T.Punctuation]) - - s = "select 1.0*(a+b) as col, sum(c)/sum(d) from myschema.mytable" - parsed = sqlparse.parse(s)[0] - self.assertEqual(len(parsed.tokens), 7) - self.assert_(isinstance(parsed.tokens[2], sql.IdentifierList)) - self.assertEqual(len(parsed.tokens[2].tokens), 4) - identifiers = list(parsed.tokens[2].get_identifiers()) - self.assertEqual(len(identifiers), 2) - self.assertEquals(identifiers[0].get_alias(), u"col") - - def test_identifier_wildcard(self): - p = sqlparse.parse('a.*, b.id')[0] - self.assert_(isinstance(p.tokens[0], sql.IdentifierList)) - self.assert_(isinstance(p.tokens[0].tokens[0], sql.Identifier)) - self.assert_(isinstance(p.tokens[0].tokens[-1], sql.Identifier)) - - def test_identifier_name_wildcard(self): - p = sqlparse.parse('a.*')[0] - t = p.tokens[0] - self.assertEqual(t.get_name(), '*') - self.assertEqual(t.is_wildcard(), True) - - def test_identifier_invalid(self): - p = sqlparse.parse('a.')[0] - self.assert_(isinstance(p.tokens[0], sql.Identifier)) - self.assertEqual(p.tokens[0].has_alias(), False) - self.assertEqual(p.tokens[0].get_name(), None) - self.assertEqual(p.tokens[0].get_real_name(), None) - self.assertEqual(p.tokens[0].get_parent_name(), 'a') - - def test_identifier_invalid_in_middle(self): - # issue261 - s = 'SELECT foo. FROM foo' - p = sqlparse.parse(s)[0] - assert isinstance(p[2], sql.Identifier) - assert p[2][1].ttype == T.Punctuation - assert p[3].ttype == T.Whitespace - assert str(p[2]) == 'foo.' - - def test_identifier_as_invalid(self): # issue8 - p = sqlparse.parse('foo as select *')[0] - self.assert_(len(p.tokens), 5) - self.assert_(isinstance(p.tokens[0], sql.Identifier)) - self.assertEqual(len(p.tokens[0].tokens), 1) - self.assertEqual(p.tokens[2].ttype, T.Keyword) - - def test_identifier_function(self): - p = sqlparse.parse('foo() as bar')[0] - self.assert_(isinstance(p.tokens[0], sql.Identifier)) - self.assert_(isinstance(p.tokens[0].tokens[0], sql.Function)) - p = sqlparse.parse('foo()||col2 bar')[0] - self.assert_(isinstance(p.tokens[0], sql.Identifier)) - self.assert_(isinstance(p.tokens[0].tokens[0], sql.Operation)) - self.assert_(isinstance(p.tokens[0].tokens[0].tokens[0], sql.Function)) - - def test_identifier_extended(self): # issue 15 - p = sqlparse.parse('foo+100')[0] - self.assert_(isinstance(p.tokens[0], sql.Operation)) - p = sqlparse.parse('foo + 100')[0] - self.assert_(isinstance(p.tokens[0], sql.Operation)) - p = sqlparse.parse('foo*100')[0] - self.assert_(isinstance(p.tokens[0], sql.Operation)) - - def test_identifier_list(self): - p = sqlparse.parse('a, b, c')[0] - self.assert_(isinstance(p.tokens[0], sql.IdentifierList)) - p = sqlparse.parse('(a, b, c)')[0] - self.assert_(isinstance(p.tokens[0].tokens[1], sql.IdentifierList)) - - def test_identifier_list_subquery(self): - """identifier lists should still work in subqueries with aliases""" - p = sqlparse.parse("select * from (" - "select a, b + c as d from table) sub")[0] - subquery = p.tokens[-1].tokens[0] - idx, iden_list = subquery.token_next_by(i=sql.IdentifierList) - self.assert_(iden_list is not None) - # all the identifiers should be within the IdentifierList - _, ilist = subquery.token_next_by(i=sql.Identifier, idx=idx) - self.assert_(ilist is None) - - def test_identifier_list_case(self): - p = sqlparse.parse('a, case when 1 then 2 else 3 end as b, c')[0] - self.assert_(isinstance(p.tokens[0], sql.IdentifierList)) - p = sqlparse.parse('(a, case when 1 then 2 else 3 end as b, c)')[0] - self.assert_(isinstance(p.tokens[0].tokens[1], sql.IdentifierList)) - - def test_identifier_list_other(self): # issue2 - p = sqlparse.parse("select *, null, 1, 'foo', bar from mytable, x")[0] - self.assert_(isinstance(p.tokens[2], sql.IdentifierList)) - l = p.tokens[2] - self.assertEqual(len(l.tokens), 13) - - def test_identifier_list_with_inline_comments(self): # issue163 - p = sqlparse.parse('foo /* a comment */, bar')[0] - self.assert_(isinstance(p.tokens[0], sql.IdentifierList)) - self.assert_(isinstance(p.tokens[0].tokens[0], sql.Identifier)) - self.assert_(isinstance(p.tokens[0].tokens[3], sql.Identifier)) - - def test_identifiers_with_operators(self): - p = sqlparse.parse('a+b as c from table where (d-e)%2= 1')[0] - self.assertEqual(len([x for x in p.flatten() - if x.ttype == sqlparse.tokens.Name]), 5) - - def test_identifier_list_with_order(self): # issue101 - p = sqlparse.parse('1, 2 desc, 3')[0] - self.assert_(isinstance(p.tokens[0], sql.IdentifierList)) - self.assert_(isinstance(p.tokens[0].tokens[3], sql.Identifier)) - self.ndiffAssertEqual(u(p.tokens[0].tokens[3]), '2 desc') - - def test_where(self): - s = 'select * from foo where bar = 1 order by id desc' - p = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, u(p)) - self.assert_(len(p.tokens) == 14) - - s = 'select x from (select y from foo where bar = 1) z' - p = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, u(p)) - self.assert_(isinstance(p.tokens[-1].tokens[0].tokens[-2], sql.Where)) - - def test_typecast(self): - s = 'select foo::integer from bar' - p = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, u(p)) - self.assertEqual(p.tokens[2].get_typecast(), 'integer') - self.assertEqual(p.tokens[2].get_name(), 'foo') - s = 'select (current_database())::information_schema.sql_identifier' - p = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, u(p)) - self.assertEqual(p.tokens[2].get_typecast(), - 'information_schema.sql_identifier') - - def test_alias(self): - s = 'select foo as bar from mytable' - p = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, u(p)) - self.assertEqual(p.tokens[2].get_real_name(), 'foo') - self.assertEqual(p.tokens[2].get_alias(), 'bar') - s = 'select foo from mytable t1' - p = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, u(p)) - self.assertEqual(p.tokens[6].get_real_name(), 'mytable') - self.assertEqual(p.tokens[6].get_alias(), 't1') - s = 'select foo::integer as bar from mytable' - p = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, u(p)) - self.assertEqual(p.tokens[2].get_alias(), 'bar') - s = ('SELECT DISTINCT ' - '(current_database())::information_schema.sql_identifier AS view') - p = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, u(p)) - self.assertEqual(p.tokens[4].get_alias(), 'view') - - def test_alias_case(self): # see issue46 - p = sqlparse.parse('CASE WHEN 1 THEN 2 ELSE 3 END foo')[0] - self.assertEqual(len(p.tokens), 1) - self.assertEqual(p.tokens[0].get_alias(), 'foo') - - def test_alias_returns_none(self): # see issue185 - p = sqlparse.parse('foo.bar')[0] - self.assertEqual(len(p.tokens), 1) - self.assertEqual(p.tokens[0].get_alias(), None) - - def test_idlist_function(self): # see issue10 too - p = sqlparse.parse('foo(1) x, bar')[0] - self.assert_(isinstance(p.tokens[0], sql.IdentifierList)) - - def test_comparison_exclude(self): - # make sure operators are not handled too lazy - p = sqlparse.parse('(=)')[0] - self.assert_(isinstance(p.tokens[0], sql.Parenthesis)) - self.assert_(not isinstance(p.tokens[0].tokens[1], sql.Comparison)) - p = sqlparse.parse('(a=1)')[0] - self.assert_(isinstance(p.tokens[0].tokens[1], sql.Comparison)) - p = sqlparse.parse('(a>=1)')[0] - self.assert_(isinstance(p.tokens[0].tokens[1], sql.Comparison)) - - def test_function(self): - p = sqlparse.parse('foo()')[0] - self.assert_(isinstance(p.tokens[0], sql.Function)) - p = sqlparse.parse('foo(null, bar)')[0] - self.assert_(isinstance(p.tokens[0], sql.Function)) - self.assertEqual(len(list(p.tokens[0].get_parameters())), 2) - - def test_function_not_in(self): # issue183 - p = sqlparse.parse('in(1, 2)')[0] - self.assertEqual(len(p.tokens), 2) - self.assertEqual(p.tokens[0].ttype, T.Keyword) - self.assert_(isinstance(p.tokens[1], sql.Parenthesis)) - - def test_varchar(self): - p = sqlparse.parse('"text" Varchar(50) NOT NULL')[0] - self.assert_(isinstance(p.tokens[2], sql.Function)) - - -class TestStatement(TestCaseBase): - - def test_get_type(self): - def f(sql): - return sqlparse.parse(sql)[0] - self.assertEqual(f('select * from foo').get_type(), 'SELECT') - self.assertEqual(f('update foo').get_type(), 'UPDATE') - self.assertEqual(f(' update foo').get_type(), 'UPDATE') - self.assertEqual(f('\nupdate foo').get_type(), 'UPDATE') - self.assertEqual(f('foo').get_type(), 'UNKNOWN') - # Statements that have a whitespace after the closing semicolon - # are parsed as two statements where later only consists of the - # trailing whitespace. - self.assertEqual(f('\n').get_type(), 'UNKNOWN') - - -def test_identifier_with_operators(): # issue 53 +from sqlparse import sql, tokens as T + + +def test_grouping_parenthesis(): + s = 'select (select (x3) x2) and (y2) bar' + parsed = sqlparse.parse(s)[0] + assert str(parsed) == s + assert len(parsed.tokens) == 7 + assert isinstance(parsed.tokens[2], sql.Parenthesis) + assert isinstance(parsed.tokens[-1], sql.Identifier) + assert len(parsed.tokens[2].tokens) == 5 + assert isinstance(parsed.tokens[2].tokens[3], sql.Identifier) + assert isinstance(parsed.tokens[2].tokens[3].tokens[0], sql.Parenthesis) + assert len(parsed.tokens[2].tokens[3].tokens) == 3 + + +def test_grouping_comments(): + s = '/*\n * foo\n */ \n bar' + parsed = sqlparse.parse(s)[0] + assert str(parsed) == s + assert len(parsed.tokens) == 2 + + +@pytest.mark.parametrize('s', ['foo := 1;', 'foo := 1']) +def test_grouping_assignment(s): + parsed = sqlparse.parse(s)[0] + assert len(parsed.tokens) == 1 + assert isinstance(parsed.tokens[0], sql.Assignment) + + +def test_grouping_identifiers(): + s = 'select foo.bar from "myscheme"."table" where fail. order' + parsed = sqlparse.parse(s)[0] + assert str(parsed) == s + assert isinstance(parsed.tokens[2], sql.Identifier) + assert isinstance(parsed.tokens[6], sql.Identifier) + assert isinstance(parsed.tokens[8], sql.Where) + s = 'select * from foo where foo.id = 1' + parsed = sqlparse.parse(s)[0] + assert str(parsed) == s + assert isinstance(parsed.tokens[-1].tokens[-1].tokens[0], sql.Identifier) + s = 'select * from (select "foo"."id" from foo)' + parsed = sqlparse.parse(s)[0] + assert str(parsed) == s + assert isinstance(parsed.tokens[-1].tokens[3], sql.Identifier) + + s = "INSERT INTO `test` VALUES('foo', 'bar');" + parsed = sqlparse.parse(s)[0] + types = [l.ttype for l in parsed.tokens if not l.is_whitespace()] + assert types == [T.DML, T.Keyword, None, T.Keyword, None, T.Punctuation] + + s = "select 1.0*(a+b) as col, sum(c)/sum(d) from myschema.mytable" + parsed = sqlparse.parse(s)[0] + assert len(parsed.tokens) == 7 + assert isinstance(parsed.tokens[2], sql.IdentifierList) + assert len(parsed.tokens[2].tokens) == 4 + identifiers = list(parsed.tokens[2].get_identifiers()) + assert len(identifiers) == 2 + assert identifiers[0].get_alias() == "col" + + +def test_grouping_identifier_wildcard(): + p = sqlparse.parse('a.*, b.id')[0] + assert isinstance(p.tokens[0], sql.IdentifierList) + assert isinstance(p.tokens[0].tokens[0], sql.Identifier) + assert isinstance(p.tokens[0].tokens[-1], sql.Identifier) + + +def test_grouping_identifier_name_wildcard(): + p = sqlparse.parse('a.*')[0] + t = p.tokens[0] + assert t.get_name() == '*' + assert t.is_wildcard() is True + + +def test_grouping_identifier_invalid(): + p = sqlparse.parse('a.')[0] + assert isinstance(p.tokens[0], sql.Identifier) + assert p.tokens[0].has_alias() is False + assert p.tokens[0].get_name() is None + assert p.tokens[0].get_real_name() is None + assert p.tokens[0].get_parent_name() == 'a' + + +def test_grouping_identifier_invalid_in_middle(): + # issue261 + s = 'SELECT foo. FROM foo' + p = sqlparse.parse(s)[0] + assert isinstance(p[2], sql.Identifier) + assert p[2][1].ttype == T.Punctuation + assert p[3].ttype == T.Whitespace + assert str(p[2]) == 'foo.' + + +def test_grouping_identifier_as_invalid(): + # issue8 + p = sqlparse.parse('foo as select *')[0] + assert len(p.tokens), 5 + assert isinstance(p.tokens[0], sql.Identifier) + assert len(p.tokens[0].tokens) == 1 + assert p.tokens[2].ttype == T.Keyword + + +def test_grouping_identifier_function(): + p = sqlparse.parse('foo() as bar')[0] + assert isinstance(p.tokens[0], sql.Identifier) + assert isinstance(p.tokens[0].tokens[0], sql.Function) + p = sqlparse.parse('foo()||col2 bar')[0] + assert isinstance(p.tokens[0], sql.Identifier) + assert isinstance(p.tokens[0].tokens[0], sql.Operation) + assert isinstance(p.tokens[0].tokens[0].tokens[0], sql.Function) + + +@pytest.mark.parametrize('s', ['foo+100', 'foo + 100', 'foo*100']) +def test_grouping_operation(s): + p = sqlparse.parse(s)[0] + assert isinstance(p.tokens[0], sql.Operation) + + +def test_grouping_identifier_list(): + p = sqlparse.parse('a, b, c')[0] + assert isinstance(p.tokens[0], sql.IdentifierList) + p = sqlparse.parse('(a, b, c)')[0] + assert isinstance(p.tokens[0].tokens[1], sql.IdentifierList) + + +def test_grouping_identifier_list_subquery(): + """identifier lists should still work in subqueries with aliases""" + p = sqlparse.parse("select * from (" + "select a, b + c as d from table) sub")[0] + subquery = p.tokens[-1].tokens[0] + idx, iden_list = subquery.token_next_by(i=sql.IdentifierList) + assert iden_list is not None + # all the identifiers should be within the IdentifierList + _, ilist = subquery.token_next_by(i=sql.Identifier, idx=idx) + assert ilist is None + + +def test_grouping_identifier_list_case(): + p = sqlparse.parse('a, case when 1 then 2 else 3 end as b, c')[0] + assert isinstance(p.tokens[0], sql.IdentifierList) + p = sqlparse.parse('(a, case when 1 then 2 else 3 end as b, c)')[0] + assert isinstance(p.tokens[0].tokens[1], sql.IdentifierList) + + +def test_grouping_identifier_list_other(): + # issue2 + p = sqlparse.parse("select *, null, 1, 'foo', bar from mytable, x")[0] + assert isinstance(p.tokens[2], sql.IdentifierList) + assert len(p.tokens[2].tokens) == 13 + + +def test_grouping_identifier_list_with_inline_comments(): + # issue163 + p = sqlparse.parse('foo /* a comment */, bar')[0] + assert isinstance(p.tokens[0], sql.IdentifierList) + assert isinstance(p.tokens[0].tokens[0], sql.Identifier) + assert isinstance(p.tokens[0].tokens[3], sql.Identifier) + + +def test_grouping_identifiers_with_operators(): + p = sqlparse.parse('a+b as c from table where (d-e)%2= 1')[0] + assert len([x for x in p.flatten() if x.ttype == T.Name]) == 5 + + +def test_grouping_identifier_list_with_order(): + # issue101 + p = sqlparse.parse('1, 2 desc, 3')[0] + assert isinstance(p.tokens[0], sql.IdentifierList) + assert isinstance(p.tokens[0].tokens[3], sql.Identifier) + assert str(p.tokens[0].tokens[3]) == '2 desc' + + +def test_grouping_where(): + s = 'select * from foo where bar = 1 order by id desc' + p = sqlparse.parse(s)[0] + assert str(p) == s + assert len(p.tokens) == 14 + + s = 'select x from (select y from foo where bar = 1) z' + p = sqlparse.parse(s)[0] + assert str(p) == s + assert isinstance(p.tokens[-1].tokens[0].tokens[-2], sql.Where) + + +def test_grouping_typecast(): + s = 'select foo::integer from bar' + p = sqlparse.parse(s)[0] + assert str(p) == s + assert p.tokens[2].get_typecast() == 'integer' + assert p.tokens[2].get_name() == 'foo' + s = 'select (current_database())::information_schema.sql_identifier' + p = sqlparse.parse(s)[0] + assert str(p) == s + assert (p.tokens[2].get_typecast() == 'information_schema.sql_identifier') + + +def test_grouping_alias(): + s = 'select foo as bar from mytable' + p = sqlparse.parse(s)[0] + assert str(p) == s + assert p.tokens[2].get_real_name() == 'foo' + assert p.tokens[2].get_alias() == 'bar' + s = 'select foo from mytable t1' + p = sqlparse.parse(s)[0] + assert str(p) == s + assert p.tokens[6].get_real_name() == 'mytable' + assert p.tokens[6].get_alias() == 't1' + s = 'select foo::integer as bar from mytable' + p = sqlparse.parse(s)[0] + assert str(p) == s + assert p.tokens[2].get_alias() == 'bar' + s = ('SELECT DISTINCT ' + '(current_database())::information_schema.sql_identifier AS view') + p = sqlparse.parse(s)[0] + assert str(p) == s + assert p.tokens[4].get_alias() == 'view' + + +def test_grouping_alias_case(): + # see issue46 + p = sqlparse.parse('CASE WHEN 1 THEN 2 ELSE 3 END foo')[0] + assert len(p.tokens) == 1 + assert p.tokens[0].get_alias() == 'foo' + + +def test_grouping_alias_returns_none(): + # see issue185 + p = sqlparse.parse('foo.bar')[0] + assert len(p.tokens) == 1 + assert p.tokens[0].get_alias() is None + + +def test_grouping_idlist_function(): + # see issue10 too + p = sqlparse.parse('foo(1) x, bar')[0] + assert isinstance(p.tokens[0], sql.IdentifierList) + + +def test_grouping_comparison_exclude(): + # make sure operators are not handled too lazy + p = sqlparse.parse('(=)')[0] + assert isinstance(p.tokens[0], sql.Parenthesis) + assert not isinstance(p.tokens[0].tokens[1], sql.Comparison) + p = sqlparse.parse('(a=1)')[0] + assert isinstance(p.tokens[0].tokens[1], sql.Comparison) + p = sqlparse.parse('(a>=1)')[0] + assert isinstance(p.tokens[0].tokens[1], sql.Comparison) + + +def test_grouping_function(): + p = sqlparse.parse('foo()')[0] + assert isinstance(p.tokens[0], sql.Function) + p = sqlparse.parse('foo(null, bar)')[0] + assert isinstance(p.tokens[0], sql.Function) + assert len(list(p.tokens[0].get_parameters())) == 2 + + +def test_grouping_function_not_in(): + # issue183 + p = sqlparse.parse('in(1, 2)')[0] + assert len(p.tokens) == 2 + assert p.tokens[0].ttype == T.Keyword + assert isinstance(p.tokens[1], sql.Parenthesis) + + +def test_grouping_varchar(): + p = sqlparse.parse('"text" Varchar(50) NOT NULL')[0] + assert isinstance(p.tokens[2], sql.Function) + + +def test_statement_get_type(): + def f(sql): + return sqlparse.parse(sql)[0] + + assert f('select * from foo').get_type() == 'SELECT' + assert f('update foo').get_type() == 'UPDATE' + assert f(' update foo').get_type() == 'UPDATE' + assert f('\nupdate foo').get_type() == 'UPDATE' + assert f('foo').get_type() == 'UNKNOWN' + # Statements that have a whitespace after the closing semicolon + # are parsed as two statements where later only consists of the + # trailing whitespace. + assert f('\n').get_type() == 'UNKNOWN' + + +def test_identifier_with_operators(): + # issue 53 p = sqlparse.parse('foo||bar')[0] assert len(p.tokens) == 1 assert isinstance(p.tokens[0], sql.Operation) @@ -293,7 +309,7 @@ def test_identifier_with_op_trailing_ws(): def test_identifier_with_string_literals(): - p = sqlparse.parse('foo + \'bar\'')[0] + p = sqlparse.parse("foo + 'bar'")[0] assert len(p.tokens) == 1 assert isinstance(p.tokens[0], sql.Operation) @@ -302,12 +318,13 @@ def test_identifier_with_string_literals(): # showed that this shouldn't be an identifier at all. I'm leaving this # commented in the source for a while. # def test_identifier_string_concat(): -# p = sqlparse.parse('\'foo\' || bar')[0] +# p = sqlparse.parse("'foo' || bar")[0] # assert len(p.tokens) == 1 # assert isinstance(p.tokens[0], sql.Identifier) -def test_identifier_consumes_ordering(): # issue89 +def test_identifier_consumes_ordering(): + # issue89 p = sqlparse.parse('select * from foo order by c1 desc, c2, c3')[0] assert isinstance(p.tokens[-1], sql.IdentifierList) ids = list(p.tokens[-1].get_identifiers()) @@ -318,7 +335,8 @@ def test_identifier_consumes_ordering(): # issue89 assert ids[1].get_ordering() is None -def test_comparison_with_keywords(): # issue90 +def test_comparison_with_keywords(): + # issue90 # in fact these are assignments, but for now we don't distinguish them p = sqlparse.parse('foo = NULL')[0] assert len(p.tokens) == 1 @@ -332,7 +350,8 @@ def test_comparison_with_keywords(): # issue90 assert isinstance(p.tokens[0], sql.Comparison) -def test_comparison_with_floats(): # issue145 +def test_comparison_with_floats(): + # issue145 p = sqlparse.parse('foo = 25.5')[0] assert len(p.tokens) == 1 assert isinstance(p.tokens[0], sql.Comparison) @@ -341,7 +360,8 @@ def test_comparison_with_floats(): # issue145 assert p.tokens[0].right.value == '25.5' -def test_comparison_with_parenthesis(): # issue23 +def test_comparison_with_parenthesis(): + # issue23 p = sqlparse.parse('(3 + 4) = 7')[0] assert len(p.tokens) == 1 assert isinstance(p.tokens[0], sql.Comparison) @@ -350,15 +370,17 @@ def test_comparison_with_parenthesis(): # issue23 assert comp.right.ttype is T.Number.Integer -def test_comparison_with_strings(): # issue148 - p = sqlparse.parse('foo = \'bar\'')[0] +def test_comparison_with_strings(): + # issue148 + p = sqlparse.parse("foo = 'bar'")[0] assert len(p.tokens) == 1 assert isinstance(p.tokens[0], sql.Comparison) - assert p.tokens[0].right.value == '\'bar\'' + assert p.tokens[0].right.value == "'bar'" assert p.tokens[0].right.ttype == T.String.Single -def test_comparison_with_functions(): # issue230 +def test_comparison_with_functions(): + # issue230 p = sqlparse.parse('foo = DATE(bar.baz)')[0] assert len(p.tokens) == 1 assert isinstance(p.tokens[0], sql.Comparison) diff --git a/tests/test_parse.py b/tests/test_parse.py index 75a7ab5..2d23425 100644 --- a/tests/test_parse.py +++ b/tests/test_parse.py @@ -1,161 +1,150 @@ # -*- coding: utf-8 -*- -"""Tests sqlparse function.""" +"""Tests sqlparse.parse().""" import pytest -from tests.utils import TestCaseBase - import sqlparse -import sqlparse.sql -from sqlparse import tokens as T -from sqlparse.compat import u, StringIO - - -class SQLParseTest(TestCaseBase): - """Tests sqlparse.parse().""" - - def test_tokenize(self): - sql = 'select * from foo;' - stmts = sqlparse.parse(sql) - self.assertEqual(len(stmts), 1) - self.assertEqual(str(stmts[0]), sql) - - def test_multistatement(self): - sql1 = 'select * from foo;' - sql2 = 'select * from bar;' - stmts = sqlparse.parse(sql1 + sql2) - self.assertEqual(len(stmts), 2) - self.assertEqual(str(stmts[0]), sql1) - self.assertEqual(str(stmts[1]), sql2) - - def test_newlines(self): - sql = u'select\n*from foo;' - p = sqlparse.parse(sql)[0] - self.assertEqual(u(p), sql) - sql = u'select\r\n*from foo' - p = sqlparse.parse(sql)[0] - self.assertEqual(u(p), sql) - sql = u'select\r*from foo' - p = sqlparse.parse(sql)[0] - self.assertEqual(u(p), sql) - sql = u'select\r\n*from foo\n' - p = sqlparse.parse(sql)[0] - self.assertEqual(u(p), sql) - - def test_within(self): - sql = 'foo(col1, col2)' - p = sqlparse.parse(sql)[0] - col1 = p.tokens[0].tokens[1].tokens[1].tokens[0] - self.assert_(col1.within(sqlparse.sql.Function)) - - def test_child_of(self): - sql = '(col1, col2)' - p = sqlparse.parse(sql)[0] - self.assert_(p.tokens[0].tokens[1].is_child_of(p.tokens[0])) - sql = 'select foo' - p = sqlparse.parse(sql)[0] - self.assert_(not p.tokens[2].is_child_of(p.tokens[0])) - self.assert_(p.tokens[2].is_child_of(p)) - - def test_has_ancestor(self): - sql = 'foo or (bar, baz)' - p = sqlparse.parse(sql)[0] - baz = p.tokens[-1].tokens[1].tokens[-1] - self.assert_(baz.has_ancestor(p.tokens[-1].tokens[1])) - self.assert_(baz.has_ancestor(p.tokens[-1])) - self.assert_(baz.has_ancestor(p)) - - def test_float(self): - t = sqlparse.parse('.5')[0].tokens - self.assertEqual(len(t), 1) - self.assert_(t[0].ttype is sqlparse.tokens.Number.Float) - t = sqlparse.parse('.51')[0].tokens - self.assertEqual(len(t), 1) - self.assert_(t[0].ttype is sqlparse.tokens.Number.Float) - t = sqlparse.parse('1.5')[0].tokens - self.assertEqual(len(t), 1) - self.assert_(t[0].ttype is sqlparse.tokens.Number.Float) - t = sqlparse.parse('12.5')[0].tokens - self.assertEqual(len(t), 1) - self.assert_(t[0].ttype is sqlparse.tokens.Number.Float) - - def test_placeholder(self): - def _get_tokens(sql): - return sqlparse.parse(sql)[0].tokens[-1].tokens - - t = _get_tokens('select * from foo where user = ?') - self.assert_(t[-1].ttype is sqlparse.tokens.Name.Placeholder) - self.assertEqual(t[-1].value, '?') - t = _get_tokens('select * from foo where user = :1') - self.assert_(t[-1].ttype is sqlparse.tokens.Name.Placeholder) - self.assertEqual(t[-1].value, ':1') - t = _get_tokens('select * from foo where user = :name') - self.assert_(t[-1].ttype is sqlparse.tokens.Name.Placeholder) - self.assertEqual(t[-1].value, ':name') - t = _get_tokens('select * from foo where user = %s') - self.assert_(t[-1].ttype is sqlparse.tokens.Name.Placeholder) - self.assertEqual(t[-1].value, '%s') - t = _get_tokens('select * from foo where user = $a') - self.assert_(t[-1].ttype is sqlparse.tokens.Name.Placeholder) - self.assertEqual(t[-1].value, '$a') - - def test_modulo_not_placeholder(self): - tokens = list(sqlparse.lexer.tokenize('x %3')) - self.assertEqual(tokens[2][0], sqlparse.tokens.Operator) - - def test_access_symbol(self): # see issue27 - t = sqlparse.parse('select a.[foo bar] as foo')[0].tokens - self.assert_(isinstance(t[-1], sqlparse.sql.Identifier)) - self.assertEqual(t[-1].get_name(), 'foo') - self.assertEqual(t[-1].get_real_name(), '[foo bar]') - self.assertEqual(t[-1].get_parent_name(), 'a') - - def test_square_brackets_notation_isnt_too_greedy(self): # see issue153 - t = sqlparse.parse('[foo], [bar]')[0].tokens - self.assert_(isinstance(t[0], sqlparse.sql.IdentifierList)) - self.assertEqual(len(t[0].tokens), 4) - self.assertEqual(t[0].tokens[0].get_real_name(), '[foo]') - self.assertEqual(t[0].tokens[-1].get_real_name(), '[bar]') - - def test_keyword_like_identifier(self): # see issue47 - t = sqlparse.parse('foo.key')[0].tokens - self.assertEqual(len(t), 1) - self.assert_(isinstance(t[0], sqlparse.sql.Identifier)) - - def test_function_parameter(self): # see issue94 - t = sqlparse.parse('abs(some_col)')[0].tokens[0].get_parameters() - self.assertEqual(len(t), 1) - self.assert_(isinstance(t[0], sqlparse.sql.Identifier)) - - def test_function_param_single_literal(self): - t = sqlparse.parse('foo(5)')[0].tokens[0].get_parameters() - self.assertEqual(len(t), 1) - self.assert_(t[0].ttype is T.Number.Integer) - - def test_nested_function(self): - t = sqlparse.parse('foo(bar(5))')[0].tokens[0].get_parameters() - self.assertEqual(len(t), 1) - self.assert_(type(t[0]) is sqlparse.sql.Function) +from sqlparse import sql, tokens as T +from sqlparse.compat import StringIO + + +def test_parse_tokenize(): + s = 'select * from foo;' + stmts = sqlparse.parse(s) + assert len(stmts) == 1 + assert str(stmts[0]) == s + + +def test_parse_multistatement(): + sql1 = 'select * from foo;' + sql2 = 'select * from bar;' + stmts = sqlparse.parse(sql1 + sql2) + assert len(stmts) == 2 + assert str(stmts[0]) == sql1 + assert str(stmts[1]) == sql2 + + +@pytest.mark.parametrize('s', ['select\n*from foo;', + 'select\r\n*from foo', + 'select\r*from foo', + 'select\r\n*from foo\n']) +def test_parse_newlines(s): + p = sqlparse.parse(s)[0] + assert str(p) == s + + +def test_parse_within(): + s = 'foo(col1, col2)' + p = sqlparse.parse(s)[0] + col1 = p.tokens[0].tokens[1].tokens[1].tokens[0] + assert col1.within(sql.Function) + + +def test_parse_child_of(): + s = '(col1, col2)' + p = sqlparse.parse(s)[0] + assert p.tokens[0].tokens[1].is_child_of(p.tokens[0]) + s = 'select foo' + p = sqlparse.parse(s)[0] + assert not p.tokens[2].is_child_of(p.tokens[0]) + assert p.tokens[2].is_child_of(p) + + +def test_parse_has_ancestor(): + s = 'foo or (bar, baz)' + p = sqlparse.parse(s)[0] + baz = p.tokens[-1].tokens[1].tokens[-1] + assert baz.has_ancestor(p.tokens[-1].tokens[1]) + assert baz.has_ancestor(p.tokens[-1]) + assert baz.has_ancestor(p) + + +@pytest.mark.parametrize('s', ['.5', '.51', '1.5', '12.5']) +def test_parse_float(s): + t = sqlparse.parse(s)[0].tokens + assert len(t) == 1 + assert t[0].ttype is sqlparse.tokens.Number.Float + + +@pytest.mark.parametrize('s, holder', [ + ('select * from foo where user = ?', '?'), + ('select * from foo where user = :1', ':1'), + ('select * from foo where user = :name', ':name'), + ('select * from foo where user = %s', '%s'), + ('select * from foo where user = $a', '$a')]) +def test_parse_placeholder(s, holder): + t = sqlparse.parse(s)[0].tokens[-1].tokens + assert t[-1].ttype is sqlparse.tokens.Name.Placeholder + assert t[-1].value == holder + + +def test_parse_modulo_not_placeholder(): + tokens = list(sqlparse.lexer.tokenize('x %3')) + assert tokens[2][0] == sqlparse.tokens.Operator + + +def test_parse_access_symbol(): + # see issue27 + t = sqlparse.parse('select a.[foo bar] as foo')[0].tokens + assert isinstance(t[-1], sql.Identifier) + assert t[-1].get_name() == 'foo' + assert t[-1].get_real_name() == '[foo bar]' + assert t[-1].get_parent_name() == 'a' + + +def test_parse_square_brackets_notation_isnt_too_greedy(): + # see issue153 + t = sqlparse.parse('[foo], [bar]')[0].tokens + assert isinstance(t[0], sql.IdentifierList) + assert len(t[0].tokens) == 4 + assert t[0].tokens[0].get_real_name() == '[foo]' + assert t[0].tokens[-1].get_real_name() == '[bar]' + + +def test_parse_keyword_like_identifier(): + # see issue47 + t = sqlparse.parse('foo.key')[0].tokens + assert len(t) == 1 + assert isinstance(t[0], sql.Identifier) + + +def test_parse_function_parameter(): + # see issue94 + t = sqlparse.parse('abs(some_col)')[0].tokens[0].get_parameters() + assert len(t) == 1 + assert isinstance(t[0], sql.Identifier) + + +def test_parse_function_param_single_literal(): + t = sqlparse.parse('foo(5)')[0].tokens[0].get_parameters() + assert len(t) == 1 + assert t[0].ttype is T.Number.Integer + + +def test_parse_nested_function(): + t = sqlparse.parse('foo(bar(5))')[0].tokens[0].get_parameters() + assert len(t) == 1 + assert type(t[0]) is sql.Function def test_quoted_identifier(): t = sqlparse.parse('select x.y as "z" from foo')[0].tokens - assert isinstance(t[2], sqlparse.sql.Identifier) + assert isinstance(t[2], sql.Identifier) assert t[2].get_name() == 'z' assert t[2].get_real_name() == 'y' -@pytest.mark.parametrize('name', [ - 'foo', - '_foo', -]) -def test_valid_identifier_names(name): # issue175 +@pytest.mark.parametrize('name', ['foo', '_foo']) +def test_valid_identifier_names(name): + # issue175 t = sqlparse.parse(name)[0].tokens - assert isinstance(t[0], sqlparse.sql.Identifier) + assert isinstance(t[0], sql.Identifier) + +def test_psql_quotation_marks(): + # issue83 -def test_psql_quotation_marks(): # issue83 # regression: make sure plain $$ work t = sqlparse.split(""" CREATE OR REPLACE FUNCTION testfunc1(integer) RETURNS integer AS $$ @@ -165,6 +154,7 @@ def test_psql_quotation_marks(): # issue83 .... $$ LANGUAGE plpgsql;""") assert len(t) == 2 + # make sure $SOMETHING$ works too t = sqlparse.split(""" CREATE OR REPLACE FUNCTION testfunc1(integer) RETURNS integer AS $PROC_1$ @@ -177,8 +167,8 @@ def test_psql_quotation_marks(): # issue83 def test_double_precision_is_builtin(): - sql = 'DOUBLE PRECISION' - t = sqlparse.parse(sql)[0].tokens + s = 'DOUBLE PRECISION' + t = sqlparse.parse(s)[0].tokens assert len(t) == 1 assert t[0].ttype == sqlparse.tokens.Name.Builtin assert t[0].value == 'DOUBLE PRECISION' @@ -207,10 +197,11 @@ def test_single_quotes_are_strings(): def test_double_quotes_are_identifiers(): p = sqlparse.parse('"foo"')[0].tokens assert len(p) == 1 - assert isinstance(p[0], sqlparse.sql.Identifier) + assert isinstance(p[0], sql.Identifier) -def test_single_quotes_with_linebreaks(): # issue118 +def test_single_quotes_with_linebreaks(): + # issue118 p = sqlparse.parse("'f\nf'")[0].tokens assert len(p) == 1 assert p[0].ttype is T.String.Single @@ -221,7 +212,7 @@ def test_sqlite_identifiers(): p = sqlparse.parse('[col1],[col2]')[0].tokens id_names = [id_.get_name() for id_ in p[0].get_identifiers()] assert len(p) == 1 - assert isinstance(p[0], sqlparse.sql.IdentifierList) + assert isinstance(p[0], sql.IdentifierList) assert id_names == ['[col1]', '[col2]'] p = sqlparse.parse('[col1]+[col2]')[0] @@ -280,35 +271,29 @@ def test_typed_array_definition(): # indentifer names p = sqlparse.parse('x int, y int[], z int')[0] names = [x.get_name() for x in p.get_sublists() - if isinstance(x, sqlparse.sql.Identifier)] + if isinstance(x, sql.Identifier)] assert names == ['x', 'y', 'z'] -@pytest.mark.parametrize('sql', [ - 'select 1 -- foo', - 'select 1 # foo' # see issue178 -]) -def test_single_line_comments(sql): - p = sqlparse.parse(sql)[0] +@pytest.mark.parametrize('s', ['select 1 -- foo', 'select 1 # foo']) +def test_single_line_comments(s): + # see issue178 + p = sqlparse.parse(s)[0] assert len(p.tokens) == 5 assert p.tokens[-1].ttype == T.Comment.Single -@pytest.mark.parametrize('sql', [ - 'foo', - '@foo', - '#foo', # see issue192 - '##foo' -]) -def test_names_and_special_names(sql): - p = sqlparse.parse(sql)[0] +@pytest.mark.parametrize('s', ['foo', '@foo', '#foo', '##foo']) +def test_names_and_special_names(s): + # see issue192 + p = sqlparse.parse(s)[0] assert len(p.tokens) == 1 - assert isinstance(p.tokens[0], sqlparse.sql.Identifier) + assert isinstance(p.tokens[0], sql.Identifier) def test_get_token_at_offset(): - # 0123456789 p = sqlparse.parse('select * from dual')[0] + # 0123456789 assert p.get_token_at_offset(0) == p.tokens[0] assert p.get_token_at_offset(1) == p.tokens[0] assert p.get_token_at_offset(6) == p.tokens[1] @@ -324,60 +309,61 @@ def test_pprint(): output = StringIO() p._pprint_tree(f=output) - pprint = u'\n'.join([" 0 DML 'select'", - " 1 Whitespace ' '", - " 2 IdentifierList 'a0, b0...'", - " | 0 Identifier 'a0'", - " | | 0 Name 'a0'", - " | 1 Punctuation ','", - " | 2 Whitespace ' '", - " | 3 Identifier 'b0'", - " | | 0 Name 'b0'", - " | 4 Punctuation ','", - " | 5 Whitespace ' '", - " | 6 Identifier 'c0'", - " | | 0 Name 'c0'", - " | 7 Punctuation ','", - " | 8 Whitespace ' '", - " | 9 Identifier 'd0'", - " | | 0 Name 'd0'", - " | 10 Punctuation ','", - " | 11 Whitespace ' '", - " | 12 Float 'e0'", - " 3 Whitespace ' '", - " 4 Keyword 'from'", - " 5 Whitespace ' '", - " 6 Identifier '(selec...'", - " | 0 Parenthesis '(selec...'", - " | | 0 Punctuation '('", - " | | 1 DML 'select'", - " | | 2 Whitespace ' '", - " | | 3 Wildcard '*'", - " | | 4 Whitespace ' '", - " | | 5 Keyword 'from'", - " | | 6 Whitespace ' '", - " | | 7 Identifier 'dual'", - " | | | 0 Name 'dual'", - " | | 8 Punctuation ')'", - " | 1 Whitespace ' '", - " | 2 Identifier 'q0'", - " | | 0 Name 'q0'", - " 7 Whitespace ' '", - " 8 Where 'where ...'", - " | 0 Keyword 'where'", - " | 1 Whitespace ' '", - " | 2 Comparison '1=1'", - " | | 0 Integer '1'", - " | | 1 Comparison '='", - " | | 2 Integer '1'", - " | 3 Whitespace ' '", - " | 4 Keyword 'and'", - " | 5 Whitespace ' '", - " | 6 Comparison '2=2'", - " | | 0 Integer '2'", - " | | 1 Comparison '='", - " | | 2 Integer '2'", - "", ]) + pprint = '\n'.join([ + " 0 DML 'select'", + " 1 Whitespace ' '", + " 2 IdentifierList 'a0, b0...'", + " | 0 Identifier 'a0'", + " | | 0 Name 'a0'", + " | 1 Punctuation ','", + " | 2 Whitespace ' '", + " | 3 Identifier 'b0'", + " | | 0 Name 'b0'", + " | 4 Punctuation ','", + " | 5 Whitespace ' '", + " | 6 Identifier 'c0'", + " | | 0 Name 'c0'", + " | 7 Punctuation ','", + " | 8 Whitespace ' '", + " | 9 Identifier 'd0'", + " | | 0 Name 'd0'", + " | 10 Punctuation ','", + " | 11 Whitespace ' '", + " | 12 Float 'e0'", + " 3 Whitespace ' '", + " 4 Keyword 'from'", + " 5 Whitespace ' '", + " 6 Identifier '(selec...'", + " | 0 Parenthesis '(selec...'", + " | | 0 Punctuation '('", + " | | 1 DML 'select'", + " | | 2 Whitespace ' '", + " | | 3 Wildcard '*'", + " | | 4 Whitespace ' '", + " | | 5 Keyword 'from'", + " | | 6 Whitespace ' '", + " | | 7 Identifier 'dual'", + " | | | 0 Name 'dual'", + " | | 8 Punctuation ')'", + " | 1 Whitespace ' '", + " | 2 Identifier 'q0'", + " | | 0 Name 'q0'", + " 7 Whitespace ' '", + " 8 Where 'where ...'", + " | 0 Keyword 'where'", + " | 1 Whitespace ' '", + " | 2 Comparison '1=1'", + " | | 0 Integer '1'", + " | | 1 Comparison '='", + " | | 2 Integer '1'", + " | 3 Whitespace ' '", + " | 4 Keyword 'and'", + " | 5 Whitespace ' '", + " | 6 Comparison '2=2'", + " | | 0 Integer '2'", + " | | 1 Comparison '='", + " | | 2 Integer '2'", + ""]) assert output.getvalue() == pprint @@ -394,7 +380,7 @@ def test_wildcard_multiplication(): def test_stmt_tokens_parents(): # see issue 226 - sql = "CREATE TABLE test();" - stmt = sqlparse.parse(sql)[0] + s = "CREATE TABLE test();" + stmt = sqlparse.parse(s)[0] for token in stmt.tokens: assert token.has_ancestor(stmt) diff --git a/tests/test_regressions.py b/tests/test_regressions.py index 14aab24..255493c 100644 --- a/tests/test_regressions.py +++ b/tests/test_regressions.py @@ -1,163 +1,146 @@ # -*- coding: utf-8 -*- -import sys - -import pytest # noqa -from tests.utils import TestCaseBase, load_file +import pytest import sqlparse -from sqlparse import sql -from sqlparse import tokens as T - - -class RegressionTests(TestCaseBase): - - def test_issue9(self): - # make sure where doesn't consume parenthesis - p = sqlparse.parse('(where 1)')[0] - self.assert_(isinstance(p, sql.Statement)) - self.assertEqual(len(p.tokens), 1) - self.assert_(isinstance(p.tokens[0], sql.Parenthesis)) - prt = p.tokens[0] - self.assertEqual(len(prt.tokens), 3) - self.assertEqual(prt.tokens[0].ttype, T.Punctuation) - self.assertEqual(prt.tokens[-1].ttype, T.Punctuation) - - def test_issue13(self): - parsed = sqlparse.parse(("select 'one';\n" - "select 'two\\'';\n" - "select 'three';")) - self.assertEqual(len(parsed), 3) - self.assertEqual(str(parsed[1]).strip(), "select 'two\\'';") - - def test_issue26(self): - # parse stand-alone comments - p = sqlparse.parse('--hello')[0] - self.assertEqual(len(p.tokens), 1) - self.assert_(p.tokens[0].ttype is T.Comment.Single) - p = sqlparse.parse('-- hello')[0] - self.assertEqual(len(p.tokens), 1) - self.assert_(p.tokens[0].ttype is T.Comment.Single) - p = sqlparse.parse('--hello\n')[0] - self.assertEqual(len(p.tokens), 1) - self.assert_(p.tokens[0].ttype is T.Comment.Single) - p = sqlparse.parse('--')[0] - self.assertEqual(len(p.tokens), 1) - self.assert_(p.tokens[0].ttype is T.Comment.Single) - p = sqlparse.parse('--\n')[0] - self.assertEqual(len(p.tokens), 1) - self.assert_(p.tokens[0].ttype is T.Comment.Single) - - def test_issue34(self): - t = sqlparse.parse("create")[0].token_first() - self.assertEqual(t.match(T.Keyword.DDL, "create"), True) - self.assertEqual(t.match(T.Keyword.DDL, "CREATE"), True) - - def test_issue35(self): - # missing space before LIMIT - sql = sqlparse.format("select * from foo where bar = 1 limit 1", - reindent=True) - self.ndiffAssertEqual(sql, "\n".join(["select *", - "from foo", - "where bar = 1 limit 1"])) - - def test_issue38(self): - sql = sqlparse.format("SELECT foo; -- comment", - strip_comments=True) - self.ndiffAssertEqual(sql, "SELECT foo;") - sql = sqlparse.format("/* foo */", strip_comments=True) - self.ndiffAssertEqual(sql, "") - - def test_issue39(self): - p = sqlparse.parse('select user.id from user')[0] - self.assertEqual(len(p.tokens), 7) - idt = p.tokens[2] - self.assertEqual(idt.__class__, sql.Identifier) - self.assertEqual(len(idt.tokens), 3) - self.assertEqual(idt.tokens[0].match(T.Name, 'user'), True) - self.assertEqual(idt.tokens[1].match(T.Punctuation, '.'), True) - self.assertEqual(idt.tokens[2].match(T.Name, 'id'), True) - - def test_issue40(self): - # make sure identifier lists in subselects are grouped - p = sqlparse.parse(('SELECT id, name FROM ' - '(SELECT id, name FROM bar) as foo'))[0] - self.assertEqual(len(p.tokens), 7) - self.assertEqual(p.tokens[2].__class__, sql.IdentifierList) - self.assertEqual(p.tokens[-1].__class__, sql.Identifier) - self.assertEqual(p.tokens[-1].get_name(), u'foo') - sp = p.tokens[-1].tokens[0] - self.assertEqual(sp.tokens[3].__class__, sql.IdentifierList) - # make sure that formatting works as expected - self.ndiffAssertEqual( - sqlparse.format(('SELECT id, name FROM ' - '(SELECT id, name FROM bar)'), - reindent=True), - ('SELECT id,\n' - ' name\n' - 'FROM\n' - ' (SELECT id,\n' - ' name\n' - ' FROM bar)')) - self.ndiffAssertEqual( - sqlparse.format(('SELECT id, name FROM ' - '(SELECT id, name FROM bar) as foo'), - reindent=True), - ('SELECT id,\n' - ' name\n' - 'FROM\n' - ' (SELECT id,\n' - ' name\n' - ' FROM bar) as foo')) - - -def test_issue78(): +from sqlparse import sql, tokens as T +from sqlparse.compat import PY2 + + +def test_issue9(): + # make sure where doesn't consume parenthesis + p = sqlparse.parse('(where 1)')[0] + assert isinstance(p, sql.Statement) + assert len(p.tokens) == 1 + assert isinstance(p.tokens[0], sql.Parenthesis) + prt = p.tokens[0] + assert len(prt.tokens) == 3 + assert prt.tokens[0].ttype == T.Punctuation + assert prt.tokens[-1].ttype == T.Punctuation + + +def test_issue13(): + parsed = sqlparse.parse(("select 'one';\n" + "select 'two\\'';\n" + "select 'three';")) + assert len(parsed) == 3 + assert str(parsed[1]).strip() == "select 'two\\'';" + + +@pytest.mark.parametrize('s', ['--hello', '-- hello', '--hello\n', + '--', '--\n']) +def test_issue26(s): + # parse stand-alone comments + p = sqlparse.parse(s)[0] + assert len(p.tokens) == 1 + assert p.tokens[0].ttype is T.Comment.Single + + +@pytest.mark.parametrize('value', ['create', 'CREATE']) +def test_issue34(value): + t = sqlparse.parse("create")[0].token_first() + assert t.match(T.Keyword.DDL, value) is True + + +def test_issue35(): + # missing space before LIMIT + sql = sqlparse.format("select * from foo where bar = 1 limit 1", + reindent=True) + assert sql == "\n".join([ + "select *", + "from foo", + "where bar = 1 limit 1"]) + + +def test_issue38(): + sql = sqlparse.format("SELECT foo; -- comment", strip_comments=True) + assert sql == "SELECT foo;" + sql = sqlparse.format("/* foo */", strip_comments=True) + assert sql == "" + + +def test_issue39(): + p = sqlparse.parse('select user.id from user')[0] + assert len(p.tokens) == 7 + idt = p.tokens[2] + assert idt.__class__ == sql.Identifier + assert len(idt.tokens) == 3 + assert idt.tokens[0].match(T.Name, 'user') is True + assert idt.tokens[1].match(T.Punctuation, '.') is True + assert idt.tokens[2].match(T.Name, 'id') is True + + +def test_issue40(): + # make sure identifier lists in subselects are grouped + p = sqlparse.parse(('SELECT id, name FROM ' + '(SELECT id, name FROM bar) as foo'))[0] + assert len(p.tokens) == 7 + assert p.tokens[2].__class__ == sql.IdentifierList + assert p.tokens[-1].__class__ == sql.Identifier + assert p.tokens[-1].get_name() == 'foo' + sp = p.tokens[-1].tokens[0] + assert sp.tokens[3].__class__ == sql.IdentifierList + # make sure that formatting works as expected + s = sqlparse.format('SELECT id == name FROM ' + '(SELECT id, name FROM bar)', reindent=True) + assert s == '\n'.join([ + 'SELECT id == name', + 'FROM', + ' (SELECT id,', + ' name', + ' FROM bar)']) + + s = sqlparse.format('SELECT id == name FROM ' + '(SELECT id, name FROM bar) as foo', reindent=True) + assert s == '\n'.join([ + 'SELECT id == name', + 'FROM', + ' (SELECT id,', + ' name', + ' FROM bar) as foo']) + + +@pytest.mark.parametrize('s', ['select x.y::text as z from foo', + 'select x.y::text as "z" from foo', + 'select x."y"::text as z from foo', + 'select x."y"::text as "z" from foo', + 'select "x".y::text as z from foo', + 'select "x".y::text as "z" from foo', + 'select "x"."y"::text as z from foo', + 'select "x"."y"::text as "z" from foo']) +@pytest.mark.parametrize('func_name, result', [('get_name', 'z'), + ('get_real_name', 'y'), + ('get_parent_name', 'x'), + ('get_alias', 'z'), + ('get_typecast', 'text')]) +def test_issue78(s, func_name, result): # the bug author provided this nice examples, let's use them! - def _get_identifier(sql): - p = sqlparse.parse(sql)[0] - return p.tokens[2] - results = (('get_name', 'z'), - ('get_real_name', 'y'), - ('get_parent_name', 'x'), - ('get_alias', 'z'), - ('get_typecast', 'text')) - variants = ( - 'select x.y::text as z from foo', - 'select x.y::text as "z" from foo', - 'select x."y"::text as z from foo', - 'select x."y"::text as "z" from foo', - 'select "x".y::text as z from foo', - 'select "x".y::text as "z" from foo', - 'select "x"."y"::text as z from foo', - 'select "x"."y"::text as "z" from foo', - ) - for variant in variants: - i = _get_identifier(variant) - assert isinstance(i, sql.Identifier) - for func_name, result in results: - func = getattr(i, func_name) - assert func() == result + p = sqlparse.parse(s)[0] + i = p.tokens[2] + assert isinstance(i, sql.Identifier) + + func = getattr(i, func_name) + assert func() == result def test_issue83(): - sql = """ -CREATE OR REPLACE FUNCTION func_a(text) - RETURNS boolean LANGUAGE plpgsql STRICT IMMUTABLE AS -$_$ -BEGIN - ... -END; -$_$; - -CREATE OR REPLACE FUNCTION func_b(text) - RETURNS boolean LANGUAGE plpgsql STRICT IMMUTABLE AS -$_$ -BEGIN - ... -END; -$_$; - -ALTER TABLE..... ;""" + sql = """ CREATE OR REPLACE FUNCTION func_a(text) + RETURNS boolean LANGUAGE plpgsql STRICT IMMUTABLE AS + $_$ + BEGIN + ... + END; + $_$; + + CREATE OR REPLACE FUNCTION func_b(text) + RETURNS boolean LANGUAGE plpgsql STRICT IMMUTABLE AS + $_$ + BEGIN + ... + END; + $_$; + + ALTER TABLE..... ;""" t = sqlparse.split(sql) assert len(t) == 3 @@ -177,7 +160,7 @@ def test_parse_sql_with_binary(): sql = "select * from foo where bar = '{0}'".format(digest) formatted = sqlparse.format(sql, reindent=True) tformatted = "select *\nfrom foo\nwhere bar = '{0}'".format(digest) - if sys.version_info < (3,): + if PY2: tformatted = tformatted.decode('unicode-escape') assert formatted == tformatted @@ -192,7 +175,8 @@ def test_dont_alias_keywords(): assert p.tokens[2].ttype is T.Keyword -def test_format_accepts_encoding(): # issue20 +def test_format_accepts_encoding(load_file): + # issue20 sql = load_file('test_cp1251.sql', 'cp1251') formatted = sqlparse.format(sql, reindent=True, encoding='cp1251') tformatted = u'insert into foo\nvalues (1); -- Песня про надежду\n' @@ -206,17 +190,18 @@ def test_issue90(): ' "rating_score" = 0, "thumbnail_width" = NULL,' ' "thumbnail_height" = NULL, "price" = 1, "description" = NULL') formatted = sqlparse.format(sql, reindent=True) - tformatted = '\n'.join(['UPDATE "gallery_photo"', - 'SET "owner_id" = 4018,', - ' "deleted_at" = NULL,', - ' "width" = NULL,', - ' "height" = NULL,', - ' "rating_votes" = 0,', - ' "rating_score" = 0,', - ' "thumbnail_width" = NULL,', - ' "thumbnail_height" = NULL,', - ' "price" = 1,', - ' "description" = NULL']) + tformatted = '\n'.join([ + 'UPDATE "gallery_photo"', + 'SET "owner_id" = 4018,', + ' "deleted_at" = NULL,', + ' "width" = NULL,', + ' "height" = NULL,', + ' "rating_votes" = 0,', + ' "rating_score" = 0,', + ' "thumbnail_width" = NULL,', + ' "thumbnail_height" = NULL,', + ' "price" = 1,', + ' "description" = NULL']) assert formatted == tformatted @@ -230,8 +215,7 @@ def test_except_formatting(): 'EXCEPT', 'SELECT 2', 'FROM bar', - 'WHERE 1 = 2' - ]) + 'WHERE 1 = 2']) assert formatted == tformatted @@ -241,32 +225,31 @@ def test_null_with_as(): tformatted = '\n'.join([ 'SELECT NULL AS c1,', ' NULL AS c2', - 'FROM t1' - ]) + 'FROM t1']) assert formatted == tformatted def test_issue193_splitting_function(): - sql = """CREATE FUNCTION a(x VARCHAR(20)) RETURNS VARCHAR(20) -BEGIN - DECLARE y VARCHAR(20); - RETURN x; -END; -SELECT * FROM a.b;""" + sql = """ CREATE FUNCTION a(x VARCHAR(20)) RETURNS VARCHAR(20) + BEGIN + DECLARE y VARCHAR(20); + RETURN x; + END; + SELECT * FROM a.b;""" splitted = sqlparse.split(sql) assert len(splitted) == 2 def test_issue194_splitting_function(): - sql = """CREATE FUNCTION a(x VARCHAR(20)) RETURNS VARCHAR(20) -BEGIN - DECLARE y VARCHAR(20); - IF (1 = 1) THEN - SET x = y; - END IF; - RETURN x; -END; -SELECT * FROM a.b;""" + sql = """ CREATE FUNCTION a(x VARCHAR(20)) RETURNS VARCHAR(20) + BEGIN + DECLARE y VARCHAR(20); + IF (1 = 1) THEN + SET x = y; + END IF; + RETURN x; + END; + SELECT * FROM a.b;""" splitted = sqlparse.split(sql) assert len(splitted) == 2 @@ -278,8 +261,8 @@ def test_issue186_get_type(): def test_issue212_py2unicode(): - t1 = sql.Token(T.String, u"schöner ") - t2 = sql.Token(T.String, u"bug") + t1 = sql.Token(T.String, u'schöner ') + t2 = sql.Token(T.String, 'bug') l = sql.TokenList([t1, t2]) assert str(l) == 'schöner bug' @@ -295,23 +278,23 @@ def test_issue227_gettype_cte(): with_stmt = sqlparse.parse('WITH foo AS (SELECT 1, 2, 3)' 'SELECT * FROM foo;') assert with_stmt[0].get_type() == 'SELECT' - with2_stmt = sqlparse.parse(''' + with2_stmt = sqlparse.parse(""" WITH foo AS (SELECT 1 AS abc, 2 AS def), bar AS (SELECT * FROM something WHERE x > 1) - INSERT INTO elsewhere SELECT * FROM foo JOIN bar; - ''') + INSERT INTO elsewhere SELECT * FROM foo JOIN bar;""") assert with2_stmt[0].get_type() == 'INSERT' def test_issue207_runaway_format(): sql = 'select 1 from (select 1 as one, 2 as two, 3 from dual) t0' p = sqlparse.format(sql, reindent=True) - assert p == '\n'.join(["select 1", - "from", - " (select 1 as one,", - " 2 as two,", - " 3", - " from dual) t0"]) + assert p == '\n'.join([ + "select 1", + "from", + " (select 1 as one,", + " 2 as two,", + " 3", + " from dual) t0"]) def token_next_doesnt_ignore_skip_cm(): diff --git a/tests/test_split.py b/tests/test_split.py index 7c2645d..af7c9ce 100644 --- a/tests/test_split.py +++ b/tests/test_split.py @@ -4,139 +4,133 @@ import types -from tests.utils import load_file, TestCaseBase +import pytest import sqlparse -from sqlparse.compat import StringIO, u, text_type - - -class SQLSplitTest(TestCaseBase): - """Tests sqlparse.sqlsplit().""" - - _sql1 = 'select * from foo;' - _sql2 = 'select * from bar;' - - def test_split_semicolon(self): - sql2 = 'select * from foo where bar = \'foo;bar\';' - stmts = sqlparse.parse(''.join([self._sql1, sql2])) - self.assertEqual(len(stmts), 2) - self.ndiffAssertEqual(u(stmts[0]), self._sql1) - self.ndiffAssertEqual(u(stmts[1]), sql2) - - def test_split_backslash(self): - stmts = sqlparse.parse(r"select '\\'; select '\''; select '\\\'';") - self.assertEqual(len(stmts), 3) - - def test_create_function(self): - sql = load_file('function.sql') - stmts = sqlparse.parse(sql) - self.assertEqual(len(stmts), 1) - self.ndiffAssertEqual(u(stmts[0]), sql) - - def test_create_function_psql(self): - sql = load_file('function_psql.sql') - stmts = sqlparse.parse(sql) - self.assertEqual(len(stmts), 1) - self.ndiffAssertEqual(u(stmts[0]), sql) - - def test_create_function_psql3(self): - sql = load_file('function_psql3.sql') - stmts = sqlparse.parse(sql) - self.assertEqual(len(stmts), 1) - self.ndiffAssertEqual(u(stmts[0]), sql) - - def test_create_function_psql2(self): - sql = load_file('function_psql2.sql') - stmts = sqlparse.parse(sql) - self.assertEqual(len(stmts), 1) - self.ndiffAssertEqual(u(stmts[0]), sql) - - def test_dashcomments(self): - sql = load_file('dashcomment.sql') - stmts = sqlparse.parse(sql) - self.assertEqual(len(stmts), 3) - self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql) - - def test_dashcomments_eol(self): - stmts = sqlparse.parse('select foo; -- comment\n') - self.assertEqual(len(stmts), 1) - stmts = sqlparse.parse('select foo; -- comment\r') - self.assertEqual(len(stmts), 1) - stmts = sqlparse.parse('select foo; -- comment\r\n') - self.assertEqual(len(stmts), 1) - stmts = sqlparse.parse('select foo; -- comment') - self.assertEqual(len(stmts), 1) - - def test_begintag(self): - sql = load_file('begintag.sql') - stmts = sqlparse.parse(sql) - self.assertEqual(len(stmts), 3) - self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql) - - def test_begintag_2(self): - sql = load_file('begintag_2.sql') - stmts = sqlparse.parse(sql) - self.assertEqual(len(stmts), 1) - self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql) - - def test_dropif(self): - sql = 'DROP TABLE IF EXISTS FOO;\n\nSELECT * FROM BAR;' - stmts = sqlparse.parse(sql) - self.assertEqual(len(stmts), 2) - self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql) - - def test_comment_with_umlaut(self): - sql = (u'select * from foo;\n' - u'-- Testing an umlaut: ä\n' - u'select * from bar;') - stmts = sqlparse.parse(sql) - self.assertEqual(len(stmts), 2) - self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql) - - def test_comment_end_of_line(self): - sql = ('select * from foo; -- foo\n' - 'select * from bar;') - stmts = sqlparse.parse(sql) - self.assertEqual(len(stmts), 2) - self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql) - # make sure the comment belongs to first query - self.ndiffAssertEqual(u(stmts[0]), 'select * from foo; -- foo\n') - - def test_casewhen(self): - sql = ('SELECT case when val = 1 then 2 else null end as foo;\n' - 'comment on table actor is \'The actor table.\';') - stmts = sqlparse.split(sql) - self.assertEqual(len(stmts), 2) - - def test_cursor_declare(self): - sql = ('DECLARE CURSOR "foo" AS SELECT 1;\n' - 'SELECT 2;') - stmts = sqlparse.split(sql) - self.assertEqual(len(stmts), 2) - - def test_if_function(self): # see issue 33 - # don't let IF as a function confuse the splitter - sql = ('CREATE TEMPORARY TABLE tmp ' - 'SELECT IF(a=1, a, b) AS o FROM one; ' - 'SELECT t FROM two') - stmts = sqlparse.split(sql) - self.assertEqual(len(stmts), 2) - - def test_split_stream(self): - stream = StringIO("SELECT 1; SELECT 2;") - stmts = sqlparse.parsestream(stream) - self.assertEqual(type(stmts), types.GeneratorType) - self.assertEqual(len(list(stmts)), 2) - - def test_encoding_parsestream(self): - stream = StringIO("SELECT 1; SELECT 2;") - stmts = list(sqlparse.parsestream(stream)) - self.assertEqual(type(stmts[0].tokens[0].value), text_type) - - def test_unicode_parsestream(self): - stream = StringIO(u"SELECT ö") - stmts = list(sqlparse.parsestream(stream)) - self.assertEqual(str(stmts[0]), "SELECT ö") +from sqlparse.compat import StringIO, text_type + + +def test_split_semicolon(): + sql1 = 'select * from foo;' + sql2 = "select * from foo where bar = 'foo;bar';" + stmts = sqlparse.parse(''.join([sql1, sql2])) + assert len(stmts) == 2 + assert str(stmts[0]) == sql1 + assert str(stmts[1]) == sql2 + + +def test_split_backslash(): + stmts = sqlparse.parse(r"select '\\'; select '\''; select '\\\'';") + assert len(stmts) == 3 + + +@pytest.mark.parametrize('fn', ['function.sql', + 'function_psql.sql', + 'function_psql2.sql', + 'function_psql3.sql']) +def test_split_create_function(load_file, fn): + sql = load_file(fn) + stmts = sqlparse.parse(sql) + assert len(stmts) == 1 + assert text_type(stmts[0]) == sql + + +def test_split_dashcomments(load_file): + sql = load_file('dashcomment.sql') + stmts = sqlparse.parse(sql) + assert len(stmts) == 3 + assert ''.join(str(q) for q in stmts) == sql + + +@pytest.mark.parametrize('s', ['select foo; -- comment\n', + 'select foo; -- comment\r', + 'select foo; -- comment\r\n', + 'select foo; -- comment']) +def test_split_dashcomments_eol(s): + stmts = sqlparse.parse(s) + assert len(stmts) == 1 + + +def test_split_begintag(load_file): + sql = load_file('begintag.sql') + stmts = sqlparse.parse(sql) + assert len(stmts) == 3 + assert ''.join(str(q) for q in stmts) == sql + + +def test_split_begintag_2(load_file): + sql = load_file('begintag_2.sql') + stmts = sqlparse.parse(sql) + assert len(stmts) == 1 + assert ''.join(str(q) for q in stmts) == sql + + +def test_split_dropif(): + sql = 'DROP TABLE IF EXISTS FOO;\n\nSELECT * FROM BAR;' + stmts = sqlparse.parse(sql) + assert len(stmts) == 2 + assert ''.join(str(q) for q in stmts) == sql + + +def test_split_comment_with_umlaut(): + sql = (u'select * from foo;\n' + u'-- Testing an umlaut: ä\n' + u'select * from bar;') + stmts = sqlparse.parse(sql) + assert len(stmts) == 2 + assert ''.join(text_type(q) for q in stmts) == sql + + +def test_split_comment_end_of_line(): + sql = ('select * from foo; -- foo\n' + 'select * from bar;') + stmts = sqlparse.parse(sql) + assert len(stmts) == 2 + assert ''.join(str(q) for q in stmts) == sql + # make sure the comment belongs to first query + assert str(stmts[0]) == 'select * from foo; -- foo\n' + + +def test_split_casewhen(): + sql = ("SELECT case when val = 1 then 2 else null end as foo;\n" + "comment on table actor is 'The actor table.';") + stmts = sqlparse.split(sql) + assert len(stmts) == 2 + + +def test_split_cursor_declare(): + sql = ('DECLARE CURSOR "foo" AS SELECT 1;\n' + 'SELECT 2;') + stmts = sqlparse.split(sql) + assert len(stmts) == 2 + + +def test_split_if_function(): # see issue 33 + # don't let IF as a function confuse the splitter + sql = ('CREATE TEMPORARY TABLE tmp ' + 'SELECT IF(a=1, a, b) AS o FROM one; ' + 'SELECT t FROM two') + stmts = sqlparse.split(sql) + assert len(stmts) == 2 + + +def test_split_stream(): + stream = StringIO("SELECT 1; SELECT 2;") + stmts = sqlparse.parsestream(stream) + assert isinstance(stmts, types.GeneratorType) + assert len(list(stmts)) == 2 + + +def test_split_encoding_parsestream(): + stream = StringIO("SELECT 1; SELECT 2;") + stmts = list(sqlparse.parsestream(stream)) + assert isinstance(stmts[0].tokens[0].value, text_type) + + +def test_split_unicode_parsestream(): + stream = StringIO(u'SELECT ö') + stmts = list(sqlparse.parsestream(stream)) + assert str(stmts[0]) == 'SELECT ö' def test_split_simple(): diff --git a/tests/test_tokenize.py b/tests/test_tokenize.py index 61eaa3e..6cc0dfa 100644 --- a/tests/test_tokenize.py +++ b/tests/test_tokenize.py @@ -1,171 +1,159 @@ # -*- coding: utf-8 -*- import types -import unittest import pytest import sqlparse from sqlparse import lexer -from sqlparse import sql -from sqlparse import tokens as T +from sqlparse import sql, tokens as T from sqlparse.compat import StringIO -class TestTokenize(unittest.TestCase): - - def test_simple(self): - s = 'select * from foo;' - stream = lexer.tokenize(s) - self.assert_(isinstance(stream, types.GeneratorType)) - tokens = list(stream) - self.assertEqual(len(tokens), 8) - self.assertEqual(len(tokens[0]), 2) - self.assertEqual(tokens[0], (T.Keyword.DML, u'select')) - self.assertEqual(tokens[-1], (T.Punctuation, u';')) - - def test_backticks(self): - s = '`foo`.`bar`' - tokens = list(lexer.tokenize(s)) - self.assertEqual(len(tokens), 3) - self.assertEqual(tokens[0], (T.Name, u'`foo`')) - - def test_linebreaks(self): # issue1 - s = 'foo\nbar\n' - tokens = lexer.tokenize(s) - self.assertEqual(''.join(str(x[1]) for x in tokens), s) - s = 'foo\rbar\r' - tokens = lexer.tokenize(s) - self.assertEqual(''.join(str(x[1]) for x in tokens), s) - s = 'foo\r\nbar\r\n' - tokens = lexer.tokenize(s) - self.assertEqual(''.join(str(x[1]) for x in tokens), s) - s = 'foo\r\nbar\n' - tokens = lexer.tokenize(s) - self.assertEqual(''.join(str(x[1]) for x in tokens), s) - - def test_inline_keywords(self): # issue 7 - s = "create created_foo" - tokens = list(lexer.tokenize(s)) - self.assertEqual(len(tokens), 3) - self.assertEqual(tokens[0][0], T.Keyword.DDL) - self.assertEqual(tokens[2][0], T.Name) - self.assertEqual(tokens[2][1], u'created_foo') - s = "enddate" - tokens = list(lexer.tokenize(s)) - self.assertEqual(len(tokens), 1) - self.assertEqual(tokens[0][0], T.Name) - s = "join_col" - tokens = list(lexer.tokenize(s)) - self.assertEqual(len(tokens), 1) - self.assertEqual(tokens[0][0], T.Name) - s = "left join_col" - tokens = list(lexer.tokenize(s)) - self.assertEqual(len(tokens), 3) - self.assertEqual(tokens[2][0], T.Name) - self.assertEqual(tokens[2][1], 'join_col') - - def test_negative_numbers(self): - s = "values(-1)" - tokens = list(lexer.tokenize(s)) - self.assertEqual(len(tokens), 4) - self.assertEqual(tokens[2][0], T.Number.Integer) - self.assertEqual(tokens[2][1], '-1') - - -class TestToken(unittest.TestCase): - - def test_str(self): - token = sql.Token(None, 'FoO') - self.assertEqual(str(token), 'FoO') - - def test_repr(self): - token = sql.Token(T.Keyword, 'foo') - tst = "<Keyword 'foo' at 0x" - self.assertEqual(repr(token)[:len(tst)], tst) - token = sql.Token(T.Keyword, '1234567890') - tst = "<Keyword '123456...' at 0x" - self.assertEqual(repr(token)[:len(tst)], tst) - - def test_flatten(self): - token = sql.Token(T.Keyword, 'foo') - gen = token.flatten() - self.assertEqual(type(gen), types.GeneratorType) - lgen = list(gen) - self.assertEqual(lgen, [token]) - - -class TestTokenList(unittest.TestCase): - - def test_repr(self): - p = sqlparse.parse('foo, bar, baz')[0] - tst = "<IdentifierList 'foo, b...' at 0x" - self.assertEqual(repr(p.tokens[0])[:len(tst)], tst) - - def test_token_first(self): - p = sqlparse.parse(' select foo')[0] - first = p.token_first() - self.assertEqual(first.value, 'select') - self.assertEqual(p.token_first(skip_ws=False).value, ' ') - self.assertEqual(sql.TokenList([]).token_first(), None) - - def test_token_matching(self): - t1 = sql.Token(T.Keyword, 'foo') - t2 = sql.Token(T.Punctuation, ',') - x = sql.TokenList([t1, t2]) - self.assertEqual(x.token_matching( - [lambda t: t.ttype is T.Keyword], 0), t1) - self.assertEqual(x.token_matching( - [lambda t: t.ttype is T.Punctuation], 0), t2) - self.assertEqual(x.token_matching( - [lambda t: t.ttype is T.Keyword], 1), None) - - -class TestStream(unittest.TestCase): - def test_simple(self): - stream = StringIO("SELECT 1; SELECT 2;") - - tokens = lexer.tokenize(stream) - self.assertEqual(len(list(tokens)), 9) - - stream.seek(0) - tokens = list(lexer.tokenize(stream)) - self.assertEqual(len(tokens), 9) - - stream.seek(0) - tokens = list(lexer.tokenize(stream)) - self.assertEqual(len(tokens), 9) - - def test_error(self): - stream = StringIO("FOOBAR{") - - tokens = list(lexer.tokenize(stream)) - self.assertEqual(len(tokens), 2) - self.assertEqual(tokens[1][0], T.Error) - - -@pytest.mark.parametrize('expr', ['JOIN', 'LEFT JOIN', 'LEFT OUTER JOIN', - 'FULL OUTER JOIN', 'NATURAL JOIN', - 'CROSS JOIN', 'STRAIGHT JOIN', - 'INNER JOIN', 'LEFT INNER JOIN']) +def test_tokenize_simple(): + s = 'select * from foo;' + stream = lexer.tokenize(s) + assert isinstance(stream, types.GeneratorType) + tokens = list(stream) + assert len(tokens) == 8 + assert len(tokens[0]) == 2 + assert tokens[0] == (T.Keyword.DML, 'select') + assert tokens[-1] == (T.Punctuation, ';') + + +def test_tokenize_backticks(): + s = '`foo`.`bar`' + tokens = list(lexer.tokenize(s)) + assert len(tokens) == 3 + assert tokens[0] == (T.Name, '`foo`') + + +@pytest.mark.parametrize('s', ['foo\nbar\n', 'foo\rbar\r', + 'foo\r\nbar\r\n', 'foo\r\nbar\n']) +def test_tokenize_linebreaks(s): + # issue1 + tokens = lexer.tokenize(s) + assert ''.join(str(x[1]) for x in tokens) == s + + +def test_tokenize_inline_keywords(): + # issue 7 + s = "create created_foo" + tokens = list(lexer.tokenize(s)) + assert len(tokens) == 3 + assert tokens[0][0] == T.Keyword.DDL + assert tokens[2][0] == T.Name + assert tokens[2][1] == 'created_foo' + s = "enddate" + tokens = list(lexer.tokenize(s)) + assert len(tokens) == 1 + assert tokens[0][0] == T.Name + s = "join_col" + tokens = list(lexer.tokenize(s)) + assert len(tokens) == 1 + assert tokens[0][0] == T.Name + s = "left join_col" + tokens = list(lexer.tokenize(s)) + assert len(tokens) == 3 + assert tokens[2][0] == T.Name + assert tokens[2][1] == 'join_col' + + +def test_tokenize_negative_numbers(): + s = "values(-1)" + tokens = list(lexer.tokenize(s)) + assert len(tokens) == 4 + assert tokens[2][0] == T.Number.Integer + assert tokens[2][1] == '-1' + + +def test_token_str(): + token = sql.Token(None, 'FoO') + assert str(token) == 'FoO' + + +def test_token_repr(): + token = sql.Token(T.Keyword, 'foo') + tst = "<Keyword 'foo' at 0x" + assert repr(token)[:len(tst)] == tst + token = sql.Token(T.Keyword, '1234567890') + tst = "<Keyword '123456...' at 0x" + assert repr(token)[:len(tst)] == tst + + +def test_token_flatten(): + token = sql.Token(T.Keyword, 'foo') + gen = token.flatten() + assert isinstance(gen, types.GeneratorType) + lgen = list(gen) + assert lgen == [token] + + +def test_tokenlist_repr(): + p = sqlparse.parse('foo, bar, baz')[0] + tst = "<IdentifierList 'foo, b...' at 0x" + assert repr(p.tokens[0])[:len(tst)] == tst + + +def test_tokenlist_first(): + p = sqlparse.parse(' select foo')[0] + first = p.token_first() + assert first.value == 'select' + assert p.token_first(skip_ws=False).value == ' ' + assert sql.TokenList([]).token_first() is None + + +def test_tokenlist_token_matching(): + t1 = sql.Token(T.Keyword, 'foo') + t2 = sql.Token(T.Punctuation, ',') + x = sql.TokenList([t1, t2]) + assert x.token_matching([lambda t: t.ttype is T.Keyword], 0) == t1 + assert x.token_matching([lambda t: t.ttype is T.Punctuation], 0) == t2 + assert x.token_matching([lambda t: t.ttype is T.Keyword], 1) is None + + +def test_stream_simple(): + stream = StringIO("SELECT 1; SELECT 2;") + + tokens = lexer.tokenize(stream) + assert len(list(tokens)) == 9 + + stream.seek(0) + tokens = list(lexer.tokenize(stream)) + assert len(tokens) == 9 + + stream.seek(0) + tokens = list(lexer.tokenize(stream)) + assert len(tokens) == 9 + + +def test_stream_error(): + stream = StringIO("FOOBAR{") + + tokens = list(lexer.tokenize(stream)) + assert len(tokens) == 2 + assert tokens[1][0] == T.Error + + +@pytest.mark.parametrize('expr', [ + 'JOIN', + 'LEFT JOIN', + 'LEFT OUTER JOIN', + 'FULL OUTER JOIN', + 'NATURAL JOIN', + 'CROSS JOIN', + 'STRAIGHT JOIN', + 'INNER JOIN', + 'LEFT INNER JOIN']) def test_parse_join(expr): p = sqlparse.parse('{0} foo'.format(expr))[0] assert len(p.tokens) == 3 assert p.tokens[0].ttype is T.Keyword -def test_parse_endifloop(): - p = sqlparse.parse('END IF')[0] - assert len(p.tokens) == 1 - assert p.tokens[0].ttype is T.Keyword - p = sqlparse.parse('END IF')[0] - assert len(p.tokens) == 1 - p = sqlparse.parse('END\t\nIF')[0] - assert len(p.tokens) == 1 - assert p.tokens[0].ttype is T.Keyword - p = sqlparse.parse('END LOOP')[0] - assert len(p.tokens) == 1 - assert p.tokens[0].ttype is T.Keyword - p = sqlparse.parse('END LOOP')[0] +@pytest.mark.parametrize('s', ['END IF', 'END IF', 'END\t\nIF', + 'END LOOP', 'END LOOP', 'END\t\nLOOP']) +def test_parse_endifloop(s): + p = sqlparse.parse(s)[0] assert len(p.tokens) == 1 assert p.tokens[0].ttype is T.Keyword diff --git a/tests/utils.py b/tests/utils.py deleted file mode 100644 index 8000ae5..0000000 --- a/tests/utils.py +++ /dev/null @@ -1,41 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Helpers for testing.""" - -import difflib -import io -import os -import unittest - -from sqlparse.utils import split_unquoted_newlines -from sqlparse.compat import StringIO - -DIR_PATH = os.path.dirname(__file__) -FILES_DIR = os.path.join(DIR_PATH, 'files') - - -def load_file(filename, encoding='utf-8'): - """Opens filename with encoding and return its contents.""" - with io.open(os.path.join(FILES_DIR, filename), encoding=encoding) as f: - return f.read() - - -class TestCaseBase(unittest.TestCase): - """Base class for test cases with some additional checks.""" - - # Adopted from Python's tests. - def ndiffAssertEqual(self, first, second): - """Like failUnlessEqual except use ndiff for readable output.""" - if first != second: - # Using the built-in .splitlines() method here will cause incorrect - # results when splitting statements that have quoted CR/CR+LF - # characters. - sfirst = split_unquoted_newlines(first) - ssecond = split_unquoted_newlines(second) - diff = difflib.ndiff(sfirst, ssecond) - - fp = StringIO() - fp.write('\n') - fp.write('\n'.join(diff)) - - raise self.failureException(fp.getvalue()) |
