summaryrefslogtreecommitdiff
path: root/tests/test_split.py
blob: 8b22aadd7492eb7cf53d4ff65df47673872a82ae (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# -*- coding: utf-8 -*-

# Tests splitting functions.
import types

from tests.utils import load_file, TestCaseBase

import sqlparse
from sqlparse import compat


class SQLSplitTest(TestCaseBase):
    """Tests statement splitting via sqlparse.parse/split/parsestream."""

    _sql1 = 'select * from foo;'
    _sql2 = 'select * from bar;'

    def _assert_roundtrip(self, sql, expected_count):
        """Parse *sql*, check the statement count and a lossless round-trip.

        Asserts that ``sqlparse.parse(sql)`` yields exactly
        *expected_count* statements and that concatenating the text of all
        statements reproduces *sql* unchanged.

        Returns the parsed statements so callers can add further checks.
        """
        stmts = sqlparse.parse(sql)
        self.assertEqual(len(stmts), expected_count)
        self.ndiffAssertEqual(
            ''.join(compat.text_type(stmt) for stmt in stmts), sql)
        return stmts

    def test_split_semicolon(self):
        # A semicolon inside a quoted string must not end the statement.
        quoted = 'select * from foo where bar = \'foo;bar\';'
        stmts = sqlparse.parse(''.join([self._sql1, quoted]))
        self.assertEqual(len(stmts), 2)
        self.ndiffAssertEqual(compat.text_type(stmts[0]), self._sql1)
        self.ndiffAssertEqual(compat.text_type(stmts[1]), quoted)

    def test_split_backslash(self):
        # Backslash sequences inside string literals must not hide the
        # closing quote (or the semicolon that follows it).
        stmts = sqlparse.parse(r"select '\\'; select '\''; select '\\\'';")
        self.assertEqual(len(stmts), 3)

    def test_create_function(self):
        self._assert_roundtrip(load_file('function.sql'), 1)

    def test_create_function_psql(self):
        self._assert_roundtrip(load_file('function_psql.sql'), 1)

    def test_create_function_psql3(self):
        self._assert_roundtrip(load_file('function_psql3.sql'), 1)

    def test_create_function_psql2(self):
        self._assert_roundtrip(load_file('function_psql2.sql'), 1)

    def test_dashcomments(self):
        self._assert_roundtrip(load_file('dashcomment.sql'), 3)

    def test_dashcomments_eol(self):
        # A trailing "--" comment stays attached to the preceding statement
        # regardless of how (or whether) the line is terminated.
        for eol in ('\n', '\r', '\r\n', ''):
            stmts = sqlparse.parse('select foo; -- comment' + eol)
            self.assertEqual(
                len(stmts), 1,
                'expected 1 statement for line ending %r, got %d'
                % (eol, len(stmts)))

    def test_begintag(self):
        self._assert_roundtrip(load_file('begintag.sql'), 3)

    def test_begintag_2(self):
        self._assert_roundtrip(load_file('begintag_2.sql'), 1)

    def test_dropif(self):
        self._assert_roundtrip(
            'DROP TABLE IF EXISTS FOO;\n\nSELECT * FROM BAR;', 2)

    def test_comment_with_umlaut(self):
        # The non-ASCII character must be a real umlaut; the module is
        # declared as UTF-8 encoded.
        sql = (compat.u(
            'select * from foo;\n'
            '-- Testing an umlaut: ä\n'
            'select * from bar;'))
        self._assert_roundtrip(sql, 2)

    def test_comment_end_of_line(self):
        sql = ('select * from foo; -- foo\n'
               'select * from bar;')
        stmts = self._assert_roundtrip(sql, 2)
        # make sure the comment belongs to first query
        self.ndiffAssertEqual(
            compat.text_type(stmts[0]), 'select * from foo; -- foo\n')

    def test_casewhen(self):
        # The END of a CASE expression must not be mistaken for the end
        # of a block.
        sql = ('SELECT case when val = 1 then 2 else null end as foo;\n'
               'comment on table actor is \'The actor table.\';')
        stmts = sqlparse.split(sql)
        self.assertEqual(len(stmts), 2)

    def test_cursor_declare(self):
        sql = ('DECLARE CURSOR "foo" AS SELECT 1;\n'
               'SELECT 2;')
        stmts = sqlparse.split(sql)
        self.assertEqual(len(stmts), 2)

    def test_if_function(self):  # see issue 33
        # don't let IF as a function confuse the splitter
        sql = ('CREATE TEMPORARY TABLE tmp '
               'SELECT IF(a=1, a, b) AS o FROM one; '
               'SELECT t FROM two')
        stmts = sqlparse.split(sql)
        self.assertEqual(len(stmts), 2)

    def test_split_stream(self):
        # parsestream() is lazy: it hands back a generator, not a list.
        stream = compat.StringIO("SELECT 1; SELECT 2;")
        stmts = sqlparse.parsestream(stream)
        self.assertEqual(type(stmts), types.GeneratorType)
        self.assertEqual(len(list(stmts)), 2)

    def test_encoding_parsestream(self):
        # Token values coming out of a stream are text, not bytes.
        stream = compat.StringIO("SELECT 1; SELECT 2;")
        stmts = list(sqlparse.parsestream(stream))
        self.assertEqual(type(stmts[0].tokens[0].value), compat.text_type)


def test_split_simple():
    # Two statements separated by a single space: split() must return
    # each one on its own, without the separating whitespace.
    expected = ('select * from foo;', 'select * from bar;')
    stmts = sqlparse.split(' '.join(expected))
    assert len(stmts) == len(expected)
    for position, statement in enumerate(expected):
        assert stmts[position] == statement