diff options
| -rw-r--r-- | CHANGES | 10 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/compiler.py | 7 | ||||
| -rw-r--r-- | test/ext/test_compiler.py | 32 |
3 files changed, 47 insertions, 2 deletions
@@ -170,7 +170,15 @@ CHANGES subclass. It cannot, however, define one that is not present in the __table__, and the error message here now works. [ticket:1821] - + +- compiler extension + - The 'default' compiler is automatically copied over + when overriding the compilation of a built in + clause construct, so no KeyError is raised if the + user-defined compiler is specific to certain + backends and compilation for a different backend + is invoked. [ticket:1838] + - documentation - Added documentation for the Inspector. [ticket:1820] diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py index 68c434fd9..12f1e443d 100644 --- a/lib/sqlalchemy/ext/compiler.py +++ b/lib/sqlalchemy/ext/compiler.py @@ -198,9 +198,13 @@ A big part of using the compiler extension is subclassing SQLAlchemy expression def compiles(class_, *specs): def decorate(fn): existing = class_.__dict__.get('_compiler_dispatcher', None) + existing_dispatch = class_.__dict__.get('_compiler_dispatch') if not existing: existing = _dispatcher() - + + if existing_dispatch: + existing.specs['default'] = existing_dispatch + # TODO: why is the lambda needed ? setattr(class_, '_compiler_dispatch', lambda *arg, **kw: existing(*arg, **kw)) setattr(class_, '_compiler_dispatcher', existing) @@ -208,6 +212,7 @@ def compiles(class_, *specs): if specs: for s in specs: existing.specs[s] = fn + else: existing.specs['default'] = fn return fn diff --git a/test/ext/test_compiler.py b/test/ext/test_compiler.py index fa1e3c162..3ed84fe61 100644 --- a/test/ext/test_compiler.py +++ b/test/ext/test_compiler.py @@ -125,7 +125,39 @@ class UserDefinedTest(TestBase, AssertsCompiledSQL): ) finally: Select._compiler_dispatch = dispatch + if hasattr(Select, '_compiler_dispatcher'): + del Select._compiler_dispatcher + def test_default_on_existing(self): + """test that the existing compiler function remains + as 'default' when overriding the compilation of an + existing construct.""" + + + t1 = table('t1', column('c1'), column('c2')) + + dispatch = Select._compiler_dispatch + try: + + @compiles(Select, 'sqlite') + def compile(element, compiler, **kw): + return "OVERRIDE" + + s1 = select([t1]) + self.assert_compile( + s1, "SELECT t1.c1, t1.c2 FROM t1", + ) + + from sqlalchemy.dialects.sqlite import base as sqlite + self.assert_compile( + s1, "OVERRIDE", + dialect=sqlite.dialect() + ) + finally: + Select._compiler_dispatch = dispatch + if hasattr(Select, '_compiler_dispatcher'): + del Select._compiler_dispatcher + def test_dialect_specific(self): class AddThingy(DDLElement): __visit_name__ = 'add_thingy' |
