diff options
Diffstat (limited to 'tests/test_format.py')
| -rw-r--r-- | tests/test_format.py | 472 |
1 files changed, 222 insertions, 250 deletions
diff --git a/tests/test_format.py b/tests/test_format.py index 74fce71..fc15b2f 100644 --- a/tests/test_format.py +++ b/tests/test_format.py @@ -2,75 +2,73 @@ import pytest -from tests.utils import TestCaseBase - import sqlparse from sqlparse.exceptions import SQLParseError +from tests.utils import TestCaseBase class TestFormat(TestCaseBase): - def test_keywordcase(self): sql = 'select * from bar; -- select foo\n' res = sqlparse.format(sql, keyword_case='upper') - self.ndiffAssertEqual(res, 'SELECT * FROM bar; -- select foo\n') + assert res == 'SELECT * FROM bar; -- select foo\n' res = sqlparse.format(sql, keyword_case='capitalize') - self.ndiffAssertEqual(res, 'Select * From bar; -- select foo\n') + assert res == 'Select * From bar; -- select foo\n' res = sqlparse.format(sql.upper(), keyword_case='lower') - self.ndiffAssertEqual(res, 'select * from BAR; -- SELECT FOO\n') + assert res == 'select * from BAR; -- SELECT FOO\n' self.assertRaises(SQLParseError, sqlparse.format, sql, keyword_case='foo') def test_identifiercase(self): sql = 'select * from bar; -- select foo\n' res = sqlparse.format(sql, identifier_case='upper') - self.ndiffAssertEqual(res, 'select * from BAR; -- select foo\n') + assert res == 'select * from BAR; -- select foo\n' res = sqlparse.format(sql, identifier_case='capitalize') - self.ndiffAssertEqual(res, 'select * from Bar; -- select foo\n') + assert res == 'select * from Bar; -- select foo\n' res = sqlparse.format(sql.upper(), identifier_case='lower') - self.ndiffAssertEqual(res, 'SELECT * FROM bar; -- SELECT FOO\n') + assert res == 'SELECT * FROM bar; -- SELECT FOO\n' self.assertRaises(SQLParseError, sqlparse.format, sql, identifier_case='foo') sql = 'select * from "foo"."bar"' res = sqlparse.format(sql, identifier_case="upper") - self.ndiffAssertEqual(res, 'select * from "foo"."bar"') + assert res == 'select * from "foo"."bar"' def test_strip_comments_single(self): sql = 'select *-- statement starts here\nfrom foo' res = sqlparse.format(sql, strip_comments=True) - self.ndiffAssertEqual(res, 'select * from foo') + assert res == 'select * from foo' sql = 'select * -- statement starts here\nfrom foo' res = sqlparse.format(sql, strip_comments=True) - self.ndiffAssertEqual(res, 'select * from foo') + assert res == 'select * from foo' sql = 'select-- foo\nfrom -- bar\nwhere' res = sqlparse.format(sql, strip_comments=True) - self.ndiffAssertEqual(res, 'select from where') + assert res == 'select from where' self.assertRaises(SQLParseError, sqlparse.format, sql, strip_comments=None) def test_strip_comments_multi(self): sql = '/* sql starts here */\nselect' res = sqlparse.format(sql, strip_comments=True) - self.ndiffAssertEqual(res, 'select') + assert res == 'select' sql = '/* sql starts here */ select' res = sqlparse.format(sql, strip_comments=True) - self.ndiffAssertEqual(res, 'select') + assert res == 'select' sql = '/*\n * sql starts here\n */\nselect' res = sqlparse.format(sql, strip_comments=True) - self.ndiffAssertEqual(res, 'select') + assert res == 'select' sql = 'select (/* sql starts here */ select 2)' res = sqlparse.format(sql, strip_comments=True) - self.ndiffAssertEqual(res, 'select (select 2)') + assert res == 'select (select 2)' sql = 'select (/* sql /* starts here */ select 2)' res = sqlparse.format(sql, strip_comments=True) - self.ndiffAssertEqual(res, 'select (select 2)') + assert res == 'select (select 2)' def test_strip_ws(self): f = lambda sql: sqlparse.format(sql, strip_whitespace=True) s = 'select\n* from foo\n\twhere ( 1 = 2 )\n' - self.ndiffAssertEqual(f(s), 'select * from foo where (1 = 2)') + assert f(s) == 'select * from foo where (1 = 2)' s = 'select -- foo\nfrom bar\n' - self.ndiffAssertEqual(f(s), 'select -- foo\nfrom bar') + assert f(s) == 'select -- foo\nfrom bar' self.assertRaises(SQLParseError, sqlparse.format, s, strip_whitespace=None) @@ -78,7 +76,7 @@ class TestFormat(TestCaseBase): # preserve at least one whitespace after subgroups f = lambda sql: sqlparse.format(sql, strip_whitespace=True) s = 'select\n* /* foo */ from bar ' - self.ndiffAssertEqual(f(s), 'select * /* foo */ from bar') + assert f(s) == 'select * /* foo */ from bar' def test_notransform_of_quoted_crlf(self): # Make sure that CR/CR+LF characters inside string literals don't get @@ -92,13 +90,11 @@ class TestFormat(TestCaseBase): f = lambda x: sqlparse.format(x) # Because of the use of - self.ndiffAssertEqual(f(s1), "SELECT some_column LIKE 'value\r'") - self.ndiffAssertEqual( - f(s2), "SELECT some_column LIKE 'value\r'\nWHERE id = 1\n") - self.ndiffAssertEqual( - f(s3), "SELECT some_column LIKE 'value\\'\r' WHERE id = 1\n") - self.ndiffAssertEqual( - f(s4), "SELECT some_column LIKE 'value\\\\\\'\r' WHERE id = 1\n") + assert f(s1) == "SELECT some_column LIKE 'value\r'" + assert f(s2) == "SELECT some_column LIKE 'value\r'\nWHERE id = 1\n" + assert f(s3) == "SELECT some_column LIKE 'value\\'\r' WHERE id = 1\n" + assert f( + s4) == "SELECT some_column LIKE 'value\\\\\\'\r' WHERE id = 1\n" def test_outputformat(self): sql = 'select * from foo;' @@ -121,23 +117,22 @@ class TestFormatReindentAligned(TestCaseBase): or d is 'blue' limit 10 """ - self.ndiffAssertEqual( - self.formatter(sql), - '\n'.join([ - 'select a,', - ' b as bb,', - ' c', - ' from table', - ' join (', - ' select a * 2 as a', - ' from new_table', - ' ) other', - ' on table.a = other.a', - ' where c is true', - ' and b between 3 and 4', - " or d is 'blue'", - ' limit 10', - ])) + + assert self.formatter(sql) == '\n'.join([ + 'select a,', + ' b as bb,', + ' c', + ' from table', + ' join (', + ' select a * 2 as a', + ' from new_table', + ' ) other', + ' on table.a = other.a', + ' where c is true', + ' and b between 3 and 4', + " or d is 'blue'", + ' limit 10', + ]) def test_joins(self): sql = """ @@ -148,22 +143,20 @@ class TestFormatReindentAligned(TestCaseBase): cross join e on e.four = a.four join f using (one, two, three) """ - self.ndiffAssertEqual( - self.formatter(sql), - '\n'.join([ - 'select *', - ' from a', - ' join b', - ' on a.one = b.one', - ' left join c', - ' on c.two = a.two', - ' and c.three = a.three', - ' full outer join d', - ' on d.three = a.three', - ' cross join e', - ' on e.four = a.four', - ' join f using (one, two, three)', - ])) + assert self.formatter(sql) == '\n'.join([ + 'select *', + ' from a', + ' join b', + ' on a.one = b.one', + ' left join c', + ' on c.two = a.two', + ' and c.three = a.three', + ' full outer join d', + ' on d.three = a.three', + ' cross join e', + ' on e.four = a.four', + ' join f using (one, two, three)', + ]) def test_case_statement(self): sql = """ @@ -178,20 +171,18 @@ class TestFormatReindentAligned(TestCaseBase): where c is true and b between 3 and 4 """ - self.ndiffAssertEqual( - self.formatter(sql), - '\n'.join([ - 'select a,', - ' case when a = 0 then 1', - ' when bb = 1 then 1', - ' when c = 2 then 2', - ' else 0', - ' end as d,', - ' extra_col', - ' from table', - ' where c is true', - ' and b between 3 and 4' - ])) + assert self.formatter(sql) == '\n'.join([ + 'select a,', + ' case when a = 0 then 1', + ' when bb = 1 then 1', + ' when c = 2 then 2', + ' else 0', + ' end as d,', + ' extra_col', + ' from table', + ' where c is true', + ' and b between 3 and 4' + ]) def test_case_statement_with_between(self): sql = """ @@ -207,21 +198,19 @@ class TestFormatReindentAligned(TestCaseBase): where c is true and b between 3 and 4 """ - self.ndiffAssertEqual( - self.formatter(sql), - '\n'.join([ - 'select a,', - ' case when a = 0 then 1', - ' when bb = 1 then 1', - ' when c = 2 then 2', - ' when d between 3 and 5 then 3', - ' else 0', - ' end as d,', - ' extra_col', - ' from table', - ' where c is true', - ' and b between 3 and 4' - ])) + assert self.formatter(sql) == '\n'.join([ + 'select a,', + ' case when a = 0 then 1', + ' when bb = 1 then 1', + ' when c = 2 then 2', + ' when d between 3 and 5 then 3', + ' else 0', + ' end as d,', + ' extra_col', + ' from table', + ' where c is true', + ' and b between 3 and 4' + ]) def test_group_by(self): sql = """ @@ -232,24 +221,22 @@ class TestFormatReindentAligned(TestCaseBase): and count(y) > 5 order by 3,2,1 """ - self.ndiffAssertEqual( - self.formatter(sql), - '\n'.join([ - 'select a,', - ' b,', - ' c,', - ' sum(x) as sum_x,', - ' count(y) as cnt_y', - ' from table', - ' group by a,', - ' b,', - ' c', - 'having sum(x) > 1', - ' and count(y) > 5', - ' order by 3,', - ' 2,', - ' 1', - ])) + assert self.formatter(sql) == '\n'.join([ + 'select a,', + ' b,', + ' c,', + ' sum(x) as sum_x,', + ' count(y) as cnt_y', + ' from table', + ' group by a,', + ' b,', + ' c', + 'having sum(x) > 1', + ' and count(y) > 5', + ' order by 3,', + ' 2,', + ' 1', + ]) def test_group_by_subquery(self): # TODO: add subquery alias when test_identifier_list_subquery fixed @@ -261,21 +248,19 @@ class TestFormatReindentAligned(TestCaseBase): group by a,z) order by 1,2 """ - self.ndiffAssertEqual( - self.formatter(sql), - '\n'.join([ - 'select *,', - ' sum_b + 2 as mod_sum', - ' from (', - ' select a,', - ' sum(b) as sum_b', - ' from table', - ' group by a,', - ' z', - ' )', - ' order by 1,', - ' 2', - ])) + assert self.formatter(sql) == '\n'.join([ + 'select *,', + ' sum_b + 2 as mod_sum', + ' from (', + ' select a,', + ' sum(b) as sum_b', + ' from table', + ' group by a,', + ' z', + ' )', + ' order by 1,', + ' 2', + ]) def test_window_functions(self): sql = """ @@ -286,16 +271,14 @@ class TestFormatReindentAligned(TestCaseBase): (PARTITION BY b, c ORDER BY d DESC) as row_num from table """ - self.ndiffAssertEqual( - self.formatter(sql), - '\n'.join([ - 'select a,', - (' SUM(a) OVER (PARTITION BY b ORDER BY c ROWS ' - 'BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as sum_a,'), - (' ROW_NUMBER() OVER ' - '(PARTITION BY b, c ORDER BY d DESC) as row_num'), - ' from table', - ])) + assert self.formatter(sql) == '\n'.join([ + 'select a,', + (' SUM(a) OVER (PARTITION BY b ORDER BY c ROWS ' + 'BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as sum_a,'), + (' ROW_NUMBER() OVER ' + '(PARTITION BY b, c ORDER BY d DESC) as row_num'), + ' from table', + ]) class TestSpacesAroundOperators(TestCaseBase): @@ -306,36 +289,27 @@ class TestSpacesAroundOperators(TestCaseBase): def test_basic(self): sql = ('select a+b as d from table ' 'where (c-d)%2= 1 and e> 3.0/4 and z^2 <100') - self.ndiffAssertEqual( - self.formatter(sql), ( - 'select a + b as d from table ' - 'where (c - d) % 2 = 1 and e > 3.0 / 4 and z ^ 2 < 100') + assert self.formatter(sql) == ( + 'select a + b as d from table ' + 'where (c - d) % 2 = 1 and e > 3.0 / 4 and z ^ 2 < 100' ) def test_bools(self): sql = 'select * from table where a &&b or c||d' - self.ndiffAssertEqual( - self.formatter(sql), - 'select * from table where a && b or c || d' - ) + assert self.formatter( + sql) == 'select * from table where a && b or c || d' def test_nested(self): sql = 'select *, case when a-b then c end from table' - self.ndiffAssertEqual( - self.formatter(sql), - 'select *, case when a - b then c end from table' - ) + assert self.formatter( + sql) == 'select *, case when a - b then c end from table' def test_wildcard_vs_mult(self): sql = 'select a*b-c from table' - self.ndiffAssertEqual( - self.formatter(sql), - 'select a * b - c from table' - ) + assert self.formatter(sql) == 'select a * b - c from table' class TestFormatReindent(TestCaseBase): - def test_option(self): self.assertRaises(SQLParseError, sqlparse.format, 'foo', reindent=2) @@ -353,206 +327,204 @@ class TestFormatReindent(TestCaseBase): def test_stmts(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select foo; select bar' - self.ndiffAssertEqual(f(s), 'select foo;\n\nselect bar') + assert f(s) == 'select foo;\n\nselect bar' s = 'select foo' - self.ndiffAssertEqual(f(s), 'select foo') + assert f(s) == 'select foo' s = 'select foo; -- test\n select bar' - self.ndiffAssertEqual(f(s), 'select foo; -- test\n\nselect bar') + assert f(s) == 'select foo; -- test\n\nselect bar' def test_keywords(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select * from foo union select * from bar;' - self.ndiffAssertEqual(f(s), '\n'.join(['select *', - 'from foo', - 'union', - 'select *', - 'from bar;'])) + assert f(s) == '\n'.join(['select *', + 'from foo', + 'union', + 'select *', + 'from bar;']) def test_keywords_between(self): # issue 14 # don't break AND after BETWEEN f = lambda sql: sqlparse.format(sql, reindent=True) s = 'and foo between 1 and 2 and bar = 3' - self.ndiffAssertEqual(f(s), '\n'.join(['', - 'and foo between 1 and 2', - 'and bar = 3'])) + assert f(s) == '\n'.join(['', + 'and foo between 1 and 2', + 'and bar = 3']) def test_parenthesis(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select count(*) from (select * from foo);' - self.ndiffAssertEqual(f(s), - '\n'.join(['select count(*)', - 'from', - ' (select *', - ' from foo);', - ]) - ) + assert f(s) == '\n'.join(['select count(*)', + 'from', + ' (select *', + ' from foo);', + ] + ) def test_where(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select * from foo where bar = 1 and baz = 2 or bzz = 3;' - self.ndiffAssertEqual(f(s), ('select *\nfrom foo\n' - 'where bar = 1\n' - ' and baz = 2\n' - ' or bzz = 3;')) + assert f(s) == ('select *\nfrom foo\n' + 'where bar = 1\n' + ' and baz = 2\n' + ' or bzz = 3;') s = 'select * from foo where bar = 1 and (baz = 2 or bzz = 3);' - self.ndiffAssertEqual(f(s), ('select *\nfrom foo\n' - 'where bar = 1\n' - ' and (baz = 2\n' - ' or bzz = 3);')) + assert f(s) == ('select *\nfrom foo\n' + 'where bar = 1\n' + ' and (baz = 2\n' + ' or bzz = 3);') def test_join(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select * from foo join bar on 1 = 2' - self.ndiffAssertEqual(f(s), '\n'.join(['select *', - 'from foo', - 'join bar on 1 = 2'])) + assert f(s) == '\n'.join(['select *', + 'from foo', + 'join bar on 1 = 2']) s = 'select * from foo inner join bar on 1 = 2' - self.ndiffAssertEqual(f(s), '\n'.join(['select *', - 'from foo', - 'inner join bar on 1 = 2'])) + assert f(s) == '\n'.join(['select *', + 'from foo', + 'inner join bar on 1 = 2']) s = 'select * from foo left outer join bar on 1 = 2' - self.ndiffAssertEqual(f(s), '\n'.join(['select *', - 'from foo', - 'left outer join bar on 1 = 2'] - )) + assert f(s) == '\n'.join(['select *', + 'from foo', + 'left outer join bar on 1 = 2'] + ) s = 'select * from foo straight_join bar on 1 = 2' - self.ndiffAssertEqual(f(s), '\n'.join(['select *', - 'from foo', - 'straight_join bar on 1 = 2'] - )) + assert f(s) == '\n'.join(['select *', + 'from foo', + 'straight_join bar on 1 = 2'] + ) def test_identifier_list(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select foo, bar, baz from table1, table2 where 1 = 2' - self.ndiffAssertEqual(f(s), '\n'.join(['select foo,', - ' bar,', - ' baz', - 'from table1,', - ' table2', - 'where 1 = 2'])) + assert f(s) == '\n'.join(['select foo,', + ' bar,', + ' baz', + 'from table1,', + ' table2', + 'where 1 = 2']) s = 'select a.*, b.id from a, b' - self.ndiffAssertEqual(f(s), '\n'.join(['select a.*,', - ' b.id', - 'from a,', - ' b'])) + assert f(s) == '\n'.join(['select a.*,', + ' b.id', + 'from a,', + ' b']) def test_identifier_list_with_wrap_after(self): f = lambda sql: sqlparse.format(sql, reindent=True, wrap_after=14) s = 'select foo, bar, baz from table1, table2 where 1 = 2' - self.ndiffAssertEqual(f(s), '\n'.join(['select foo, bar,', - ' baz', - 'from table1, table2', - 'where 1 = 2'])) + assert f(s) == '\n'.join(['select foo, bar,', + ' baz', + 'from table1, table2', + 'where 1 = 2']) def test_identifier_list_with_functions(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = ("select 'abc' as foo, coalesce(col1, col2)||col3 as bar," "col3 from my_table") - self.ndiffAssertEqual(f(s), '\n'.join( + assert f(s) == '\n'.join( ["select 'abc' as foo,", " coalesce(col1, col2)||col3 as bar,", " col3", - "from my_table"])) + "from my_table"]) def test_case(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = 'case when foo = 1 then 2 when foo = 3 then 4 else 5 end' - self.ndiffAssertEqual(f(s), '\n'.join(['case', - ' when foo = 1 then 2', - ' when foo = 3 then 4', - ' else 5', - 'end'])) + assert f(s) == '\n'.join(['case', + ' when foo = 1 then 2', + ' when foo = 3 then 4', + ' else 5', + 'end']) def test_case2(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = 'case(foo) when bar = 1 then 2 else 3 end' - self.ndiffAssertEqual(f(s), '\n'.join(['case(foo)', - ' when bar = 1 then 2', - ' else 3', - 'end'])) + assert f(s) == '\n'.join(['case(foo)', + ' when bar = 1 then 2', + ' else 3', + 'end']) def test_nested_identifier_list(self): # issue4 f = lambda sql: sqlparse.format(sql, reindent=True) s = '(foo as bar, bar1, bar2 as bar3, b4 as b5)' - self.ndiffAssertEqual(f(s), '\n'.join(['(foo as bar,', - ' bar1,', - ' bar2 as bar3,', - ' b4 as b5)'])) + assert f(s) == '\n'.join(['(foo as bar,', + ' bar1,', + ' bar2 as bar3,', + ' b4 as b5)']) def test_duplicate_linebreaks(self): # issue3 f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select c1 -- column1\nfrom foo' - self.ndiffAssertEqual(f(s), '\n'.join(['select c1 -- column1', - 'from foo'])) + assert f(s) == '\n'.join(['select c1 -- column1', + 'from foo']) s = 'select c1 -- column1\nfrom foo' r = sqlparse.format(s, reindent=True, strip_comments=True) - self.ndiffAssertEqual(r, '\n'.join(['select c1', - 'from foo'])) + assert r == '\n'.join(['select c1', + 'from foo']) s = 'select c1\nfrom foo\norder by c1' - self.ndiffAssertEqual(f(s), '\n'.join(['select c1', - 'from foo', - 'order by c1'])) + assert f(s) == '\n'.join(['select c1', + 'from foo', + 'order by c1']) s = 'select c1 from t1 where (c1 = 1) order by c1' - self.ndiffAssertEqual(f(s), '\n'.join(['select c1', - 'from t1', - 'where (c1 = 1)', - 'order by c1'])) + assert f(s) == '\n'.join(['select c1', + 'from t1', + 'where (c1 = 1)', + 'order by c1']) def test_keywordfunctions(self): # issue36 f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select max(a) b, foo, bar' - self.ndiffAssertEqual(f(s), '\n'.join(['select max(a) b,', - ' foo,', - ' bar'])) + assert f(s) == '\n'.join(['select max(a) b,', + ' foo,', + ' bar']) def test_identifier_and_functions(self): # issue45 f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select foo.bar, nvl(1) from dual' - self.ndiffAssertEqual(f(s), '\n'.join(['select foo.bar,', - ' nvl(1)', - 'from dual'])) + assert f(s) == '\n'.join(['select foo.bar,', + ' nvl(1)', + 'from dual']) class TestOutputFormat(TestCaseBase): - def test_python(self): sql = 'select * from foo;' f = lambda sql: sqlparse.format(sql, output_format='python') - self.ndiffAssertEqual(f(sql), "sql = 'select * from foo;'") + assert f(sql) == "sql = 'select * from foo;'" f = lambda sql: sqlparse.format(sql, output_format='python', reindent=True) - self.ndiffAssertEqual(f(sql), ("sql = ('select * '\n" - " 'from foo;')")) + assert f(sql) == ("sql = ('select * '\n" + " 'from foo;')") def test_python_multiple_statements(self): sql = 'select * from foo; select 1 from dual' f = lambda sql: sqlparse.format(sql, output_format='python') - self.ndiffAssertEqual(f(sql), ("sql = 'select * from foo; '\n" - "sql2 = 'select 1 from dual'")) + assert f(sql) == ("sql = 'select * from foo; '\n" + "sql2 = 'select 1 from dual'") @pytest.mark.xfail(reason="Needs fixing") def test_python_multiple_statements_with_formatting(self): sql = 'select * from foo; select 1 from dual' f = lambda sql: sqlparse.format(sql, output_format='python', reindent=True) - self.ndiffAssertEqual(f(sql), ("sql = ('select * '\n" - " 'from foo;')\n" - "sql2 = ('select 1 '\n" - " 'from dual')")) + assert f(sql) == ("sql = ('select * '\n" + " 'from foo;')\n" + "sql2 = ('select 1 '\n" + " 'from dual')") def test_php(self): sql = 'select * from foo;' f = lambda sql: sqlparse.format(sql, output_format='php') - self.ndiffAssertEqual(f(sql), '$sql = "select * from foo;";') + assert f(sql) == '$sql = "select * from foo;";' f = lambda sql: sqlparse.format(sql, output_format='php', reindent=True) - self.ndiffAssertEqual(f(sql), ('$sql = "select * ";\n' - '$sql .= "from foo;";')) + assert f(sql) == ('$sql = "select * ";\n' + '$sql .= "from foo;";') def test_sql(self): # "sql" is an allowed option but has no effect sql = 'select * from foo;' f = lambda sql: sqlparse.format(sql, output_format='sql') - self.ndiffAssertEqual(f(sql), 'select * from foo;') + assert f(sql) == 'select * from foo;' def test_format_column_ordering(): # issue89 |
