diff options
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 114 | 
1 files changed, 43 insertions, 71 deletions
| diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 3e159b112..d245c781a 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1078,86 +1078,58 @@ class SQLCompiler(engine.Compiled):          return None      def _transform_select_for_nested_joins(self, select): -        adapters = [] -        stop_on = [] - -        # test for "unconditional" - any statement with -        # no_replacement_traverse setup, i.e. query.statement, from_self(), etc. -        #traverse_options = {"cloned": {}, "unconditional": True} -        traverse_options = {"unconditional": True} +        """Rewrite any "a JOIN (b JOIN c)" expression as +        "a JOIN (select * from b JOIN c) AS anon", to support +        databases that can't parse a parenthesized join correctly +        (i.e. sqlite the main one). +        """          cloned = {} -        def thing(element, **kw): -            if element in cloned: -                return cloned[element] - -            newelem = cloned[element] = element._clone() - -            if newelem.__visit_name__ == 'join' and \ -                isinstance(newelem.right, sql.FromGrouping): -                selectable = sql.select([newelem.right.element], use_labels=True) -                selectable = selectable.alias() -                newelem.right = selectable -                stop_on.append(selectable) -                for c in selectable.c: -                    c._label = c._key_label = c.name -                adapter = sql_util.ClauseAdapter(selectable, -                                        traverse_options=traverse_options) -                adapter.magic_flag = True -                adapters.append(adapter) -            else: -                newelem._copy_internals(clone=thing, **kw) - -            return newelem +        column_translate = [{}] -        elem = thing(select) -        while adapters: -            adapt = adapters.pop(-1) -            adapt.__traverse_options__['stop_on'].extend(stop_on) -            elem = adapt.traverse(elem) -        return elem +        join_name = sql.Join.__visit_name__ +        select_name = sql.Select.__visit_name__ +        def visit(element, **kw): +            if element in column_translate[-1]: +                return column_translate[-1][element] -    def _transform_select_for_nested_joins(self, select): -        adapters = [] -        stop_on = [] - -        # test for "unconditional" - any statement with -        # no_replacement_traverse setup, i.e. query.statement, from_self(), etc. -        #traverse_options = {"cloned": {}, "unconditional": True} -        traverse_options = {"unconditional": True} +            elif element in cloned: +                return cloned[element] -        def visit_join(elem): -            if isinstance(elem.right, sql.FromGrouping): -                selectable = sql.select([elem.right.element], use_labels=True) -                selectable = selectable.alias() +            newelem = cloned[element] = element._clone() -                while adapters: -                    adapt = adapters.pop(-1) -                    selectable = adapt.traverse(selectable) -                #stop_on.append(selectable) +            if newelem.__visit_name__ is join_name and \ +                isinstance(newelem.right, sql.FromGrouping): -                # test: see test_subquery_relations: -                # CyclicalInheritingEagerTestTwo.test_integrate -                stop_on.append(elem.left) +                newelem._reset_exported() +                newelem.left = visit(newelem.left, **kw) +                selectable = sql.select( +                                    [newelem.right.element], +                                    use_labels=True).alias()                  for c in selectable.c:                      c._label = c._key_label = c.name +                translate_dict = dict( +                        zip(newelem.right.element.c, selectable.c) +                    ) +                translate_dict[newelem.right.element.left] = selectable +                translate_dict[newelem.right.element.right] = selectable +                column_translate[-1].update(translate_dict) -                elem.right = selectable -                adapter = sql_util.ClauseAdapter(selectable, -                                        traverse_options=traverse_options) -                adapter.__traverse_options__['stop_on'].extend(stop_on) -                adapters.append(adapter) - +                newelem.right = selectable +                newelem.onclause = visit(newelem.onclause, **kw) +            elif newelem.__visit_name__ is select_name: +                column_translate.append({}) +                newelem._copy_internals(clone=visit, **kw) +                del column_translate[-1] +            else: +                newelem._copy_internals(clone=visit, **kw) -        select = visitors.cloned_traverse(select, -                                    traverse_options, {"join": visit_join}) +            return newelem -        for adap in reversed(adapters): -            select = adap.traverse(select) -        return select +        return visit(select)      def _transform_result_map_for_nested_joins(self, select, transformed_select):          d = dict(zip(transformed_select.inner_columns, select.inner_columns)) @@ -1172,10 +1144,12 @@ class SQLCompiler(engine.Compiled):                              positional_names=None,                              nested_join_translation=False, **kwargs): +        needs_nested_translation = \ +                            not nested_join_translation and \ +                            not self.stack and \ +                            not self.dialect.supports_right_nested_joins -        if self.dialect.supports_right_nested_joins: -            nested_join_translation = True -        if not nested_join_translation: +        if needs_nested_translation:              transformed_select = self._transform_select_for_nested_joins(select)              text = self.visit_select(                              transformed_select, asfrom=asfrom, parens=parens, @@ -1186,8 +1160,6 @@ class SQLCompiler(engine.Compiled):                              nested_join_translation=True, **kwargs                          ) - -          entry = self.stack and self.stack[-1] or {}          populate_result_map = force_result_map or ( @@ -1197,7 +1169,7 @@ class SQLCompiler(engine.Compiled):                                          )                                      ) -        if not nested_join_translation: +        if needs_nested_translation:              if populate_result_map:                  self._transform_result_map_for_nested_joins(                                                  select, transformed_select) | 
