summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/test_format.py472
-rw-r--r--tests/test_grouping.py34
-rw-r--r--tests/test_parse.py3
-rw-r--r--tests/test_regressions.py50
-rw-r--r--tests/test_split.py29
-rw-r--r--tests/test_tokenize.py3
6 files changed, 277 insertions, 314 deletions
diff --git a/tests/test_format.py b/tests/test_format.py
index 74fce71..fc15b2f 100644
--- a/tests/test_format.py
+++ b/tests/test_format.py
@@ -2,75 +2,73 @@
import pytest
-from tests.utils import TestCaseBase
-
import sqlparse
from sqlparse.exceptions import SQLParseError
+from tests.utils import TestCaseBase
class TestFormat(TestCaseBase):
-
def test_keywordcase(self):
sql = 'select * from bar; -- select foo\n'
res = sqlparse.format(sql, keyword_case='upper')
- self.ndiffAssertEqual(res, 'SELECT * FROM bar; -- select foo\n')
+ assert res == 'SELECT * FROM bar; -- select foo\n'
res = sqlparse.format(sql, keyword_case='capitalize')
- self.ndiffAssertEqual(res, 'Select * From bar; -- select foo\n')
+ assert res == 'Select * From bar; -- select foo\n'
res = sqlparse.format(sql.upper(), keyword_case='lower')
- self.ndiffAssertEqual(res, 'select * from BAR; -- SELECT FOO\n')
+ assert res == 'select * from BAR; -- SELECT FOO\n'
self.assertRaises(SQLParseError, sqlparse.format, sql,
keyword_case='foo')
def test_identifiercase(self):
sql = 'select * from bar; -- select foo\n'
res = sqlparse.format(sql, identifier_case='upper')
- self.ndiffAssertEqual(res, 'select * from BAR; -- select foo\n')
+ assert res == 'select * from BAR; -- select foo\n'
res = sqlparse.format(sql, identifier_case='capitalize')
- self.ndiffAssertEqual(res, 'select * from Bar; -- select foo\n')
+ assert res == 'select * from Bar; -- select foo\n'
res = sqlparse.format(sql.upper(), identifier_case='lower')
- self.ndiffAssertEqual(res, 'SELECT * FROM bar; -- SELECT FOO\n')
+ assert res == 'SELECT * FROM bar; -- SELECT FOO\n'
self.assertRaises(SQLParseError, sqlparse.format, sql,
identifier_case='foo')
sql = 'select * from "foo"."bar"'
res = sqlparse.format(sql, identifier_case="upper")
- self.ndiffAssertEqual(res, 'select * from "foo"."bar"')
+ assert res == 'select * from "foo"."bar"'
def test_strip_comments_single(self):
sql = 'select *-- statement starts here\nfrom foo'
res = sqlparse.format(sql, strip_comments=True)
- self.ndiffAssertEqual(res, 'select * from foo')
+ assert res == 'select * from foo'
sql = 'select * -- statement starts here\nfrom foo'
res = sqlparse.format(sql, strip_comments=True)
- self.ndiffAssertEqual(res, 'select * from foo')
+ assert res == 'select * from foo'
sql = 'select-- foo\nfrom -- bar\nwhere'
res = sqlparse.format(sql, strip_comments=True)
- self.ndiffAssertEqual(res, 'select from where')
+ assert res == 'select from where'
self.assertRaises(SQLParseError, sqlparse.format, sql,
strip_comments=None)
def test_strip_comments_multi(self):
sql = '/* sql starts here */\nselect'
res = sqlparse.format(sql, strip_comments=True)
- self.ndiffAssertEqual(res, 'select')
+ assert res == 'select'
sql = '/* sql starts here */ select'
res = sqlparse.format(sql, strip_comments=True)
- self.ndiffAssertEqual(res, 'select')
+ assert res == 'select'
sql = '/*\n * sql starts here\n */\nselect'
res = sqlparse.format(sql, strip_comments=True)
- self.ndiffAssertEqual(res, 'select')
+ assert res == 'select'
sql = 'select (/* sql starts here */ select 2)'
res = sqlparse.format(sql, strip_comments=True)
- self.ndiffAssertEqual(res, 'select (select 2)')
+ assert res == 'select (select 2)'
sql = 'select (/* sql /* starts here */ select 2)'
res = sqlparse.format(sql, strip_comments=True)
- self.ndiffAssertEqual(res, 'select (select 2)')
+ assert res == 'select (select 2)'
def test_strip_ws(self):
f = lambda sql: sqlparse.format(sql, strip_whitespace=True)
s = 'select\n* from foo\n\twhere ( 1 = 2 )\n'
- self.ndiffAssertEqual(f(s), 'select * from foo where (1 = 2)')
+ assert f(s) == 'select * from foo where (1 = 2)'
s = 'select -- foo\nfrom bar\n'
- self.ndiffAssertEqual(f(s), 'select -- foo\nfrom bar')
+ assert f(s) == 'select -- foo\nfrom bar'
self.assertRaises(SQLParseError, sqlparse.format, s,
strip_whitespace=None)
@@ -78,7 +76,7 @@ class TestFormat(TestCaseBase):
# preserve at least one whitespace after subgroups
f = lambda sql: sqlparse.format(sql, strip_whitespace=True)
s = 'select\n* /* foo */ from bar '
- self.ndiffAssertEqual(f(s), 'select * /* foo */ from bar')
+ assert f(s) == 'select * /* foo */ from bar'
def test_notransform_of_quoted_crlf(self):
# Make sure that CR/CR+LF characters inside string literals don't get
@@ -92,13 +90,11 @@ class TestFormat(TestCaseBase):
f = lambda x: sqlparse.format(x)
# Because of the use of
- self.ndiffAssertEqual(f(s1), "SELECT some_column LIKE 'value\r'")
- self.ndiffAssertEqual(
- f(s2), "SELECT some_column LIKE 'value\r'\nWHERE id = 1\n")
- self.ndiffAssertEqual(
- f(s3), "SELECT some_column LIKE 'value\\'\r' WHERE id = 1\n")
- self.ndiffAssertEqual(
- f(s4), "SELECT some_column LIKE 'value\\\\\\'\r' WHERE id = 1\n")
+ assert f(s1) == "SELECT some_column LIKE 'value\r'"
+ assert f(s2) == "SELECT some_column LIKE 'value\r'\nWHERE id = 1\n"
+ assert f(s3) == "SELECT some_column LIKE 'value\\'\r' WHERE id = 1\n"
+ assert f(
+ s4) == "SELECT some_column LIKE 'value\\\\\\'\r' WHERE id = 1\n"
def test_outputformat(self):
sql = 'select * from foo;'
@@ -121,23 +117,22 @@ class TestFormatReindentAligned(TestCaseBase):
or d is 'blue'
limit 10
"""
- self.ndiffAssertEqual(
- self.formatter(sql),
- '\n'.join([
- 'select a,',
- ' b as bb,',
- ' c',
- ' from table',
- ' join (',
- ' select a * 2 as a',
- ' from new_table',
- ' ) other',
- ' on table.a = other.a',
- ' where c is true',
- ' and b between 3 and 4',
- " or d is 'blue'",
- ' limit 10',
- ]))
+
+ assert self.formatter(sql) == '\n'.join([
+ 'select a,',
+ ' b as bb,',
+ ' c',
+ ' from table',
+ ' join (',
+ ' select a * 2 as a',
+ ' from new_table',
+ ' ) other',
+ ' on table.a = other.a',
+ ' where c is true',
+ ' and b between 3 and 4',
+ " or d is 'blue'",
+ ' limit 10',
+ ])
def test_joins(self):
sql = """
@@ -148,22 +143,20 @@ class TestFormatReindentAligned(TestCaseBase):
cross join e on e.four = a.four
join f using (one, two, three)
"""
- self.ndiffAssertEqual(
- self.formatter(sql),
- '\n'.join([
- 'select *',
- ' from a',
- ' join b',
- ' on a.one = b.one',
- ' left join c',
- ' on c.two = a.two',
- ' and c.three = a.three',
- ' full outer join d',
- ' on d.three = a.three',
- ' cross join e',
- ' on e.four = a.four',
- ' join f using (one, two, three)',
- ]))
+ assert self.formatter(sql) == '\n'.join([
+ 'select *',
+ ' from a',
+ ' join b',
+ ' on a.one = b.one',
+ ' left join c',
+ ' on c.two = a.two',
+ ' and c.three = a.three',
+ ' full outer join d',
+ ' on d.three = a.three',
+ ' cross join e',
+ ' on e.four = a.four',
+ ' join f using (one, two, three)',
+ ])
def test_case_statement(self):
sql = """
@@ -178,20 +171,18 @@ class TestFormatReindentAligned(TestCaseBase):
where c is true
and b between 3 and 4
"""
- self.ndiffAssertEqual(
- self.formatter(sql),
- '\n'.join([
- 'select a,',
- ' case when a = 0 then 1',
- ' when bb = 1 then 1',
- ' when c = 2 then 2',
- ' else 0',
- ' end as d,',
- ' extra_col',
- ' from table',
- ' where c is true',
- ' and b between 3 and 4'
- ]))
+ assert self.formatter(sql) == '\n'.join([
+ 'select a,',
+ ' case when a = 0 then 1',
+ ' when bb = 1 then 1',
+ ' when c = 2 then 2',
+ ' else 0',
+ ' end as d,',
+ ' extra_col',
+ ' from table',
+ ' where c is true',
+ ' and b between 3 and 4'
+ ])
def test_case_statement_with_between(self):
sql = """
@@ -207,21 +198,19 @@ class TestFormatReindentAligned(TestCaseBase):
where c is true
and b between 3 and 4
"""
- self.ndiffAssertEqual(
- self.formatter(sql),
- '\n'.join([
- 'select a,',
- ' case when a = 0 then 1',
- ' when bb = 1 then 1',
- ' when c = 2 then 2',
- ' when d between 3 and 5 then 3',
- ' else 0',
- ' end as d,',
- ' extra_col',
- ' from table',
- ' where c is true',
- ' and b between 3 and 4'
- ]))
+ assert self.formatter(sql) == '\n'.join([
+ 'select a,',
+ ' case when a = 0 then 1',
+ ' when bb = 1 then 1',
+ ' when c = 2 then 2',
+ ' when d between 3 and 5 then 3',
+ ' else 0',
+ ' end as d,',
+ ' extra_col',
+ ' from table',
+ ' where c is true',
+ ' and b between 3 and 4'
+ ])
def test_group_by(self):
sql = """
@@ -232,24 +221,22 @@ class TestFormatReindentAligned(TestCaseBase):
and count(y) > 5
order by 3,2,1
"""
- self.ndiffAssertEqual(
- self.formatter(sql),
- '\n'.join([
- 'select a,',
- ' b,',
- ' c,',
- ' sum(x) as sum_x,',
- ' count(y) as cnt_y',
- ' from table',
- ' group by a,',
- ' b,',
- ' c',
- 'having sum(x) > 1',
- ' and count(y) > 5',
- ' order by 3,',
- ' 2,',
- ' 1',
- ]))
+ assert self.formatter(sql) == '\n'.join([
+ 'select a,',
+ ' b,',
+ ' c,',
+ ' sum(x) as sum_x,',
+ ' count(y) as cnt_y',
+ ' from table',
+ ' group by a,',
+ ' b,',
+ ' c',
+ 'having sum(x) > 1',
+ ' and count(y) > 5',
+ ' order by 3,',
+ ' 2,',
+ ' 1',
+ ])
def test_group_by_subquery(self):
# TODO: add subquery alias when test_identifier_list_subquery fixed
@@ -261,21 +248,19 @@ class TestFormatReindentAligned(TestCaseBase):
group by a,z)
order by 1,2
"""
- self.ndiffAssertEqual(
- self.formatter(sql),
- '\n'.join([
- 'select *,',
- ' sum_b + 2 as mod_sum',
- ' from (',
- ' select a,',
- ' sum(b) as sum_b',
- ' from table',
- ' group by a,',
- ' z',
- ' )',
- ' order by 1,',
- ' 2',
- ]))
+ assert self.formatter(sql) == '\n'.join([
+ 'select *,',
+ ' sum_b + 2 as mod_sum',
+ ' from (',
+ ' select a,',
+ ' sum(b) as sum_b',
+ ' from table',
+ ' group by a,',
+ ' z',
+ ' )',
+ ' order by 1,',
+ ' 2',
+ ])
def test_window_functions(self):
sql = """
@@ -286,16 +271,14 @@ class TestFormatReindentAligned(TestCaseBase):
(PARTITION BY b, c ORDER BY d DESC) as row_num
from table
"""
- self.ndiffAssertEqual(
- self.formatter(sql),
- '\n'.join([
- 'select a,',
- (' SUM(a) OVER (PARTITION BY b ORDER BY c ROWS '
- 'BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as sum_a,'),
- (' ROW_NUMBER() OVER '
- '(PARTITION BY b, c ORDER BY d DESC) as row_num'),
- ' from table',
- ]))
+ assert self.formatter(sql) == '\n'.join([
+ 'select a,',
+ (' SUM(a) OVER (PARTITION BY b ORDER BY c ROWS '
+ 'BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as sum_a,'),
+ (' ROW_NUMBER() OVER '
+ '(PARTITION BY b, c ORDER BY d DESC) as row_num'),
+ ' from table',
+ ])
class TestSpacesAroundOperators(TestCaseBase):
@@ -306,36 +289,27 @@ class TestSpacesAroundOperators(TestCaseBase):
def test_basic(self):
sql = ('select a+b as d from table '
'where (c-d)%2= 1 and e> 3.0/4 and z^2 <100')
- self.ndiffAssertEqual(
- self.formatter(sql), (
- 'select a + b as d from table '
- 'where (c - d) % 2 = 1 and e > 3.0 / 4 and z ^ 2 < 100')
+ assert self.formatter(sql) == (
+ 'select a + b as d from table '
+ 'where (c - d) % 2 = 1 and e > 3.0 / 4 and z ^ 2 < 100'
)
def test_bools(self):
sql = 'select * from table where a &&b or c||d'
- self.ndiffAssertEqual(
- self.formatter(sql),
- 'select * from table where a && b or c || d'
- )
+ assert self.formatter(
+ sql) == 'select * from table where a && b or c || d'
def test_nested(self):
sql = 'select *, case when a-b then c end from table'
- self.ndiffAssertEqual(
- self.formatter(sql),
- 'select *, case when a - b then c end from table'
- )
+ assert self.formatter(
+ sql) == 'select *, case when a - b then c end from table'
def test_wildcard_vs_mult(self):
sql = 'select a*b-c from table'
- self.ndiffAssertEqual(
- self.formatter(sql),
- 'select a * b - c from table'
- )
+ assert self.formatter(sql) == 'select a * b - c from table'
class TestFormatReindent(TestCaseBase):
-
def test_option(self):
self.assertRaises(SQLParseError, sqlparse.format, 'foo',
reindent=2)
@@ -353,206 +327,204 @@ class TestFormatReindent(TestCaseBase):
def test_stmts(self):
f = lambda sql: sqlparse.format(sql, reindent=True)
s = 'select foo; select bar'
- self.ndiffAssertEqual(f(s), 'select foo;\n\nselect bar')
+ assert f(s) == 'select foo;\n\nselect bar'
s = 'select foo'
- self.ndiffAssertEqual(f(s), 'select foo')
+ assert f(s) == 'select foo'
s = 'select foo; -- test\n select bar'
- self.ndiffAssertEqual(f(s), 'select foo; -- test\n\nselect bar')
+ assert f(s) == 'select foo; -- test\n\nselect bar'
def test_keywords(self):
f = lambda sql: sqlparse.format(sql, reindent=True)
s = 'select * from foo union select * from bar;'
- self.ndiffAssertEqual(f(s), '\n'.join(['select *',
- 'from foo',
- 'union',
- 'select *',
- 'from bar;']))
+ assert f(s) == '\n'.join(['select *',
+ 'from foo',
+ 'union',
+ 'select *',
+ 'from bar;'])
def test_keywords_between(self): # issue 14
# don't break AND after BETWEEN
f = lambda sql: sqlparse.format(sql, reindent=True)
s = 'and foo between 1 and 2 and bar = 3'
- self.ndiffAssertEqual(f(s), '\n'.join(['',
- 'and foo between 1 and 2',
- 'and bar = 3']))
+ assert f(s) == '\n'.join(['',
+ 'and foo between 1 and 2',
+ 'and bar = 3'])
def test_parenthesis(self):
f = lambda sql: sqlparse.format(sql, reindent=True)
s = 'select count(*) from (select * from foo);'
- self.ndiffAssertEqual(f(s),
- '\n'.join(['select count(*)',
- 'from',
- ' (select *',
- ' from foo);',
- ])
- )
+ assert f(s) == '\n'.join(['select count(*)',
+ 'from',
+ ' (select *',
+ ' from foo);',
+ ]
+ )
def test_where(self):
f = lambda sql: sqlparse.format(sql, reindent=True)
s = 'select * from foo where bar = 1 and baz = 2 or bzz = 3;'
- self.ndiffAssertEqual(f(s), ('select *\nfrom foo\n'
- 'where bar = 1\n'
- ' and baz = 2\n'
- ' or bzz = 3;'))
+ assert f(s) == ('select *\nfrom foo\n'
+ 'where bar = 1\n'
+ ' and baz = 2\n'
+ ' or bzz = 3;')
s = 'select * from foo where bar = 1 and (baz = 2 or bzz = 3);'
- self.ndiffAssertEqual(f(s), ('select *\nfrom foo\n'
- 'where bar = 1\n'
- ' and (baz = 2\n'
- ' or bzz = 3);'))
+ assert f(s) == ('select *\nfrom foo\n'
+ 'where bar = 1\n'
+ ' and (baz = 2\n'
+ ' or bzz = 3);')
def test_join(self):
f = lambda sql: sqlparse.format(sql, reindent=True)
s = 'select * from foo join bar on 1 = 2'
- self.ndiffAssertEqual(f(s), '\n'.join(['select *',
- 'from foo',
- 'join bar on 1 = 2']))
+ assert f(s) == '\n'.join(['select *',
+ 'from foo',
+ 'join bar on 1 = 2'])
s = 'select * from foo inner join bar on 1 = 2'
- self.ndiffAssertEqual(f(s), '\n'.join(['select *',
- 'from foo',
- 'inner join bar on 1 = 2']))
+ assert f(s) == '\n'.join(['select *',
+ 'from foo',
+ 'inner join bar on 1 = 2'])
s = 'select * from foo left outer join bar on 1 = 2'
- self.ndiffAssertEqual(f(s), '\n'.join(['select *',
- 'from foo',
- 'left outer join bar on 1 = 2']
- ))
+ assert f(s) == '\n'.join(['select *',
+ 'from foo',
+ 'left outer join bar on 1 = 2']
+ )
s = 'select * from foo straight_join bar on 1 = 2'
- self.ndiffAssertEqual(f(s), '\n'.join(['select *',
- 'from foo',
- 'straight_join bar on 1 = 2']
- ))
+ assert f(s) == '\n'.join(['select *',
+ 'from foo',
+ 'straight_join bar on 1 = 2']
+ )
def test_identifier_list(self):
f = lambda sql: sqlparse.format(sql, reindent=True)
s = 'select foo, bar, baz from table1, table2 where 1 = 2'
- self.ndiffAssertEqual(f(s), '\n'.join(['select foo,',
- ' bar,',
- ' baz',
- 'from table1,',
- ' table2',
- 'where 1 = 2']))
+ assert f(s) == '\n'.join(['select foo,',
+ ' bar,',
+ ' baz',
+ 'from table1,',
+ ' table2',
+ 'where 1 = 2'])
s = 'select a.*, b.id from a, b'
- self.ndiffAssertEqual(f(s), '\n'.join(['select a.*,',
- ' b.id',
- 'from a,',
- ' b']))
+ assert f(s) == '\n'.join(['select a.*,',
+ ' b.id',
+ 'from a,',
+ ' b'])
def test_identifier_list_with_wrap_after(self):
f = lambda sql: sqlparse.format(sql, reindent=True, wrap_after=14)
s = 'select foo, bar, baz from table1, table2 where 1 = 2'
- self.ndiffAssertEqual(f(s), '\n'.join(['select foo, bar,',
- ' baz',
- 'from table1, table2',
- 'where 1 = 2']))
+ assert f(s) == '\n'.join(['select foo, bar,',
+ ' baz',
+ 'from table1, table2',
+ 'where 1 = 2'])
def test_identifier_list_with_functions(self):
f = lambda sql: sqlparse.format(sql, reindent=True)
s = ("select 'abc' as foo, coalesce(col1, col2)||col3 as bar,"
"col3 from my_table")
- self.ndiffAssertEqual(f(s), '\n'.join(
+ assert f(s) == '\n'.join(
["select 'abc' as foo,",
" coalesce(col1, col2)||col3 as bar,",
" col3",
- "from my_table"]))
+ "from my_table"])
def test_case(self):
f = lambda sql: sqlparse.format(sql, reindent=True)
s = 'case when foo = 1 then 2 when foo = 3 then 4 else 5 end'
- self.ndiffAssertEqual(f(s), '\n'.join(['case',
- ' when foo = 1 then 2',
- ' when foo = 3 then 4',
- ' else 5',
- 'end']))
+ assert f(s) == '\n'.join(['case',
+ ' when foo = 1 then 2',
+ ' when foo = 3 then 4',
+ ' else 5',
+ 'end'])
def test_case2(self):
f = lambda sql: sqlparse.format(sql, reindent=True)
s = 'case(foo) when bar = 1 then 2 else 3 end'
- self.ndiffAssertEqual(f(s), '\n'.join(['case(foo)',
- ' when bar = 1 then 2',
- ' else 3',
- 'end']))
+ assert f(s) == '\n'.join(['case(foo)',
+ ' when bar = 1 then 2',
+ ' else 3',
+ 'end'])
def test_nested_identifier_list(self): # issue4
f = lambda sql: sqlparse.format(sql, reindent=True)
s = '(foo as bar, bar1, bar2 as bar3, b4 as b5)'
- self.ndiffAssertEqual(f(s), '\n'.join(['(foo as bar,',
- ' bar1,',
- ' bar2 as bar3,',
- ' b4 as b5)']))
+ assert f(s) == '\n'.join(['(foo as bar,',
+ ' bar1,',
+ ' bar2 as bar3,',
+ ' b4 as b5)'])
def test_duplicate_linebreaks(self): # issue3
f = lambda sql: sqlparse.format(sql, reindent=True)
s = 'select c1 -- column1\nfrom foo'
- self.ndiffAssertEqual(f(s), '\n'.join(['select c1 -- column1',
- 'from foo']))
+ assert f(s) == '\n'.join(['select c1 -- column1',
+ 'from foo'])
s = 'select c1 -- column1\nfrom foo'
r = sqlparse.format(s, reindent=True, strip_comments=True)
- self.ndiffAssertEqual(r, '\n'.join(['select c1',
- 'from foo']))
+ assert r == '\n'.join(['select c1',
+ 'from foo'])
s = 'select c1\nfrom foo\norder by c1'
- self.ndiffAssertEqual(f(s), '\n'.join(['select c1',
- 'from foo',
- 'order by c1']))
+ assert f(s) == '\n'.join(['select c1',
+ 'from foo',
+ 'order by c1'])
s = 'select c1 from t1 where (c1 = 1) order by c1'
- self.ndiffAssertEqual(f(s), '\n'.join(['select c1',
- 'from t1',
- 'where (c1 = 1)',
- 'order by c1']))
+ assert f(s) == '\n'.join(['select c1',
+ 'from t1',
+ 'where (c1 = 1)',
+ 'order by c1'])
def test_keywordfunctions(self): # issue36
f = lambda sql: sqlparse.format(sql, reindent=True)
s = 'select max(a) b, foo, bar'
- self.ndiffAssertEqual(f(s), '\n'.join(['select max(a) b,',
- ' foo,',
- ' bar']))
+ assert f(s) == '\n'.join(['select max(a) b,',
+ ' foo,',
+ ' bar'])
def test_identifier_and_functions(self): # issue45
f = lambda sql: sqlparse.format(sql, reindent=True)
s = 'select foo.bar, nvl(1) from dual'
- self.ndiffAssertEqual(f(s), '\n'.join(['select foo.bar,',
- ' nvl(1)',
- 'from dual']))
+ assert f(s) == '\n'.join(['select foo.bar,',
+ ' nvl(1)',
+ 'from dual'])
class TestOutputFormat(TestCaseBase):
-
def test_python(self):
sql = 'select * from foo;'
f = lambda sql: sqlparse.format(sql, output_format='python')
- self.ndiffAssertEqual(f(sql), "sql = 'select * from foo;'")
+ assert f(sql) == "sql = 'select * from foo;'"
f = lambda sql: sqlparse.format(sql, output_format='python',
reindent=True)
- self.ndiffAssertEqual(f(sql), ("sql = ('select * '\n"
- " 'from foo;')"))
+ assert f(sql) == ("sql = ('select * '\n"
+ " 'from foo;')")
def test_python_multiple_statements(self):
sql = 'select * from foo; select 1 from dual'
f = lambda sql: sqlparse.format(sql, output_format='python')
- self.ndiffAssertEqual(f(sql), ("sql = 'select * from foo; '\n"
- "sql2 = 'select 1 from dual'"))
+ assert f(sql) == ("sql = 'select * from foo; '\n"
+ "sql2 = 'select 1 from dual'")
@pytest.mark.xfail(reason="Needs fixing")
def test_python_multiple_statements_with_formatting(self):
sql = 'select * from foo; select 1 from dual'
f = lambda sql: sqlparse.format(sql, output_format='python',
reindent=True)
- self.ndiffAssertEqual(f(sql), ("sql = ('select * '\n"
- " 'from foo;')\n"
- "sql2 = ('select 1 '\n"
- " 'from dual')"))
+ assert f(sql) == ("sql = ('select * '\n"
+ " 'from foo;')\n"
+ "sql2 = ('select 1 '\n"
+ " 'from dual')")
def test_php(self):
sql = 'select * from foo;'
f = lambda sql: sqlparse.format(sql, output_format='php')
- self.ndiffAssertEqual(f(sql), '$sql = "select * from foo;";')
+ assert f(sql) == '$sql = "select * from foo;";'
f = lambda sql: sqlparse.format(sql, output_format='php',
reindent=True)
- self.ndiffAssertEqual(f(sql), ('$sql = "select * ";\n'
- '$sql .= "from foo;";'))
+ assert f(sql) == ('$sql = "select * ";\n'
+ '$sql .= "from foo;";')
def test_sql(self): # "sql" is an allowed option but has no effect
sql = 'select * from foo;'
f = lambda sql: sqlparse.format(sql, output_format='sql')
- self.ndiffAssertEqual(f(sql), 'select * from foo;')
+ assert f(sql) == 'select * from foo;'
def test_format_column_ordering(): # issue89
diff --git a/tests/test_grouping.py b/tests/test_grouping.py
index bf6bfeb..e32d82c 100644
--- a/tests/test_grouping.py
+++ b/tests/test_grouping.py
@@ -6,29 +6,27 @@ import sqlparse
from sqlparse import sql
from sqlparse import tokens as T
from sqlparse.compat import u
-
from tests.utils import TestCaseBase
class TestGrouping(TestCaseBase):
-
def test_parenthesis(self):
s = 'select (select (x3) x2) and (y2) bar'
parsed = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, str(parsed))
+ assert s == str(parsed)
self.assertEqual(len(parsed.tokens), 7)
self.assert_(isinstance(parsed.tokens[2], sql.Parenthesis))
self.assert_(isinstance(parsed.tokens[-1], sql.Identifier))
self.assertEqual(len(parsed.tokens[2].tokens), 5)
self.assert_(isinstance(parsed.tokens[2].tokens[3], sql.Identifier))
self.assert_(isinstance(parsed.tokens[2].tokens[3].tokens[0],
- sql.Parenthesis))
+ sql.Parenthesis))
self.assertEqual(len(parsed.tokens[2].tokens[3].tokens), 3)
def test_comments(self):
s = '/*\n * foo\n */ \n bar'
parsed = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, u(parsed))
+ assert s == u(parsed)
self.assertEqual(len(parsed.tokens), 2)
def test_assignment(self):
@@ -44,18 +42,18 @@ class TestGrouping(TestCaseBase):
def test_identifiers(self):
s = 'select foo.bar from "myscheme"."table" where fail. order'
parsed = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, u(parsed))
+ assert s == u(parsed)
self.assert_(isinstance(parsed.tokens[2], sql.Identifier))
self.assert_(isinstance(parsed.tokens[6], sql.Identifier))
self.assert_(isinstance(parsed.tokens[8], sql.Where))
s = 'select * from foo where foo.id = 1'
parsed = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, u(parsed))
+ assert s == u(parsed)
self.assert_(isinstance(parsed.tokens[-1].tokens[-1].tokens[0],
sql.Identifier))
s = 'select * from (select "foo"."id" from foo)'
parsed = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, u(parsed))
+ assert s == u(parsed)
self.assert_(isinstance(parsed.tokens[-1].tokens[3], sql.Identifier))
s = "INSERT INTO `test` VALUES('foo', 'bar');"
@@ -170,50 +168,50 @@ class TestGrouping(TestCaseBase):
p = sqlparse.parse('1, 2 desc, 3')[0]
self.assert_(isinstance(p.tokens[0], sql.IdentifierList))
self.assert_(isinstance(p.tokens[0].tokens[3], sql.Identifier))
- self.ndiffAssertEqual(u(p.tokens[0].tokens[3]), '2 desc')
+ assert u(p.tokens[0].tokens[3]) == '2 desc'
def test_where(self):
s = 'select * from foo where bar = 1 order by id desc'
p = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, u(p))
+ assert s == u(p)
self.assert_(len(p.tokens) == 14)
s = 'select x from (select y from foo where bar = 1) z'
p = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, u(p))
+ assert s == u(p)
self.assert_(isinstance(p.tokens[-1].tokens[0].tokens[-2], sql.Where))
def test_typecast(self):
s = 'select foo::integer from bar'
p = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, u(p))
+ assert s == u(p)
self.assertEqual(p.tokens[2].get_typecast(), 'integer')
self.assertEqual(p.tokens[2].get_name(), 'foo')
s = 'select (current_database())::information_schema.sql_identifier'
p = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, u(p))
+ assert s == u(p)
self.assertEqual(p.tokens[2].get_typecast(),
'information_schema.sql_identifier')
def test_alias(self):
s = 'select foo as bar from mytable'
p = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, u(p))
+ assert s == u(p)
self.assertEqual(p.tokens[2].get_real_name(), 'foo')
self.assertEqual(p.tokens[2].get_alias(), 'bar')
s = 'select foo from mytable t1'
p = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, u(p))
+ assert s == u(p)
self.assertEqual(p.tokens[6].get_real_name(), 'mytable')
self.assertEqual(p.tokens[6].get_alias(), 't1')
s = 'select foo::integer as bar from mytable'
p = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, u(p))
+ assert s == u(p)
self.assertEqual(p.tokens[2].get_alias(), 'bar')
s = ('SELECT DISTINCT '
'(current_database())::information_schema.sql_identifier AS view')
p = sqlparse.parse(s)[0]
- self.ndiffAssertEqual(s, u(p))
+ assert s == u(p)
self.assertEqual(p.tokens[4].get_alias(), 'view')
def test_alias_case(self): # see issue46
@@ -259,10 +257,10 @@ class TestGrouping(TestCaseBase):
class TestStatement(TestCaseBase):
-
def test_get_type(self):
def f(sql):
return sqlparse.parse(sql)[0]
+
self.assertEqual(f('select * from foo').get_type(), 'SELECT')
self.assertEqual(f('update foo').get_type(), 'UPDATE')
self.assertEqual(f(' update foo').get_type(), 'UPDATE')
diff --git a/tests/test_parse.py b/tests/test_parse.py
index 75a7ab5..d8a8c27 100644
--- a/tests/test_parse.py
+++ b/tests/test_parse.py
@@ -4,12 +4,11 @@
import pytest
-from tests.utils import TestCaseBase
-
import sqlparse
import sqlparse.sql
from sqlparse import tokens as T
from sqlparse.compat import u, StringIO
+from tests.utils import TestCaseBase
class SQLParseTest(TestCaseBase):
diff --git a/tests/test_regressions.py b/tests/test_regressions.py
index 14aab24..0887c40 100644
--- a/tests/test_regressions.py
+++ b/tests/test_regressions.py
@@ -3,15 +3,14 @@
import sys
import pytest # noqa
-from tests.utils import TestCaseBase, load_file
import sqlparse
from sqlparse import sql
from sqlparse import tokens as T
+from tests.utils import TestCaseBase, load_file
class RegressionTests(TestCaseBase):
-
def test_issue9(self):
# make sure where doesn't consume parenthesis
p = sqlparse.parse('(where 1)')[0]
@@ -57,16 +56,16 @@ class RegressionTests(TestCaseBase):
# missing space before LIMIT
sql = sqlparse.format("select * from foo where bar = 1 limit 1",
reindent=True)
- self.ndiffAssertEqual(sql, "\n".join(["select *",
- "from foo",
- "where bar = 1 limit 1"]))
+ assert sql == "\n".join(["select *",
+ "from foo",
+ "where bar = 1 limit 1"])
def test_issue38(self):
sql = sqlparse.format("SELECT foo; -- comment",
strip_comments=True)
- self.ndiffAssertEqual(sql, "SELECT foo;")
+ assert sql == "SELECT foo;"
sql = sqlparse.format("/* foo */", strip_comments=True)
- self.ndiffAssertEqual(sql, "")
+ assert sql == ""
def test_issue39(self):
p = sqlparse.parse('select user.id from user')[0]
@@ -89,26 +88,22 @@ class RegressionTests(TestCaseBase):
sp = p.tokens[-1].tokens[0]
self.assertEqual(sp.tokens[3].__class__, sql.IdentifierList)
# make sure that formatting works as expected
- self.ndiffAssertEqual(
- sqlparse.format(('SELECT id, name FROM '
- '(SELECT id, name FROM bar)'),
- reindent=True),
- ('SELECT id,\n'
- ' name\n'
- 'FROM\n'
- ' (SELECT id,\n'
- ' name\n'
- ' FROM bar)'))
- self.ndiffAssertEqual(
- sqlparse.format(('SELECT id, name FROM '
- '(SELECT id, name FROM bar) as foo'),
- reindent=True),
- ('SELECT id,\n'
- ' name\n'
- 'FROM\n'
- ' (SELECT id,\n'
- ' name\n'
- ' FROM bar) as foo'))
+ assert sqlparse.format(('SELECT id == name FROM '
+ '(SELECT id, name FROM bar)'),
+ reindent=True), ('SELECT id,\n'
+ ' name\n'
+ 'FROM\n'
+ ' (SELECT id,\n'
+ ' name\n'
+ ' FROM bar)')
+ assert sqlparse.format(('SELECT id == name FROM '
+ '(SELECT id, name FROM bar) as foo'),
+ reindent=True), ('SELECT id,\n'
+ ' name\n'
+ 'FROM\n'
+ ' (SELECT id,\n'
+ ' name\n'
+ ' FROM bar) as foo')
def test_issue78():
@@ -116,6 +111,7 @@ def test_issue78():
def _get_identifier(sql):
p = sqlparse.parse(sql)[0]
return p.tokens[2]
+
results = (('get_name', 'z'),
('get_real_name', 'y'),
('get_parent_name', 'x'),
diff --git a/tests/test_split.py b/tests/test_split.py
index 7c2645d..aa74aed 100644
--- a/tests/test_split.py
+++ b/tests/test_split.py
@@ -4,10 +4,11 @@
import types
-from tests.utils import load_file, TestCaseBase
+import pytest
import sqlparse
from sqlparse.compat import StringIO, u, text_type
+from tests.utils import load_file, TestCaseBase
class SQLSplitTest(TestCaseBase):
@@ -20,8 +21,8 @@ class SQLSplitTest(TestCaseBase):
sql2 = 'select * from foo where bar = \'foo;bar\';'
stmts = sqlparse.parse(''.join([self._sql1, sql2]))
self.assertEqual(len(stmts), 2)
- self.ndiffAssertEqual(u(stmts[0]), self._sql1)
- self.ndiffAssertEqual(u(stmts[1]), sql2)
+ assert u(stmts[0]) == self._sql1
+ assert u(stmts[1]) == sql2
def test_split_backslash(self):
stmts = sqlparse.parse(r"select '\\'; select '\''; select '\\\'';")
@@ -31,31 +32,31 @@ class SQLSplitTest(TestCaseBase):
sql = load_file('function.sql')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 1)
- self.ndiffAssertEqual(u(stmts[0]), sql)
+ assert u(stmts[0]) == sql
def test_create_function_psql(self):
sql = load_file('function_psql.sql')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 1)
- self.ndiffAssertEqual(u(stmts[0]), sql)
+ assert u(stmts[0]) == sql
def test_create_function_psql3(self):
sql = load_file('function_psql3.sql')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 1)
- self.ndiffAssertEqual(u(stmts[0]), sql)
+ assert u(stmts[0]) == sql
def test_create_function_psql2(self):
sql = load_file('function_psql2.sql')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 1)
- self.ndiffAssertEqual(u(stmts[0]), sql)
+ assert u(stmts[0]) == sql
def test_dashcomments(self):
sql = load_file('dashcomment.sql')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 3)
- self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql)
+ assert ''.join(u(q) for q in stmts) == sql
def test_dashcomments_eol(self):
stmts = sqlparse.parse('select foo; -- comment\n')
@@ -71,19 +72,19 @@ class SQLSplitTest(TestCaseBase):
sql = load_file('begintag.sql')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 3)
- self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql)
+ assert ''.join(u(q) for q in stmts) == sql
def test_begintag_2(self):
sql = load_file('begintag_2.sql')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 1)
- self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql)
+ assert ''.join(u(q) for q in stmts) == sql
def test_dropif(self):
sql = 'DROP TABLE IF EXISTS FOO;\n\nSELECT * FROM BAR;'
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 2)
- self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql)
+ assert ''.join(u(q) for q in stmts) == sql
def test_comment_with_umlaut(self):
sql = (u'select * from foo;\n'
@@ -91,16 +92,16 @@ class SQLSplitTest(TestCaseBase):
u'select * from bar;')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 2)
- self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql)
+ assert ''.join(u(q) for q in stmts) == sql
def test_comment_end_of_line(self):
sql = ('select * from foo; -- foo\n'
'select * from bar;')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 2)
- self.ndiffAssertEqual(''.join(u(q) for q in stmts), sql)
+ assert ''.join(u(q) for q in stmts) == sql
# make sure the comment belongs to first query
- self.ndiffAssertEqual(u(stmts[0]), 'select * from foo; -- foo\n')
+ assert u(stmts[0]) == 'select * from foo; -- foo\n'
def test_casewhen(self):
sql = ('SELECT case when val = 1 then 2 else null end as foo;\n'
diff --git a/tests/test_tokenize.py b/tests/test_tokenize.py
index 61eaa3e..0446dfa 100644
--- a/tests/test_tokenize.py
+++ b/tests/test_tokenize.py
@@ -13,7 +13,6 @@ from sqlparse.compat import StringIO
class TestTokenize(unittest.TestCase):
-
def test_simple(self):
s = 'select * from foo;'
stream = lexer.tokenize(s)
@@ -74,7 +73,6 @@ class TestTokenize(unittest.TestCase):
class TestToken(unittest.TestCase):
-
def test_str(self):
token = sql.Token(None, 'FoO')
self.assertEqual(str(token), 'FoO')
@@ -96,7 +94,6 @@ class TestToken(unittest.TestCase):
class TestTokenList(unittest.TestCase):
-
def test_repr(self):
p = sqlparse.parse('foo, bar, baz')[0]
tst = "<IdentifierList 'foo, b...' at 0x"