summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/util.py')
-rw-r--r--lib/sqlalchemy/sql/util.py33
1 files changed, 25 insertions, 8 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 91740dc16..6f4d27e1b 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -200,15 +200,28 @@ def clause_is_present(clause, search):
"""
- stack = [search]
- while stack:
- elem = stack.pop()
+ for elem in surface_selectables(search):
if clause == elem: # use == here so that Annotated's compare
return True
- elif isinstance(elem, expression.Join):
+ else:
+ return False
+
+def surface_selectables(clause):
+ stack = [clause]
+ while stack:
+ elem = stack.pop()
+ yield elem
+ if isinstance(elem, expression.Join):
stack.extend((elem.left, elem.right))
- return False
+def selectables_overlap(left, right):
+ """Return True if left/right have some overlapping selectable"""
+
+ return bool(
+ set(surface_selectables(left)).intersection(
+ surface_selectables(right)
+ )
+ )
def bind_values(clause):
"""Return an ordered list of "bound" values in the given clause.
@@ -797,8 +810,11 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
def __init__(self, selectable, equivalents=None,
include=None, exclude=None,
include_fn=None, exclude_fn=None,
- adapt_on_names=False):
+ adapt_on_names=False,
+ traverse_options=None):
self.__traverse_options__ = {'stop_on': [selectable]}
+ if traverse_options:
+ self.__traverse_options__.update(traverse_options)
self.selectable = selectable
if include:
assert not include_fn
@@ -829,10 +845,11 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
newcol = self.selectable.c.get(col.name)
return newcol
+ magic_flag = False
def replace(self, col):
- if isinstance(col, expression.FromClause) and \
+ if not self.magic_flag and isinstance(col, expression.FromClause) and \
self.selectable.is_derived_from(col):
- return self.selectable
+ return self.selectable
elif not isinstance(col, expression.ColumnElement):
return None
elif self.include_fn and not self.include_fn(col):