diff options
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/test_format.py | 472 | ||||
| -rw-r--r-- | tests/test_grouping.py | 34 | ||||
| -rw-r--r-- | tests/test_parse.py | 3 | ||||
| -rw-r--r-- | tests/test_regressions.py | 50 | ||||
| -rw-r--r-- | tests/test_split.py | 29 | ||||
| -rw-r--r-- | tests/test_tokenize.py | 3 |
6 files changed, 277 insertions, 314 deletions
diff --git a/tests/test_format.py b/tests/test_format.py index 74fce71..fc15b2f 100644 --- a/tests/test_format.py +++ b/tests/test_format.py @@ -2,75 +2,73 @@ import pytest -from tests.utils import TestCaseBase - import sqlparse from sqlparse.exceptions import SQLParseError +from tests.utils import TestCaseBase class TestFormat(TestCaseBase): - def test_keywordcase(self): sql = 'select * from bar; -- select foo\n' res = sqlparse.format(sql, keyword_case='upper') - self.ndiffAssertEqual(res, 'SELECT * FROM bar; -- select foo\n') + assert res == 'SELECT * FROM bar; -- select foo\n' res = sqlparse.format(sql, keyword_case='capitalize') - self.ndiffAssertEqual(res, 'Select * From bar; -- select foo\n') + assert res == 'Select * From bar; -- select foo\n' res = sqlparse.format(sql.upper(), keyword_case='lower') - self.ndiffAssertEqual(res, 'select * from BAR; -- SELECT FOO\n') + assert res == 'select * from BAR; -- SELECT FOO\n' self.assertRaises(SQLParseError, sqlparse.format, sql, keyword_case='foo') def test_identifiercase(self): sql = 'select * from bar; -- select foo\n' res = sqlparse.format(sql, identifier_case='upper') - self.ndiffAssertEqual(res, 'select * from BAR; -- select foo\n') + assert res == 'select * from BAR; -- select foo\n' res = sqlparse.format(sql, identifier_case='capitalize') - self.ndiffAssertEqual(res, 'select * from Bar; -- select foo\n') + assert res == 'select * from Bar; -- select foo\n' res = sqlparse.format(sql.upper(), identifier_case='lower') - self.ndiffAssertEqual(res, 'SELECT * FROM bar; -- SELECT FOO\n') + assert res == 'SELECT * FROM bar; -- SELECT FOO\n' self.assertRaises(SQLParseError, sqlparse.format, sql, identifier_case='foo') sql = 'select * from "foo"."bar"' res = sqlparse.format(sql, identifier_case="upper") - self.ndiffAssertEqual(res, 'select * from "foo"."bar"') + assert res == 'select * from "foo"."bar"' def test_strip_comments_single(self): sql = 'select *-- statement starts here\nfrom foo' res = sqlparse.format(sql, strip_comments=True) - self.ndiffAssertEqual(res, 'select * from foo') + assert res == 'select * from foo' sql = 'select * -- statement starts here\nfrom foo' res = sqlparse.format(sql, strip_comments=True) - self.ndiffAssertEqual(res, 'select * from foo') + assert res == 'select * from foo' sql = 'select-- foo\nfrom -- bar\nwhere' res = sqlparse.format(sql, strip_comments=True) - self.ndiffAssertEqual(res, 'select from where') + assert res == 'select from where' self.assertRaises(SQLParseError, sqlparse.format, sql, strip_comments=None) def test_strip_comments_multi(self): sql = '/* sql starts here */\nselect' res = sqlparse.format(sql, strip_comments=True) - self.ndiffAssertEqual(res, 'select') + assert res == 'select' sql = '/* sql starts here */ select' res = sqlparse.format(sql, strip_comments=True) - self.ndiffAssertEqual(res, 'select') + assert res == 'select' sql = '/*\n * sql starts here\n */\nselect' res = sqlparse.format(sql, strip_comments=True) - self.ndiffAssertEqual(res, 'select') + assert res == 'select' sql = 'select (/* sql starts here */ select 2)' res = sqlparse.format(sql, strip_comments=True) - self.ndiffAssertEqual(res, 'select (select 2)') + assert res == 'select (select 2)' sql = 'select (/* sql /* starts here */ select 2)' res = sqlparse.format(sql, strip_comments=True) - self.ndiffAssertEqual(res, 'select (select 2)') + assert res == 'select (select 2)' def test_strip_ws(self): f = lambda sql: sqlparse.format(sql, strip_whitespace=True) s = 'select\n* from foo\n\twhere ( 1 = 2 )\n' - self.ndiffAssertEqual(f(s), 'select * from foo where (1 = 2)') + assert f(s) == 'select * from foo where (1 = 2)' s = 'select -- foo\nfrom bar\n' - self.ndiffAssertEqual(f(s), 'select -- foo\nfrom bar') + assert f(s) == 'select -- foo\nfrom bar' self.assertRaises(SQLParseError, sqlparse.format, s, strip_whitespace=None) @@ -78,7 +76,7 @@ class TestFormat(TestCaseBase): # preserve at least one whitespace after subgroups f = lambda sql: sqlparse.format(sql, strip_whitespace=True) s = 'select\n* /* foo */ from bar ' - self.ndiffAssertEqual(f(s), 'select * /* foo */ from bar') + assert f(s) == 'select * /* foo */ from bar' def test_notransform_of_quoted_crlf(self): # Make sure that CR/CR+LF characters inside string literals don't get @@ -92,13 +90,11 @@ class TestFormat(TestCaseBase): f = lambda x: sqlparse.format(x) # Because of the use of - self.ndiffAssertEqual(f(s1), "SELECT some_column LIKE 'value\r'") - self.ndiffAssertEqual( - f(s2), "SELECT some_column LIKE 'value\r'\nWHERE id = 1\n") - self.ndiffAssertEqual( - f(s3), "SELECT some_column LIKE 'value\\'\r' WHERE id = 1\n") - self.ndiffAssertEqual( - f(s4), "SELECT some_column LIKE 'value\\\\\\'\r' WHERE id = 1\n") + assert f(s1) == "SELECT some_column LIKE 'value\r'" + assert f(s2) == "SELECT some_column LIKE 'value\r'\nWHERE id = 1\n" + assert f(s3) == "SELECT some_column LIKE 'value\\'\r' WHERE id = 1\n" + assert f( + s4) == "SELECT some_column LIKE 'value\\\\\\'\r' WHERE id = 1\n" def test_outputformat(self): sql = 'select * from foo;' @@ -121,23 +117,22 @@ class TestFormatReindentAligned(TestCaseBase): or d is 'blue' limit 10 """ - self.ndiffAssertEqual( - self.formatter(sql), - '\n'.join([ - 'select a,', - ' b as bb,', - ' c', - ' from table', - ' join (', - ' select a * 2 as a', - ' from new_table', - ' ) other', - ' on table.a = other.a', - ' where c is true', - ' and b between 3 and 4', - " or d is 'blue'", - ' limit 10', - ])) + + assert self.formatter(sql) == '\n'.join([ + 'select a,', + ' b as bb,', + ' c', + ' from table', + ' join (', + ' select a * 2 as a', + ' from new_table', + ' ) other', + ' on table.a = other.a', + ' where c is true', + ' and b between 3 and 4', + " or d is 'blue'", + ' limit 10', + ]) def test_joins(self): sql = """ @@ -148,22 +143,20 @@ class TestFormatReindentAligned(TestCaseBase): cross join e on e.four = a.four join f using (one, two, three) """ - self.ndiffAssertEqual( - self.formatter(sql), - '\n'.join([ - 'select *', - ' from a', - ' join b', - ' on a.one = b.one', - ' left join c', - ' on c.two = a.two', - ' and c.three = a.three', - ' full outer join d', - ' on d.three = a.three', - ' cross join e', - ' on e.four = a.four', - ' join f using (one, two, three)', - ])) + assert self.formatter(sql) == '\n'.join([ + 'select *', + ' from a', + ' join b', + ' on a.one = b.one', + ' left join c', + ' on c.two = a.two', + ' and c.three = a.three', + ' full outer join d', + ' on d.three = a.three', + ' cross join e', + ' on e.four = a.four', + ' join f using (one, two, three)', + ]) def test_case_statement(self): sql = """ @@ -178,20 +171,18 @@ class TestFormatReindentAligned(TestCaseBase): where c is true and b between 3 and 4 """ - self.ndiffAssertEqual( - self.formatter(sql), - '\n'.join([ - 'select a,', - ' case when a = 0 then 1', - ' when bb = 1 then 1', - ' when c = 2 then 2', - ' else 0', - ' end as d,', - ' extra_col', - ' from table', - ' where c is true', - ' and b between 3 and 4' - ])) + assert self.formatter(sql) == '\n'.join([ + 'select a,', + ' case when a = 0 then 1', + ' when bb = 1 then 1', + ' when c = 2 then 2', + ' else 0', + ' end as d,', + ' extra_col', + ' from table', + ' where c is true', + ' and b between 3 and 4' + ]) def test_case_statement_with_between(self): sql = """ @@ -207,21 +198,19 @@ class TestFormatReindentAligned(TestCaseBase): where c is true and b between 3 and 4 """ - self.ndiffAssertEqual( - self.formatter(sql), - '\n'.join([ - 'select a,', - ' case when a = 0 then 1', - ' when bb = 1 then 1', - ' when c = 2 then 2', - ' when d between 3 and 5 then 3', - ' else 0', - ' end as d,', - ' extra_col', - ' from table', - ' where c is true', - ' and b between 3 and 4' - ])) + assert self.formatter(sql) == '\n'.join([ + 'select a,', + ' case when a = 0 then 1', + ' when bb = 1 then 1', + ' when c = 2 then 2', + ' when d between 3 and 5 then 3', + ' else 0', + ' end as d,', + ' extra_col', + ' from table', + ' where c is true', + ' and b between 3 and 4' + ]) def test_group_by(self): sql = """ @@ -232,24 +221,22 @@ class TestFormatReindentAligned(TestCaseBase): and count(y) > 5 order by 3,2,1 """ - self.ndiffAssertEqual( - self.formatter(sql), - '\n'.join([ - 'select a,', - ' b,', - ' c,', - ' sum(x) as sum_x,', - ' count(y) as cnt_y', - ' from table', - ' group by a,', - ' b,', - ' c', - 'having sum(x) > 1', - ' and count(y) > 5', - ' order by 3,', - ' 2,', - ' 1', - ])) + assert self.formatter(sql) == '\n'.join([ + 'select a,', + ' b,', + ' c,', + ' sum(x) as sum_x,', + ' count(y) as cnt_y', + ' from table', + ' group by a,', + ' b,', + ' c', + 'having sum(x) > 1', + ' and count(y) > 5', + ' order by 3,', + ' 2,', + ' 1', + ]) def test_group_by_subquery(self): # TODO: add subquery alias when test_identifier_list_subquery fixed @@ -261,21 +248,19 @@ class TestFormatReindentAligned(TestCaseBase): group by a,z) order by 1,2 """ - self.ndiffAssertEqual( - self.formatter(sql), - '\n'.join([ - 'select *,', - ' sum_b + 2 as mod_sum', - ' from (', - ' select a,', - ' sum(b) as sum_b', - ' from table', - ' group by a,', - ' z', - ' )', - ' order by 1,', - ' 2', - ])) + assert self.formatter(sql) == '\n'.join([ + 'select *,', + ' sum_b + 2 as mod_sum', + ' from (', + ' select a,', + ' sum(b) as sum_b', + ' from table', + ' group by a,', + ' z', + ' )', + ' order by 1,', + ' 2', + ]) def test_window_functions(self): sql = """ @@ -286,16 +271,14 @@ class TestFormatReindentAligned(TestCaseBase): (PARTITION BY b, c ORDER BY d DESC) as row_num from table """ - self.ndiffAssertEqual( - self.formatter(sql), - '\n'.join([ - 'select a,', - (' SUM(a) OVER (PARTITION BY b ORDER BY c ROWS ' - 'BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as sum_a,'), - (' ROW_NUMBER() OVER ' - '(PARTITION BY b, c ORDER BY d DESC) as row_num'), - ' from table', - ])) + assert self.formatter(sql) == '\n'.join([ + 'select a,', + (' SUM(a) OVER (PARTITION BY b ORDER BY c ROWS ' + 'BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as sum_a,'), + (' ROW_NUMBER() OVER ' + '(PARTITION BY b, c ORDER BY d DESC) as row_num'), + ' from table', + ]) class TestSpacesAroundOperators(TestCaseBase): @@ -306,36 +289,27 @@ class TestSpacesAroundOperators(TestCaseBase): def test_basic(self): sql = ('select a+b as d from table ' 'where (c-d)%2= 1 and e> 3.0/4 and z^2 <100') - self.ndiffAssertEqual( - self.formatter(sql), ( - 'select a + b as d from table ' - 'where (c - d) % 2 = 1 and e > 3.0 / 4 and z ^ 2 < 100') + assert self.formatter(sql) == ( + 'select a + b as d from table ' + 'where (c - d) % 2 = 1 and e > 3.0 / 4 and z ^ 2 < 100' ) def test_bools(self): sql = 'select * from table where a &&b or c||d' - self.ndiffAssertEqual( - self.formatter(sql), - 'select * from table where a && b or c || d' - ) + assert self.formatter( + sql) == 'select * from table where a && b or c || d' def test_nested(self): sql = 'select *, case when a-b then c end from table' - self.ndiffAssertEqual( - self.formatter(sql), - 'select *, case when a - b then c end from table' - ) + assert self.formatter( + sql) == 'select *, case when a - b then c end from table' def test_wildcard_vs_mult(self): sql = 'select a*b-c from table' - self.ndiffAssertEqual( - self.formatter(sql), - 'select a * b - c from table' - ) + assert self.formatter(sql) == 'select a * b - c from table' class TestFormatReindent(TestCaseBase): - def test_option(self): self.assertRaises(SQLParseError, sqlparse.format, 'foo', reindent=2) @@ -353,206 +327,204 @@ class TestFormatReindent(TestCaseBase): def test_stmts(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select foo; select bar' - self.ndiffAssertEqual(f(s), 'select foo;\n\nselect bar') + assert f(s) == 'select foo;\n\nselect bar' s = 'select foo' - self.ndiffAssertEqual(f(s), 'select foo') + assert f(s) == 'select foo' s = 'select foo; -- test\n select bar' - self.ndiffAssertEqual(f(s), 'select foo; -- test\n\nselect bar') + assert f(s) == 'select foo; -- test\n\nselect bar' def test_keywords(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select * from foo union select * from bar;' - self.ndiffAssertEqual(f(s), '\n'.join(['select *', - 'from foo', - 'union', - 'select *', - 'from bar;'])) + assert f(s) == '\n'.join(['select *', + 'from foo', + 'union', + 'select *', + 'from bar;']) def test_keywords_between(self): # issue 14 # don't break AND after BETWEEN f = lambda sql: sqlparse.format(sql, reindent=True) s = 'and foo between 1 and 2 and bar = 3' - self.ndiffAssertEqual(f(s), '\n'.join(['', - 'and foo between 1 and 2', - 'and bar = 3'])) + assert f(s) == '\n'.join(['', + 'and foo between 1 and 2', + 'and bar = 3']) def test_parenthesis(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select count(*) from (select * from foo);' - self.ndiffAssertEqual(f(s), - '\n'.join(['select count(*)', - 'from', - ' (select *', - ' from foo);', - ]) - ) + assert f(s) == '\n'.join(['select count(*)', + 'from', + ' (select *', + ' from foo);', + ] + ) def test_where(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select * from foo where bar = 1 and baz = 2 or bzz = 3;' - self.ndiffAssertEqual(f(s), ('select *\nfrom foo\n' - 'where bar = 1\n' - ' and baz = 2\n' - ' or bzz = 3;')) + assert f(s) == ('select *\nfrom foo\n' + 'where bar = 1\n' + ' and baz = 2\n' + ' or bzz = 3;') s = 'select * from foo where bar = 1 and (baz = 2 or bzz = 3);' - self.ndiffAssertEqual(f(s), ('select *\nfrom foo\n' - 'where bar = 1\n' - ' and (baz = 2\n' - ' or bzz = 3);')) + assert f(s) == ('select *\nfrom foo\n' + 'where bar = 1\n' + ' and (baz = 2\n' + ' or bzz = 3);') def test_join(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select * from foo join bar on 1 = 2' - self.ndiffAssertEqual(f(s), '\n'.join(['select *', - 'from foo', - 'join bar on 1 = 2'])) + assert f(s) == '\n'.join(['select *', + 'from foo', + 'join bar on 1 = 2']) s = 'select * from foo inner join bar on 1 = 2' - self.ndiffAssertEqual(f(s), '\n'.join(['select *', - 'from foo', - 'inner join bar on 1 = 2'])) + assert f(s) == '\n'.join(['select *', + 'from foo', + 'inner join bar on 1 = 2']) s = 'select * from foo left outer join bar on 1 = 2' - self.ndiffAssertEqual(f(s), '\n'.join(['select *', - 'from foo', - 'left outer join bar on 1 = 2'] - )) + assert f(s) == '\n'.join(['select *', + 'from foo', + 'left outer join bar on 1 = 2'] + ) s = 'select * from foo straight_join bar on 1 = 2' - self.ndiffAssertEqual(f(s), '\n'.join(['select *', - 'from foo', - 'straight_join bar on 1 = 2'] - )) + assert f(s) == '\n'.join(['select *', + 'from foo', + 'straight_join bar on 1 = 2'] + ) def test_identifier_list(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select foo, bar, baz from table1, table2 where 1 = 2' - self.ndiffAssertEqual(f(s), '\n'.join(['select foo,', - ' bar,', - ' baz', - 'from table1,', - ' table2', - 'where 1 = 2'])) + assert f(s) == '\n'.join(['select foo,', + ' bar,', + ' baz', + 'from table1,', + ' table2', + 'where 1 = 2']) s = 'select a.*, b.id from a, b' - self.ndiffAssertEqual(f(s), '\n'.join(['select a.*,', - ' b.id', - 'from a,', - ' b'])) + assert f(s) == '\n'.join(['select a.*,', + ' b.id', + 'from a,', + ' b']) def test_identifier_list_with_wrap_after(self): f = lambda sql: sqlparse.format(sql, reindent=True, wrap_after=14) s = 'select foo, bar, baz from table1, table2 where 1 = 2' - self.ndiffAssertEqual(f(s), '\n'.join(['select foo, bar,', - ' baz', - 'from table1, table2', - 'where 1 = 2'])) + assert f(s) == '\n'.join(['select foo, bar,', + ' baz', + 'from table1, table2', + 'where 1 = 2']) def test_identifier_list_with_functions(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = ("select 'abc' as foo, coalesce(col1, col2)||col3 as bar," "col3 from my_table") - self.ndiffAssertEqual(f(s), '\n'.join( + assert f(s) == '\n'.join( ["select 'abc' as foo,", " coalesce(col1, col2)||col3 as bar,", " col3", - "from my_table"])) + "from my_table"]) def test_case(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = 'case when foo = 1 then 2 when foo = 3 then 4 else 5 end' - self.ndiffAssertEqual(f(s), '\n'.join(['case', - ' when foo = 1 then 2', - ' when foo = 3 then 4', - ' else 5', - 'end'])) + assert f(s) == '\n'.join(['case', + ' when foo = 1 then 2', + ' when foo = 3 then 4', + ' else 5', + 'end']) def test_case2(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = 'case(foo) when bar = 1 then 2 else 3 end' - self.ndiffAssertEqual(f(s), '\n'.join(['case(foo)', - ' when bar = 1 then 2', - ' else 3', - 'end'])) + assert f(s) == '\n'.join(['case(foo)', + ' when bar = 1 then 2', + ' else 3', + 'end']) def test_nested_identifier_list(self): # issue4 f = lambda sql: sqlparse.format(sql, reindent=True) s = '(foo as bar, bar1, bar2 as bar3, b4 as b5)' - self.ndiffAssertEqual(f(s), '\n'.join(['(foo as bar,', - ' bar1,', - ' bar2 as bar3,', - ' b4 as b5)'])) + assert f(s) == '\n'.join(['(foo as bar,', + ' bar1,', + ' bar2 as bar3,', + ' b4 as b5)']) def test_duplicate_linebreaks(self): # issue3 f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select c1 -- column1\nfrom foo' - self.ndiffAssertEqual(f(s), '\n'.join(['select c1 -- column1', - 'from foo'])) + assert f(s) == '\n'.join(['select c1 -- column1', + 'from foo']) s = 'select c1 -- column1\nfrom foo' r = sqlparse.format(s, reindent=True, strip_comments=True) - self.ndiffAssertEqual(r, '\n'.join(['select c1', - 'from foo'])) + assert r == '\n'.join(['select c1', + 'from foo']) s = 'select c1\nfrom foo\norder by c1' - self.ndiffAssertEqual(f(s), '\n'.join(['select c1', - 'from foo', - 'order by c1'])) + assert f(s) == '\n'.join(['select c1', + 'from foo', + 'order by c1']) s = 'select c1 from t1 where (c1 = 1) order by c1' - self.ndiffAssertEqual(f(s), '\n'.join(['select c1', - 'from t1', - 'where (c1 = 1)', - 'order by c1'])) + assert f(s) == '\n'.join(['select c1', + 'from t1', + 'where (c1 = 1)', + 'order by c1']) def test_keywordfunctions(self): # issue36 f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select max(a) b, foo, bar' - self.ndiffAssertEqual(f(s), '\n'.join(['select max(a) b,', - ' foo,', - ' bar'])) + assert f(s) == '\n'.join(['select max(a) b,', + ' foo,', + ' bar']) def test_identifier_and_functions(self): # issue45 f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select foo.bar, nvl(1) from dual' - self.ndiffAssertEqual(f(s), '\n'.join(['select foo.bar,', - ' nvl(1)', - 'from dual'])) + assert f(s) == '\n'.join(['select foo.bar,', + ' nvl(1)', + 'from dual']) class TestOutputFormat(TestCaseBase): - def test_python(self): sql = 'select * from foo;' f = lambda sql: sqlparse.format(sql, output_format='python') - self.ndiffAssertEqual(f(sql), "sql = 'select * from foo;'") + assert f(sql) == "sql = 'select * from foo;'" f = lambda sql: sqlparse.format(sql, output_format='python', reindent=True) - self.ndiffAssertEqual(f(sql), ("sql = ('select * '\n" - " 'from foo;')")) + assert f(sql) == ("sql = ('select * '\n" + " 'from foo;')") def test_python_multiple_statements(self): sql = 'select * from foo; select 1 from dual' f = lambda sql: sqlparse.format(sql, output_format='python') - self.ndiffAssertEqual(f(sql), ("sql = 'select * from foo; '\n" - "sql2 = 'select 1 from dual'")) + assert f(sql) == ("sql = 'select * from foo; '\n" + "sql2 = 'select 1 from dual'") @pytest.mark.xfail(reason="Needs fixing") def test_python_multiple_statements_with_formatting(self): sql = 'select * from foo; select 1 from dual' f = lambda sql: sqlparse.format(sql, output_format='python', reindent=True) - self.ndiffAssertEqual(f(sql), ("sql = ('select * '\n" - " 'from foo;')\n" - "sql2 = ('select 1 '\n" - " 'from dual')")) + assert f(sql) == ("sql = ('select * '\n" + " 'from foo;')\n" + "sql2 = ('select 1 '\n" + " 'from dual')") def test_php(self): sql = 'select * from foo;' f = lambda sql: sqlparse.format(sql, output_format='php') - self.ndiffAssertEqual(f(sql), '$sql = "select * from foo;";') + assert f(sql) == '$sql = "select * from foo;";' f = lambda sql: sqlparse.format(sql, output_format='php', reindent=True) - self.ndiffAssertEqual(f(sql), ('$sql = "select * ";\n' - '$sql .= "from foo;";')) + assert f(sql) == ('$sql = "select * ";\n' + '$sql .= "from foo;";') def test_sql(self): # "sql" is an allowed option but has no effect sql = 'select * from foo;' f = lambda sql: sqlparse.format(sql, output_format='sql') - self.ndiffAssertEqual(f(sql), 'select * from foo;') + assert f(sql) == 'select * from foo;' def test_format_column_ordering(): # issue89 diff --git a/tests/test_grouping.py b/tests/test_grouping.py index bf6bfeb..e32d82c 100644 --- a/tests/test_grouping.py +++ b/tests/test_grouping.py @@ -6,29 +6,27 @@ import sqlparse from sqlparse import sql from sqlparse import tokens as T from sqlparse.compat import u - from tests.utils import TestCaseBase class TestGrouping(TestCaseBase): - def test_parenthesis(self): s = 'select (select (x3) x2) and (y2) bar' parsed = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, str(parsed)) + assert s == str(parsed) self.assertEqual(len(parsed.tokens), 7) self.assert_(isinstance(parsed.tokens[2], sql.Parenthesis)) self.assert_(isinstance(parsed.tokens[-1], sql.Identifier)) self.assertEqual(len(parsed.tokens[2].tokens), 5) self.assert_(isinstance(parsed.tokens[2].tokens[3], sql.Identifier)) self.assert_(isinstance(parsed.tokens[2].tokens[3].tokens[0], - sql.Parenthesis)) + sql.Parenthesis)) self.assertEqual(len(parsed.tokens[2].tokens[3].tokens), 3) def test_comments(self): s = '/*\n * foo\n */ \n bar' parsed = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, u(parsed)) + assert s == u(parsed) self.assertEqual(len(parsed.tokens), 2) def test_assignment(self): @@ -44,18 +42,18 @@ class TestGrouping(TestCaseBase): def test_identifiers(self): s = 'select foo.bar from "myscheme"."table" where fail. order' parsed = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, u(parsed)) + assert s == u(parsed) self.assert_(isinstance(parsed.tokens[2], sql.Identifier)) self.assert_(isinstance(parsed.tokens[6], sql.Identifier)) self.assert_(isinstance(parsed.tokens[8], sql.Where)) s = 'select * from foo where foo.id = 1' parsed = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, u(parsed)) + assert s == u(parsed) self.assert_(isinstance(parsed.tokens[-1].tokens[-1].tokens[0], sql.Identifier)) s = 'select * from (select "foo"."id" from foo)' parsed = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, u(parsed)) + assert s == u(parsed) self.assert_(isinstance(parsed.tokens[-1].tokens[3], sql.Identifier)) s = "INSERT INTO `test` VALUES('foo', 'bar');" @@ -170,50 +168,50 @@ class TestGrouping(TestCaseBase): p = sqlparse.parse('1, 2 desc, 3')[0] self.assert_(isinstance(p.tokens[0], sql.IdentifierList)) self.assert_(isinstance(p.tokens[0].tokens[3], sql.Identifier)) - self.ndiffAssertEqual(u(p.tokens[0].tokens[3]), '2 desc') + assert u(p.tokens[0].tokens[3]) == '2 desc' def test_where(self): s = 'select * from foo where bar = 1 order by id desc' p = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, u(p)) + assert s == u(p) self.assert_(len(p.tokens) == 14) s = 'select x from (select y from foo where bar = 1) z' p = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, u(p)) + assert s == u(p) self.assert_(isinstance(p.tokens[-1].tokens[0].tokens[-2], sql.Where)) def test_typecast(self): s = 'select foo::integer from bar' p = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, u(p)) + assert s == u(p) self.assertEqual(p.tokens[2].get_typecast(), 'integer') self.assertEqual(p.tokens[2].get_name(), 'foo') s = 'select (current_database())::information_schema.sql_identifier' p = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, u(p)) + assert s == u(p) self.assertEqual(p.tokens[2].get_typecast(), 'information_schema.sql_identifier') def test_alias(self): s = 'select foo as bar from mytable' p = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, u(p)) + assert s == u(p) self.assertEqual(p.tokens[2].get_real_name(), 'foo') self.assertEqual(p.tokens[2].get_alias(), 'bar') s = 'select foo from mytable t1' p = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, u(p)) + assert s == u(p) self.assertEqual(p.tokens[6].get_real_name(), 'mytable') self.assertEqual(p.tokens[6].get_alias(), 't1') s = 'select foo::integer as bar from mytable' p = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, u(p)) + assert s == u(p) self.assertEqual(p.tokens[2].get_alias(), 'bar') s = ('SELECT DISTINCT ' '(current_database())::information_schema.sql_identifier AS view') p = sqlparse.parse(s)[0] - self.ndiffAssertEqual(s, u(p)) + assert s == u(p) self.assertEqual(p.tokens[4].get_alias(), 'view') def test_alias_case(self): # see issue46 @@ -259,10 +257,10 @@ class TestGrouping(TestCaseBase): class TestStatement(TestCaseBase): - def test_get_type(self): def f(sql): return sqlparse.parse(sql)[0] + self.assertEqual(f('select * from foo').get_type(), 'SELECT') self.assertEqual(f('update foo').get_type(), 'UPDATE') self.assertEqual(f(' update foo').get_type(), 'UPDATE') diff --git a/tests/test_parse.py b/tests/test_parse.py index 75a7ab5..d8a8c27 100644 --- a/tests/test_parse.py +++ b/tests/test_parse.py @@ -4,12 +4,11 @@ import pytest -from tests.utils import TestCaseBase - import sqlparse import sqlparse.sql from sqlparse import tokens as T from sqlparse.compat import u, StringIO +from tests.utils import TestCaseBase class SQLParseTest(TestCaseBase): diff --git a/tests/test_regressions.py b/tests/test_regressions.py index 14aab24..0887c40 100644 --- a/tests/test_regressions.py +++ b/tests/test_regressions.py @@ -3,15 +3,14 @@ import sys import pytest # noqa -from tests.utils import TestCaseBase, load_file import sqlparse from sqlparse import sql from sqlparse import tokens as T +from tests.utils import TestCaseBase, load_file class RegressionTests(TestCaseBase): - def test_issue9(self): # make sure where doesn't consume parenthesis p = sqlparse.parse('(where 1)')[0] @@ -57,16 +56,16 @@ class RegressionTests(TestCaseBase): # missing space before LIMIT sql = sqlparse.format("select * from foo where bar = 1 limit 1", reindent=True) - self.ndiffAssertEqual(sql, "\n".join(["select *", - "from foo", - "where bar = 1 limit 1"])) + assert sql == "\n".join(["select *", + "from foo", + "where bar = 1 limit 1"]) def test_issue38(self): sql = sqlparse.format("SELECT foo; -- comment", strip_comments=True) - self.ndiffAssertEqual(sql, "SELECT foo;") + assert sql == "SELECT foo;" sql = sqlparse.format("/* foo */", strip_comments=True) - self.ndiffAssertEqual(sql, "") + assert sql == "" def test_issue39(self): p = sqlparse.parse('select user.id from user')[0] @@ -89,26 +88,22 @@ class RegressionTests(TestCaseBase): sp = p.tokens[-1].tokens[0] self.assertEqual(sp.tokens[3].__class__, sql.IdentifierList) # make sure that formatting works as expected - self.ndiffAssertEqual( - sqlparse.format(('SELECT id, name FROM ' - '(SELECT id, name FROM bar)'), - reindent=True), - ('SELECT id,\n' - ' name\n' - 'FROM\n' - ' (SELECT id,\n' - ' name\n' - ' FROM bar)')) - self.ndiffAssertEqual( - sqlparse.format(('SELECT id, name FROM ' - '(SELECT id, name FROM bar) as foo'), - reindent=True), - ('SELECT id,\n' - ' name\n' - 'FROM\n' - ' (SELECT id,\n' - ' name\n' - ' FROM bar) as foo')) + assert sqlparse.format(('SELECT id == name FROM ' + '(SELECT id, name FROM bar)'), + reindent=True), ('SELECT id,\n' + ' name\n' + 'FROM\n' + ' (SELECT id,\n' + ' name\n' + ' FROM bar)') + assert sqlparse.format(('SELECT id == name FROM ' + '(SELECT id, name FROM bar) as foo'), + reindent=True), ('SELECT id,\n' + ' name\n' + 'FROM\n' + ' (SELECT id,\n' + ' name\n' + ' FROM bar) as foo') def test_issue78(): @@ -116,6 +111,7 @@ def test_issue78(): def _get_identifier(sql): p = sqlparse.parse(sql)[0] return p.tokens[2] + results = (('get_name', 'z'), ('get_real_name', 'y'), ('get_parent_name', 'x'), diff --git a/tests/test_split.py b/tests/test_split.py index 7c2645d..aa74aed 100644 --- a/tests/test_split.py +++ b/tests/test_split.py @@ -4,10 +4,11 @@ import types -from tests.utils import load_file, TestCaseBase +import pytest import sqlparse from sqlparse.compat import StringIO, u, text_type +from tests.utils import load_file, TestCaseBase class SQLSplitTest(TestCaseBase): @@ -20,8 +21,8 @@ class SQLSplitTest(TestCaseBase): sql2 = 'select * from foo where bar = \'foo;bar\';' stmts = sqlparse.parse(''.join([self._sql1, sql2])) self.assertEqual(len(stmts), 2) - self.ndiffAssertEqual(u(stmts[0]), self._sql1) - self.ndiffAssertEqual(u(stmts[1]), sql2) + assert u(stmts[0]) == self._sql1 + assert u(stmts[1]) == sql2 def test_split_backslash(self): stmts = sqlparse.parse(r"select '\\'; select '\''; select '\\\'';") @@ -31,31 +32,31 @@ class SQLSplitTest(TestCaseBase): sql = load_file('function.sql') stmts = sqlparse.parse(sql) self.assertEqual(len(stmts), 1) - self.ndiffAssertEqual(u(stmts[0]), sql) + assert u(stmts[0]) == sql def test_create_function_psql(self): sql = load_file('function_psql.sql') stmts = sqlparse.parse(sql) self.assertEqual(len(stmts), 1) - self.ndiffAssertEqual(u(stmts[0]), sql) + assert u(stmts[0]) == sql def test_create_function_psql3(self): sql = load_file('function_psql3.sql') stmts = sqlparse.parse(sql) self.assertEqual(len(stmts), 1) - self.ndiffAssertEqual(u(stmts[0]), sql) + assert u(stmts[0]) == sql def test_create_function_psql2(self): sql = load_file('function_psql2.sql') stmts = sqlparse.parse(sql) self.assertEqual(len(stmts), 1) - self.ndiffAssertEqual(u(stmts[0]), sql) + assert u(stmts[0]) == sql def test_dashcomments(self): sql = load_file('dashcomment.sql') stmts = sqlparse.parse(sql) self.assertEqual(len(stmts), 3) - self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql) + assert ''.join(u(q) for q in stmts) == sql def test_dashcomments_eol(self): stmts = sqlparse.parse('select foo; -- comment\n') @@ -71,19 +72,19 @@ class SQLSplitTest(TestCaseBase): sql = load_file('begintag.sql') stmts = sqlparse.parse(sql) self.assertEqual(len(stmts), 3) - self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql) + assert ''.join(u(q) for q in stmts) == sql def test_begintag_2(self): sql = load_file('begintag_2.sql') stmts = sqlparse.parse(sql) self.assertEqual(len(stmts), 1) - self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql) + assert ''.join(u(q) for q in stmts) == sql def test_dropif(self): sql = 'DROP TABLE IF EXISTS FOO;\n\nSELECT * FROM BAR;' stmts = sqlparse.parse(sql) self.assertEqual(len(stmts), 2) - self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql) + assert ''.join(u(q) for q in stmts) == sql def test_comment_with_umlaut(self): sql = (u'select * from foo;\n' @@ -91,16 +92,16 @@ class SQLSplitTest(TestCaseBase): u'select * from bar;') stmts = sqlparse.parse(sql) self.assertEqual(len(stmts), 2) - self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql) + assert ''.join(u(q) for q in stmts) == sql def test_comment_end_of_line(self): sql = ('select * from foo; -- foo\n' 'select * from bar;') stmts = sqlparse.parse(sql) self.assertEqual(len(stmts), 2) - self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql) + assert ''.join(u(q) for q in stmts) == sql # make sure the comment belongs to first query - self.ndiffAssertEqual(u(stmts[0]), 'select * from foo; -- foo\n') + assert u(stmts[0]) == 'select * from foo; -- foo\n' def test_casewhen(self): sql = ('SELECT case when val = 1 then 2 else null end as foo;\n' diff --git a/tests/test_tokenize.py b/tests/test_tokenize.py index 61eaa3e..0446dfa 100644 --- a/tests/test_tokenize.py +++ b/tests/test_tokenize.py @@ -13,7 +13,6 @@ from sqlparse.compat import StringIO class TestTokenize(unittest.TestCase): - def test_simple(self): s = 'select * from foo;' stream = lexer.tokenize(s) @@ -74,7 +73,6 @@ class TestTokenize(unittest.TestCase): class TestToken(unittest.TestCase): - def test_str(self): token = sql.Token(None, 'FoO') self.assertEqual(str(token), 'FoO') @@ -96,7 +94,6 @@ class TestToken(unittest.TestCase): class TestTokenList(unittest.TestCase): - def test_repr(self): p = sqlparse.parse('foo, bar, baz')[0] tst = "<IdentifierList 'foo, b...' at 0x" |
