summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2019-09-30 15:32:18 +0000
committerGerrit Code Review <gerrit@bbpush.zzzcomputing.com>2019-09-30 15:32:18 +0000
commitff1124444e88260184ea807a7cbb16a1f6ee0ff4 (patch)
tree30b9afc4875cead27be166669ca4a0de5bd3e908 /lib/sqlalchemy
parent9f3539b1745cbb287a1338812872d27cde4ebf24 (diff)
parent6ddb62a8ba66b19afd41b967911ce5982250856e (diff)
downloadsqlalchemy-ff1124444e88260184ea807a7cbb16a1f6ee0ff4.tar.gz
Merge "Simplify _ColumnEntity, related"
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/orm/attributes.py20
-rw-r--r--lib/sqlalchemy/orm/descriptor_props.py21
-rw-r--r--lib/sqlalchemy/orm/interfaces.py12
-rw-r--r--lib/sqlalchemy/orm/properties.py3
-rw-r--r--lib/sqlalchemy/orm/query.py317
-rw-r--r--lib/sqlalchemy/orm/relationships.py4
-rw-r--r--lib/sqlalchemy/orm/util.py9
-rw-r--r--lib/sqlalchemy/sql/coercions.py17
-rw-r--r--lib/sqlalchemy/sql/util.py16
9 files changed, 208 insertions, 211 deletions
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index 9d404e00d..117dd4cea 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -168,18 +168,18 @@ class QueryableAttribute(
"""
return inspection.inspect(self._parententity)
- @property
+ @util.memoized_property
def expression(self):
- return self.comparator.__clause_element__()
-
- def __clause_element__(self):
- return self.comparator.__clause_element__()
+ return self.comparator.__clause_element__()._annotate(
+ {"orm_key": self.key}
+ )
- def _query_clause_element(self):
- """like __clause_element__(), but called specifically
- by :class:`.Query` to allow special behavior."""
+ @property
+ def _annotations(self):
+ return self.__clause_element__()._annotations
- return self.comparator._query_clause_element()
+ def __clause_element__(self):
+ return self.expression
def _bulk_update_tuples(self, value):
"""Return setter tuples for a bulk UPDATE."""
@@ -207,7 +207,7 @@ class QueryableAttribute(
)
def label(self, name):
- return self._query_clause_element().label(name)
+ return self.__clause_element__().label(name)
def operate(self, op, *other, **kwargs):
return op(self.comparator, *other, **kwargs)
diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py
index 28b3bc5db..075638fed 100644
--- a/lib/sqlalchemy/orm/descriptor_props.py
+++ b/lib/sqlalchemy/orm/descriptor_props.py
@@ -413,19 +413,26 @@ class CompositeProperty(DescriptorProperty):
__hash__ = None
- @property
+ @util.memoized_property
def clauses(self):
- return self.__clause_element__()
-
- def __clause_element__(self):
return expression.ClauseList(
group=False, *self._comparable_elements
)
- def _query_clause_element(self):
- return CompositeProperty.CompositeBundle(
- self.prop, self.__clause_element__()
+ def __clause_element__(self):
+ return self.expression
+
+ @util.memoized_property
+ def expression(self):
+ clauses = self.clauses._annotate(
+ {
+ "bundle": True,
+ "parententity": self._parententity,
+ "parentmapper": self._parententity,
+ "orm_key": self.prop.key,
+ }
)
+ return CompositeProperty.CompositeBundle(self.prop, clauses)
def _bulk_update_tuples(self, value):
if value is None:
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index 5098a55ce..d6bdfb924 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -363,6 +363,7 @@ class PropComparator(operators.ColumnOperators):
__slots__ = "prop", "property", "_parententity", "_adapt_to_entity"
def __init__(self, prop, parentmapper, adapt_to_entity=None):
+ # type: (MapperProperty, Mapper, Optional(AliasedInsp))
self.prop = self.property = prop
self._parententity = adapt_to_entity or parentmapper
self._adapt_to_entity = adapt_to_entity
@@ -370,10 +371,15 @@ class PropComparator(operators.ColumnOperators):
def __clause_element__(self):
raise NotImplementedError("%r" % self)
- def _query_clause_element(self):
- return self.__clause_element__()
-
def _bulk_update_tuples(self, value):
+ # type: (ColumnOperators) -> List[tuple[ColumnOperators, Any]]
+ """Receive a SQL expression that represents a value in the SET
+ clause of an UPDATE statement.
+
+ Return a tuple that can be passed to a :class:`.Update` construct.
+
+ """
+
return [(self.__clause_element__(), value)]
def adapt_to_entity(self, adapt_to_entity):
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index e2c10e50a..f804d6eed 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -292,7 +292,7 @@ class ColumnProperty(StrategizedProperty):
def _memoized_method___clause_element__(self):
if self.adapter:
- return self.adapter(self.prop.columns[0])
+ return self.adapter(self.prop.columns[0], self.prop.key)
else:
# no adapter, so we aren't aliased
# assert self._parententity is self._parentmapper
@@ -300,6 +300,7 @@ class ColumnProperty(StrategizedProperty):
{
"parententity": self._parententity,
"parentmapper": self._parententity,
+ "orm_key": self.prop.key,
}
)
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 37bd77f63..3d08dce22 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -114,7 +114,6 @@ class Query(Generative):
_from_obj = ()
_join_entities = ()
_select_from_entity = None
- _mapper_adapter_map = {}
_filter_aliases = ()
_from_obj_alias = None
_joinpath = _joinpoint = util.immutabledict()
@@ -177,61 +176,23 @@ class Query(Generative):
self._primary_entity = None
self._has_mapper_entities = False
- # 1. don't run util.to_list() or _set_entity_selectables
- # if no entities were passed - major performance bottleneck
- # from lazy loader implementation when it seeks to use Query
- # class for an identity lookup, causes test_orm.py to fail
- # with thousands of extra function calls, see issue #4228
- # for why this use had to be added
- # 2. can't use classmethod on Query because session.query_cls
- # is an arbitrary callable in some user recipes, not
- # necessarily a class, so we don't have the class available.
- # see issue #4256
- # 3. can't do "if entities is not None" because we usually get here
- # from session.query() which takes in *entities.
- # 4. can't do "if entities" because users make use of undocumented
- # to_list() behavior here and they pass clause expressions that
- # can't be evaluated as boolean. See issue #4269.
- # 5. the empty tuple is a singleton in cPython, take advantage of this
- # so that we can skip for the empty "*entities" case without using
- # any Python overloadable operators.
- #
if entities is not ():
for ent in util.to_list(entities):
entity_wrapper(self, ent)
- self._set_entity_selectables(self._entities)
-
- def _set_entity_selectables(self, entities):
- self._mapper_adapter_map = d = self._mapper_adapter_map.copy()
-
- for ent in entities:
- for entity in ent.entities:
- if entity not in d:
- ext_info = inspect(entity)
- if (
- not ext_info.is_aliased_class
- and ext_info.mapper.with_polymorphic
- ):
- if (
- ext_info.mapper.persist_selectable
- not in self._polymorphic_adapters
- ):
- self._mapper_loads_polymorphically_with(
- ext_info.mapper,
- sql_util.ColumnAdapter(
- ext_info.selectable,
- ext_info.mapper._equivalent_columns,
- ),
- )
- aliased_adapter = None
- elif ext_info.is_aliased_class:
- aliased_adapter = ext_info._adapter
- else:
- aliased_adapter = None
-
- d[entity] = (ext_info, aliased_adapter)
- ent.setup_entity(*d[entity])
+ def _setup_query_adapters(self, entity, ext_info):
+ if not ext_info.is_aliased_class and ext_info.mapper.with_polymorphic:
+ if (
+ ext_info.mapper.persist_selectable
+ not in self._polymorphic_adapters
+ ):
+ self._mapper_loads_polymorphically_with(
+ ext_info.mapper,
+ sql_util.ColumnAdapter(
+ ext_info.selectable,
+ ext_info.mapper._equivalent_columns,
+ ),
+ )
def _mapper_loads_polymorphically_with(self, mapper, adapter):
for m2 in mapper._with_polymorphic_mappers or [mapper]:
@@ -1162,8 +1123,7 @@ class Query(Generative):
entity = aliased(entity, alias)
self._entities = list(self._entities)
- m = _MapperEntity(self, entity)
- self._set_entity_selectables([m])
+ _MapperEntity(self, entity)
@_generative
def with_session(self, session):
@@ -1455,12 +1415,9 @@ class Query(Generative):
of result columns to be returned."""
self._entities = list(self._entities)
- l = len(self._entities)
+
for c in column:
_ColumnEntity(self, c)
- # _ColumnEntity may add many entities if the
- # given arg is a FROM clause
- self._set_entity_selectables(self._entities[l:])
@util.pending_deprecation(
"0.7",
@@ -2464,9 +2421,13 @@ class Query(Generative):
)
else:
# add a new element to the self._from_obj list
-
if use_entity_index is not None:
- # why doesn't this work as .entity_zero_or_selectable?
+ # make use of _MapperEntity selectable, which is usually
+ # entity_zero.selectable, but if with_polymorphic() were used
+ # might be distinct
+ assert isinstance(
+ self._entities[use_entity_index], _MapperEntity
+ )
left_clause = self._entities[use_entity_index].selectable
else:
left_clause = left
@@ -3529,7 +3490,7 @@ class Query(Generative):
# we get just "SELECT 1" without any entities.
return sql.exists(
self.enable_eagerloads(False)
- .add_columns("1")
+ .add_columns(sql.literal_column("1"))
.with_labels()
.statement.with_only_columns([1])
)
@@ -4029,10 +3990,10 @@ class Query(Generative):
"""
- search = set(self._mapper_adapter_map.values())
+ search = set(context.single_inh_entities.values())
if (
self._select_from_entity
- and self._select_from_entity not in self._mapper_adapter_map
+ and self._select_from_entity not in context.single_inh_entities
):
insp = inspect(self._select_from_entity)
if insp.is_aliased_class:
@@ -4110,23 +4071,27 @@ class _MapperEntity(_QueryEntity):
self.entities = [entity]
self.expr = entity
- supports_single_entity = True
-
- use_id_for_hash = True
+ ext_info = self.entity_zero = inspect(entity)
- def setup_entity(self, ext_info, aliased_adapter):
self.mapper = ext_info.mapper
- self.aliased_adapter = aliased_adapter
+
+ if ext_info.is_aliased_class:
+ self._label_name = ext_info.name
+ else:
+ self._label_name = self.mapper.class_.__name__
+
self.selectable = ext_info.selectable
self.is_aliased_class = ext_info.is_aliased_class
self._with_polymorphic = ext_info.with_polymorphic_mappers
self._polymorphic_discriminator = ext_info.polymorphic_on
- self.entity_zero = ext_info
- if ext_info.is_aliased_class:
- self._label_name = self.entity_zero.name
- else:
- self._label_name = self.mapper.class_.__name__
- self.path = self.entity_zero._path_registry
+ self.path = ext_info._path_registry
+
+ if ext_info.mapper.with_polymorphic:
+ query._setup_query_adapters(entity, ext_info)
+
+ supports_single_entity = True
+
+ use_id_for_hash = True
def set_with_polymorphic(
self, query, cls_or_mappers, selectable, polymorphic_on
@@ -4185,7 +4150,7 @@ class _MapperEntity(_QueryEntity):
if query._polymorphic_adapters:
adapter = query._polymorphic_adapters.get(self.mapper, None)
else:
- adapter = self.aliased_adapter
+ adapter = self.entity_zero._adapter
if adapter:
if query._from_obj_alias:
@@ -4235,6 +4200,14 @@ class _MapperEntity(_QueryEntity):
def setup_context(self, query, context):
adapter = self._get_entity_clauses(query, context)
+ single_table_crit = self.mapper._single_table_criterion
+ if single_table_crit is not None:
+ ext_info = self.entity_zero
+ context.single_inh_entities[ext_info] = (
+ ext_info,
+ ext_info._adapter if ext_info.is_aliased_class else None,
+ )
+
# if self._adapted_selectable is None:
context.froms += (self.selectable,)
@@ -4352,7 +4325,9 @@ class Bundle(InspectionAttr):
return cloned
def __clause_element__(self):
- return expression.ClauseList(group=False, *self.exprs)
+ return expression.ClauseList(group=False, *self.exprs)._annotate(
+ {"bundle": True}
+ )
@property
def clauses(self):
@@ -4386,8 +4361,19 @@ class Bundle(InspectionAttr):
class _BundleEntity(_QueryEntity):
use_id_for_hash = False
- def __init__(self, query, bundle, setup_entities=True):
- query._entities.append(self)
+ def __init__(self, query, expr, setup_entities=True, parent_bundle=None):
+ if parent_bundle:
+ parent_bundle._entities.append(self)
+ else:
+ query._entities.append(self)
+
+ if isinstance(
+ expr, (attributes.QueryableAttribute, interfaces.PropComparator)
+ ):
+ bundle = expr.__clause_element__()
+ else:
+ bundle = expr
+
self.bundle = self.expr = bundle
self.type = type(bundle)
self._label_name = bundle.name
@@ -4396,9 +4382,9 @@ class _BundleEntity(_QueryEntity):
if setup_entities:
for expr in bundle.exprs:
if isinstance(expr, Bundle):
- _BundleEntity(self, expr)
+ _BundleEntity(query, expr, parent_bundle=self)
else:
- _ColumnEntity(self, expr)
+ _ColumnEntity(query, expr, parent_bundle=self)
self.supports_single_entity = self.bundle.single_entity
@@ -4448,18 +4434,19 @@ class _BundleEntity(_QueryEntity):
else:
return None
- def adapt_to_selectable(self, query, sel):
- c = _BundleEntity(query, self.bundle, setup_entities=False)
+ def adapt_to_selectable(self, query, sel, parent_bundle=None):
+ c = _BundleEntity(
+ query,
+ self.bundle,
+ setup_entities=False,
+ parent_bundle=parent_bundle,
+ )
# c._label_name = self._label_name
# c.entity_zero = self.entity_zero
# c.entities = self.entities
for ent in self._entities:
- ent.adapt_to_selectable(c, sel)
-
- def setup_entity(self, ext_info, aliased_adapter):
- for ent in self._entities:
- ent.setup_entity(ext_info, aliased_adapter)
+ ent.adapt_to_selectable(query, sel, parent_bundle=c)
def setup_context(self, query, context):
for ent in self._entities:
@@ -4481,76 +4468,52 @@ class _BundleEntity(_QueryEntity):
class _ColumnEntity(_QueryEntity):
"""Column/expression based entity."""
- def __init__(self, query, column, namespace=None):
- self.expr = column
+ froms = frozenset()
+
+ def __init__(self, query, column, namespace=None, parent_bundle=None):
+ self.expr = expr = column
self.namespace = namespace
- search_entities = True
- check_column = False
-
- if isinstance(column, util.string_types):
- column = sql.literal_column(column)
- self._label_name = column.name
- search_entities = False
- check_column = True
- _entity = None
- elif isinstance(
- column, (attributes.QueryableAttribute, interfaces.PropComparator)
- ):
- _entity = getattr(column, "_parententity", None)
- if _entity is not None:
- search_entities = False
- self._label_name = column.key
- column = column._query_clause_element()
- check_column = True
- if isinstance(column, Bundle):
- _BundleEntity(query, column)
- return
+ _label_name = None
- if not isinstance(column, sql.ColumnElement):
- if hasattr(column, "_select_iterable"):
- # break out an object like Table into
- # individual columns
- for c in column._select_iterable:
- if c is column:
- break
- _ColumnEntity(query, c, namespace=column)
- else:
- return
+ column = coercions.expect(roles.ColumnsClauseRole, column)
- raise sa_exc.InvalidRequestError(
- "SQL expression, column, or mapped entity "
- "expected - got '%r'" % (column,)
- )
- elif not check_column:
+ annotations = column._annotations
+
+ if annotations.get("bundle", False):
+ _BundleEntity(query, expr, parent_bundle=parent_bundle)
+ return
+
+ orm_expr = False
+
+ if "parententity" in annotations:
+ _entity = annotations["parententity"]
+ self._label_name = _label_name = annotations.get("orm_key", None)
+ orm_expr = True
+
+ if hasattr(column, "_select_iterable"):
+ # break out an object like Table into
+ # individual columns
+ for c in column._select_iterable:
+ if c is column:
+ break
+ _ColumnEntity(query, c, namespace=column)
+ else:
+ return
+
+ if _label_name is None:
self._label_name = getattr(column, "key", None)
- search_entities = True
self.type = type_ = column.type
self.use_id_for_hash = not type_.hashable
- # If the Column is unnamed, give it a
- # label() so that mutable column expressions
- # can be located in the result even
- # if the expression's identity has been changed
- # due to adaption.
-
- if not column._label and not getattr(column, "is_literal", False):
- column = column.label(self._label_name)
-
- query._entities.append(self)
+ if parent_bundle:
+ parent_bundle._entities.append(self)
+ else:
+ query._entities.append(self)
self.column = column
- self.froms = set()
-
- # look for ORM entities represented within the
- # given expression. Try to count only entities
- # for columns whose FROM object is in the actual list
- # of FROMs for the overall expression - this helps
- # subqueries which were built from ORM constructs from
- # leaking out their entities into the main select construct
- self.actual_froms = set(column._from_objects)
- if not search_entities:
+ if orm_expr:
self.entity_zero = _entity
if _entity:
self.entities = [_entity]
@@ -4559,21 +4522,20 @@ class _ColumnEntity(_QueryEntity):
self.entities = []
self.mapper = None
else:
- all_elements = [
- elem
- for elem in sql_util.surface_column_elements(
- column, include_scalar_selects=False
- )
- if "parententity" in elem._annotations
- ]
- self.entities = util.unique_list(
- [elem._annotations["parententity"] for elem in all_elements]
+ entity = sql_util.extract_first_column_annotation(
+ column, "parententity"
)
+ if entity:
+ self.entities = [entity]
+ else:
+ self.entities = []
+
if self.entities:
self.entity_zero = self.entities[0]
self.mapper = self.entity_zero.mapper
+
elif self.namespace is not None:
self.entity_zero = self.namespace
self.mapper = None
@@ -4581,6 +4543,9 @@ class _ColumnEntity(_QueryEntity):
self.entity_zero = None
self.mapper = None
+ if self.entities and self.entity_zero.mapper.with_polymorphic:
+ query._setup_query_adapters(self.entity_zero, self.entity_zero)
+
supports_single_entity = False
def _deep_entity_zero(self):
@@ -4603,24 +4568,21 @@ class _ColumnEntity(_QueryEntity):
def entity_zero_or_selectable(self):
if self.entity_zero is not None:
return self.entity_zero
- elif self.actual_froms:
- return list(self.actual_froms)[0]
+ elif self.column._from_objects:
+ return self.column._from_objects[0]
else:
return None
- def adapt_to_selectable(self, query, sel):
- c = _ColumnEntity(query, sel.corresponding_column(self.column))
+ def adapt_to_selectable(self, query, sel, parent_bundle=None):
+ c = _ColumnEntity(
+ query,
+ sel.corresponding_column(self.column),
+ parent_bundle=parent_bundle,
+ )
c._label_name = self._label_name
c.entity_zero = self.entity_zero
c.entities = self.entities
- def setup_entity(self, ext_info, aliased_adapter):
- if "selectable" not in self.__dict__:
- self.selectable = ext_info.selectable
-
- if self.actual_froms.intersection(ext_info.selectable._from_objects):
- self.froms.add(ext_info.selectable)
-
def corresponds_to(self, entity):
if self.entity_zero is None:
return False
@@ -4651,13 +4613,32 @@ class _ColumnEntity(_QueryEntity):
def setup_context(self, query, context):
column = query._adapt_clause(self.column, False, True)
+ ezero = self.entity_zero
+
+ if self.mapper:
+ single_table_crit = self.mapper._single_table_criterion
+ if single_table_crit is not None:
+ context.single_inh_entities[ezero] = (
+ ezero,
+ ezero._adapter if ezero.is_aliased_class else None,
+ )
if column._annotations:
# annotated columns perform more slowly in compiler and
# result due to the __eq__() method, so use deannotated
column = column._deannotate()
- context.froms += tuple(self.froms)
+ if ezero is not None:
+ # use entity_zero as the from if we have it. this is necessary
+ # for polymorpic scenarios where our FROM is based on ORM entity,
+ # not the FROM of the column. but also, don't use it if our column
+ # doesn't actually have any FROMs that line up, such as when its
+ # a scalar subquery.
+ if set(self.column._from_objects).intersection(
+ ezero.selectable._from_objects
+ ):
+ context.froms += (ezero.selectable,)
+
context.primary_columns.append(column)
context.attributes[("fetch_column", self)] = column
@@ -4697,6 +4678,7 @@ class QueryContext(object):
"partials",
"post_load_paths",
"identity_token",
+ "single_inh_entities",
)
def __init__(self, query):
@@ -4731,6 +4713,7 @@ class QueryContext(object):
self.secondary_columns = []
self.eager_order_by = []
self.eager_joins = {}
+ self.single_inh_entities = {}
self.create_eager_joins = []
self.propagate_options = set(
o for o in query._with_options if o.propagate_to_loaders
diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py
index 63ec21099..731947cba 100644
--- a/lib/sqlalchemy/orm/relationships.py
+++ b/lib/sqlalchemy/orm/relationships.py
@@ -2319,11 +2319,11 @@ class JoinCondition(object):
"""
self.primaryjoin = _deep_deannotate(
- self.primaryjoin, values=("parententity",)
+ self.primaryjoin, values=("parententity", "orm_key")
)
if self.secondaryjoin is not None:
self.secondaryjoin = _deep_deannotate(
- self.secondaryjoin, values=("parententity",)
+ self.secondaryjoin, values=("parententity", "orm_key")
)
def _determine_joins(self):
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 4b4fa4052..747ec7e65 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -668,10 +668,11 @@ class AliasedInsp(InspectionAttr):
state["represents_outer_join"],
)
- def _adapt_element(self, elem):
- return self._adapter.traverse(elem)._annotate(
- {"parententity": self, "parentmapper": self.mapper}
- )
+ def _adapt_element(self, elem, key=None):
+ d = {"parententity": self, "parentmapper": self.mapper}
+ if key:
+ d["orm_key"] = key
+ return self._adapter.traverse(elem)._annotate(d)
def _entity_for_mapper(self, mapper):
self_poly = self.with_polymorphic_mappers
diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py
index a7a856bba..95aee0468 100644
--- a/lib/sqlalchemy/sql/coercions.py
+++ b/lib/sqlalchemy/sql/coercions.py
@@ -57,7 +57,7 @@ def expect(role, element, **kw):
else:
resolved = element
- if issubclass(resolved.__class__, impl._role_class):
+ if impl._role_class in resolved.__class__.__mro__:
if impl._post_coercion:
resolved = impl._post_coercion(resolved, **kw)
return resolved
@@ -102,13 +102,16 @@ class RoleImpl(object):
def _resolve_for_clause_element(self, element, argname=None, **kw):
original_element = element
- is_clause_element = False
+ is_clause_element = hasattr(element, "__clause_element__")
- while hasattr(element, "__clause_element__") and not isinstance(
- element, (elements.ClauseElement, schema.SchemaItem)
- ):
- element = element.__clause_element__()
- is_clause_element = True
+ if is_clause_element:
+ while not isinstance(
+ element, (elements.ClauseElement, schema.SchemaItem)
+ ):
+ try:
+ element = element.__clause_element__()
+ except AttributeError:
+ break
if not is_clause_element:
if self._use_inspection:
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index fe83b163c..3c7f904de 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -364,23 +364,19 @@ def surface_selectables_only(clause):
stack.append(elem.table)
-def surface_column_elements(clause, include_scalar_selects=True):
- """traverse and yield only outer-exposed column elements, such as would
- be addressable in the WHERE clause of a SELECT if this element were
- in the columns clause."""
+def extract_first_column_annotation(column, annotation_name):
+ filter_ = (FromGrouping, SelectBase)
- filter_ = (FromGrouping,)
- if not include_scalar_selects:
- filter_ += (SelectBase,)
-
- stack = deque([clause])
+ stack = deque([column])
while stack:
elem = stack.popleft()
- yield elem
+ if annotation_name in elem._annotations:
+ return elem._annotations[annotation_name]
for sub in elem.get_children():
if isinstance(sub, filter_):
continue
stack.append(sub)
+ return None
def selectables_overlap(left, right):