diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2013-06-02 18:05:47 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2013-06-02 18:05:47 -0400 |
commit | af3fa1f69c077789b8a1c5078d1bb94a8d5e2240 (patch) | |
tree | 2cf4009db2cb4c3793d400147e8308dfa9e08b4a | |
parent | 32716eae773e6f6b7f37baf705342c1ed89df461 (diff) | |
download | sqlalchemy-af3fa1f69c077789b8a1c5078d1bb94a8d5e2240.tar.gz |
implement join rewriting inside of visit_select(). Currently this is global or not based on fixing nested_join_translation as True or not.
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 74 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/util.py | 7 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 7 |
3 files changed, 75 insertions, 13 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 41ef20a7a..030d6dce9 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1077,23 +1077,64 @@ class SQLCompiler(engine.Compiled): def get_crud_hint_text(self, table, text): return None + def _transform_select_for_nested_joins(self, select): + adapters = [] + + traverse_options = {"cloned": {}} + + def visit_join(elem): + if isinstance(elem.right, sql.FromGrouping): + selectable = sql.select([elem.right.element], use_labels=True) + selectable = selectable.alias() + + while adapters: + adapt = adapters.pop(-1) + selectable = adapt.traverse(selectable) + + for c in selectable.c: + c._label = c._key_label = c.name + + elem.right = selectable + adapters.append( + sql_util.ClauseAdapter(selectable, + traverse_options=traverse_options) + ) + + select = visitors.cloned_traverse(select, + traverse_options, {"join": visit_join}) + + for adap in reversed(adapters): + select = adap.traverse(select) + return select + + def _transform_result_map_for_nested_joins(self, select, transformed_select): + d = dict(zip(transformed_select.inner_columns, select.inner_columns)) + for key, (name, objs, typ) in list(self.result_map.items()): + objs = tuple([d.get(col, col) for col in objs]) + self.result_map[key] = (name, objs, typ) + def visit_select(self, select, asfrom=False, parens=True, iswrapper=False, fromhints=None, compound_index=0, force_result_map=False, - positional_names=None, **kwargs): - entry = self.stack and self.stack[-1] or {} + positional_names=None, + nested_join_translation=False, **kwargs): + + #nested_join_translation = True + if not nested_join_translation: + transformed_select = self._transform_select_for_nested_joins(select) + text = self.visit_select( + transformed_select, asfrom=asfrom, parens=parens, + iswrapper=iswrapper, fromhints=fromhints, + compound_index=compound_index, + force_result_map=force_result_map, + positional_names=positional_names, + nested_join_translation=True, **kwargs + ) - existingfroms = entry.get('from', None) - froms = select._get_display_froms(existingfroms, asfrom=asfrom) - correlate_froms = set(sql._from_objects(*froms)) - - # TODO: might want to propagate existing froms for - # select(select(select)) where innermost select should correlate - # to outermost if existingfroms: correlate_froms = - # correlate_froms.union(existingfroms) + entry = self.stack and self.stack[-1] or {} populate_result_map = force_result_map or ( compound_index == 0 and ( @@ -1102,6 +1143,19 @@ class SQLCompiler(engine.Compiled): ) ) + if not nested_join_translation: + if populate_result_map: + self._transform_result_map_for_nested_joins( + select, transformed_select) + return text + + existingfroms = entry.get('from', None) + + froms = select._get_display_froms(existingfroms, asfrom=asfrom) + + correlate_froms = set(sql._from_objects(*froms)) + + self.stack.append({'from': correlate_froms, 'iswrapper': iswrapper}) diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 91740dc16..ffa07d3df 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -797,8 +797,11 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): def __init__(self, selectable, equivalents=None, include=None, exclude=None, include_fn=None, exclude_fn=None, - adapt_on_names=False): + adapt_on_names=False, + traverse_options=None): self.__traverse_options__ = {'stop_on': [selectable]} + if traverse_options: + self.__traverse_options__.update(traverse_options) self.selectable = selectable if include: assert not include_fn @@ -832,7 +835,7 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): def replace(self, col): if isinstance(col, expression.FromClause) and \ self.selectable.is_derived_from(col): - return self.selectable + return self.selectable elif not isinstance(col, expression.ColumnElement): return None elif self.include_fn and not self.include_fn(col): diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 62f46ab64..31ac686e3 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -30,6 +30,7 @@ import operator __all__ = ['VisitableType', 'Visitable', 'ClauseVisitor', 'CloningVisitor', 'ReplacingCloningVisitor', 'iterate', 'iterate_depthfirst', 'traverse_using', 'traverse', + 'traverse_depthfirst', 'cloned_traverse', 'replacement_traverse'] @@ -255,7 +256,11 @@ def cloned_traverse(obj, opts, visitors): """clone the given expression structure, allowing modifications by visitors.""" - cloned = util.column_dict() + + if "cloned" in opts: + cloned = opts['cloned'] + else: + cloned = util.column_dict() stop_on = util.column_set(opts.get('stop_on', [])) def clone(elem): |