diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2013-06-04 14:30:29 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2013-06-04 14:30:29 -0400 |
commit | 9998e9e0131ff83a4e38e3c17a835a0854789174 (patch) | |
tree | 5bfe2cd41b4d2bfef1a0119bdfe85af45b1749d3 /lib/sqlalchemy/sql/compiler.py | |
parent | 822786dfaea7a56b16669561b4818ca1bf3a800f (diff) | |
download | sqlalchemy-9998e9e0131ff83a4e38e3c17a835a0854789174.tar.gz |
rewriting scheme now works.
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 114 |
1 files changed, 43 insertions, 71 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 3e159b112..d245c781a 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1078,86 +1078,58 @@ class SQLCompiler(engine.Compiled): return None def _transform_select_for_nested_joins(self, select): - adapters = [] - stop_on = [] - - # test for "unconditional" - any statement with - # no_replacement_traverse setup, i.e. query.statement, from_self(), etc. - #traverse_options = {"cloned": {}, "unconditional": True} - traverse_options = {"unconditional": True} + """Rewrite any "a JOIN (b JOIN c)" expression as + "a JOIN (select * from b JOIN c) AS anon", to support + databases that can't parse a parenthesized join correctly + (i.e. sqlite the main one). + """ cloned = {} - def thing(element, **kw): - if element in cloned: - return cloned[element] - - newelem = cloned[element] = element._clone() - - if newelem.__visit_name__ == 'join' and \ - isinstance(newelem.right, sql.FromGrouping): - selectable = sql.select([newelem.right.element], use_labels=True) - selectable = selectable.alias() - newelem.right = selectable - stop_on.append(selectable) - for c in selectable.c: - c._label = c._key_label = c.name - adapter = sql_util.ClauseAdapter(selectable, - traverse_options=traverse_options) - adapter.magic_flag = True - adapters.append(adapter) - else: - newelem._copy_internals(clone=thing, **kw) - - return newelem + column_translate = [{}] - elem = thing(select) - while adapters: - adapt = adapters.pop(-1) - adapt.__traverse_options__['stop_on'].extend(stop_on) - elem = adapt.traverse(elem) - return elem + join_name = sql.Join.__visit_name__ + select_name = sql.Select.__visit_name__ + def visit(element, **kw): + if element in column_translate[-1]: + return column_translate[-1][element] - def _transform_select_for_nested_joins(self, select): - adapters = [] - stop_on = [] - - # test for "unconditional" - any statement with - # no_replacement_traverse setup, i.e. query.statement, from_self(), etc. - #traverse_options = {"cloned": {}, "unconditional": True} - traverse_options = {"unconditional": True} + elif element in cloned: + return cloned[element] - def visit_join(elem): - if isinstance(elem.right, sql.FromGrouping): - selectable = sql.select([elem.right.element], use_labels=True) - selectable = selectable.alias() + newelem = cloned[element] = element._clone() - while adapters: - adapt = adapters.pop(-1) - selectable = adapt.traverse(selectable) - #stop_on.append(selectable) + if newelem.__visit_name__ is join_name and \ + isinstance(newelem.right, sql.FromGrouping): - # test: see test_subquery_relations: - # CyclicalInheritingEagerTestTwo.test_integrate - stop_on.append(elem.left) + newelem._reset_exported() + newelem.left = visit(newelem.left, **kw) + selectable = sql.select( + [newelem.right.element], + use_labels=True).alias() for c in selectable.c: c._label = c._key_label = c.name + translate_dict = dict( + zip(newelem.right.element.c, selectable.c) + ) + translate_dict[newelem.right.element.left] = selectable + translate_dict[newelem.right.element.right] = selectable + column_translate[-1].update(translate_dict) - elem.right = selectable - adapter = sql_util.ClauseAdapter(selectable, - traverse_options=traverse_options) - adapter.__traverse_options__['stop_on'].extend(stop_on) - adapters.append(adapter) - + newelem.right = selectable + newelem.onclause = visit(newelem.onclause, **kw) + elif newelem.__visit_name__ is select_name: + column_translate.append({}) + newelem._copy_internals(clone=visit, **kw) + del column_translate[-1] + else: + newelem._copy_internals(clone=visit, **kw) - select = visitors.cloned_traverse(select, - traverse_options, {"join": visit_join}) + return newelem - for adap in reversed(adapters): - select = adap.traverse(select) - return select + return visit(select) def _transform_result_map_for_nested_joins(self, select, transformed_select): d = dict(zip(transformed_select.inner_columns, select.inner_columns)) @@ -1172,10 +1144,12 @@ class SQLCompiler(engine.Compiled): positional_names=None, nested_join_translation=False, **kwargs): + needs_nested_translation = \ + not nested_join_translation and \ + not self.stack and \ + not self.dialect.supports_right_nested_joins - if self.dialect.supports_right_nested_joins: - nested_join_translation = True - if not nested_join_translation: + if needs_nested_translation: transformed_select = self._transform_select_for_nested_joins(select) text = self.visit_select( transformed_select, asfrom=asfrom, parens=parens, @@ -1186,8 +1160,6 @@ class SQLCompiler(engine.Compiled): nested_join_translation=True, **kwargs ) - - entry = self.stack and self.stack[-1] or {} populate_result_map = force_result_map or ( @@ -1197,7 +1169,7 @@ class SQLCompiler(engine.Compiled): ) ) - if not nested_join_translation: + if needs_nested_translation: if populate_result_map: self._transform_result_map_for_nested_joins( select, transformed_select) |