diff options
Diffstat (limited to 'test/ext/compiler.py')
-rw-r--r-- | test/ext/compiler.py | 126 |
1 files changed, 126 insertions, 0 deletions
diff --git a/test/ext/compiler.py b/test/ext/compiler.py new file mode 100644 index 000000000..370ea62ab --- /dev/null +++ b/test/ext/compiler.py @@ -0,0 +1,126 @@ +import testenv; testenv.configure_for_tests() +from sqlalchemy import * +from sqlalchemy.sql.expression import ClauseElement, ColumnClause +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.sql import table, column +from testlib import * + +class UserDefinedTest(TestBase, AssertsCompiledSQL): + + def test_column(self): + + class MyThingy(ColumnClause): + def __init__(self, arg= None): + super(MyThingy, self).__init__(arg or 'MYTHINGY!') + + @compiles(MyThingy) + def visit_thingy(thingy, compiler, **kw): + return ">>%s<<" % thingy.name + + self.assert_compile( + select([column('foo'), MyThingy()]), + "SELECT foo, >>MYTHINGY!<<" + ) + + self.assert_compile( + select([MyThingy('x'), MyThingy('y')]).where(MyThingy() == 5), + "SELECT >>x<<, >>y<< WHERE >>MYTHINGY!<< = :MYTHINGY!_1" + ) + + def test_stateful(self): + class MyThingy(ColumnClause): + def __init__(self): + super(MyThingy, self).__init__('MYTHINGY!') + + @compiles(MyThingy) + def visit_thingy(thingy, compiler, **kw): + if not hasattr(compiler, 'counter'): + compiler.counter = 0 + compiler.counter += 1 + return str(compiler.counter) + + self.assert_compile( + select([column('foo'), MyThingy()]).order_by(desc(MyThingy())), + "SELECT foo, 1 ORDER BY 2 DESC" + ) + + self.assert_compile( + select([MyThingy(), MyThingy()]).where(MyThingy() == 5), + "SELECT 1, 2 WHERE 3 = :MYTHINGY!_1" + ) + + def test_callout_to_compiler(self): + class InsertFromSelect(ClauseElement): + def __init__(self, table, select): + self.table = table + self.select = select + + @compiles(InsertFromSelect) + def visit_insert_from_select(element, compiler, **kw): + return "INSERT INTO %s (%s)" % ( + compiler.process(element.table, asfrom=True), + compiler.process(element.select) + ) + + t1 = table("mytable", column('x'), column('y'), column('z')) + self.assert_compile( + InsertFromSelect( + t1, + select([t1]).where(t1.c.x>5) + ), + "INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z FROM mytable WHERE mytable.x > :x_1)" + ) + + def test_dialect_specific(self): + class AddThingy(ClauseElement): + __visit_name__ = 'add_thingy' + + class DropThingy(ClauseElement): + __visit_name__ = 'drop_thingy' + + @compiles(AddThingy, 'sqlite') + def visit_add_thingy(thingy, compiler, **kw): + return "ADD SPECIAL SL THINGY" + + @compiles(AddThingy) + def visit_add_thingy(thingy, compiler, **kw): + return "ADD THINGY" + + @compiles(DropThingy) + def visit_drop_thingy(thingy, compiler, **kw): + return "DROP THINGY" + + self.assert_compile(AddThingy(), + "ADD THINGY" + ) + + self.assert_compile(DropThingy(), + "DROP THINGY" + ) + + from sqlalchemy.databases import sqlite as base + self.assert_compile(AddThingy(), + "ADD SPECIAL SL THINGY", + dialect=base.dialect() + ) + + self.assert_compile(DropThingy(), + "DROP THINGY", + dialect=base.dialect() + ) + + @compiles(DropThingy, 'sqlite') + def visit_drop_thingy(thingy, compiler, **kw): + return "DROP SPECIAL SL THINGY" + + self.assert_compile(DropThingy(), + "DROP SPECIAL SL THINGY", + dialect=base.dialect() + ) + + self.assert_compile(DropThingy(), + "DROP THINGY", + ) + +if __name__ == '__main__': + testenv.main() |