summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2012-10-24 15:37:06 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2012-10-24 15:37:06 -0400
commit07d7c4905d65b7f28c1ffcbd33f81ee52c9fd847 (patch)
treea9c1213e4db2f905d65ae5cef2eed101560fe9be
parente656bf4f47cf3c06975c6207ea6e54131b292bf7 (diff)
downloadsqlalchemy-07d7c4905d65b7f28c1ffcbd33f81ee52c9fd847.tar.gz
Fixed bug where keyword arguments passed to
:meth:`.Compiler.process` wouldn't get propagated to the column expressions present in the columns clause of a SELECT statement. In particular this would come up when used by custom compilation schemes that relied upon special flags. [ticket:2593]
-rw-r--r--doc/build/changelog/changelog_08.rst11
-rw-r--r--lib/sqlalchemy/engine/interfaces.py8
-rw-r--r--lib/sqlalchemy/sql/compiler.py14
-rw-r--r--test/sql/test_compiler.py49
4 files changed, 76 insertions, 6 deletions
diff --git a/doc/build/changelog/changelog_08.rst b/doc/build/changelog/changelog_08.rst
index c741cdda9..2efcee98e 100644
--- a/doc/build/changelog/changelog_08.rst
+++ b/doc/build/changelog/changelog_08.rst
@@ -9,6 +9,17 @@
:released:
.. change::
+ :tags: sql, bug
+ :tickets: 2593
+
+ Fixed bug where keyword arguments passed to
+ :meth:`.Compiler.process` wouldn't get propagated
+ to the column expressions present in the columns
+ clause of a SELECT statement. In particular this would
+ come up when used by custom compilation schemes that
+ relied upon special flags.
+
+ .. change::
:tags: sql, feature
Added a new method :meth:`.Engine.execution_options`
diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py
index e9e0da436..c60120166 100644
--- a/lib/sqlalchemy/engine/interfaces.py
+++ b/lib/sqlalchemy/engine/interfaces.py
@@ -673,7 +673,8 @@ class Compiled(object):
defaults.
"""
- def __init__(self, dialect, statement, bind=None):
+ def __init__(self, dialect, statement, bind=None,
+ compile_kwargs=util.immutabledict()):
"""Construct a new ``Compiled`` object.
:param dialect: ``Dialect`` to compile against.
@@ -682,6 +683,9 @@ class Compiled(object):
:param bind: Optional Engine or Connection to compile this
statement against.
+
+ :param compile_kwargs: additional kwargs that will be
+ passed to the initial call to :meth:`.Compiled.process`.
"""
self.dialect = dialect
@@ -689,7 +693,7 @@ class Compiled(object):
if statement is not None:
self.statement = statement
self.can_execute = statement.supports_execution
- self.string = self.process(self.statement)
+ self.string = self.process(self.statement, **compile_kwargs)
@util.deprecated("0.7", ":class:`.Compiled` objects now compile "
"within the constructor.")
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 6da51c31c..0847335c2 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -1044,9 +1044,12 @@ class SQLCompiler(engine.Compiled):
else:
result_expr = col_expr
+ column_clause_args.update(
+ within_columns_clause=within_columns_clause,
+ add_to_result_map=add_to_result_map
+ )
return result_expr._compiler_dispatch(
- self, within_columns_clause=within_columns_clause,
- add_to_result_map=add_to_result_map,
+ self,
**column_clause_args
)
@@ -1098,7 +1101,12 @@ class SQLCompiler(engine.Compiled):
self.stack.append({'from': correlate_froms,
'iswrapper': iswrapper})
- column_clause_args = {'positional_names': positional_names}
+ column_clause_args = kwargs.copy()
+ column_clause_args.update({
+ 'positional_names': positional_names,
+ 'within_label_clause': False,
+ 'within_columns_clause': False
+ })
# the actual list of columns to print in the SELECT column list.
inner_columns = [
diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py
index bb819472a..50b425a01 100644
--- a/test/sql/test_compiler.py
+++ b/test/sql/test_compiler.py
@@ -18,7 +18,7 @@ from sqlalchemy import Integer, String, MetaData, Table, Column, select, \
insert, literal, and_, null, type_coerce, alias, or_, literal_column,\
Float, TIMESTAMP, Numeric, Date, Text, collate, union, except_,\
intersect, union_all, Boolean, distinct, join, outerjoin, asc, desc,\
- over, subquery
+ over, subquery, case
import decimal
from sqlalchemy import exc, sql, util, types, schema
from sqlalchemy.sql import table, column, label
@@ -2437,6 +2437,53 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
)
+class KwargPropagationTest(fixtures.TestBase):
+
+ @classmethod
+ def setup_class(cls):
+ from sqlalchemy.sql.expression import ColumnClause, TableClause
+ class CatchCol(ColumnClause):
+ pass
+
+ class CatchTable(TableClause):
+ pass
+
+ cls.column = CatchCol("x")
+ cls.table = CatchTable("y")
+ cls.criterion = cls.column == CatchCol('y')
+
+ @compiles(CatchCol)
+ def compile_col(element, compiler, **kw):
+ assert "canary" in kw
+ return compiler.visit_column(element)
+
+ @compiles(CatchTable)
+ def compile_table(element, compiler, **kw):
+ assert "canary" in kw
+ return compiler.visit_table(element)
+
+ def _do_test(self, element):
+ d = default.DefaultDialect()
+ d.statement_compiler(d, element,
+ compile_kwargs={"canary": True})
+
+ def test_binary(self):
+ self._do_test(self.column == 5)
+
+ def test_select(self):
+ s = select([self.column]).select_from(self.table).\
+ where(self.column == self.criterion).\
+ order_by(self.column)
+ self._do_test(s)
+
+ def test_case(self):
+ c = case([(self.criterion, self.column)], else_=self.column)
+ self._do_test(c)
+
+ def test_cast(self):
+ c = cast(self.column, Integer)
+ self._do_test(c)
+
class CRUDTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = 'default'