summaryrefslogtreecommitdiff
path: root/test/ext/test_compiler.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/ext/test_compiler.py')
-rw-r--r--test/ext/test_compiler.py213
1 files changed, 98 insertions, 115 deletions
diff --git a/test/ext/test_compiler.py b/test/ext/test_compiler.py
index c23d5f2ac..22ab1f163 100644
--- a/test/ext/test_compiler.py
+++ b/test/ext/test_compiler.py
@@ -1,8 +1,13 @@
from sqlalchemy import *
from sqlalchemy.types import TypeEngine
-from sqlalchemy.sql.expression import ClauseElement, ColumnClause,\
- FunctionElement, Select, \
- BindParameter, ColumnElement
+from sqlalchemy.sql.expression import (
+ ClauseElement,
+ ColumnClause,
+ FunctionElement,
+ Select,
+ BindParameter,
+ ColumnElement,
+)
from sqlalchemy.schema import DDLElement, CreateColumn, CreateTable
from sqlalchemy.ext.compiler import compiles, deregister
@@ -14,92 +19,87 @@ from sqlalchemy.testing import fixtures, AssertsCompiledSQL
class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL):
- __dialect__ = 'default'
+ __dialect__ = "default"
def test_column(self):
-
class MyThingy(ColumnClause):
def __init__(self, arg=None):
- super(MyThingy, self).__init__(arg or 'MYTHINGY!')
+ super(MyThingy, self).__init__(arg or "MYTHINGY!")
@compiles(MyThingy)
def visit_thingy(thingy, compiler, **kw):
return ">>%s<<" % thingy.name
self.assert_compile(
- select([column('foo'), MyThingy()]),
- "SELECT foo, >>MYTHINGY!<<"
+ select([column("foo"), MyThingy()]), "SELECT foo, >>MYTHINGY!<<"
)
self.assert_compile(
- select([MyThingy('x'), MyThingy('y')]).where(MyThingy() == 5),
- "SELECT >>x<<, >>y<< WHERE >>MYTHINGY!<< = :MYTHINGY!_1"
+ select([MyThingy("x"), MyThingy("y")]).where(MyThingy() == 5),
+ "SELECT >>x<<, >>y<< WHERE >>MYTHINGY!<< = :MYTHINGY!_1",
)
def test_create_column_skip(self):
@compiles(CreateColumn)
def skip_xmin(element, compiler, **kw):
- if element.element.name == 'xmin':
+ if element.element.name == "xmin":
return None
else:
return compiler.visit_create_column(element, **kw)
- t = Table('t', MetaData(), Column('a', Integer),
- Column('xmin', Integer),
- Column('c', Integer))
+ t = Table(
+ "t",
+ MetaData(),
+ Column("a", Integer),
+ Column("xmin", Integer),
+ Column("c", Integer),
+ )
self.assert_compile(
- CreateTable(t),
- "CREATE TABLE t (a INTEGER, c INTEGER)"
+ CreateTable(t), "CREATE TABLE t (a INTEGER, c INTEGER)"
)
def test_types(self):
class MyType(TypeEngine):
pass
- @compiles(MyType, 'sqlite')
+ @compiles(MyType, "sqlite")
def visit_type(type, compiler, **kw):
return "SQLITE_FOO"
- @compiles(MyType, 'postgresql')
+ @compiles(MyType, "postgresql")
def visit_type(type, compiler, **kw):
return "POSTGRES_FOO"
from sqlalchemy.dialects.sqlite import base as sqlite
from sqlalchemy.dialects.postgresql import base as postgresql
- self.assert_compile(
- MyType(),
- "SQLITE_FOO",
- dialect=sqlite.dialect()
- )
+ self.assert_compile(MyType(), "SQLITE_FOO", dialect=sqlite.dialect())
self.assert_compile(
- MyType(),
- "POSTGRES_FOO",
- dialect=postgresql.dialect()
+ MyType(), "POSTGRES_FOO", dialect=postgresql.dialect()
)
def test_stateful(self):
class MyThingy(ColumnClause):
def __init__(self):
- super(MyThingy, self).__init__('MYTHINGY!')
+ super(MyThingy, self).__init__("MYTHINGY!")
@compiles(MyThingy)
def visit_thingy(thingy, compiler, **kw):
- if not hasattr(compiler, 'counter'):
+ if not hasattr(compiler, "counter"):
compiler.counter = 0
compiler.counter += 1
return str(compiler.counter)
self.assert_compile(
- select([column('foo'), MyThingy()]).order_by(desc(MyThingy())),
- "SELECT foo, 1 ORDER BY 2 DESC"
+ select([column("foo"), MyThingy()]).order_by(desc(MyThingy())),
+ "SELECT foo, 1 ORDER BY 2 DESC",
)
self.assert_compile(
select([MyThingy(), MyThingy()]).where(MyThingy() == 5),
- "SELECT 1, 2 WHERE 3 = :MYTHINGY!_1"
+ "SELECT 1, 2 WHERE 3 = :MYTHINGY!_1",
)
def test_callout_to_compiler(self):
@@ -112,34 +112,31 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL):
def visit_insert_from_select(element, compiler, **kw):
return "INSERT INTO %s (%s)" % (
compiler.process(element.table, asfrom=True),
- compiler.process(element.select)
+ compiler.process(element.select),
)
- t1 = table("mytable", column('x'), column('y'), column('z'))
+ t1 = table("mytable", column("x"), column("y"), column("z"))
self.assert_compile(
- InsertFromSelect(
- t1,
- select([t1]).where(t1.c.x > 5)
- ),
+ InsertFromSelect(t1, select([t1]).where(t1.c.x > 5)),
"INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z "
- "FROM mytable WHERE mytable.x > :x_1)"
+ "FROM mytable WHERE mytable.x > :x_1)",
)
def test_no_default_but_has_a_visit(self):
class MyThingy(ColumnClause):
pass
- @compiles(MyThingy, 'postgresql')
+ @compiles(MyThingy, "postgresql")
def visit_thingy(thingy, compiler, **kw):
return "mythingy"
- eq_(str(MyThingy('x')), "x")
+ eq_(str(MyThingy("x")), "x")
def test_no_default_has_no_visit(self):
class MyThingy(TypeEngine):
pass
- @compiles(MyThingy, 'postgresql')
+ @compiles(MyThingy, "postgresql")
def visit_thingy(thingy, compiler, **kw):
return "mythingy"
@@ -147,14 +144,15 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL):
exc.CompileError,
"<class 'test.ext.test_compiler..*MyThingy'> "
"construct has no default compilation handler.",
- str, MyThingy()
+ str,
+ MyThingy(),
)
def test_no_default_message(self):
class MyThingy(ClauseElement):
pass
- @compiles(MyThingy, 'postgresql')
+ @compiles(MyThingy, "postgresql")
def visit_thingy(thingy, compiler, **kw):
return "mythingy"
@@ -162,7 +160,8 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL):
exc.CompileError,
"<class 'test.ext.test_compiler..*MyThingy'> "
"construct has no default compilation handler.",
- str, MyThingy()
+ str,
+ MyThingy(),
)
def test_default_subclass(self):
@@ -176,9 +175,7 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL):
return "array"
self.assert_compile(
- MyArray(Integer),
- "INTEGER[]",
- dialect="postgresql"
+ MyArray(Integer), "INTEGER[]", dialect="postgresql"
)
def test_annotations(self):
@@ -187,35 +184,31 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL):
"""
- t1 = table('t1', column('c1'), column('c2'))
+ t1 = table("t1", column("c1"), column("c2"))
dispatch = Select._compiler_dispatch
try:
+
@compiles(Select)
def compile(element, compiler, **kw):
return "OVERRIDE"
s1 = select([t1])
- self.assert_compile(
- s1, "OVERRIDE"
- )
- self.assert_compile(
- s1._annotate({}),
- "OVERRIDE"
- )
+ self.assert_compile(s1, "OVERRIDE")
+ self.assert_compile(s1._annotate({}), "OVERRIDE")
finally:
Select._compiler_dispatch = dispatch
- if hasattr(Select, '_compiler_dispatcher'):
+ if hasattr(Select, "_compiler_dispatcher"):
del Select._compiler_dispatcher
def test_dialect_specific(self):
class AddThingy(DDLElement):
- __visit_name__ = 'add_thingy'
+ __visit_name__ = "add_thingy"
class DropThingy(DDLElement):
- __visit_name__ = 'drop_thingy'
+ __visit_name__ = "drop_thingy"
- @compiles(AddThingy, 'sqlite')
+ @compiles(AddThingy, "sqlite")
def visit_add_thingy(thingy, compiler, **kw):
return "ADD SPECIAL SL THINGY"
@@ -232,21 +225,22 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL):
self.assert_compile(DropThingy(), "DROP THINGY")
from sqlalchemy.dialects.sqlite import base
- self.assert_compile(AddThingy(),
- "ADD SPECIAL SL THINGY",
- dialect=base.dialect())
- self.assert_compile(DropThingy(),
- "DROP THINGY",
- dialect=base.dialect())
+ self.assert_compile(
+ AddThingy(), "ADD SPECIAL SL THINGY", dialect=base.dialect()
+ )
- @compiles(DropThingy, 'sqlite')
+ self.assert_compile(
+ DropThingy(), "DROP THINGY", dialect=base.dialect()
+ )
+
+ @compiles(DropThingy, "sqlite")
def visit_drop_thingy(thingy, compiler, **kw):
return "DROP SPECIAL SL THINGY"
- self.assert_compile(DropThingy(),
- "DROP SPECIAL SL THINGY",
- dialect=base.dialect())
+ self.assert_compile(
+ DropThingy(), "DROP SPECIAL SL THINGY", dialect=base.dialect()
+ )
self.assert_compile(DropThingy(), "DROP THINGY")
@@ -260,19 +254,17 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL):
def visit_myfunc(element, compiler, **kw):
return "utcnow()"
- @compiles(MyUtcFunction, 'postgresql')
+ @compiles(MyUtcFunction, "postgresql")
def visit_myfunc(element, compiler, **kw):
return "timezone('utc', current_timestamp)"
self.assert_compile(
- MyUtcFunction(),
- "utcnow()",
- use_default_dialect=True
+ MyUtcFunction(), "utcnow()", use_default_dialect=True
)
self.assert_compile(
MyUtcFunction(),
"timezone('utc', current_timestamp)",
- dialect=postgresql.dialect()
+ dialect=postgresql.dialect(),
)
def test_function_calls_base(self):
@@ -280,13 +272,13 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL):
class greatest(FunctionElement):
type = Numeric()
- name = 'greatest'
+ name = "greatest"
@compiles(greatest)
def default_greatest(element, compiler, **kw):
return compiler.visit_function(element)
- @compiles(greatest, 'mssql')
+ @compiles(greatest, "mssql")
def case_greatest(element, compiler, **kw):
arg1, arg2 = list(element.clauses)
return "CASE WHEN %s > %s THEN %s ELSE %s END" % (
@@ -297,26 +289,26 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL):
)
self.assert_compile(
- greatest('a', 'b'),
- 'greatest(:greatest_1, :greatest_2)',
- use_default_dialect=True
+ greatest("a", "b"),
+ "greatest(:greatest_1, :greatest_2)",
+ use_default_dialect=True,
)
self.assert_compile(
- greatest('a', 'b'),
+ greatest("a", "b"),
"CASE WHEN :greatest_1 > :greatest_2 "
"THEN :greatest_1 ELSE :greatest_2 END",
- dialect=mssql.dialect()
+ dialect=mssql.dialect(),
)
def test_subclasses_one(self):
class Base(FunctionElement):
- name = 'base'
+ name = "base"
class Sub1(Base):
- name = 'sub1'
+ name = "sub1"
class Sub2(Base):
- name = 'sub2'
+ name = "sub2"
@compiles(Base)
def visit_base(element, compiler, **kw):
@@ -328,31 +320,31 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL):
self.assert_compile(
select([Sub1(), Sub2()]),
- 'SELECT FOOsub1, sub2',
- use_default_dialect=True
+ "SELECT FOOsub1, sub2",
+ use_default_dialect=True,
)
def test_subclasses_two(self):
class Base(FunctionElement):
- name = 'base'
+ name = "base"
class Sub1(Base):
- name = 'sub1'
+ name = "sub1"
@compiles(Base)
def visit_base(element, compiler, **kw):
return element.name
class Sub2(Base):
- name = 'sub2'
+ name = "sub2"
class SubSub1(Sub1):
- name = 'subsub1'
+ name = "subsub1"
self.assert_compile(
select([Sub1(), Sub2(), SubSub1()]),
- 'SELECT sub1, sub2, subsub1',
- use_default_dialect=True
+ "SELECT sub1, sub2, subsub1",
+ use_default_dialect=True,
)
@compiles(Sub1)
@@ -361,42 +353,36 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL):
self.assert_compile(
select([Sub1(), Sub2(), SubSub1()]),
- 'SELECT FOOsub1, sub2, FOOsubsub1',
- use_default_dialect=True
+ "SELECT FOOsub1, sub2, FOOsubsub1",
+ use_default_dialect=True,
)
class DefaultOnExistingTest(fixtures.TestBase, AssertsCompiledSQL):
"""Test replacement of default compilation on existing constructs."""
- __dialect__ = 'default'
+
+ __dialect__ = "default"
def teardown(self):
for cls in (Select, BindParameter):
deregister(cls)
def test_select(self):
- t1 = table('t1', column('c1'), column('c2'))
+ t1 = table("t1", column("c1"), column("c2"))
- @compiles(Select, 'sqlite')
+ @compiles(Select, "sqlite")
def compile(element, compiler, **kw):
return "OVERRIDE"
s1 = select([t1])
- self.assert_compile(
- s1, "SELECT t1.c1, t1.c2 FROM t1",
- )
+ self.assert_compile(s1, "SELECT t1.c1, t1.c2 FROM t1")
from sqlalchemy.dialects.sqlite import base as sqlite
- self.assert_compile(
- s1, "OVERRIDE",
- dialect=sqlite.dialect()
- )
+
+ self.assert_compile(s1, "OVERRIDE", dialect=sqlite.dialect())
def test_binds_in_select(self):
- t = table('t',
- column('a'),
- column('b'),
- column('c'))
+ t = table("t", column("a"), column("b"), column("c"))
@compiles(BindParameter)
def gen_bind(element, compiler, **kw):
@@ -405,14 +391,11 @@ class DefaultOnExistingTest(fixtures.TestBase, AssertsCompiledSQL):
self.assert_compile(
t.select().where(t.c.c == 5),
"SELECT t.a, t.b, t.c FROM t WHERE t.c = BIND(:c_1)",
- use_default_dialect=True
+ use_default_dialect=True,
)
def test_binds_in_dml(self):
- t = table('t',
- column('a'),
- column('b'),
- column('c'))
+ t = table("t", column("a"), column("b"), column("c"))
@compiles(BindParameter)
def gen_bind(element, compiler, **kw):
@@ -421,6 +404,6 @@ class DefaultOnExistingTest(fixtures.TestBase, AssertsCompiledSQL):
self.assert_compile(
t.insert(),
"INSERT INTO t (a, b) VALUES (BIND(:a), BIND(:b))",
- {'a': 1, 'b': 2},
- use_default_dialect=True
+ {"a": 1, "b": 2},
+ use_default_dialect=True,
)