1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
|
from sqlalchemy import *
from sqlalchemy.types import TypeEngine
from sqlalchemy.sql.expression import ClauseElement, ColumnClause,\
FunctionElement
from sqlalchemy.schema import DDLElement
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import table, column
from sqlalchemy.test import *
class UserDefinedTest(TestBase, AssertsCompiledSQL):
    """Tests for the ``@compiles`` decorator from ``sqlalchemy.ext.compiler``.

    Each test declares throwaway ClauseElement / TypeEngine / DDLElement /
    FunctionElement subclasses, attaches compile functions to them with
    ``@compiles`` (optionally restricted to a dialect name), and checks the
    SQL string produced by ``assert_compile``.
    """

    def test_column(self):
        """A ColumnClause subclass can supply its own rendering."""

        class MyThingy(ColumnClause):
            def __init__(self, arg=None):
                super(MyThingy, self).__init__(arg or 'MYTHINGY!')

        @compiles(MyThingy)
        def visit_thingy(thingy, compiler, **kw):
            return ">>%s<<" % thingy.name

        self.assert_compile(
            select([column('foo'), MyThingy()]),
            "SELECT foo, >>MYTHINGY!<<"
        )
        # The custom rendering is used for the column itself, while the
        # bind parameter in the comparison is named from the column's
        # name ('MYTHINGY!').
        self.assert_compile(
            select([MyThingy('x'), MyThingy('y')]).where(MyThingy() == 5),
            "SELECT >>x<<, >>y<< WHERE >>MYTHINGY!<< = :MYTHINGY!_1"
        )

    def test_types(self):
        """A TypeEngine subclass can render differently per dialect."""

        class MyType(TypeEngine):
            pass

        # Handlers take ``type_`` rather than ``type`` so the builtin is not
        # shadowed; each gets a distinct name so neither silently replaces
        # the other in this scope (the decorator registers both).
        @compiles(MyType, 'sqlite')
        def visit_type_sqlite(type_, compiler, **kw):
            return "SQLITE_FOO"

        @compiles(MyType, 'postgresql')
        def visit_type_postgresql(type_, compiler, **kw):
            return "POSTGRES_FOO"

        from sqlalchemy.dialects.sqlite import base as sqlite
        from sqlalchemy.dialects.postgresql import base as postgresql

        self.assert_compile(
            MyType(),
            "SQLITE_FOO",
            dialect=sqlite.dialect()
        )
        self.assert_compile(
            MyType(),
            "POSTGRES_FOO",
            dialect=postgresql.dialect()
        )

    def test_stateful(self):
        """Compile functions may keep state on the compiler object.

        The expected SQL shows the counter increasing across every MyThingy
        within one statement, and restarting at 1 for the next
        ``assert_compile`` call.
        """

        class MyThingy(ColumnClause):
            def __init__(self):
                super(MyThingy, self).__init__('MYTHINGY!')

        @compiles(MyThingy)
        def visit_thingy(thingy, compiler, **kw):
            if not hasattr(compiler, 'counter'):
                compiler.counter = 0
            compiler.counter += 1
            return str(compiler.counter)

        self.assert_compile(
            select([column('foo'), MyThingy()]).order_by(desc(MyThingy())),
            "SELECT foo, 1 ORDER BY 2 DESC"
        )
        self.assert_compile(
            select([MyThingy(), MyThingy()]).where(MyThingy() == 5),
            "SELECT 1, 2 WHERE 3 = :MYTHINGY!_1"
        )

    def test_callout_to_compiler(self):
        """A compile function can delegate sub-elements to ``compiler.process``."""

        class InsertFromSelect(ClauseElement):
            # Parameter names avoid shadowing the module-level ``table`` and
            # ``select`` constructors; the attribute names are unchanged.
            def __init__(self, target_table, source_select):
                self.table = target_table
                self.select = source_select

        @compiles(InsertFromSelect)
        def visit_insert_from_select(element, compiler, **kw):
            return "INSERT INTO %s (%s)" % (
                compiler.process(element.table, asfrom=True),
                compiler.process(element.select)
            )

        t1 = table("mytable", column('x'), column('y'), column('z'))
        self.assert_compile(
            InsertFromSelect(
                t1,
                select([t1]).where(t1.c.x > 5)
            ),
            "INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z "
            "FROM mytable WHERE mytable.x > :x_1)"
        )

    def test_dialect_specific(self):
        """Dialect-specific handlers take precedence only for that dialect."""

        class AddThingy(DDLElement):
            __visit_name__ = 'add_thingy'

        class DropThingy(DDLElement):
            __visit_name__ = 'drop_thingy'

        @compiles(AddThingy, 'sqlite')
        def visit_add_thingy_sqlite(thingy, compiler, **kw):
            return "ADD SPECIAL SL THINGY"

        @compiles(AddThingy)
        def visit_add_thingy(thingy, compiler, **kw):
            return "ADD THINGY"

        @compiles(DropThingy)
        def visit_drop_thingy(thingy, compiler, **kw):
            return "DROP THINGY"

        self.assert_compile(AddThingy(),
            "ADD THINGY"
        )
        self.assert_compile(DropThingy(),
            "DROP THINGY"
        )

        from sqlalchemy.dialects.sqlite import base as sqlite

        self.assert_compile(AddThingy(),
            "ADD SPECIAL SL THINGY",
            dialect=sqlite.dialect()
        )
        # No sqlite-specific DropThingy handler exists yet, so the default
        # one is used.
        self.assert_compile(DropThingy(),
            "DROP THINGY",
            dialect=sqlite.dialect()
        )

        # Registering a dialect-specific handler after the default one was
        # already used: it must take over for sqlite only.  This must stay
        # after the assertions above.
        @compiles(DropThingy, 'sqlite')
        def visit_drop_thingy_sqlite(thingy, compiler, **kw):
            return "DROP SPECIAL SL THINGY"

        self.assert_compile(DropThingy(),
            "DROP SPECIAL SL THINGY",
            dialect=sqlite.dialect()
        )
        self.assert_compile(DropThingy(),
            "DROP THINGY",
        )

    def test_functions(self):
        """A FunctionElement subclass with a default and a postgresql rendering."""
        from sqlalchemy.dialects.postgresql import base as postgresql

        class MyUtcFunction(FunctionElement):
            pass

        @compiles(MyUtcFunction)
        def visit_myfunc(element, compiler, **kw):
            return "utcnow()"

        @compiles(MyUtcFunction, 'postgresql')
        def visit_myfunc_postgresql(element, compiler, **kw):
            return "timezone('utc', current_timestamp)"

        self.assert_compile(
            MyUtcFunction(),
            "utcnow()",
            use_default_dialect=True
        )
        self.assert_compile(
            MyUtcFunction(),
            "timezone('utc', current_timestamp)",
            dialect=postgresql.dialect()
        )
|