diff options
| author | Victor Uriarte <victor.m.uriarte@intel.com> | 2016-06-19 00:58:41 -0700 |
|---|---|---|
| committer | Victor Uriarte <victor.m.uriarte@intel.com> | 2016-06-20 07:41:13 -0700 |
| commit | 5d50f349cda37986bb3704e8fe25d57c78e6047a (patch) | |
| tree | 8c2b2061534767316e11f01a77bcf55c4c560889 | |
| parent | c9e8230502fd2c72833a2ea4d2c6bac9234e580a (diff) | |
| download | sqlparse-5d50f349cda37986bb3704e8fe25d57c78e6047a.tar.gz | |
Remove some test classes and clean-up
| -rw-r--r-- | tests/test_format.py | 332 | ||||
| -rw-r--r-- | tests/test_grouping.py | 587 | ||||
| -rw-r--r-- | tests/test_parse.py | 417 | ||||
| -rw-r--r-- | tests/test_regressions.py | 241 | ||||
| -rw-r--r-- | tests/test_split.py | 267 | ||||
| -rw-r--r-- | tests/test_tokenize.py | 282 |
6 files changed, 1102 insertions, 1024 deletions
diff --git a/tests/test_format.py b/tests/test_format.py index 0518b07..6cf4973 100644 --- a/tests/test_format.py +++ b/tests/test_format.py @@ -26,11 +26,11 @@ class TestFormat(object): assert res == 'select * from Bar; -- select foo\n' res = sqlparse.format(sql.upper(), identifier_case='lower') assert res == 'SELECT * FROM bar; -- SELECT FOO\n' - with pytest.raises(SQLParseError): - sqlparse.format(sql, identifier_case='foo') sql = 'select * from "foo"."bar"' res = sqlparse.format(sql, identifier_case="upper") assert res == 'select * from "foo"."bar"' + with pytest.raises(SQLParseError): + sqlparse.format(sql, identifier_case='foo') def test_strip_comments_single(self): sql = 'select *-- statement starts here\nfrom foo' @@ -130,8 +130,7 @@ class TestFormatReindentAligned(object): ' where c is true', ' and b between 3 and 4', " or d is 'blue'", - ' limit 10', - ]) + ' limit 10']) def test_joins(self): sql = """ @@ -154,8 +153,7 @@ class TestFormatReindentAligned(object): ' on d.three = a.three', ' cross join e', ' on e.four = a.four', - ' join f using (one, two, three)', - ]) + ' join f using (one, two, three)']) def test_case_statement(self): sql = """ @@ -180,8 +178,7 @@ class TestFormatReindentAligned(object): ' extra_col', ' from table', ' where c is true', - ' and b between 3 and 4' - ]) + ' and b between 3 and 4']) def test_case_statement_with_between(self): sql = """ @@ -208,8 +205,7 @@ class TestFormatReindentAligned(object): ' extra_col', ' from table', ' where c is true', - ' and b between 3 and 4' - ]) + ' and b between 3 and 4']) def test_group_by(self): sql = """ @@ -234,8 +230,7 @@ class TestFormatReindentAligned(object): ' and count(y) > 5', ' order by 3,', ' 2,', - ' 1', - ]) + ' 1']) def test_group_by_subquery(self): # TODO: add subquery alias when test_identifier_list_subquery fixed @@ -258,8 +253,7 @@ class TestFormatReindentAligned(object): ' z', ' )', ' order by 1,', - ' 2', - ]) + ' 2']) def test_window_functions(self): sql = """ @@ -268,16 +262,14 @@ class TestFormatReindentAligned(object): BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as sum_a, ROW_NUMBER() OVER (PARTITION BY b, c ORDER BY d DESC) as row_num - from table - """ + from table""" assert self.formatter(sql) == '\n'.join([ 'select a,', - (' SUM(a) OVER (PARTITION BY b ORDER BY c ROWS ' - 'BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as sum_a,'), - (' ROW_NUMBER() OVER ' - '(PARTITION BY b, c ORDER BY d DESC) as row_num'), - ' from table', - ]) + ' SUM(a) OVER (PARTITION BY b ORDER BY c ROWS ' + 'BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as sum_a,', + ' ROW_NUMBER() OVER ' + '(PARTITION BY b, c ORDER BY d DESC) as row_num', + ' from table']) class TestSpacesAroundOperators(object): @@ -290,8 +282,7 @@ class TestSpacesAroundOperators(object): 'where (c-d)%2= 1 and e> 3.0/4 and z^2 <100') assert self.formatter(sql) == ( 'select a + b as d from table ' - 'where (c - d) % 2 = 1 and e > 3.0 / 4 and z ^ 2 < 100' - ) + 'where (c - d) % 2 = 1 and e > 3.0 / 4 and z ^ 2 < 100') def test_bools(self): sql = 'select * from table where a &&b or c||d' @@ -335,156 +326,179 @@ class TestFormatReindent(object): def test_keywords(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select * from foo union select * from bar;' - assert f(s) == '\n'.join(['select *', - 'from foo', - 'union', - 'select *', - 'from bar;']) + assert f(s) == '\n'.join([ + 'select *', + 'from foo', + 'union', + 'select *', + 'from bar;']) - def test_keywords_between(self): # issue 14 + def test_keywords_between(self): + # issue 14 # don't break AND after BETWEEN f = lambda sql: sqlparse.format(sql, reindent=True) s = 'and foo between 1 and 2 and bar = 3' - assert f(s) == '\n'.join(['', - 'and foo between 1 and 2', - 'and bar = 3']) + assert f(s) == '\n'.join([ + '', + 'and foo between 1 and 2', + 'and bar = 3']) def test_parenthesis(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select count(*) from (select * from foo);' - assert f(s) == '\n'.join(['select count(*)', - 'from', - ' (select *', - ' from foo);', - ]) + assert f(s) == '\n'.join([ + 'select count(*)', + 'from', + ' (select *', + ' from foo);']) def test_where(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select * from foo where bar = 1 and baz = 2 or bzz = 3;' - assert f(s) == ('select *\nfrom foo\n' - 'where bar = 1\n' - ' and baz = 2\n' - ' or bzz = 3;') + assert f(s) == '\n'.join([ + 'select *', + 'from foo', + 'where bar = 1', + ' and baz = 2', + ' or bzz = 3;']) + s = 'select * from foo where bar = 1 and (baz = 2 or bzz = 3);' - assert f(s) == ('select *\nfrom foo\n' - 'where bar = 1\n' - ' and (baz = 2\n' - ' or bzz = 3);') + assert f(s) == '\n'.join([ + 'select *', + 'from foo', + 'where bar = 1', + ' and (baz = 2', + ' or bzz = 3);']) def test_join(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select * from foo join bar on 1 = 2' - assert f(s) == '\n'.join(['select *', - 'from foo', - 'join bar on 1 = 2']) + assert f(s) == '\n'.join([ + 'select *', + 'from foo', + 'join bar on 1 = 2']) s = 'select * from foo inner join bar on 1 = 2' - assert f(s) == '\n'.join(['select *', - 'from foo', - 'inner join bar on 1 = 2']) + assert f(s) == '\n'.join([ + 'select *', + 'from foo', + 'inner join bar on 1 = 2']) s = 'select * from foo left outer join bar on 1 = 2' - assert f(s) == '\n'.join(['select *', - 'from foo', - 'left outer join bar on 1 = 2' - ]) + assert f(s) == '\n'.join([ + 'select *', + 'from foo', + 'left outer join bar on 1 = 2']) s = 'select * from foo straight_join bar on 1 = 2' - assert f(s) == '\n'.join(['select *', - 'from foo', - 'straight_join bar on 1 = 2' - ]) + assert f(s) == '\n'.join([ + 'select *', + 'from foo', + 'straight_join bar on 1 = 2']) def test_identifier_list(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select foo, bar, baz from table1, table2 where 1 = 2' - assert f(s) == '\n'.join(['select foo,', - ' bar,', - ' baz', - 'from table1,', - ' table2', - 'where 1 = 2' - ]) + assert f(s) == '\n'.join([ + 'select foo,', + ' bar,', + ' baz', + 'from table1,', + ' table2', + 'where 1 = 2']) s = 'select a.*, b.id from a, b' - assert f(s) == '\n'.join(['select a.*,', - ' b.id', - 'from a,', - ' b' - ]) + assert f(s) == '\n'.join([ + 'select a.*,', + ' b.id', + 'from a,', + ' b']) def test_identifier_list_with_wrap_after(self): f = lambda sql: sqlparse.format(sql, reindent=True, wrap_after=14) s = 'select foo, bar, baz from table1, table2 where 1 = 2' - assert f(s) == '\n'.join(['select foo, bar,', - ' baz', - 'from table1, table2', - 'where 1 = 2' - ]) + assert f(s) == '\n'.join([ + 'select foo, bar,', + ' baz', + 'from table1, table2', + 'where 1 = 2']) def test_identifier_list_with_functions(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = ("select 'abc' as foo, coalesce(col1, col2)||col3 as bar," "col3 from my_table") - assert f(s) == '\n'.join( - ["select 'abc' as foo,", - " coalesce(col1, col2)||col3 as bar,", - " col3", - "from my_table"]) + assert f(s) == '\n'.join([ + "select 'abc' as foo,", + " coalesce(col1, col2)||col3 as bar,", + " col3", + "from my_table"]) def test_case(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = 'case when foo = 1 then 2 when foo = 3 then 4 else 5 end' - assert f(s) == '\n'.join(['case', - ' when foo = 1 then 2', - ' when foo = 3 then 4', - ' else 5', - 'end']) + assert f(s) == '\n'.join([ + 'case', + ' when foo = 1 then 2', + ' when foo = 3 then 4', + ' else 5', + 'end']) def test_case2(self): f = lambda sql: sqlparse.format(sql, reindent=True) s = 'case(foo) when bar = 1 then 2 else 3 end' - assert f(s) == '\n'.join(['case(foo)', - ' when bar = 1 then 2', - ' else 3', - 'end']) - - def test_nested_identifier_list(self): # issue4 + assert f(s) == '\n'.join([ + 'case(foo)', + ' when bar = 1 then 2', + ' else 3', + 'end']) + + def test_nested_identifier_list(self): + # issue4 f = lambda sql: sqlparse.format(sql, reindent=True) s = '(foo as bar, bar1, bar2 as bar3, b4 as b5)' - assert f(s) == '\n'.join(['(foo as bar,', - ' bar1,', - ' bar2 as bar3,', - ' b4 as b5)']) - - def test_duplicate_linebreaks(self): # issue3 + assert f(s) == '\n'.join([ + '(foo as bar,', + ' bar1,', + ' bar2 as bar3,', + ' b4 as b5)']) + + def test_duplicate_linebreaks(self): + # issue3 f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select c1 -- column1\nfrom foo' - assert f(s) == '\n'.join(['select c1 -- column1', - 'from foo']) + assert f(s) == '\n'.join([ + 'select c1 -- column1', + 'from foo']) s = 'select c1 -- column1\nfrom foo' r = sqlparse.format(s, reindent=True, strip_comments=True) - assert r == '\n'.join(['select c1', - 'from foo']) + assert r == '\n'.join([ + 'select c1', + 'from foo']) s = 'select c1\nfrom foo\norder by c1' - assert f(s) == '\n'.join(['select c1', - 'from foo', - 'order by c1']) + assert f(s) == '\n'.join([ + 'select c1', + 'from foo', + 'order by c1']) s = 'select c1 from t1 where (c1 = 1) order by c1' - assert f(s) == '\n'.join(['select c1', - 'from t1', - 'where (c1 = 1)', - 'order by c1']) - - def test_keywordfunctions(self): # issue36 + assert f(s) == '\n'.join([ + 'select c1', + 'from t1', + 'where (c1 = 1)', + 'order by c1']) + + def test_keywordfunctions(self): + # issue36 f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select max(a) b, foo, bar' - assert f(s) == '\n'.join(['select max(a) b,', - ' foo,', - ' bar']) + assert f(s) == '\n'.join([ + 'select max(a) b,', + ' foo,', + ' bar']) - def test_identifier_and_functions(self): # issue45 + def test_identifier_and_functions(self): + # issue45 f = lambda sql: sqlparse.format(sql, reindent=True) s = 'select foo.bar, nvl(1) from dual' - assert f(s) == '\n'.join(['select foo.bar,', - ' nvl(1)', - 'from dual']) + assert f(s) == '\n'.join([ + 'select foo.bar,', + ' nvl(1)', + 'from dual']) class TestOutputFormat(object): @@ -494,24 +508,27 @@ class TestOutputFormat(object): assert f(sql) == "sql = 'select * from foo;'" f = lambda sql: sqlparse.format(sql, output_format='python', reindent=True) - assert f(sql) == ("sql = ('select * '\n" - " 'from foo;')") + assert f(sql) == '\n'.join([ + "sql = ('select * '", + " 'from foo;')"]) def test_python_multiple_statements(self): sql = 'select * from foo; select 1 from dual' f = lambda sql: sqlparse.format(sql, output_format='python') - assert f(sql) == ("sql = 'select * from foo; '\n" - "sql2 = 'select 1 from dual'") + assert f(sql) == '\n'.join([ + "sql = 'select * from foo; '", + "sql2 = 'select 1 from dual'"]) @pytest.mark.xfail(reason="Needs fixing") def test_python_multiple_statements_with_formatting(self): sql = 'select * from foo; select 1 from dual' f = lambda sql: sqlparse.format(sql, output_format='python', reindent=True) - assert f(sql) == ("sql = ('select * '\n" - " 'from foo;')\n" - "sql2 = ('select 1 '\n" - " 'from dual')") + assert f(sql) == '\n'.join([ + "sql = ('select * '", + " 'from foo;')", + "sql2 = ('select 1 '", + " 'from dual')"]) def test_php(self): sql = 'select * from foo;' @@ -519,72 +536,69 @@ class TestOutputFormat(object): assert f(sql) == '$sql = "select * from foo;";' f = lambda sql: sqlparse.format(sql, output_format='php', reindent=True) - assert f(sql) == ('$sql = "select * ";\n' - '$sql .= "from foo;";') + assert f(sql) == '\n'.join([ + '$sql = "select * ";', + '$sql .= "from foo;";']) - def test_sql(self): # "sql" is an allowed option but has no effect + def test_sql(self): + # "sql" is an allowed option but has no effect sql = 'select * from foo;' f = lambda sql: sqlparse.format(sql, output_format='sql') assert f(sql) == 'select * from foo;' -def test_format_column_ordering(): # issue89 +def test_format_column_ordering(): + # issue89 sql = 'select * from foo order by c1 desc, c2, c3;' formatted = sqlparse.format(sql, reindent=True) - expected = '\n'.join(['select *', - 'from foo', - 'order by c1 desc,', - ' c2,', - ' c3;']) + expected = '\n'.join([ + 'select *', + 'from foo', + 'order by c1 desc,', + ' c2,', + ' c3;']) assert formatted == expected def test_truncate_strings(): - sql = 'update foo set value = \'' + 'x' * 1000 + '\';' + sql = "update foo set value = '{0}';".format('x' * 1000) formatted = sqlparse.format(sql, truncate_strings=10) - assert formatted == 'update foo set value = \'xxxxxxxxxx[...]\';' + assert formatted == "update foo set value = 'xxxxxxxxxx[...]';" formatted = sqlparse.format(sql, truncate_strings=3, truncate_char='YYY') - assert formatted == 'update foo set value = \'xxxYYY\';' + assert formatted == "update foo set value = 'xxxYYY';" -def test_truncate_strings_invalid_option(): - pytest.raises(SQLParseError, sqlparse.format, 'foo', - truncate_strings='bar') - pytest.raises(SQLParseError, sqlparse.format, 'foo', - truncate_strings=-1) - pytest.raises(SQLParseError, sqlparse.format, 'foo', - truncate_strings=0) +@pytest.mark.parametrize('option', ['bar', -1, 0]) +def test_truncate_strings_invalid_option2(option): + with pytest.raises(SQLParseError): + sqlparse.format('foo', truncate_strings=option) -@pytest.mark.parametrize('sql', ['select verrrylongcolumn from foo', - 'select "verrrylongcolumn" from "foo"', - ]) +@pytest.mark.parametrize('sql', [ + 'select verrrylongcolumn from foo', + 'select "verrrylongcolumn" from "foo"']) def test_truncate_strings_doesnt_truncate_identifiers(sql): formatted = sqlparse.format(sql, truncate_strings=2) assert formatted == sql def test_having_produces_newline(): - sql = ( - 'select * from foo, bar where bar.id = foo.bar_id' - ' having sum(bar.value) > 100') + sql = ('select * from foo, bar where bar.id = foo.bar_id ' + 'having sum(bar.value) > 100') formatted = sqlparse.format(sql, reindent=True) expected = [ 'select *', 'from foo,', ' bar', 'where bar.id = foo.bar_id', - 'having sum(bar.value) > 100' - ] + 'having sum(bar.value) > 100'] assert formatted == '\n'.join(expected) -def test_format_right_margin_invalid_input(): - with pytest.raises(SQLParseError): - sqlparse.format('foo', right_margin=2) - +@pytest.mark.parametrize('right_margin', ['ten', 2]) +def test_format_right_margin_invalid_input(right_margin): with pytest.raises(SQLParseError): - sqlparse.format('foo', right_margin="two") + sqlparse.format('foo', right_margin=right_margin) @pytest.mark.xfail(reason="Needs fixing") diff --git a/tests/test_grouping.py b/tests/test_grouping.py index bab0d9a..8356e16 100644 --- a/tests/test_grouping.py +++ b/tests/test_grouping.py @@ -6,270 +6,299 @@ import sqlparse from sqlparse import sql, tokens as T -class TestGrouping(object): - def test_parenthesis(self): - s = 'select (select (x3) x2) and (y2) bar' - parsed = sqlparse.parse(s)[0] - assert str(parsed) == s - assert len(parsed.tokens) == 7 - assert isinstance(parsed.tokens[2], sql.Parenthesis) - assert isinstance(parsed.tokens[-1], sql.Identifier) - assert len(parsed.tokens[2].tokens) == 5 - assert isinstance(parsed.tokens[2].tokens[3], sql.Identifier) - assert isinstance(parsed.tokens[2].tokens[3].tokens[0], - sql.Parenthesis) - assert len(parsed.tokens[2].tokens[3].tokens) == 3 - - def test_comments(self): - s = '/*\n * foo\n */ \n bar' - parsed = sqlparse.parse(s)[0] - assert str(parsed) == s - self.assertEqual(len(parsed.tokens), 2) - - def test_assignment(self): - s = 'foo := 1;' - parsed = sqlparse.parse(s)[0] - assert len(parsed.tokens) == 1 - assert isinstance(parsed.tokens[0], sql.Assignment) - s = 'foo := 1' - parsed = sqlparse.parse(s)[0] - assert len(parsed.tokens) == 1 - assert isinstance(parsed.tokens[0], sql.Assignment) - - def test_identifiers(self): - s = 'select foo.bar from "myscheme"."table" where fail. order' - parsed = sqlparse.parse(s)[0] - assert str(parsed) == s - assert isinstance(parsed.tokens[2], sql.Identifier) - assert isinstance(parsed.tokens[6], sql.Identifier) - assert isinstance(parsed.tokens[8], sql.Where) - s = 'select * from foo where foo.id = 1' - parsed = sqlparse.parse(s)[0] - assert str(parsed) == s - assert isinstance(parsed.tokens[-1].tokens[-1].tokens[0], - sql.Identifier) - s = 'select * from (select "foo"."id" from foo)' - parsed = sqlparse.parse(s)[0] - assert str(parsed) == s - assert isinstance(parsed.tokens[-1].tokens[3], sql.Identifier) - - s = "INSERT INTO `test` VALUES('foo', 'bar');" - parsed = sqlparse.parse(s)[0] - types = [l.ttype for l in parsed.tokens if not l.is_whitespace()] - assert types == [T.DML, T.Keyword, None, - T.Keyword, None, T.Punctuation] - - s = "select 1.0*(a+b) as col, sum(c)/sum(d) from myschema.mytable" - parsed = sqlparse.parse(s)[0] - assert len(parsed.tokens) == 7 - assert isinstance(parsed.tokens[2], sql.IdentifierList) - assert len(parsed.tokens[2].tokens) == 4 - identifiers = list(parsed.tokens[2].get_identifiers()) - assert len(identifiers) == 2 - assert identifiers[0].get_alias() == "col" - - def test_identifier_wildcard(self): - p = sqlparse.parse('a.*, b.id')[0] - assert isinstance(p.tokens[0], sql.IdentifierList) - assert isinstance(p.tokens[0].tokens[0], sql.Identifier) - assert isinstance(p.tokens[0].tokens[-1], sql.Identifier) - - def test_identifier_name_wildcard(self): - p = sqlparse.parse('a.*')[0] - t = p.tokens[0] - assert t.get_name() == '*' - assert t.is_wildcard() is True - - def test_identifier_invalid(self): - p = sqlparse.parse('a.')[0] - assert isinstance(p.tokens[0], sql.Identifier) - assert p.tokens[0].has_alias() is False - assert p.tokens[0].get_name() is None - assert p.tokens[0].get_real_name() is None - assert p.tokens[0].get_parent_name() == 'a' - - def test_identifier_invalid_in_middle(self): - # issue261 - s = 'SELECT foo. FROM foo' - p = sqlparse.parse(s)[0] - assert isinstance(p[2], sql.Identifier) - assert p[2][1].ttype == T.Punctuation - assert p[3].ttype == T.Whitespace - assert str(p[2]) == 'foo.' - - def test_identifier_as_invalid(self): # issue8 - p = sqlparse.parse('foo as select *')[0] - assert len(p.tokens), 5 - assert isinstance(p.tokens[0], sql.Identifier) - assert len(p.tokens[0].tokens) == 1 - assert p.tokens[2].ttype == T.Keyword - - def test_identifier_function(self): - p = sqlparse.parse('foo() as bar')[0] - assert isinstance(p.tokens[0], sql.Identifier) - assert isinstance(p.tokens[0].tokens[0], sql.Function) - p = sqlparse.parse('foo()||col2 bar')[0] - assert isinstance(p.tokens[0], sql.Identifier) - assert isinstance(p.tokens[0].tokens[0], sql.Operation) - assert isinstance(p.tokens[0].tokens[0].tokens[0], sql.Function) - - def test_identifier_extended(self): # issue 15 - p = sqlparse.parse('foo+100')[0] - assert isinstance(p.tokens[0], sql.Operation) - p = sqlparse.parse('foo + 100')[0] - assert isinstance(p.tokens[0], sql.Operation) - p = sqlparse.parse('foo*100')[0] - assert isinstance(p.tokens[0], sql.Operation) - - def test_identifier_list(self): - p = sqlparse.parse('a, b, c')[0] - assert isinstance(p.tokens[0], sql.IdentifierList) - p = sqlparse.parse('(a, b, c)')[0] - assert isinstance(p.tokens[0].tokens[1], sql.IdentifierList) - - def test_identifier_list_subquery(self): - """identifier lists should still work in subqueries with aliases""" - p = sqlparse.parse("select * from (" - "select a, b + c as d from table) sub")[0] - subquery = p.tokens[-1].tokens[0] - idx, iden_list = subquery.token_next_by(i=sql.IdentifierList) - assert iden_list is not None - # all the identifiers should be within the IdentifierList - _, ilist = subquery.token_next_by(i=sql.Identifier, idx=idx) - assert ilist is None - - def test_identifier_list_case(self): - p = sqlparse.parse('a, case when 1 then 2 else 3 end as b, c')[0] - assert isinstance(p.tokens[0], sql.IdentifierList) - p = sqlparse.parse('(a, case when 1 then 2 else 3 end as b, c)')[0] - assert isinstance(p.tokens[0].tokens[1], sql.IdentifierList) - - def test_identifier_list_other(self): # issue2 - p = sqlparse.parse("select *, null, 1, 'foo', bar from mytable, x")[0] - assert isinstance(p.tokens[2], sql.IdentifierList) - l = p.tokens[2] - assert len(l.tokens) == 13 - - def test_identifier_list_with_inline_comments(self): # issue163 - p = sqlparse.parse('foo /* a comment */, bar')[0] - assert isinstance(p.tokens[0], sql.IdentifierList) - assert isinstance(p.tokens[0].tokens[0], sql.Identifier) - self.assert_(isinstance(p.tokens[0].tokens[3], sql.Identifier)) - - def test_identifiers_with_operators(self): - p = sqlparse.parse('a+b as c from table where (d-e)%2= 1')[0] - assert len([x for x in p.flatten() - if x.ttype == sqlparse.tokens.Name]) == 5 - - def test_identifier_list_with_order(self): # issue101 - p = sqlparse.parse('1, 2 desc, 3')[0] - assert isinstance(p.tokens[0], sql.IdentifierList) - assert isinstance(p.tokens[0].tokens[3], sql.Identifier) - assert str(p.tokens[0].tokens[3]) == '2 desc' - - def test_where(self): - s = 'select * from foo where bar = 1 order by id desc' - p = sqlparse.parse(s)[0] - assert str(p) == s - self.assert_(len(p.tokens) == 14) - - s = 'select x from (select y from foo where bar = 1) z' - p = sqlparse.parse(s)[0] - assert str(p) == s - assert isinstance(p.tokens[-1].tokens[0].tokens[-2], sql.Where) - - def test_typecast(self): - s = 'select foo::integer from bar' - p = sqlparse.parse(s)[0] - assert str(p) == s - assert p.tokens[2].get_typecast() == 'integer' - assert p.tokens[2].get_name() == 'foo' - s = 'select (current_database())::information_schema.sql_identifier' - p = sqlparse.parse(s)[0] - assert str(p) == s - assert (p.tokens[2].get_typecast() == - 'information_schema.sql_identifier') - - def test_alias(self): - s = 'select foo as bar from mytable' - p = sqlparse.parse(s)[0] - assert str(p) == s - assert p.tokens[2].get_real_name() == 'foo' - assert p.tokens[2].get_alias() == 'bar' - s = 'select foo from mytable t1' - p = sqlparse.parse(s)[0] - assert str(p) == s - assert p.tokens[6].get_real_name() == 'mytable' - assert p.tokens[6].get_alias() == 't1' - s = 'select foo::integer as bar from mytable' - p = sqlparse.parse(s)[0] - assert str(p) == s - assert p.tokens[2].get_alias() == 'bar' - s = ('SELECT DISTINCT ' - '(current_database())::information_schema.sql_identifier AS view') - p = sqlparse.parse(s)[0] - assert str(p) == s - assert p.tokens[4].get_alias() == 'view' - - def test_alias_case(self): # see issue46 - p = sqlparse.parse('CASE WHEN 1 THEN 2 ELSE 3 END foo')[0] - assert len(p.tokens) == 1 - assert p.tokens[0].get_alias() == 'foo' - - def test_alias_returns_none(self): # see issue185 - p = sqlparse.parse('foo.bar')[0] - assert len(p.tokens) == 1 - assert p.tokens[0].get_alias() is None - - def test_idlist_function(self): # see issue10 too - p = sqlparse.parse('foo(1) x, bar')[0] - assert isinstance(p.tokens[0], sql.IdentifierList) - - def test_comparison_exclude(self): - # make sure operators are not handled too lazy - p = sqlparse.parse('(=)')[0] - assert isinstance(p.tokens[0], sql.Parenthesis) - assert not isinstance(p.tokens[0].tokens[1], sql.Comparison) - p = sqlparse.parse('(a=1)')[0] - assert isinstance(p.tokens[0].tokens[1], sql.Comparison) - p = sqlparse.parse('(a>=1)')[0] - assert isinstance(p.tokens[0].tokens[1], sql.Comparison) - - def test_function(self): - p = sqlparse.parse('foo()')[0] - assert isinstance(p.tokens[0], sql.Function) - p = sqlparse.parse('foo(null, bar)')[0] - assert isinstance(p.tokens[0], sql.Function) - assert len(list(p.tokens[0].get_parameters())) == 2 - - def test_function_not_in(self): # issue183 - p = sqlparse.parse('in(1, 2)')[0] - assert len(p.tokens) == 2 - assert p.tokens[0].ttype == T.Keyword - assert isinstance(p.tokens[1], sql.Parenthesis) - - def test_varchar(self): - p = sqlparse.parse('"text" Varchar(50) NOT NULL')[0] - assert isinstance(p.tokens[2], sql.Function) - - -class TestStatement(object): - def test_get_type(self): - def f(sql): - return sqlparse.parse(sql)[0] - - assert f('select * from foo').get_type() == 'SELECT' - assert f('update foo').get_type() == 'UPDATE' - assert f(' update foo').get_type() == 'UPDATE' - assert f('\nupdate foo').get_type() == 'UPDATE' - assert f('foo').get_type() == 'UNKNOWN' - # Statements that have a whitespace after the closing semicolon - # are parsed as two statements where later only consists of the - # trailing whitespace. - assert f('\n').get_type() == 'UNKNOWN' - - -def test_identifier_with_operators(): # issue 53 +def test_grouping_parenthesis(): + s = 'select (select (x3) x2) and (y2) bar' + parsed = sqlparse.parse(s)[0] + assert str(parsed) == s + assert len(parsed.tokens) == 7 + assert isinstance(parsed.tokens[2], sql.Parenthesis) + assert isinstance(parsed.tokens[-1], sql.Identifier) + assert len(parsed.tokens[2].tokens) == 5 + assert isinstance(parsed.tokens[2].tokens[3], sql.Identifier) + assert isinstance(parsed.tokens[2].tokens[3].tokens[0], sql.Parenthesis) + assert len(parsed.tokens[2].tokens[3].tokens) == 3 + + +def test_grouping_comments(): + s = '/*\n * foo\n */ \n bar' + parsed = sqlparse.parse(s)[0] + assert str(parsed) == s + assert len(parsed.tokens) == 2 + + +def test_grouping_assignment(): + s = 'foo := 1;' + parsed = sqlparse.parse(s)[0] + assert len(parsed.tokens) == 1 + assert isinstance(parsed.tokens[0], sql.Assignment) + s = 'foo := 1' + parsed = sqlparse.parse(s)[0] + assert len(parsed.tokens) == 1 + assert isinstance(parsed.tokens[0], sql.Assignment) + + +def test_grouping_identifiers(): + s = 'select foo.bar from "myscheme"."table" where fail. order' + parsed = sqlparse.parse(s)[0] + assert str(parsed) == s + assert isinstance(parsed.tokens[2], sql.Identifier) + assert isinstance(parsed.tokens[6], sql.Identifier) + assert isinstance(parsed.tokens[8], sql.Where) + s = 'select * from foo where foo.id = 1' + parsed = sqlparse.parse(s)[0] + assert str(parsed) == s + assert isinstance(parsed.tokens[-1].tokens[-1].tokens[0], sql.Identifier) + s = 'select * from (select "foo"."id" from foo)' + parsed = sqlparse.parse(s)[0] + assert str(parsed) == s + assert isinstance(parsed.tokens[-1].tokens[3], sql.Identifier) + + s = "INSERT INTO `test` VALUES('foo', 'bar');" + parsed = sqlparse.parse(s)[0] + types = [l.ttype for l in parsed.tokens if not l.is_whitespace()] + assert types == [T.DML, T.Keyword, None, T.Keyword, None, T.Punctuation] + + s = "select 1.0*(a+b) as col, sum(c)/sum(d) from myschema.mytable" + parsed = sqlparse.parse(s)[0] + assert len(parsed.tokens) == 7 + assert isinstance(parsed.tokens[2], sql.IdentifierList) + assert len(parsed.tokens[2].tokens) == 4 + identifiers = list(parsed.tokens[2].get_identifiers()) + assert len(identifiers) == 2 + assert identifiers[0].get_alias() == "col" + + +def test_grouping_identifier_wildcard(): + p = sqlparse.parse('a.*, b.id')[0] + assert isinstance(p.tokens[0], sql.IdentifierList) + assert isinstance(p.tokens[0].tokens[0], sql.Identifier) + assert isinstance(p.tokens[0].tokens[-1], sql.Identifier) + + +def test_grouping_identifier_name_wildcard(): + p = sqlparse.parse('a.*')[0] + t = p.tokens[0] + assert t.get_name() == '*' + assert t.is_wildcard() is True + + +def test_grouping_identifier_invalid(): + p = sqlparse.parse('a.')[0] + assert isinstance(p.tokens[0], sql.Identifier) + assert p.tokens[0].has_alias() is False + assert p.tokens[0].get_name() is None + assert p.tokens[0].get_real_name() is None + assert p.tokens[0].get_parent_name() == 'a' + + +def test_grouping_identifier_invalid_in_middle(): + # issue261 + s = 'SELECT foo. FROM foo' + p = sqlparse.parse(s)[0] + assert isinstance(p[2], sql.Identifier) + assert p[2][1].ttype == T.Punctuation + assert p[3].ttype == T.Whitespace + assert str(p[2]) == 'foo.' + + +def test_grouping_identifier_as_invalid(): + # issue8 + p = sqlparse.parse('foo as select *')[0] + assert len(p.tokens), 5 + assert isinstance(p.tokens[0], sql.Identifier) + assert len(p.tokens[0].tokens) == 1 + assert p.tokens[2].ttype == T.Keyword + + +def test_grouping_identifier_function(): + p = sqlparse.parse('foo() as bar')[0] + assert isinstance(p.tokens[0], sql.Identifier) + assert isinstance(p.tokens[0].tokens[0], sql.Function) + p = sqlparse.parse('foo()||col2 bar')[0] + assert isinstance(p.tokens[0], sql.Identifier) + assert isinstance(p.tokens[0].tokens[0], sql.Operation) + assert isinstance(p.tokens[0].tokens[0].tokens[0], sql.Function) + + +def test_grouping_identifier_extended(): + # issue 15 + p = sqlparse.parse('foo+100')[0] + assert isinstance(p.tokens[0], sql.Operation) + p = sqlparse.parse('foo + 100')[0] + assert isinstance(p.tokens[0], sql.Operation) + p = sqlparse.parse('foo*100')[0] + assert isinstance(p.tokens[0], sql.Operation) + + +def test_grouping_identifier_list(): + p = sqlparse.parse('a, b, c')[0] + assert isinstance(p.tokens[0], sql.IdentifierList) + p = sqlparse.parse('(a, b, c)')[0] + assert isinstance(p.tokens[0].tokens[1], sql.IdentifierList) + + +def test_grouping_identifier_list_subquery(): + """identifier lists should still work in subqueries with aliases""" + p = sqlparse.parse("select * from (" + "select a, b + c as d from table) sub")[0] + subquery = p.tokens[-1].tokens[0] + idx, iden_list = subquery.token_next_by(i=sql.IdentifierList) + assert iden_list is not None + # all the identifiers should be within the IdentifierList + _, ilist = subquery.token_next_by(i=sql.Identifier, idx=idx) + assert ilist is None + + +def test_grouping_identifier_list_case(): + p = sqlparse.parse('a, case when 1 then 2 else 3 end as b, c')[0] + assert isinstance(p.tokens[0], sql.IdentifierList) + p = sqlparse.parse('(a, case when 1 then 2 else 3 end as b, c)')[0] + assert isinstance(p.tokens[0].tokens[1], sql.IdentifierList) + + +def test_grouping_identifier_list_other(): + # issue2 + p = sqlparse.parse("select *, null, 1, 'foo', bar from mytable, x")[0] + assert isinstance(p.tokens[2], sql.IdentifierList) + assert len(p.tokens[2].tokens) == 13 + + +def test_grouping_identifier_list_with_inline_comments(): + # issue163 + p = sqlparse.parse('foo /* a comment */, bar')[0] + assert isinstance(p.tokens[0], sql.IdentifierList) + assert isinstance(p.tokens[0].tokens[0], sql.Identifier) + assert isinstance(p.tokens[0].tokens[3], sql.Identifier) + + +def test_grouping_identifiers_with_operators(): + p = sqlparse.parse('a+b as c from table where (d-e)%2= 1')[0] + assert len([x for x in p.flatten() if x.ttype == T.Name]) == 5 + + +def test_grouping_identifier_list_with_order(): + # issue101 + p = sqlparse.parse('1, 2 desc, 3')[0] + assert isinstance(p.tokens[0], sql.IdentifierList) + assert isinstance(p.tokens[0].tokens[3], sql.Identifier) + assert str(p.tokens[0].tokens[3]) == '2 desc' + + +def test_grouping_where(): + s = 'select * from foo where bar = 1 order by id desc' + p = sqlparse.parse(s)[0] + assert str(p) == s + assert len(p.tokens) == 14 + + s = 'select x from (select y from foo where bar = 1) z' + p = sqlparse.parse(s)[0] + assert str(p) == s + assert isinstance(p.tokens[-1].tokens[0].tokens[-2], sql.Where) + + +def test_grouping_typecast(): + s = 'select foo::integer from bar' + p = sqlparse.parse(s)[0] + assert str(p) == s + assert p.tokens[2].get_typecast() == 'integer' + assert p.tokens[2].get_name() == 'foo' + s = 'select (current_database())::information_schema.sql_identifier' + p = sqlparse.parse(s)[0] + assert str(p) == s + assert (p.tokens[2].get_typecast() == 'information_schema.sql_identifier') + + +def test_grouping_alias(): + s = 'select foo as bar from mytable' + p = sqlparse.parse(s)[0] + assert str(p) == s + assert p.tokens[2].get_real_name() == 'foo' + assert p.tokens[2].get_alias() == 'bar' + s = 'select foo from mytable t1' + p = sqlparse.parse(s)[0] + assert str(p) == s + assert p.tokens[6].get_real_name() == 'mytable' + assert p.tokens[6].get_alias() == 't1' + s = 'select foo::integer as bar from mytable' + p = sqlparse.parse(s)[0] + assert str(p) == s + assert p.tokens[2].get_alias() == 'bar' + s = ('SELECT DISTINCT ' + '(current_database())::information_schema.sql_identifier AS view') + p = sqlparse.parse(s)[0] + assert str(p) == s + assert p.tokens[4].get_alias() == 'view' + + +def test_grouping_alias_case(): + # see issue46 + p = sqlparse.parse('CASE WHEN 1 THEN 2 ELSE 3 END foo')[0] + assert len(p.tokens) == 1 + assert p.tokens[0].get_alias() == 'foo' + + +def test_grouping_alias_returns_none(): + # see issue185 + p = sqlparse.parse('foo.bar')[0] + assert len(p.tokens) == 1 + assert p.tokens[0].get_alias() is None + + +def test_grouping_idlist_function(): + # see issue10 too + p = sqlparse.parse('foo(1) x, bar')[0] + assert isinstance(p.tokens[0], sql.IdentifierList) + + +def test_grouping_comparison_exclude(): + # make sure operators are not handled too lazy + p = sqlparse.parse('(=)')[0] + assert isinstance(p.tokens[0], sql.Parenthesis) + assert not isinstance(p.tokens[0].tokens[1], sql.Comparison) + p = sqlparse.parse('(a=1)')[0] + assert isinstance(p.tokens[0].tokens[1], sql.Comparison) + p = sqlparse.parse('(a>=1)')[0] + assert isinstance(p.tokens[0].tokens[1], sql.Comparison) + + +def test_grouping_function(): + p = sqlparse.parse('foo()')[0] + assert isinstance(p.tokens[0], sql.Function) + p = sqlparse.parse('foo(null, bar)')[0] + assert isinstance(p.tokens[0], sql.Function) + assert len(list(p.tokens[0].get_parameters())) == 2 + + +def test_grouping_function_not_in(): + # issue183 + p = sqlparse.parse('in(1, 2)')[0] + assert len(p.tokens) == 2 + assert p.tokens[0].ttype == T.Keyword + assert isinstance(p.tokens[1], sql.Parenthesis) + + +def test_grouping_varchar(): + p = sqlparse.parse('"text" Varchar(50) NOT NULL')[0] + assert isinstance(p.tokens[2], sql.Function) + + +def test_statement_get_type(): + def f(sql): + return sqlparse.parse(sql)[0] + + assert f('select * from foo').get_type() == 'SELECT' + assert f('update foo').get_type() == 'UPDATE' + assert f(' update foo').get_type() == 'UPDATE' + assert f('\nupdate foo').get_type() == 'UPDATE' + assert f('foo').get_type() == 'UNKNOWN' + # Statements that have a whitespace after the closing semicolon + # are parsed as two statements where later only consists of the + # trailing whitespace. + assert f('\n').get_type() == 'UNKNOWN' + + +def test_identifier_with_operators(): + # issue 53 p = sqlparse.parse('foo||bar')[0] assert len(p.tokens) == 1 assert isinstance(p.tokens[0], sql.Operation) @@ -288,7 +317,7 @@ def test_identifier_with_op_trailing_ws(): def test_identifier_with_string_literals(): - p = sqlparse.parse('foo + \'bar\'')[0] + p = sqlparse.parse("foo + 'bar'")[0] assert len(p.tokens) == 1 assert isinstance(p.tokens[0], sql.Operation) @@ -297,12 +326,13 @@ def test_identifier_with_string_literals(): # showed that this shouldn't be an identifier at all. I'm leaving this # commented in the source for a while. # def test_identifier_string_concat(): -# p = sqlparse.parse('\'foo\' || bar')[0] +# p = sqlparse.parse("'foo' || bar")[0] # assert len(p.tokens) == 1 # assert isinstance(p.tokens[0], sql.Identifier) -def test_identifier_consumes_ordering(): # issue89 +def test_identifier_consumes_ordering(): + # issue89 p = sqlparse.parse('select * from foo order by c1 desc, c2, c3')[0] assert isinstance(p.tokens[-1], sql.IdentifierList) ids = list(p.tokens[-1].get_identifiers()) @@ -313,7 +343,8 @@ def test_identifier_consumes_ordering(): # issue89 assert ids[1].get_ordering() is None -def test_comparison_with_keywords(): # issue90 +def test_comparison_with_keywords(): + # issue90 # in fact these are assignments, but for now we don't distinguish them p = sqlparse.parse('foo = NULL')[0] assert len(p.tokens) == 1 @@ -327,7 +358,8 @@ def test_comparison_with_keywords(): # issue90 assert isinstance(p.tokens[0], sql.Comparison) -def test_comparison_with_floats(): # issue145 +def test_comparison_with_floats(): + # issue145 p = sqlparse.parse('foo = 25.5')[0] assert len(p.tokens) == 1 assert isinstance(p.tokens[0], sql.Comparison) @@ -336,7 +368,8 @@ def test_comparison_with_floats(): # issue145 assert p.tokens[0].right.value == '25.5' -def test_comparison_with_parenthesis(): # issue23 +def test_comparison_with_parenthesis(): + # issue23 p = sqlparse.parse('(3 + 4) = 7')[0] assert len(p.tokens) == 1 assert isinstance(p.tokens[0], sql.Comparison) @@ -345,15 +378,17 @@ def test_comparison_with_parenthesis(): # issue23 assert comp.right.ttype is T.Number.Integer -def test_comparison_with_strings(): # issue148 - p = sqlparse.parse('foo = \'bar\'')[0] +def test_comparison_with_strings(): + # issue148 + p = sqlparse.parse("foo = 'bar'")[0] assert len(p.tokens) == 1 assert isinstance(p.tokens[0], sql.Comparison) - assert p.tokens[0].right.value == '\'bar\'' + assert p.tokens[0].right.value == "'bar'" assert p.tokens[0].right.ttype == T.String.Single -def test_comparison_with_functions(): # issue230 +def test_comparison_with_functions(): + # issue230 p = sqlparse.parse('foo = DATE(bar.baz)')[0] assert len(p.tokens) == 1 assert isinstance(p.tokens[0], sql.Comparison) @@ -376,9 +411,7 @@ def test_comparison_with_functions(): # issue230 assert p.tokens[0].right.value == 'bar.baz' -@pytest.mark.parametrize('start', ['FOR', - 'FOREACH', - ]) +@pytest.mark.parametrize('start', ['FOR', 'FOREACH']) def test_forloops(start): p = sqlparse.parse('{0} foo in bar LOOP foobar END LOOP'.format(start))[0] assert (len(p.tokens)) == 1 diff --git a/tests/test_parse.py b/tests/test_parse.py index 8ab7e9b..aa2cd15 100644 --- a/tests/test_parse.py +++ b/tests/test_parse.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -"""Tests sqlparse function.""" +"""Tests sqlparse.parse().""" import pytest @@ -9,131 +9,146 @@ from sqlparse import sql, tokens as T from sqlparse.compat import StringIO -class SQLParseTest(object): - """Tests sqlparse.parse().""" - - def test_tokenize(self): - s = 'select * from foo;' - stmts = sqlparse.parse(s) - assert len(stmts) == 1 - assert str(stmts[0]) == s - - def test_multistatement(self): - sql1 = 'select * from foo;' - sql2 = 'select * from bar;' - stmts = sqlparse.parse(sql1 + sql2) - assert len(stmts) == 2 - assert str(stmts[0]) == sql1 - assert str(stmts[1]) == sql2 - - def test_newlines(self): - s = 'select\n*from foo;' - p = sqlparse.parse(s)[0] - assert str(p) == s - s = 'select\r\n*from foo' - p = sqlparse.parse(s)[0] - assert str(p) == s - s = 'select\r*from foo' - p = sqlparse.parse(s)[0] - assert str(p) == s - s = 'select\r\n*from foo\n' - p = sqlparse.parse(s)[0] - assert str(p) == s - - def test_within(self): - s = 'foo(col1, col2)' - p = sqlparse.parse(s)[0] - col1 = p.tokens[0].tokens[1].tokens[1].tokens[0] - assert col1.within(sql.Function) - - def test_child_of(self): - s = '(col1, col2)' - p = sqlparse.parse(s)[0] - assert p.tokens[0].tokens[1].is_child_of(p.tokens[0]) - s = 'select foo' - p = sqlparse.parse(s)[0] - assert not p.tokens[2].is_child_of(p.tokens[0]) - assert p.tokens[2].is_child_of(p) - - def test_has_ancestor(self): - s = 'foo or (bar, baz)' - p = sqlparse.parse(s)[0] - baz = p.tokens[-1].tokens[1].tokens[-1] - assert baz.has_ancestor(p.tokens[-1].tokens[1]) - assert baz.has_ancestor(p.tokens[-1]) - assert baz.has_ancestor(p) - - def test_float(self): - t = sqlparse.parse('.5')[0].tokens - assert len(t) == 1 - assert t[0].ttype is sqlparse.tokens.Number.Float - t = sqlparse.parse('.51')[0].tokens - assert len(t) == 1 - assert t[0].ttype is sqlparse.tokens.Number.Float - t = sqlparse.parse('1.5')[0].tokens - assert len(t) == 1 - assert t[0].ttype is sqlparse.tokens.Number.Float - t = sqlparse.parse('12.5')[0].tokens - assert len(t) == 1 - assert t[0].ttype is sqlparse.tokens.Number.Float - - def test_placeholder(self): - def _get_tokens(s): - return sqlparse.parse(s)[0].tokens[-1].tokens - - t = _get_tokens('select * from foo where user = ?') - assert t[-1].ttype is sqlparse.tokens.Name.Placeholder - assert t[-1].value == '?' - t = _get_tokens('select * from foo where user = :1') - assert t[-1].ttype is sqlparse.tokens.Name.Placeholder - assert t[-1].value == ':1' - t = _get_tokens('select * from foo where user = :name') - assert t[-1].ttype is sqlparse.tokens.Name.Placeholder - assert t[-1].value == ':name' - t = _get_tokens('select * from foo where user = %s') - assert t[-1].ttype is sqlparse.tokens.Name.Placeholder - assert t[-1].value == '%s' - t = _get_tokens('select * from foo where user = $a') - assert t[-1].ttype is sqlparse.tokens.Name.Placeholder - assert t[-1].value == '$a' - - def test_modulo_not_placeholder(self): - tokens = list(sqlparse.lexer.tokenize('x %3')) - assert tokens[2][0] == sqlparse.tokens.Operator - - def test_access_symbol(self): # see issue27 - t = sqlparse.parse('select a.[foo bar] as foo')[0].tokens - assert isinstance(t[-1], sql.Identifier) - assert t[-1].get_name() == 'foo' - assert t[-1].get_real_name() == '[foo bar]' - assert t[-1].get_parent_name() == 'a' - - def test_square_brackets_notation_isnt_too_greedy(self): # see issue153 - t = sqlparse.parse('[foo], [bar]')[0].tokens - assert isinstance(t[0], sql.IdentifierList) - assert len(t[0].tokens) == 4 - assert t[0].tokens[0].get_real_name() == '[foo]' - assert t[0].tokens[-1].get_real_name() == '[bar]' - - def test_keyword_like_identifier(self): # see issue47 - t = sqlparse.parse('foo.key')[0].tokens - assert len(t) == 1 - assert isinstance(t[0], sql.Identifier) - - def test_function_parameter(self): # see issue94 - t = sqlparse.parse('abs(some_col)')[0].tokens[0].get_parameters() - assert len(t) == 1 - assert isinstance(t[0], sql.Identifier) - - def test_function_param_single_literal(self): - t = sqlparse.parse('foo(5)')[0].tokens[0].get_parameters() - assert len(t) == 1 - assert t[0].ttype is T.Number.Integer - - def test_nested_function(self): - t = sqlparse.parse('foo(bar(5))')[0].tokens[0].get_parameters() - assert len(t) == 1 - assert type(t[0]) is sql.Function +def test_parse_tokenize(): + s = 'select * from foo;' + stmts = sqlparse.parse(s) + assert len(stmts) == 1 + assert str(stmts[0]) == s + + +def test_parse_multistatement(): + sql1 = 'select * from foo;' + sql2 = 'select * from bar;' + stmts = sqlparse.parse(sql1 + sql2) + assert len(stmts) == 2 + assert str(stmts[0]) == sql1 + assert str(stmts[1]) == sql2 + + +def test_parse_newlines(): + s = 'select\n*from foo;' + p = sqlparse.parse(s)[0] + assert str(p) == s + s = 'select\r\n*from foo' + p = sqlparse.parse(s)[0] + assert str(p) == s + s = 'select\r*from foo' + p = sqlparse.parse(s)[0] + assert str(p) == s + s = 'select\r\n*from foo\n' + p = sqlparse.parse(s)[0] + assert str(p) == s + + +def test_parse_within(): + s = 'foo(col1, col2)' + p = sqlparse.parse(s)[0] + col1 = p.tokens[0].tokens[1].tokens[1].tokens[0] + assert col1.within(sql.Function) + + +def test_parse_child_of(): + s = '(col1, col2)' + p = sqlparse.parse(s)[0] + assert p.tokens[0].tokens[1].is_child_of(p.tokens[0]) + s = 'select foo' + p = sqlparse.parse(s)[0] + assert not p.tokens[2].is_child_of(p.tokens[0]) + assert p.tokens[2].is_child_of(p) + + +def test_parse_has_ancestor(): + s = 'foo or (bar, baz)' + p = sqlparse.parse(s)[0] + baz = p.tokens[-1].tokens[1].tokens[-1] + assert baz.has_ancestor(p.tokens[-1].tokens[1]) + assert baz.has_ancestor(p.tokens[-1]) + assert baz.has_ancestor(p) + + +def test_parse_float(): + t = sqlparse.parse('.5')[0].tokens + assert len(t) == 1 + assert t[0].ttype is sqlparse.tokens.Number.Float + t = sqlparse.parse('.51')[0].tokens + assert len(t) == 1 + assert t[0].ttype is sqlparse.tokens.Number.Float + t = sqlparse.parse('1.5')[0].tokens + assert len(t) == 1 + assert t[0].ttype is sqlparse.tokens.Number.Float + t = sqlparse.parse('12.5')[0].tokens + assert len(t) == 1 + assert t[0].ttype is sqlparse.tokens.Number.Float + + +def test_parse_placeholder(): + def _get_tokens(s): + return sqlparse.parse(s)[0].tokens[-1].tokens + + t = _get_tokens('select * from foo where user = ?') + assert t[-1].ttype is sqlparse.tokens.Name.Placeholder + assert t[-1].value == '?' + t = _get_tokens('select * from foo where user = :1') + assert t[-1].ttype is sqlparse.tokens.Name.Placeholder + assert t[-1].value == ':1' + t = _get_tokens('select * from foo where user = :name') + assert t[-1].ttype is sqlparse.tokens.Name.Placeholder + assert t[-1].value == ':name' + t = _get_tokens('select * from foo where user = %s') + assert t[-1].ttype is sqlparse.tokens.Name.Placeholder + assert t[-1].value == '%s' + t = _get_tokens('select * from foo where user = $a') + assert t[-1].ttype is sqlparse.tokens.Name.Placeholder + assert t[-1].value == '$a' + + +def test_parse_modulo_not_placeholder(): + tokens = list(sqlparse.lexer.tokenize('x %3')) + assert tokens[2][0] == sqlparse.tokens.Operator + + +def test_parse_access_symbol(): + # see issue27 + t = sqlparse.parse('select a.[foo bar] as foo')[0].tokens + assert isinstance(t[-1], sql.Identifier) + assert t[-1].get_name() == 'foo' + assert t[-1].get_real_name() == '[foo bar]' + assert t[-1].get_parent_name() == 'a' + + +def test_parse_square_brackets_notation_isnt_too_greedy(): + # see issue153 + t = sqlparse.parse('[foo], [bar]')[0].tokens + assert isinstance(t[0], sql.IdentifierList) + assert len(t[0].tokens) == 4 + assert t[0].tokens[0].get_real_name() == '[foo]' + assert t[0].tokens[-1].get_real_name() == '[bar]' + + +def test_parse_keyword_like_identifier(): + # see issue47 + t = sqlparse.parse('foo.key')[0].tokens + assert len(t) == 1 + assert isinstance(t[0], sql.Identifier) + + +def test_parse_function_parameter(): + # see issue94 + t = sqlparse.parse('abs(some_col)')[0].tokens[0].get_parameters() + assert len(t) == 1 + assert isinstance(t[0], sql.Identifier) + + +def test_parse_function_param_single_literal(): + t = sqlparse.parse('foo(5)')[0].tokens[0].get_parameters() + assert len(t) == 1 + assert t[0].ttype is T.Number.Integer + + +def test_parse_nested_function(): + t = sqlparse.parse('foo(bar(5))')[0].tokens[0].get_parameters() + assert len(t) == 1 + assert type(t[0]) is sql.Function def test_quoted_identifier(): @@ -143,15 +158,16 @@ def test_quoted_identifier(): assert t[2].get_real_name() == 'y' -@pytest.mark.parametrize('name', ['foo', - '_foo', - ]) -def test_valid_identifier_names(name): # issue175 +@pytest.mark.parametrize('name', ['foo', '_foo']) +def test_valid_identifier_names(name): + # issue175 t = sqlparse.parse(name)[0].tokens assert isinstance(t[0], sql.Identifier) -def test_psql_quotation_marks(): # issue83 +def test_psql_quotation_marks(): + # issue83 + # regression: make sure plain $$ work t = sqlparse.split(""" CREATE OR REPLACE FUNCTION testfunc1(integer) RETURNS integer AS $$ @@ -161,6 +177,7 @@ def test_psql_quotation_marks(): # issue83 .... $$ LANGUAGE plpgsql;""") assert len(t) == 2 + # make sure $SOMETHING$ works too t = sqlparse.split(""" CREATE OR REPLACE FUNCTION testfunc1(integer) RETURNS integer AS $PROC_1$ @@ -180,22 +197,14 @@ def test_double_precision_is_builtin(): assert t[0].value == 'DOUBLE PRECISION' -@pytest.mark.parametrize('ph', ['?', - ':1', - ':foo', - '%s', - '%(foo)s', - ]) +@pytest.mark.parametrize('ph', ['?', ':1', ':foo', '%s', '%(foo)s']) def test_placeholder(ph): p = sqlparse.parse(ph)[0].tokens assert len(p) == 1 assert p[0].ttype is T.Name.Placeholder -@pytest.mark.parametrize('num', ['6.67428E-8', - '1.988e33', - '1e-12', - ]) +@pytest.mark.parametrize('num', ['6.67428E-8', '1.988e33', '1e-12']) def test_scientific_numbers(num): p = sqlparse.parse(num)[0].tokens assert len(p) == 1 @@ -214,7 +223,8 @@ def test_double_quotes_are_identifiers(): assert isinstance(p[0], sql.Identifier) -def test_single_quotes_with_linebreaks(): # issue118 +def test_single_quotes_with_linebreaks(): + # issue118 p = sqlparse.parse("'f\nf'")[0].tokens assert len(p) == 1 assert p[0].ttype is T.String.Single @@ -288,29 +298,25 @@ def test_typed_array_definition(): assert names == ['x', 'y', 'z'] -@pytest.mark.parametrize('s', ['select 1 -- foo', - 'select 1 # foo', # see issue178 - ]) +@pytest.mark.parametrize('s', ['select 1 -- foo', 'select 1 # foo']) def test_single_line_comments(s): + # see issue178 p = sqlparse.parse(s)[0] assert len(p.tokens) == 5 assert p.tokens[-1].ttype == T.Comment.Single -@pytest.mark.parametrize('s', ['foo', - '@foo', - '#foo', # see issue192 - '##foo', - ]) +@pytest.mark.parametrize('s', ['foo', '@foo', '#foo', '##foo']) def test_names_and_special_names(s): + # see issue192 p = sqlparse.parse(s)[0] assert len(p.tokens) == 1 assert isinstance(p.tokens[0], sql.Identifier) def test_get_token_at_offset(): - # 0123456789 p = sqlparse.parse('select * from dual')[0] + # 0123456789 assert p.get_token_at_offset(0) == p.tokens[0] assert p.get_token_at_offset(1) == p.tokens[0] assert p.get_token_at_offset(6) == p.tokens[1] @@ -326,60 +332,61 @@ def test_pprint(): output = StringIO() p._pprint_tree(f=output) - pprint = u'\n'.join([" 0 DML 'select'", - " 1 Whitespace ' '", - " 2 IdentifierList 'a0, b0...'", - " | 0 Identifier 'a0'", - " | | 0 Name 'a0'", - " | 1 Punctuation ','", - " | 2 Whitespace ' '", - " | 3 Identifier 'b0'", - " | | 0 Name 'b0'", - " | 4 Punctuation ','", - " | 5 Whitespace ' '", - " | 6 Identifier 'c0'", - " | | 0 Name 'c0'", - " | 7 Punctuation ','", - " | 8 Whitespace ' '", - " | 9 Identifier 'd0'", - " | | 0 Name 'd0'", - " | 10 Punctuation ','", - " | 11 Whitespace ' '", - " | 12 Float 'e0'", - " 3 Whitespace ' '", - " 4 Keyword 'from'", - " 5 Whitespace ' '", - " 6 Identifier '(selec...'", - " | 0 Parenthesis '(selec...'", - " | | 0 Punctuation '('", - " | | 1 DML 'select'", - " | | 2 Whitespace ' '", - " | | 3 Wildcard '*'", - " | | 4 Whitespace ' '", - " | | 5 Keyword 'from'", - " | | 6 Whitespace ' '", - " | | 7 Identifier 'dual'", - " | | | 0 Name 'dual'", - " | | 8 Punctuation ')'", - " | 1 Whitespace ' '", - " | 2 Identifier 'q0'", - " | | 0 Name 'q0'", - " 7 Whitespace ' '", - " 8 Where 'where ...'", - " | 0 Keyword 'where'", - " | 1 Whitespace ' '", - " | 2 Comparison '1=1'", - " | | 0 Integer '1'", - " | | 1 Comparison '='", - " | | 2 Integer '1'", - " | 3 Whitespace ' '", - " | 4 Keyword 'and'", - " | 5 Whitespace ' '", - " | 6 Comparison '2=2'", - " | | 0 Integer '2'", - " | | 1 Comparison '='", - " | | 2 Integer '2'", - "", ]) + pprint = '\n'.join([ + " 0 DML 'select'", + " 1 Whitespace ' '", + " 2 IdentifierList 'a0, b0...'", + " | 0 Identifier 'a0'", + " | | 0 Name 'a0'", + " | 1 Punctuation ','", + " | 2 Whitespace ' '", + " | 3 Identifier 'b0'", + " | | 0 Name 'b0'", + " | 4 Punctuation ','", + " | 5 Whitespace ' '", + " | 6 Identifier 'c0'", + " | | 0 Name 'c0'", + " | 7 Punctuation ','", + " | 8 Whitespace ' '", + " | 9 Identifier 'd0'", + " | | 0 Name 'd0'", + " | 10 Punctuation ','", + " | 11 Whitespace ' '", + " | 12 Float 'e0'", + " 3 Whitespace ' '", + " 4 Keyword 'from'", + " 5 Whitespace ' '", + " 6 Identifier '(selec...'", + " | 0 Parenthesis '(selec...'", + " | | 0 Punctuation '('", + " | | 1 DML 'select'", + " | | 2 Whitespace ' '", + " | | 3 Wildcard '*'", + " | | 4 Whitespace ' '", + " | | 5 Keyword 'from'", + " | | 6 Whitespace ' '", + " | | 7 Identifier 'dual'", + " | | | 0 Name 'dual'", + " | | 8 Punctuation ')'", + " | 1 Whitespace ' '", + " | 2 Identifier 'q0'", + " | | 0 Name 'q0'", + " 7 Whitespace ' '", + " 8 Where 'where ...'", + " | 0 Keyword 'where'", + " | 1 Whitespace ' '", + " | 2 Comparison '1=1'", + " | | 0 Integer '1'", + " | | 1 Comparison '='", + " | | 2 Integer '1'", + " | 3 Whitespace ' '", + " | 4 Keyword 'and'", + " | 5 Whitespace ' '", + " | 6 Comparison '2=2'", + " | | 0 Integer '2'", + " | | 1 Comparison '='", + " | | 2 Integer '2'", + ""]) assert output.getvalue() == pprint diff --git a/tests/test_regressions.py b/tests/test_regressions.py index 10fd4a1..71fa2bd 100644 --- a/tests/test_regressions.py +++ b/tests/test_regressions.py @@ -7,100 +7,107 @@ from sqlparse import sql, tokens as T from sqlparse.compat import PY2 -class RegressionTests(object): - def test_issue9(self): - # make sure where doesn't consume parenthesis - p = sqlparse.parse('(where 1)')[0] - assert isinstance(p, sql.Statement) - assert len(p.tokens) == 1 - assert isinstance(p.tokens[0], sql.Parenthesis) - prt = p.tokens[0] - assert len(prt.tokens) == 3 - assert prt.tokens[0].ttype == T.Punctuation - assert prt.tokens[-1].ttype == T.Punctuation - - def test_issue13(self): - parsed = sqlparse.parse(("select 'one';\n" - "select 'two\\'';\n" - "select 'three';")) - assert len(parsed) == 3 - assert str(parsed[1]).strip() == "select 'two\\'';" - - def test_issue26(self): - # parse stand-alone comments - p = sqlparse.parse('--hello')[0] - assert len(p.tokens) == 1 - assert p.tokens[0].ttype is T.Comment.Single - p = sqlparse.parse('-- hello')[0] - assert len(p.tokens) == 1 - assert p.tokens[0].ttype is T.Comment.Single - p = sqlparse.parse('--hello\n')[0] - assert len(p.tokens) == 1 - assert p.tokens[0].ttype is T.Comment.Single - p = sqlparse.parse('--')[0] - assert len(p.tokens) == 1 - assert p.tokens[0].ttype is T.Comment.Single - p = sqlparse.parse('--\n')[0] - assert len(p.tokens) == 1 - assert p.tokens[0].ttype is T.Comment.Single - - def test_issue34(self): - t = sqlparse.parse("create")[0].token_first() - assert t.match(T.Keyword.DDL, "create") is True - assert t.match(T.Keyword.DDL, "CREATE") is True - - def test_issue35(self): - # missing space before LIMIT - sql = sqlparse.format("select * from foo where bar = 1 limit 1", - reindent=True) - assert sql == "\n".join(["select *", - "from foo", - "where bar = 1 limit 1"]) - - def test_issue38(self): - sql = sqlparse.format("SELECT foo; -- comment", - strip_comments=True) - assert sql == "SELECT foo;" - sql = sqlparse.format("/* foo */", strip_comments=True) - assert sql == "" - - def test_issue39(self): - p = sqlparse.parse('select user.id from user')[0] - assert len(p.tokens) == 7 - idt = p.tokens[2] - assert idt.__class__ == sql.Identifier - assert len(idt.tokens) == 3 - assert idt.tokens[0].match(T.Name, 'user') is True - assert idt.tokens[1].match(T.Punctuation, '.') is True - assert idt.tokens[2].match(T.Name, 'id') is True - - def test_issue40(self): - # make sure identifier lists in subselects are grouped - p = sqlparse.parse(('SELECT id, name FROM ' - '(SELECT id, name FROM bar) as foo'))[0] - assert len(p.tokens) == 7 - assert p.tokens[2].__class__ == sql.IdentifierList - assert p.tokens[-1].__class__ == sql.Identifier - assert p.tokens[-1].get_name() == u'foo' - sp = p.tokens[-1].tokens[0] - assert sp.tokens[3].__class__ == sql.IdentifierList - # make sure that formatting works as expected - assert sqlparse.format(('SELECT id == name FROM ' - '(SELECT id, name FROM bar)'), - reindent=True), ('SELECT id,\n' - ' name\n' - 'FROM\n' - ' (SELECT id,\n' - ' name\n' - ' FROM bar)') - assert sqlparse.format(('SELECT id == name FROM ' - '(SELECT id, name FROM bar) as foo'), - reindent=True), ('SELECT id,\n' - ' name\n' - 'FROM\n' - ' (SELECT id,\n' - ' name\n' - ' FROM bar) as foo') +def test_issue9(): + # make sure where doesn't consume parenthesis + p = sqlparse.parse('(where 1)')[0] + assert isinstance(p, sql.Statement) + assert len(p.tokens) == 1 + assert isinstance(p.tokens[0], sql.Parenthesis) + prt = p.tokens[0] + assert len(prt.tokens) == 3 + assert prt.tokens[0].ttype == T.Punctuation + assert prt.tokens[-1].ttype == T.Punctuation + + +def test_issue13(): + parsed = sqlparse.parse(("select 'one';\n" + "select 'two\\'';\n" + "select 'three';")) + assert len(parsed) == 3 + assert str(parsed[1]).strip() == "select 'two\\'';" + + +def test_issue26(): + # parse stand-alone comments + p = sqlparse.parse('--hello')[0] + assert len(p.tokens) == 1 + assert p.tokens[0].ttype is T.Comment.Single + p = sqlparse.parse('-- hello')[0] + assert len(p.tokens) == 1 + assert p.tokens[0].ttype is T.Comment.Single + p = sqlparse.parse('--hello\n')[0] + assert len(p.tokens) == 1 + assert p.tokens[0].ttype is T.Comment.Single + p = sqlparse.parse('--')[0] + assert len(p.tokens) == 1 + assert p.tokens[0].ttype is T.Comment.Single + p = sqlparse.parse('--\n')[0] + assert len(p.tokens) == 1 + assert p.tokens[0].ttype is T.Comment.Single + + +@pytest.mark.parametrize('value', ['create', 'CREATE']) +def test_issue34(value): + t = sqlparse.parse("create")[0].token_first() + assert t.match(T.Keyword.DDL, value) is True + + +def test_issue35(): + # missing space before LIMIT + sql = sqlparse.format("select * from foo where bar = 1 limit 1", + reindent=True) + assert sql == "\n".join([ + "select *", + "from foo", + "where bar = 1 limit 1"]) + + +def test_issue38(): + sql = sqlparse.format("SELECT foo; -- comment", strip_comments=True) + assert sql == "SELECT foo;" + sql = sqlparse.format("/* foo */", strip_comments=True) + assert sql == "" + + +def test_issue39(): + p = sqlparse.parse('select user.id from user')[0] + assert len(p.tokens) == 7 + idt = p.tokens[2] + assert idt.__class__ == sql.Identifier + assert len(idt.tokens) == 3 + assert idt.tokens[0].match(T.Name, 'user') is True + assert idt.tokens[1].match(T.Punctuation, '.') is True + assert idt.tokens[2].match(T.Name, 'id') is True + + +def test_issue40(): + # make sure identifier lists in subselects are grouped + p = sqlparse.parse(('SELECT id, name FROM ' + '(SELECT id, name FROM bar) as foo'))[0] + assert len(p.tokens) == 7 + assert p.tokens[2].__class__ == sql.IdentifierList + assert p.tokens[-1].__class__ == sql.Identifier + assert p.tokens[-1].get_name() == 'foo' + sp = p.tokens[-1].tokens[0] + assert sp.tokens[3].__class__ == sql.IdentifierList + # make sure that formatting works as expected + s = sqlparse.format('SELECT id == name FROM ' + '(SELECT id, name FROM bar)', reindent=True) + assert s == '\n'.join([ + 'SELECT id == name', + 'FROM', + ' (SELECT id,', + ' name', + ' FROM bar)']) + + s = sqlparse.format('SELECT id == name FROM ' + '(SELECT id, name FROM bar) as foo', reindent=True) + assert s == '\n'.join([ + 'SELECT id == name', + 'FROM', + ' (SELECT id,', + ' name', + ' FROM bar) as foo']) def test_issue78(): @@ -198,17 +205,18 @@ def test_issue90(): ' "rating_score" = 0, "thumbnail_width" = NULL,' ' "thumbnail_height" = NULL, "price" = 1, "description" = NULL') formatted = sqlparse.format(sql, reindent=True) - tformatted = '\n'.join(['UPDATE "gallery_photo"', - 'SET "owner_id" = 4018,', - ' "deleted_at" = NULL,', - ' "width" = NULL,', - ' "height" = NULL,', - ' "rating_votes" = 0,', - ' "rating_score" = 0,', - ' "thumbnail_width" = NULL,', - ' "thumbnail_height" = NULL,', - ' "price" = 1,', - ' "description" = NULL']) + tformatted = '\n'.join([ + 'UPDATE "gallery_photo"', + 'SET "owner_id" = 4018,', + ' "deleted_at" = NULL,', + ' "width" = NULL,', + ' "height" = NULL,', + ' "rating_votes" = 0,', + ' "rating_score" = 0,', + ' "thumbnail_width" = NULL,', + ' "thumbnail_height" = NULL,', + ' "price" = 1,', + ' "description" = NULL']) assert formatted == tformatted @@ -222,8 +230,7 @@ def test_except_formatting(): 'EXCEPT', 'SELECT 2', 'FROM bar', - 'WHERE 1 = 2' - ]) + 'WHERE 1 = 2']) assert formatted == tformatted @@ -233,8 +240,7 @@ def test_null_with_as(): tformatted = '\n'.join([ 'SELECT NULL AS c1,', ' NULL AS c2', - 'FROM t1' - ]) + 'FROM t1']) assert formatted == tformatted @@ -270,8 +276,8 @@ def test_issue186_get_type(): def test_issue212_py2unicode(): - t1 = sql.Token(T.String, u"schöner ") - t2 = sql.Token(T.String, u"bug") + t1 = sql.Token(T.String, u'schöner ') + t2 = sql.Token(T.String, 'bug') l = sql.TokenList([t1, t2]) assert str(l) == 'schöner bug' @@ -298,12 +304,13 @@ def test_issue227_gettype_cte(): def test_issue207_runaway_format(): sql = 'select 1 from (select 1 as one, 2 as two, 3 from dual) t0' p = sqlparse.format(sql, reindent=True) - assert p == '\n'.join(["select 1", - "from", - " (select 1 as one,", - " 2 as two,", - " 3", - " from dual) t0"]) + assert p == '\n'.join([ + "select 1", + "from", + " (select 1 as one,", + " 2 as two,", + " 3", + " from dual) t0"]) def token_next_doesnt_ignore_skip_cm(): diff --git a/tests/test_split.py b/tests/test_split.py index 5968806..8a2fe2d 100644 --- a/tests/test_split.py +++ b/tests/test_split.py @@ -10,133 +10,146 @@ import sqlparse from sqlparse.compat import StringIO, text_type -class SQLSplitTest(object): - """Tests sqlparse.sqlsplit().""" - - _sql1 = 'select * from foo;' - _sql2 = 'select * from bar;' - - def test_split_semicolon(self): - sql2 = 'select * from foo where bar = \'foo;bar\';' - stmts = sqlparse.parse(''.join([self._sql1, sql2])) - assert len(stmts) == 2 - assert str(stmts[0]) == self._sql1 - assert str(stmts[1]) == sql2 - - def test_split_backslash(self): - stmts = sqlparse.parse(r"select '\\'; select '\''; select '\\\'';") - assert len(stmts) == 3 - - def test_create_function(self, load_file): - sql = load_file('function.sql') - stmts = sqlparse.parse(sql) - assert len(stmts) == 1 - assert str(stmts[0]) == sql - - def test_create_function_psql(self, load_file): - sql = load_file('function_psql.sql') - stmts = sqlparse.parse(sql) - assert len(stmts) == 1 - assert str(stmts[0]) == sql - - def test_create_function_psql3(self, load_file): - sql = load_file('function_psql3.sql') - stmts = sqlparse.parse(sql) - assert len(stmts) == 1 - assert str(stmts[0]) == sql - - def test_create_function_psql2(self, load_file): - sql = load_file('function_psql2.sql') - stmts = sqlparse.parse(sql) - assert len(stmts) == 1 - assert str(stmts[0]) == sql - - def test_dashcomments(self, load_file): - sql = load_file('dashcomment.sql') - stmts = sqlparse.parse(sql) - assert len(stmts) == 3 - assert ''.join(str(q) for q in stmts) == sql - - def test_dashcomments_eol(self): - stmts = sqlparse.parse('select foo; -- comment\n') - assert len(stmts) == 1 - stmts = sqlparse.parse('select foo; -- comment\r') - assert len(stmts) == 1 - stmts = sqlparse.parse('select foo; -- comment\r\n') - assert len(stmts) == 1 - stmts = sqlparse.parse('select foo; -- comment') - assert len(stmts) == 1 - - def test_begintag(self, load_file): - sql = load_file('begintag.sql') - stmts = sqlparse.parse(sql) - assert len(stmts) == 3 - assert ''.join(str(q) for q in stmts) == sql - - def test_begintag_2(self, load_file): - sql = load_file('begintag_2.sql') - stmts = sqlparse.parse(sql) - assert len(stmts) == 1 - assert ''.join(str(q) for q in stmts) == sql - - def test_dropif(self): - sql = 'DROP TABLE IF EXISTS FOO;\n\nSELECT * FROM BAR;' - stmts = sqlparse.parse(sql) - assert len(stmts) == 2 - assert ''.join(str(q) for q in stmts) == sql - - def test_comment_with_umlaut(self): - sql = (u'select * from foo;\n' - u'-- Testing an umlaut: ä\n' - u'select * from bar;') - stmts = sqlparse.parse(sql) - assert len(stmts) == 2 - assert ''.join(str(q) for q in stmts) == sql - - def test_comment_end_of_line(self): - sql = ('select * from foo; -- foo\n' - 'select * from bar;') - stmts = sqlparse.parse(sql) - assert len(stmts) == 2 - assert ''.join(str(q) for q in stmts) == sql - # make sure the comment belongs to first query - assert str(stmts[0]) == 'select * from foo; -- foo\n' - - def test_casewhen(self): - sql = ('SELECT case when val = 1 then 2 else null end as foo;\n' - 'comment on table actor is \'The actor table.\';') - stmts = sqlparse.split(sql) - assert len(stmts) == 2 - - def test_cursor_declare(self): - sql = ('DECLARE CURSOR "foo" AS SELECT 1;\n' - 'SELECT 2;') - stmts = sqlparse.split(sql) - assert len(stmts) == 2 - - def test_if_function(self): # see issue 33 - # don't let IF as a function confuse the splitter - sql = ('CREATE TEMPORARY TABLE tmp ' - 'SELECT IF(a=1, a, b) AS o FROM one; ' - 'SELECT t FROM two') - stmts = sqlparse.split(sql) - assert len(stmts) == 2 - - def test_split_stream(self): - stream = StringIO("SELECT 1; SELECT 2;") - stmts = sqlparse.parsestream(stream) - assert isinstance(stmts, types.GeneratorType) - assert len(list(stmts)) == 2 - - def test_encoding_parsestream(self): - stream = StringIO("SELECT 1; SELECT 2;") - stmts = list(sqlparse.parsestream(stream)) - assert isinstance(stmts[0].tokens[0].value, text_type) - - def test_unicode_parsestream(self): - stream = StringIO(u"SELECT ö") - stmts = list(sqlparse.parsestream(stream)) - assert str(stmts[0]) == "SELECT ö" +def test_split_semicolon(): + sql1 = 'select * from foo;' + sql2 = "select * from foo where bar = 'foo;bar';" + stmts = sqlparse.parse(''.join([sql1, sql2])) + assert len(stmts) == 2 + assert str(stmts[0]) == sql1 + assert str(stmts[1]) == sql2 + + +def test_split_backslash(): + stmts = sqlparse.parse(r"select '\\'; select '\''; select '\\\'';") + assert len(stmts) == 3 + + +def test_split_create_function(load_file): + sql = load_file('function.sql') + stmts = sqlparse.parse(sql) + assert len(stmts) == 1 + assert str(stmts[0]) == sql + + +def test_split_create_function_psql(load_file): + sql = load_file('function_psql.sql') + stmts = sqlparse.parse(sql) + assert len(stmts) == 1 + assert text_type(stmts[0]) == sql + + +def test_split_create_function_psql3(load_file): + sql = load_file('function_psql3.sql') + stmts = sqlparse.parse(sql) + assert len(stmts) == 1 + assert str(stmts[0]) == sql + + +def test_split_create_function_psql2(load_file): + sql = load_file('function_psql2.sql') + stmts = sqlparse.parse(sql) + assert len(stmts) == 1 + assert str(stmts[0]) == sql + + +def test_split_dashcomments(load_file): + sql = load_file('dashcomment.sql') + stmts = sqlparse.parse(sql) + assert len(stmts) == 3 + assert ''.join(str(q) for q in stmts) == sql + + +def test_split_dashcomments_eol(): + stmts = sqlparse.parse('select foo; -- comment\n') + assert len(stmts) == 1 + stmts = sqlparse.parse('select foo; -- comment\r') + assert len(stmts) == 1 + stmts = sqlparse.parse('select foo; -- comment\r\n') + assert len(stmts) == 1 + stmts = sqlparse.parse('select foo; -- comment') + assert len(stmts) == 1 + + +def test_split_begintag(load_file): + sql = load_file('begintag.sql') + stmts = sqlparse.parse(sql) + assert len(stmts) == 3 + assert ''.join(str(q) for q in stmts) == sql + + +def test_split_begintag_2(load_file): + sql = load_file('begintag_2.sql') + stmts = sqlparse.parse(sql) + assert len(stmts) == 1 + assert ''.join(str(q) for q in stmts) == sql + + +def test_split_dropif(): + sql = 'DROP TABLE IF EXISTS FOO;\n\nSELECT * FROM BAR;' + stmts = sqlparse.parse(sql) + assert len(stmts) == 2 + assert ''.join(str(q) for q in stmts) == sql + + +def test_split_comment_with_umlaut(): + sql = (u'select * from foo;\n' + u'-- Testing an umlaut: ä\n' + u'select * from bar;') + stmts = sqlparse.parse(sql) + assert len(stmts) == 2 + assert ''.join(text_type(q) for q in stmts) == sql + + +def test_split_comment_end_of_line(): + sql = ('select * from foo; -- foo\n' + 'select * from bar;') + stmts = sqlparse.parse(sql) + assert len(stmts) == 2 + assert ''.join(str(q) for q in stmts) == sql + # make sure the comment belongs to first query + assert str(stmts[0]) == 'select * from foo; -- foo\n' + + +def test_split_casewhen(): + sql = ("SELECT case when val = 1 then 2 else null end as foo;\n" + "comment on table actor is 'The actor table.';") + stmts = sqlparse.split(sql) + assert len(stmts) == 2 + + +def test_split_cursor_declare(): + sql = ('DECLARE CURSOR "foo" AS SELECT 1;\n' + 'SELECT 2;') + stmts = sqlparse.split(sql) + assert len(stmts) == 2 + + +def test_split_if_function(): # see issue 33 + # don't let IF as a function confuse the splitter + sql = ('CREATE TEMPORARY TABLE tmp ' + 'SELECT IF(a=1, a, b) AS o FROM one; ' + 'SELECT t FROM two') + stmts = sqlparse.split(sql) + assert len(stmts) == 2 + + +def test_split_stream(): + stream = StringIO("SELECT 1; SELECT 2;") + stmts = sqlparse.parsestream(stream) + assert isinstance(stmts, types.GeneratorType) + assert len(list(stmts)) == 2 + + +def test_split_encoding_parsestream(): + stream = StringIO("SELECT 1; SELECT 2;") + stmts = list(sqlparse.parsestream(stream)) + assert isinstance(stmts[0].tokens[0].value, text_type) + + +def test_split_unicode_parsestream(): + stream = StringIO(u'SELECT ö') + stmts = list(sqlparse.parsestream(stream)) + assert str(stmts[0]) == 'SELECT ö' def test_split_simple(): diff --git a/tests/test_tokenize.py b/tests/test_tokenize.py index 4234043..a2cc388 100644 --- a/tests/test_tokenize.py +++ b/tests/test_tokenize.py @@ -10,145 +10,149 @@ from sqlparse import sql, tokens as T from sqlparse.compat import StringIO -class TestTokenize(object): - def test_simple(self): - s = 'select * from foo;' - stream = lexer.tokenize(s) - assert isinstance(stream, types.GeneratorType) - tokens = list(stream) - assert len(tokens) == 8 - assert len(tokens[0]) == 2 - assert tokens[0] == (T.Keyword.DML, 'select') - assert tokens[-1] == (T.Punctuation, ';') - - def test_backticks(self): - s = '`foo`.`bar`' - tokens = list(lexer.tokenize(s)) - assert len(tokens) == 3 - assert tokens[0] == (T.Name, '`foo`') - - def test_linebreaks(self): # issue1 - s = 'foo\nbar\n' - tokens = lexer.tokenize(s) - assert ''.join(str(x[1]) for x in tokens) == s - s = 'foo\rbar\r' - tokens = lexer.tokenize(s) - assert ''.join(str(x[1]) for x in tokens) == s - s = 'foo\r\nbar\r\n' - tokens = lexer.tokenize(s) - assert ''.join(str(x[1]) for x in tokens) == s - s = 'foo\r\nbar\n' - tokens = lexer.tokenize(s) - assert ''.join(str(x[1]) for x in tokens) == s - - def test_inline_keywords(self): # issue 7 - s = "create created_foo" - tokens = list(lexer.tokenize(s)) - assert len(tokens) == 3 - assert tokens[0][0] == T.Keyword.DDL - assert tokens[2][0] == T.Name - assert tokens[2][1] == 'created_foo' - s = "enddate" - tokens = list(lexer.tokenize(s)) - assert len(tokens) == 1 - assert tokens[0][0] == T.Name - s = "join_col" - tokens = list(lexer.tokenize(s)) - assert len(tokens) == 1 - assert tokens[0][0] == T.Name - s = "left join_col" - tokens = list(lexer.tokenize(s)) - assert len(tokens) == 3 - assert tokens[2][0] == T.Name - assert tokens[2][1] == 'join_col' - - def test_negative_numbers(self): - s = "values(-1)" - tokens = list(lexer.tokenize(s)) - assert len(tokens) == 4 - assert tokens[2][0] == T.Number.Integer - assert tokens[2][1] == '-1' - - -class TestToken(object): - def test_str(self): - token = sql.Token(None, 'FoO') - assert str(token) == 'FoO' - - def test_repr(self): - token = sql.Token(T.Keyword, 'foo') - tst = "<Keyword 'foo' at 0x" - assert repr(token)[:len(tst)] == tst - token = sql.Token(T.Keyword, '1234567890') - tst = "<Keyword '123456...' at 0x" - assert repr(token)[:len(tst)] == tst - - def test_flatten(self): - token = sql.Token(T.Keyword, 'foo') - gen = token.flatten() - assert isinstance(gen, types.GeneratorType) - lgen = list(gen) - assert lgen == [token] - - -class TestTokenList(object): - def test_repr(self): - p = sqlparse.parse('foo, bar, baz')[0] - tst = "<IdentifierList 'foo, b...' at 0x" - assert repr(p.tokens[0])[:len(tst)] == tst - - def test_token_first(self): - p = sqlparse.parse(' select foo')[0] - first = p.token_first() - assert first.value == 'select' - assert p.token_first(skip_ws=False).value == ' ' - assert sql.TokenList([]).token_first() is None - - def test_token_matching(self): - t1 = sql.Token(T.Keyword, 'foo') - t2 = sql.Token(T.Punctuation, ',') - x = sql.TokenList([t1, t2]) - assert x.token_matching( - [lambda t: t.ttype is T.Keyword], 0) == t1 - assert x.token_matching( - [lambda t: t.ttype is T.Punctuation], 0) == t2 - assert x.token_matching( - [lambda t: t.ttype is T.Keyword], 1) is None - - -class TestStream(object): - def test_simple(self): - stream = StringIO("SELECT 1; SELECT 2;") - - tokens = lexer.tokenize(stream) - assert len(list(tokens)) == 9 - - stream.seek(0) - tokens = list(lexer.tokenize(stream)) - assert len(tokens) == 9 - - stream.seek(0) - tokens = list(lexer.tokenize(stream)) - assert len(tokens) == 9 - - def test_error(self): - stream = StringIO("FOOBAR{") - - tokens = list(lexer.tokenize(stream)) - assert len(tokens) == 2 - assert tokens[1][0] == T.Error - - -@pytest.mark.parametrize('expr', ['JOIN', - 'LEFT JOIN', - 'LEFT OUTER JOIN', - 'FULL OUTER JOIN', - 'NATURAL JOIN', - 'CROSS JOIN', - 'STRAIGHT JOIN', - 'INNER JOIN', - 'LEFT INNER JOIN', - ]) +def test_tokenize_simple(): + s = 'select * from foo;' + stream = lexer.tokenize(s) + assert isinstance(stream, types.GeneratorType) + tokens = list(stream) + assert len(tokens) == 8 + assert len(tokens[0]) == 2 + assert tokens[0] == (T.Keyword.DML, 'select') + assert tokens[-1] == (T.Punctuation, ';') + + +def test_tokenize_backticks(): + s = '`foo`.`bar`' + tokens = list(lexer.tokenize(s)) + assert len(tokens) == 3 + assert tokens[0] == (T.Name, '`foo`') + + +def test_tokenize_linebreaks(): + # issue1 + s = 'foo\nbar\n' + tokens = lexer.tokenize(s) + assert ''.join(str(x[1]) for x in tokens) == s + s = 'foo\rbar\r' + tokens = lexer.tokenize(s) + assert ''.join(str(x[1]) for x in tokens) == s + s = 'foo\r\nbar\r\n' + tokens = lexer.tokenize(s) + assert ''.join(str(x[1]) for x in tokens) == s + s = 'foo\r\nbar\n' + tokens = lexer.tokenize(s) + assert ''.join(str(x[1]) for x in tokens) == s + + +def test_tokenize_inline_keywords(): + # issue 7 + s = "create created_foo" + tokens = list(lexer.tokenize(s)) + assert len(tokens) == 3 + assert tokens[0][0] == T.Keyword.DDL + assert tokens[2][0] == T.Name + assert tokens[2][1] == 'created_foo' + s = "enddate" + tokens = list(lexer.tokenize(s)) + assert len(tokens) == 1 + assert tokens[0][0] == T.Name + s = "join_col" + tokens = list(lexer.tokenize(s)) + assert len(tokens) == 1 + assert tokens[0][0] == T.Name + s = "left join_col" + tokens = list(lexer.tokenize(s)) + assert len(tokens) == 3 + assert tokens[2][0] == T.Name + assert tokens[2][1] == 'join_col' + + +def test_tokenize_negative_numbers(): + s = "values(-1)" + tokens = list(lexer.tokenize(s)) + assert len(tokens) == 4 + assert tokens[2][0] == T.Number.Integer + assert tokens[2][1] == '-1' + + +def test_token_str(): + token = sql.Token(None, 'FoO') + assert str(token) == 'FoO' + + +def test_token_repr(): + token = sql.Token(T.Keyword, 'foo') + tst = "<Keyword 'foo' at 0x" + assert repr(token)[:len(tst)] == tst + token = sql.Token(T.Keyword, '1234567890') + tst = "<Keyword '123456...' at 0x" + assert repr(token)[:len(tst)] == tst + + +def test_token_flatten(): + token = sql.Token(T.Keyword, 'foo') + gen = token.flatten() + assert isinstance(gen, types.GeneratorType) + lgen = list(gen) + assert lgen == [token] + + +def test_tokenlist_repr(): + p = sqlparse.parse('foo, bar, baz')[0] + tst = "<IdentifierList 'foo, b...' at 0x" + assert repr(p.tokens[0])[:len(tst)] == tst + + +def test_tokenlist_first(): + p = sqlparse.parse(' select foo')[0] + first = p.token_first() + assert first.value == 'select' + assert p.token_first(skip_ws=False).value == ' ' + assert sql.TokenList([]).token_first() is None + + +def test_tokenlist_token_matching(): + t1 = sql.Token(T.Keyword, 'foo') + t2 = sql.Token(T.Punctuation, ',') + x = sql.TokenList([t1, t2]) + assert x.token_matching([lambda t: t.ttype is T.Keyword], 0) == t1 + assert x.token_matching([lambda t: t.ttype is T.Punctuation], 0) == t2 + assert x.token_matching([lambda t: t.ttype is T.Keyword], 1) is None + + +def test_stream_simple(): + stream = StringIO("SELECT 1; SELECT 2;") + + tokens = lexer.tokenize(stream) + assert len(list(tokens)) == 9 + + stream.seek(0) + tokens = list(lexer.tokenize(stream)) + assert len(tokens) == 9 + + stream.seek(0) + tokens = list(lexer.tokenize(stream)) + assert len(tokens) == 9 + + +def test_stream_error(): + stream = StringIO("FOOBAR{") + + tokens = list(lexer.tokenize(stream)) + assert len(tokens) == 2 + assert tokens[1][0] == T.Error + + +@pytest.mark.parametrize('expr', [ + 'JOIN', + 'LEFT JOIN', + 'LEFT OUTER JOIN', + 'FULL OUTER JOIN', + 'NATURAL JOIN', + 'CROSS JOIN', + 'STRAIGHT JOIN', + 'INNER JOIN', + 'LEFT INNER JOIN']) def test_parse_join(expr): p = sqlparse.parse('{0} foo'.format(expr))[0] assert len(p.tokens) == 3 |
