diff options
Diffstat (limited to 'test/ext/test_compiler.py')
| -rw-r--r-- | test/ext/test_compiler.py | 213 |
1 files changed, 98 insertions, 115 deletions
diff --git a/test/ext/test_compiler.py b/test/ext/test_compiler.py index c23d5f2ac..22ab1f163 100644 --- a/test/ext/test_compiler.py +++ b/test/ext/test_compiler.py @@ -1,8 +1,13 @@ from sqlalchemy import * from sqlalchemy.types import TypeEngine -from sqlalchemy.sql.expression import ClauseElement, ColumnClause,\ - FunctionElement, Select, \ - BindParameter, ColumnElement +from sqlalchemy.sql.expression import ( + ClauseElement, + ColumnClause, + FunctionElement, + Select, + BindParameter, + ColumnElement, +) from sqlalchemy.schema import DDLElement, CreateColumn, CreateTable from sqlalchemy.ext.compiler import compiles, deregister @@ -14,92 +19,87 @@ from sqlalchemy.testing import fixtures, AssertsCompiledSQL class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def test_column(self): - class MyThingy(ColumnClause): def __init__(self, arg=None): - super(MyThingy, self).__init__(arg or 'MYTHINGY!') + super(MyThingy, self).__init__(arg or "MYTHINGY!") @compiles(MyThingy) def visit_thingy(thingy, compiler, **kw): return ">>%s<<" % thingy.name self.assert_compile( - select([column('foo'), MyThingy()]), - "SELECT foo, >>MYTHINGY!<<" + select([column("foo"), MyThingy()]), "SELECT foo, >>MYTHINGY!<<" ) self.assert_compile( - select([MyThingy('x'), MyThingy('y')]).where(MyThingy() == 5), - "SELECT >>x<<, >>y<< WHERE >>MYTHINGY!<< = :MYTHINGY!_1" + select([MyThingy("x"), MyThingy("y")]).where(MyThingy() == 5), + "SELECT >>x<<, >>y<< WHERE >>MYTHINGY!<< = :MYTHINGY!_1", ) def test_create_column_skip(self): @compiles(CreateColumn) def skip_xmin(element, compiler, **kw): - if element.element.name == 'xmin': + if element.element.name == "xmin": return None else: return compiler.visit_create_column(element, **kw) - t = Table('t', MetaData(), Column('a', Integer), - Column('xmin', Integer), - Column('c', Integer)) + t = Table( + "t", + MetaData(), + Column("a", Integer), + Column("xmin", Integer), + Column("c", Integer), + ) self.assert_compile( - CreateTable(t), - "CREATE TABLE t (a INTEGER, c INTEGER)" + CreateTable(t), "CREATE TABLE t (a INTEGER, c INTEGER)" ) def test_types(self): class MyType(TypeEngine): pass - @compiles(MyType, 'sqlite') + @compiles(MyType, "sqlite") def visit_type(type, compiler, **kw): return "SQLITE_FOO" - @compiles(MyType, 'postgresql') + @compiles(MyType, "postgresql") def visit_type(type, compiler, **kw): return "POSTGRES_FOO" from sqlalchemy.dialects.sqlite import base as sqlite from sqlalchemy.dialects.postgresql import base as postgresql - self.assert_compile( - MyType(), - "SQLITE_FOO", - dialect=sqlite.dialect() - ) + self.assert_compile(MyType(), "SQLITE_FOO", dialect=sqlite.dialect()) self.assert_compile( - MyType(), - "POSTGRES_FOO", - dialect=postgresql.dialect() + MyType(), "POSTGRES_FOO", dialect=postgresql.dialect() ) def test_stateful(self): class MyThingy(ColumnClause): def __init__(self): - super(MyThingy, self).__init__('MYTHINGY!') + super(MyThingy, self).__init__("MYTHINGY!") @compiles(MyThingy) def visit_thingy(thingy, compiler, **kw): - if not hasattr(compiler, 'counter'): + if not hasattr(compiler, "counter"): compiler.counter = 0 compiler.counter += 1 return str(compiler.counter) self.assert_compile( - select([column('foo'), MyThingy()]).order_by(desc(MyThingy())), - "SELECT foo, 1 ORDER BY 2 DESC" + select([column("foo"), MyThingy()]).order_by(desc(MyThingy())), + "SELECT foo, 1 ORDER BY 2 DESC", ) self.assert_compile( select([MyThingy(), MyThingy()]).where(MyThingy() == 5), - "SELECT 1, 2 WHERE 3 = :MYTHINGY!_1" + "SELECT 1, 2 WHERE 3 = :MYTHINGY!_1", ) def test_callout_to_compiler(self): @@ -112,34 +112,31 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL): def visit_insert_from_select(element, compiler, **kw): return "INSERT INTO %s (%s)" % ( compiler.process(element.table, asfrom=True), - compiler.process(element.select) + compiler.process(element.select), ) - t1 = table("mytable", column('x'), column('y'), column('z')) + t1 = table("mytable", column("x"), column("y"), column("z")) self.assert_compile( - InsertFromSelect( - t1, - select([t1]).where(t1.c.x > 5) - ), + InsertFromSelect(t1, select([t1]).where(t1.c.x > 5)), "INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z " - "FROM mytable WHERE mytable.x > :x_1)" + "FROM mytable WHERE mytable.x > :x_1)", ) def test_no_default_but_has_a_visit(self): class MyThingy(ColumnClause): pass - @compiles(MyThingy, 'postgresql') + @compiles(MyThingy, "postgresql") def visit_thingy(thingy, compiler, **kw): return "mythingy" - eq_(str(MyThingy('x')), "x") + eq_(str(MyThingy("x")), "x") def test_no_default_has_no_visit(self): class MyThingy(TypeEngine): pass - @compiles(MyThingy, 'postgresql') + @compiles(MyThingy, "postgresql") def visit_thingy(thingy, compiler, **kw): return "mythingy" @@ -147,14 +144,15 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL): exc.CompileError, "<class 'test.ext.test_compiler..*MyThingy'> " "construct has no default compilation handler.", - str, MyThingy() + str, + MyThingy(), ) def test_no_default_message(self): class MyThingy(ClauseElement): pass - @compiles(MyThingy, 'postgresql') + @compiles(MyThingy, "postgresql") def visit_thingy(thingy, compiler, **kw): return "mythingy" @@ -162,7 +160,8 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL): exc.CompileError, "<class 'test.ext.test_compiler..*MyThingy'> " "construct has no default compilation handler.", - str, MyThingy() + str, + MyThingy(), ) def test_default_subclass(self): @@ -176,9 +175,7 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL): return "array" self.assert_compile( - MyArray(Integer), - "INTEGER[]", - dialect="postgresql" + MyArray(Integer), "INTEGER[]", dialect="postgresql" ) def test_annotations(self): @@ -187,35 +184,31 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL): """ - t1 = table('t1', column('c1'), column('c2')) + t1 = table("t1", column("c1"), column("c2")) dispatch = Select._compiler_dispatch try: + @compiles(Select) def compile(element, compiler, **kw): return "OVERRIDE" s1 = select([t1]) - self.assert_compile( - s1, "OVERRIDE" - ) - self.assert_compile( - s1._annotate({}), - "OVERRIDE" - ) + self.assert_compile(s1, "OVERRIDE") + self.assert_compile(s1._annotate({}), "OVERRIDE") finally: Select._compiler_dispatch = dispatch - if hasattr(Select, '_compiler_dispatcher'): + if hasattr(Select, "_compiler_dispatcher"): del Select._compiler_dispatcher def test_dialect_specific(self): class AddThingy(DDLElement): - __visit_name__ = 'add_thingy' + __visit_name__ = "add_thingy" class DropThingy(DDLElement): - __visit_name__ = 'drop_thingy' + __visit_name__ = "drop_thingy" - @compiles(AddThingy, 'sqlite') + @compiles(AddThingy, "sqlite") def visit_add_thingy(thingy, compiler, **kw): return "ADD SPECIAL SL THINGY" @@ -232,21 +225,22 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile(DropThingy(), "DROP THINGY") from sqlalchemy.dialects.sqlite import base - self.assert_compile(AddThingy(), - "ADD SPECIAL SL THINGY", - dialect=base.dialect()) - self.assert_compile(DropThingy(), - "DROP THINGY", - dialect=base.dialect()) + self.assert_compile( + AddThingy(), "ADD SPECIAL SL THINGY", dialect=base.dialect() + ) - @compiles(DropThingy, 'sqlite') + self.assert_compile( + DropThingy(), "DROP THINGY", dialect=base.dialect() + ) + + @compiles(DropThingy, "sqlite") def visit_drop_thingy(thingy, compiler, **kw): return "DROP SPECIAL SL THINGY" - self.assert_compile(DropThingy(), - "DROP SPECIAL SL THINGY", - dialect=base.dialect()) + self.assert_compile( + DropThingy(), "DROP SPECIAL SL THINGY", dialect=base.dialect() + ) self.assert_compile(DropThingy(), "DROP THINGY") @@ -260,19 +254,17 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL): def visit_myfunc(element, compiler, **kw): return "utcnow()" - @compiles(MyUtcFunction, 'postgresql') + @compiles(MyUtcFunction, "postgresql") def visit_myfunc(element, compiler, **kw): return "timezone('utc', current_timestamp)" self.assert_compile( - MyUtcFunction(), - "utcnow()", - use_default_dialect=True + MyUtcFunction(), "utcnow()", use_default_dialect=True ) self.assert_compile( MyUtcFunction(), "timezone('utc', current_timestamp)", - dialect=postgresql.dialect() + dialect=postgresql.dialect(), ) def test_function_calls_base(self): @@ -280,13 +272,13 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL): class greatest(FunctionElement): type = Numeric() - name = 'greatest' + name = "greatest" @compiles(greatest) def default_greatest(element, compiler, **kw): return compiler.visit_function(element) - @compiles(greatest, 'mssql') + @compiles(greatest, "mssql") def case_greatest(element, compiler, **kw): arg1, arg2 = list(element.clauses) return "CASE WHEN %s > %s THEN %s ELSE %s END" % ( @@ -297,26 +289,26 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL): ) self.assert_compile( - greatest('a', 'b'), - 'greatest(:greatest_1, :greatest_2)', - use_default_dialect=True + greatest("a", "b"), + "greatest(:greatest_1, :greatest_2)", + use_default_dialect=True, ) self.assert_compile( - greatest('a', 'b'), + greatest("a", "b"), "CASE WHEN :greatest_1 > :greatest_2 " "THEN :greatest_1 ELSE :greatest_2 END", - dialect=mssql.dialect() + dialect=mssql.dialect(), ) def test_subclasses_one(self): class Base(FunctionElement): - name = 'base' + name = "base" class Sub1(Base): - name = 'sub1' + name = "sub1" class Sub2(Base): - name = 'sub2' + name = "sub2" @compiles(Base) def visit_base(element, compiler, **kw): @@ -328,31 +320,31 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( select([Sub1(), Sub2()]), - 'SELECT FOOsub1, sub2', - use_default_dialect=True + "SELECT FOOsub1, sub2", + use_default_dialect=True, ) def test_subclasses_two(self): class Base(FunctionElement): - name = 'base' + name = "base" class Sub1(Base): - name = 'sub1' + name = "sub1" @compiles(Base) def visit_base(element, compiler, **kw): return element.name class Sub2(Base): - name = 'sub2' + name = "sub2" class SubSub1(Sub1): - name = 'subsub1' + name = "subsub1" self.assert_compile( select([Sub1(), Sub2(), SubSub1()]), - 'SELECT sub1, sub2, subsub1', - use_default_dialect=True + "SELECT sub1, sub2, subsub1", + use_default_dialect=True, ) @compiles(Sub1) @@ -361,42 +353,36 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( select([Sub1(), Sub2(), SubSub1()]), - 'SELECT FOOsub1, sub2, FOOsubsub1', - use_default_dialect=True + "SELECT FOOsub1, sub2, FOOsubsub1", + use_default_dialect=True, ) class DefaultOnExistingTest(fixtures.TestBase, AssertsCompiledSQL): """Test replacement of default compilation on existing constructs.""" - __dialect__ = 'default' + + __dialect__ = "default" def teardown(self): for cls in (Select, BindParameter): deregister(cls) def test_select(self): - t1 = table('t1', column('c1'), column('c2')) + t1 = table("t1", column("c1"), column("c2")) - @compiles(Select, 'sqlite') + @compiles(Select, "sqlite") def compile(element, compiler, **kw): return "OVERRIDE" s1 = select([t1]) - self.assert_compile( - s1, "SELECT t1.c1, t1.c2 FROM t1", - ) + self.assert_compile(s1, "SELECT t1.c1, t1.c2 FROM t1") from sqlalchemy.dialects.sqlite import base as sqlite - self.assert_compile( - s1, "OVERRIDE", - dialect=sqlite.dialect() - ) + + self.assert_compile(s1, "OVERRIDE", dialect=sqlite.dialect()) def test_binds_in_select(self): - t = table('t', - column('a'), - column('b'), - column('c')) + t = table("t", column("a"), column("b"), column("c")) @compiles(BindParameter) def gen_bind(element, compiler, **kw): @@ -405,14 +391,11 @@ class DefaultOnExistingTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( t.select().where(t.c.c == 5), "SELECT t.a, t.b, t.c FROM t WHERE t.c = BIND(:c_1)", - use_default_dialect=True + use_default_dialect=True, ) def test_binds_in_dml(self): - t = table('t', - column('a'), - column('b'), - column('c')) + t = table("t", column("a"), column("b"), column("c")) @compiles(BindParameter) def gen_bind(element, compiler, **kw): @@ -421,6 +404,6 @@ class DefaultOnExistingTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( t.insert(), "INSERT INTO t (a, b) VALUES (BIND(:a), BIND(:b))", - {'a': 1, 'b': 2}, - use_default_dialect=True + {"a": 1, "b": 2}, + use_default_dialect=True, ) |
