summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/coercions.py17
-rw-r--r--lib/sqlalchemy/sql/util.py16
2 files changed, 16 insertions, 17 deletions
diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py
index a7a856bba..95aee0468 100644
--- a/lib/sqlalchemy/sql/coercions.py
+++ b/lib/sqlalchemy/sql/coercions.py
@@ -57,7 +57,7 @@ def expect(role, element, **kw):
else:
resolved = element
- if issubclass(resolved.__class__, impl._role_class):
+ if impl._role_class in resolved.__class__.__mro__:
if impl._post_coercion:
resolved = impl._post_coercion(resolved, **kw)
return resolved
@@ -102,13 +102,16 @@ class RoleImpl(object):
def _resolve_for_clause_element(self, element, argname=None, **kw):
original_element = element
- is_clause_element = False
+ is_clause_element = hasattr(element, "__clause_element__")
- while hasattr(element, "__clause_element__") and not isinstance(
- element, (elements.ClauseElement, schema.SchemaItem)
- ):
- element = element.__clause_element__()
- is_clause_element = True
+ if is_clause_element:
+ while not isinstance(
+ element, (elements.ClauseElement, schema.SchemaItem)
+ ):
+ try:
+ element = element.__clause_element__()
+ except AttributeError:
+ break
if not is_clause_element:
if self._use_inspection:
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index fe83b163c..3c7f904de 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -364,23 +364,19 @@ def surface_selectables_only(clause):
stack.append(elem.table)
-def surface_column_elements(clause, include_scalar_selects=True):
- """traverse and yield only outer-exposed column elements, such as would
- be addressable in the WHERE clause of a SELECT if this element were
- in the columns clause."""
+def extract_first_column_annotation(column, annotation_name):
+ filter_ = (FromGrouping, SelectBase)
- filter_ = (FromGrouping,)
- if not include_scalar_selects:
- filter_ += (SelectBase,)
-
- stack = deque([clause])
+ stack = deque([column])
while stack:
elem = stack.popleft()
- yield elem
+ if annotation_name in elem._annotations:
+ return elem._annotations[annotation_name]
for sub in elem.get_children():
if isinstance(sub, filter_):
continue
stack.append(sub)
+ return None
def selectables_overlap(left, right):