summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2013-06-04 21:38:56 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2013-06-04 21:38:56 -0400
commit26ec0507be72d2e1a5abde8b7307012864a88a6b (patch)
treeb83e68a0ee000059ff176d29f68706fb0a6da830 /lib/sqlalchemy/sql/compiler.py
parentada19275299f0105f4aaed5bbe0d373ea33feea6 (diff)
parent69e9574fefd5fbb4673c99ad476a00b03fe22318 (diff)
downloadsqlalchemy-26ec0507be72d2e1a5abde8b7307012864a88a6b.tar.gz
Merge branch 'ticket_2587'
Conflicts: test/profiles.txt test/sql/test_selectable.py
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py129
1 files changed, 117 insertions, 12 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 73b094053..dd2a6e08c 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -1113,23 +1113,115 @@ class SQLCompiler(engine.Compiled):
def get_crud_hint_text(self, table, text):
return None
+ def _transform_select_for_nested_joins(self, select):
+ """Rewrite any "a JOIN (b JOIN c)" expression as
+ "a JOIN (select * from b JOIN c) AS anon", to support
+ databases that can't parse a parenthesized join correctly
+ (i.e. sqlite the main one).
+
+ """
+ cloned = {}
+ column_translate = [{}]
+
+ # TODO: should we be using isinstance() for this,
+ # as this whole system won't work for custom Join/Select
+ # subclasses where compilation routines
+ # call down to compiler.visit_join(), compiler.visit_select()
+ join_name = sql.Join.__visit_name__
+ select_name = sql.Select.__visit_name__
+
+ def visit(element, **kw):
+ if element in column_translate[-1]:
+ return column_translate[-1][element]
+
+ elif element in cloned:
+ return cloned[element]
+
+ newelem = cloned[element] = element._clone()
+
+ if newelem.__visit_name__ is join_name and \
+ isinstance(newelem.right, sql.FromGrouping):
+
+ newelem._reset_exported()
+ newelem.left = visit(newelem.left, **kw)
+
+ right = visit(newelem.right, **kw)
+
+ selectable = sql.select(
+ [right.element],
+ use_labels=True).alias()
+
+ for c in selectable.c:
+ c._label = c._key_label = c.name
+ translate_dict = dict(
+ zip(right.element.c, selectable.c)
+ )
+ translate_dict[right.element.left] = selectable
+ translate_dict[right.element.right] = selectable
+
+ # propagate translations that we've gained
+ # from nested visit(newelem.right) outwards
+ # to the enclosing select here. this happens
+ # only when we have more than one level of right
+ # join nesting, i.e. "a JOIN (b JOIN (c JOIN d))"
+ for k, v in list(column_translate[-1].items()):
+ if v in translate_dict:
+ # remarkably, no current ORM tests (May 2013)
+ # hit this condition, only test_join_rewriting
+ # does.
+ column_translate[-1][k] = translate_dict[v]
+
+ column_translate[-1].update(translate_dict)
+
+ newelem.right = selectable
+ newelem.onclause = visit(newelem.onclause, **kw)
+ elif newelem.__visit_name__ is select_name:
+ column_translate.append({})
+ newelem._copy_internals(clone=visit, **kw)
+ del column_translate[-1]
+ else:
+ newelem._copy_internals(clone=visit, **kw)
+
+ return newelem
+
+ return visit(select)
+
+ def _transform_result_map_for_nested_joins(self, select, transformed_select):
+ inner_col = dict((c._key_label, c) for
+ c in transformed_select.inner_columns)
+ d = dict(
+ (inner_col[c._key_label], c)
+ for c in select.inner_columns
+ )
+ for key, (name, objs, typ) in list(self.result_map.items()):
+ objs = tuple([d.get(col, col) for col in objs])
+ self.result_map[key] = (name, objs, typ)
+
def visit_select(self, select, asfrom=False, parens=True,
iswrapper=False, fromhints=None,
compound_index=0,
force_result_map=False,
- positional_names=None, **kwargs):
- entry = self.stack and self.stack[-1] or {}
-
- existingfroms = entry.get('from', None)
-
- froms = select._get_display_froms(existingfroms, asfrom=asfrom)
-
- correlate_froms = set(sql._from_objects(*froms))
+ positional_names=None,
+ nested_join_translation=False, **kwargs):
+
+ needs_nested_translation = \
+ select.use_labels and \
+ not nested_join_translation and \
+ not self.stack and \
+ not self.dialect.supports_right_nested_joins
+
+ if needs_nested_translation:
+ transformed_select = self._transform_select_for_nested_joins(select)
+ text = self.visit_select(
+ transformed_select, asfrom=asfrom, parens=parens,
+ iswrapper=iswrapper, fromhints=fromhints,
+ compound_index=compound_index,
+ force_result_map=force_result_map,
+ positional_names=positional_names,
+ nested_join_translation=True, **kwargs
+ )
- # TODO: might want to propagate existing froms for
- # select(select(select)) where innermost select should correlate
- # to outermost if existingfroms: correlate_froms =
- # correlate_froms.union(existingfroms)
+ entry = self.stack and self.stack[-1] or {}
populate_result_map = force_result_map or (
compound_index == 0 and (
@@ -1138,6 +1230,19 @@ class SQLCompiler(engine.Compiled):
)
)
+ if needs_nested_translation:
+ if populate_result_map:
+ self._transform_result_map_for_nested_joins(
+ select, transformed_select)
+ return text
+
+ existingfroms = entry.get('from', None)
+
+ froms = select._get_display_froms(existingfroms, asfrom=asfrom)
+
+ correlate_froms = set(sql._from_objects(*froms))
+
+
self.stack.append({'from': correlate_froms,
'iswrapper': iswrapper})