summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2019-09-27 17:32:10 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2019-09-30 10:10:58 -0400
commit6ddb62a8ba66b19afd41b967911ce5982250856e (patch)
treeb98eba226618f2dff301366ac7e384f052feb69e /lib/sqlalchemy
parent2c34d2503a17316cae3282192405b9b9d60df6fe (diff)
downloadsqlalchemy-6ddb62a8ba66b19afd41b967911ce5982250856e.tar.gz
Simplify _ColumnEntity, related
In the interests of making Query much more lightweight up front, rework the calculations done at the top when the entities are constructed to be much less inolved. Use the new coercion system for _ColumnEntity and stop accepting plain strings, this will need to emit a deprecation warning in 1.3.x. Use annotations and other techniques to reduce the decisionmaking and complexity of Query. For the use case of subquery(), .statement, etc. we would like to do minimal work in order to get the columns clause. Change-Id: I7e459bbd3bb10ec71235f75ef4f3b0a969bec590
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/orm/attributes.py20
-rw-r--r--lib/sqlalchemy/orm/descriptor_props.py21
-rw-r--r--lib/sqlalchemy/orm/interfaces.py12
-rw-r--r--lib/sqlalchemy/orm/properties.py3
-rw-r--r--lib/sqlalchemy/orm/query.py317
-rw-r--r--lib/sqlalchemy/orm/relationships.py4
-rw-r--r--lib/sqlalchemy/orm/util.py9
-rw-r--r--lib/sqlalchemy/sql/coercions.py17
-rw-r--r--lib/sqlalchemy/sql/util.py16
9 files changed, 208 insertions, 211 deletions
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index 9d404e00d..117dd4cea 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -168,18 +168,18 @@ class QueryableAttribute(
"""
return inspection.inspect(self._parententity)
- @property
+ @util.memoized_property
def expression(self):
- return self.comparator.__clause_element__()
-
- def __clause_element__(self):
- return self.comparator.__clause_element__()
+ return self.comparator.__clause_element__()._annotate(
+ {"orm_key": self.key}
+ )
- def _query_clause_element(self):
- """like __clause_element__(), but called specifically
- by :class:`.Query` to allow special behavior."""
+ @property
+ def _annotations(self):
+ return self.__clause_element__()._annotations
- return self.comparator._query_clause_element()
+ def __clause_element__(self):
+ return self.expression
def _bulk_update_tuples(self, value):
"""Return setter tuples for a bulk UPDATE."""
@@ -207,7 +207,7 @@ class QueryableAttribute(
)
def label(self, name):
- return self._query_clause_element().label(name)
+ return self.__clause_element__().label(name)
def operate(self, op, *other, **kwargs):
return op(self.comparator, *other, **kwargs)
diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py
index 28b3bc5db..075638fed 100644
--- a/lib/sqlalchemy/orm/descriptor_props.py
+++ b/lib/sqlalchemy/orm/descriptor_props.py
@@ -413,19 +413,26 @@ class CompositeProperty(DescriptorProperty):
__hash__ = None
- @property
+ @util.memoized_property
def clauses(self):
- return self.__clause_element__()
-
- def __clause_element__(self):
return expression.ClauseList(
group=False, *self._comparable_elements
)
- def _query_clause_element(self):
- return CompositeProperty.CompositeBundle(
- self.prop, self.__clause_element__()
+ def __clause_element__(self):
+ return self.expression
+
+ @util.memoized_property
+ def expression(self):
+ clauses = self.clauses._annotate(
+ {
+ "bundle": True,
+ "parententity": self._parententity,
+ "parentmapper": self._parententity,
+ "orm_key": self.prop.key,
+ }
)
+ return CompositeProperty.CompositeBundle(self.prop, clauses)
def _bulk_update_tuples(self, value):
if value is None:
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index 5098a55ce..d6bdfb924 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -363,6 +363,7 @@ class PropComparator(operators.ColumnOperators):
__slots__ = "prop", "property", "_parententity", "_adapt_to_entity"
def __init__(self, prop, parentmapper, adapt_to_entity=None):
+ # type: (MapperProperty, Mapper, Optional(AliasedInsp))
self.prop = self.property = prop
self._parententity = adapt_to_entity or parentmapper
self._adapt_to_entity = adapt_to_entity
@@ -370,10 +371,15 @@ class PropComparator(operators.ColumnOperators):
def __clause_element__(self):
raise NotImplementedError("%r" % self)
- def _query_clause_element(self):
- return self.__clause_element__()
-
def _bulk_update_tuples(self, value):
+ # type: (ColumnOperators) -> List[tuple[ColumnOperators, Any]]
+ """Receive a SQL expression that represents a value in the SET
+ clause of an UPDATE statement.
+
+ Return a tuple that can be passed to a :class:`.Update` construct.
+
+ """
+
return [(self.__clause_element__(), value)]
def adapt_to_entity(self, adapt_to_entity):
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index e2c10e50a..f804d6eed 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -292,7 +292,7 @@ class ColumnProperty(StrategizedProperty):
def _memoized_method___clause_element__(self):
if self.adapter:
- return self.adapter(self.prop.columns[0])
+ return self.adapter(self.prop.columns[0], self.prop.key)
else:
# no adapter, so we aren't aliased
# assert self._parententity is self._parentmapper
@@ -300,6 +300,7 @@ class ColumnProperty(StrategizedProperty):
{
"parententity": self._parententity,
"parentmapper": self._parententity,
+ "orm_key": self.prop.key,
}
)
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 37bd77f63..3d08dce22 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -114,7 +114,6 @@ class Query(Generative):
_from_obj = ()
_join_entities = ()
_select_from_entity = None
- _mapper_adapter_map = {}
_filter_aliases = ()
_from_obj_alias = None
_joinpath = _joinpoint = util.immutabledict()
@@ -177,61 +176,23 @@ class Query(Generative):
self._primary_entity = None
self._has_mapper_entities = False
- # 1. don't run util.to_list() or _set_entity_selectables
- # if no entities were passed - major performance bottleneck
- # from lazy loader implementation when it seeks to use Query
- # class for an identity lookup, causes test_orm.py to fail
- # with thousands of extra function calls, see issue #4228
- # for why this use had to be added
- # 2. can't use classmethod on Query because session.query_cls
- # is an arbitrary callable in some user recipes, not
- # necessarily a class, so we don't have the class available.
- # see issue #4256
- # 3. can't do "if entities is not None" because we usually get here
- # from session.query() which takes in *entities.
- # 4. can't do "if entities" because users make use of undocumented
- # to_list() behavior here and they pass clause expressions that
- # can't be evaluated as boolean. See issue #4269.
- # 5. the empty tuple is a singleton in cPython, take advantage of this
- # so that we can skip for the empty "*entities" case without using
- # any Python overloadable operators.
- #
if entities is not ():
for ent in util.to_list(entities):
entity_wrapper(self, ent)
- self._set_entity_selectables(self._entities)
-
- def _set_entity_selectables(self, entities):
- self._mapper_adapter_map = d = self._mapper_adapter_map.copy()
-
- for ent in entities:
- for entity in ent.entities:
- if entity not in d:
- ext_info = inspect(entity)
- if (
- not ext_info.is_aliased_class
- and ext_info.mapper.with_polymorphic
- ):
- if (
- ext_info.mapper.persist_selectable
- not in self._polymorphic_adapters
- ):
- self._mapper_loads_polymorphically_with(
- ext_info.mapper,
- sql_util.ColumnAdapter(
- ext_info.selectable,
- ext_info.mapper._equivalent_columns,
- ),
- )
- aliased_adapter = None
- elif ext_info.is_aliased_class:
- aliased_adapter = ext_info._adapter
- else:
- aliased_adapter = None
-
- d[entity] = (ext_info, aliased_adapter)
- ent.setup_entity(*d[entity])
+ def _setup_query_adapters(self, entity, ext_info):
+ if not ext_info.is_aliased_class and ext_info.mapper.with_polymorphic:
+ if (
+ ext_info.mapper.persist_selectable
+ not in self._polymorphic_adapters
+ ):
+ self._mapper_loads_polymorphically_with(
+ ext_info.mapper,
+ sql_util.ColumnAdapter(
+ ext_info.selectable,
+ ext_info.mapper._equivalent_columns,
+ ),
+ )
def _mapper_loads_polymorphically_with(self, mapper, adapter):
for m2 in mapper._with_polymorphic_mappers or [mapper]:
@@ -1162,8 +1123,7 @@ class Query(Generative):
entity = aliased(entity, alias)
self._entities = list(self._entities)
- m = _MapperEntity(self, entity)
- self._set_entity_selectables([m])
+ _MapperEntity(self, entity)
@_generative
def with_session(self, session):
@@ -1455,12 +1415,9 @@ class Query(Generative):
of result columns to be returned."""
self._entities = list(self._entities)
- l = len(self._entities)
+
for c in column:
_ColumnEntity(self, c)
- # _ColumnEntity may add many entities if the
- # given arg is a FROM clause
- self._set_entity_selectables(self._entities[l:])
@util.pending_deprecation(
"0.7",
@@ -2464,9 +2421,13 @@ class Query(Generative):
)
else:
# add a new element to the self._from_obj list
-
if use_entity_index is not None:
- # why doesn't this work as .entity_zero_or_selectable?
+ # make use of _MapperEntity selectable, which is usually
+ # entity_zero.selectable, but if with_polymorphic() were used
+ # might be distinct
+ assert isinstance(
+ self._entities[use_entity_index], _MapperEntity
+ )
left_clause = self._entities[use_entity_index].selectable
else:
left_clause = left
@@ -3529,7 +3490,7 @@ class Query(Generative):
# we get just "SELECT 1" without any entities.
return sql.exists(
self.enable_eagerloads(False)
- .add_columns("1")
+ .add_columns(sql.literal_column("1"))
.with_labels()
.statement.with_only_columns([1])
)
@@ -4029,10 +3990,10 @@ class Query(Generative):
"""
- search = set(self._mapper_adapter_map.values())
+ search = set(context.single_inh_entities.values())
if (
self._select_from_entity
- and self._select_from_entity not in self._mapper_adapter_map
+ and self._select_from_entity not in context.single_inh_entities
):
insp = inspect(self._select_from_entity)
if insp.is_aliased_class:
@@ -4110,23 +4071,27 @@ class _MapperEntity(_QueryEntity):
self.entities = [entity]
self.expr = entity
- supports_single_entity = True
-
- use_id_for_hash = True
+ ext_info = self.entity_zero = inspect(entity)
- def setup_entity(self, ext_info, aliased_adapter):
self.mapper = ext_info.mapper
- self.aliased_adapter = aliased_adapter
+
+ if ext_info.is_aliased_class:
+ self._label_name = ext_info.name
+ else:
+ self._label_name = self.mapper.class_.__name__
+
self.selectable = ext_info.selectable
self.is_aliased_class = ext_info.is_aliased_class
self._with_polymorphic = ext_info.with_polymorphic_mappers
self._polymorphic_discriminator = ext_info.polymorphic_on
- self.entity_zero = ext_info
- if ext_info.is_aliased_class:
- self._label_name = self.entity_zero.name
- else:
- self._label_name = self.mapper.class_.__name__
- self.path = self.entity_zero._path_registry
+ self.path = ext_info._path_registry
+
+ if ext_info.mapper.with_polymorphic:
+ query._setup_query_adapters(entity, ext_info)
+
+ supports_single_entity = True
+
+ use_id_for_hash = True
def set_with_polymorphic(
self, query, cls_or_mappers, selectable, polymorphic_on
@@ -4185,7 +4150,7 @@ class _MapperEntity(_QueryEntity):
if query._polymorphic_adapters:
adapter = query._polymorphic_adapters.get(self.mapper, None)
else:
- adapter = self.aliased_adapter
+ adapter = self.entity_zero._adapter
if adapter:
if query._from_obj_alias:
@@ -4235,6 +4200,14 @@ class _MapperEntity(_QueryEntity):
def setup_context(self, query, context):
adapter = self._get_entity_clauses(query, context)
+ single_table_crit = self.mapper._single_table_criterion
+ if single_table_crit is not None:
+ ext_info = self.entity_zero
+ context.single_inh_entities[ext_info] = (
+ ext_info,
+ ext_info._adapter if ext_info.is_aliased_class else None,
+ )
+
# if self._adapted_selectable is None:
context.froms += (self.selectable,)
@@ -4352,7 +4325,9 @@ class Bundle(InspectionAttr):
return cloned
def __clause_element__(self):
- return expression.ClauseList(group=False, *self.exprs)
+ return expression.ClauseList(group=False, *self.exprs)._annotate(
+ {"bundle": True}
+ )
@property
def clauses(self):
@@ -4386,8 +4361,19 @@ class Bundle(InspectionAttr):
class _BundleEntity(_QueryEntity):
use_id_for_hash = False
- def __init__(self, query, bundle, setup_entities=True):
- query._entities.append(self)
+ def __init__(self, query, expr, setup_entities=True, parent_bundle=None):
+ if parent_bundle:
+ parent_bundle._entities.append(self)
+ else:
+ query._entities.append(self)
+
+ if isinstance(
+ expr, (attributes.QueryableAttribute, interfaces.PropComparator)
+ ):
+ bundle = expr.__clause_element__()
+ else:
+ bundle = expr
+
self.bundle = self.expr = bundle
self.type = type(bundle)
self._label_name = bundle.name
@@ -4396,9 +4382,9 @@ class _BundleEntity(_QueryEntity):
if setup_entities:
for expr in bundle.exprs:
if isinstance(expr, Bundle):
- _BundleEntity(self, expr)
+ _BundleEntity(query, expr, parent_bundle=self)
else:
- _ColumnEntity(self, expr)
+ _ColumnEntity(query, expr, parent_bundle=self)
self.supports_single_entity = self.bundle.single_entity
@@ -4448,18 +4434,19 @@ class _BundleEntity(_QueryEntity):
else:
return None
- def adapt_to_selectable(self, query, sel):
- c = _BundleEntity(query, self.bundle, setup_entities=False)
+ def adapt_to_selectable(self, query, sel, parent_bundle=None):
+ c = _BundleEntity(
+ query,
+ self.bundle,
+ setup_entities=False,
+ parent_bundle=parent_bundle,
+ )
# c._label_name = self._label_name
# c.entity_zero = self.entity_zero
# c.entities = self.entities
for ent in self._entities:
- ent.adapt_to_selectable(c, sel)
-
- def setup_entity(self, ext_info, aliased_adapter):
- for ent in self._entities:
- ent.setup_entity(ext_info, aliased_adapter)
+ ent.adapt_to_selectable(query, sel, parent_bundle=c)
def setup_context(self, query, context):
for ent in self._entities:
@@ -4481,76 +4468,52 @@ class _BundleEntity(_QueryEntity):
class _ColumnEntity(_QueryEntity):
"""Column/expression based entity."""
- def __init__(self, query, column, namespace=None):
- self.expr = column
+ froms = frozenset()
+
+ def __init__(self, query, column, namespace=None, parent_bundle=None):
+ self.expr = expr = column
self.namespace = namespace
- search_entities = True
- check_column = False
-
- if isinstance(column, util.string_types):
- column = sql.literal_column(column)
- self._label_name = column.name
- search_entities = False
- check_column = True
- _entity = None
- elif isinstance(
- column, (attributes.QueryableAttribute, interfaces.PropComparator)
- ):
- _entity = getattr(column, "_parententity", None)
- if _entity is not None:
- search_entities = False
- self._label_name = column.key
- column = column._query_clause_element()
- check_column = True
- if isinstance(column, Bundle):
- _BundleEntity(query, column)
- return
+ _label_name = None
- if not isinstance(column, sql.ColumnElement):
- if hasattr(column, "_select_iterable"):
- # break out an object like Table into
- # individual columns
- for c in column._select_iterable:
- if c is column:
- break
- _ColumnEntity(query, c, namespace=column)
- else:
- return
+ column = coercions.expect(roles.ColumnsClauseRole, column)
- raise sa_exc.InvalidRequestError(
- "SQL expression, column, or mapped entity "
- "expected - got '%r'" % (column,)
- )
- elif not check_column:
+ annotations = column._annotations
+
+ if annotations.get("bundle", False):
+ _BundleEntity(query, expr, parent_bundle=parent_bundle)
+ return
+
+ orm_expr = False
+
+ if "parententity" in annotations:
+ _entity = annotations["parententity"]
+ self._label_name = _label_name = annotations.get("orm_key", None)
+ orm_expr = True
+
+ if hasattr(column, "_select_iterable"):
+ # break out an object like Table into
+ # individual columns
+ for c in column._select_iterable:
+ if c is column:
+ break
+ _ColumnEntity(query, c, namespace=column)
+ else:
+ return
+
+ if _label_name is None:
self._label_name = getattr(column, "key", None)
- search_entities = True
self.type = type_ = column.type
self.use_id_for_hash = not type_.hashable
- # If the Column is unnamed, give it a
- # label() so that mutable column expressions
- # can be located in the result even
- # if the expression's identity has been changed
- # due to adaption.
-
- if not column._label and not getattr(column, "is_literal", False):
- column = column.label(self._label_name)
-
- query._entities.append(self)
+ if parent_bundle:
+ parent_bundle._entities.append(self)
+ else:
+ query._entities.append(self)
self.column = column
- self.froms = set()
-
- # look for ORM entities represented within the
- # given expression. Try to count only entities
- # for columns whose FROM object is in the actual list
- # of FROMs for the overall expression - this helps
- # subqueries which were built from ORM constructs from
- # leaking out their entities into the main select construct
- self.actual_froms = set(column._from_objects)
- if not search_entities:
+ if orm_expr:
self.entity_zero = _entity
if _entity:
self.entities = [_entity]
@@ -4559,21 +4522,20 @@ class _ColumnEntity(_QueryEntity):
self.entities = []
self.mapper = None
else:
- all_elements = [
- elem
- for elem in sql_util.surface_column_elements(
- column, include_scalar_selects=False
- )
- if "parententity" in elem._annotations
- ]
- self.entities = util.unique_list(
- [elem._annotations["parententity"] for elem in all_elements]
+ entity = sql_util.extract_first_column_annotation(
+ column, "parententity"
)
+ if entity:
+ self.entities = [entity]
+ else:
+ self.entities = []
+
if self.entities:
self.entity_zero = self.entities[0]
self.mapper = self.entity_zero.mapper
+
elif self.namespace is not None:
self.entity_zero = self.namespace
self.mapper = None
@@ -4581,6 +4543,9 @@ class _ColumnEntity(_QueryEntity):
self.entity_zero = None
self.mapper = None
+ if self.entities and self.entity_zero.mapper.with_polymorphic:
+ query._setup_query_adapters(self.entity_zero, self.entity_zero)
+
supports_single_entity = False
def _deep_entity_zero(self):
@@ -4603,24 +4568,21 @@ class _ColumnEntity(_QueryEntity):
def entity_zero_or_selectable(self):
if self.entity_zero is not None:
return self.entity_zero
- elif self.actual_froms:
- return list(self.actual_froms)[0]
+ elif self.column._from_objects:
+ return self.column._from_objects[0]
else:
return None
- def adapt_to_selectable(self, query, sel):
- c = _ColumnEntity(query, sel.corresponding_column(self.column))
+ def adapt_to_selectable(self, query, sel, parent_bundle=None):
+ c = _ColumnEntity(
+ query,
+ sel.corresponding_column(self.column),
+ parent_bundle=parent_bundle,
+ )
c._label_name = self._label_name
c.entity_zero = self.entity_zero
c.entities = self.entities
- def setup_entity(self, ext_info, aliased_adapter):
- if "selectable" not in self.__dict__:
- self.selectable = ext_info.selectable
-
- if self.actual_froms.intersection(ext_info.selectable._from_objects):
- self.froms.add(ext_info.selectable)
-
def corresponds_to(self, entity):
if self.entity_zero is None:
return False
@@ -4651,13 +4613,32 @@ class _ColumnEntity(_QueryEntity):
def setup_context(self, query, context):
column = query._adapt_clause(self.column, False, True)
+ ezero = self.entity_zero
+
+ if self.mapper:
+ single_table_crit = self.mapper._single_table_criterion
+ if single_table_crit is not None:
+ context.single_inh_entities[ezero] = (
+ ezero,
+ ezero._adapter if ezero.is_aliased_class else None,
+ )
if column._annotations:
# annotated columns perform more slowly in compiler and
# result due to the __eq__() method, so use deannotated
column = column._deannotate()
- context.froms += tuple(self.froms)
+ if ezero is not None:
+ # use entity_zero as the from if we have it. this is necessary
+ # for polymorpic scenarios where our FROM is based on ORM entity,
+ # not the FROM of the column. but also, don't use it if our column
+ # doesn't actually have any FROMs that line up, such as when its
+ # a scalar subquery.
+ if set(self.column._from_objects).intersection(
+ ezero.selectable._from_objects
+ ):
+ context.froms += (ezero.selectable,)
+
context.primary_columns.append(column)
context.attributes[("fetch_column", self)] = column
@@ -4697,6 +4678,7 @@ class QueryContext(object):
"partials",
"post_load_paths",
"identity_token",
+ "single_inh_entities",
)
def __init__(self, query):
@@ -4731,6 +4713,7 @@ class QueryContext(object):
self.secondary_columns = []
self.eager_order_by = []
self.eager_joins = {}
+ self.single_inh_entities = {}
self.create_eager_joins = []
self.propagate_options = set(
o for o in query._with_options if o.propagate_to_loaders
diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py
index 63ec21099..731947cba 100644
--- a/lib/sqlalchemy/orm/relationships.py
+++ b/lib/sqlalchemy/orm/relationships.py
@@ -2319,11 +2319,11 @@ class JoinCondition(object):
"""
self.primaryjoin = _deep_deannotate(
- self.primaryjoin, values=("parententity",)
+ self.primaryjoin, values=("parententity", "orm_key")
)
if self.secondaryjoin is not None:
self.secondaryjoin = _deep_deannotate(
- self.secondaryjoin, values=("parententity",)
+ self.secondaryjoin, values=("parententity", "orm_key")
)
def _determine_joins(self):
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 4b4fa4052..747ec7e65 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -668,10 +668,11 @@ class AliasedInsp(InspectionAttr):
state["represents_outer_join"],
)
- def _adapt_element(self, elem):
- return self._adapter.traverse(elem)._annotate(
- {"parententity": self, "parentmapper": self.mapper}
- )
+ def _adapt_element(self, elem, key=None):
+ d = {"parententity": self, "parentmapper": self.mapper}
+ if key:
+ d["orm_key"] = key
+ return self._adapter.traverse(elem)._annotate(d)
def _entity_for_mapper(self, mapper):
self_poly = self.with_polymorphic_mappers
diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py
index a7a856bba..95aee0468 100644
--- a/lib/sqlalchemy/sql/coercions.py
+++ b/lib/sqlalchemy/sql/coercions.py
@@ -57,7 +57,7 @@ def expect(role, element, **kw):
else:
resolved = element
- if issubclass(resolved.__class__, impl._role_class):
+ if impl._role_class in resolved.__class__.__mro__:
if impl._post_coercion:
resolved = impl._post_coercion(resolved, **kw)
return resolved
@@ -102,13 +102,16 @@ class RoleImpl(object):
def _resolve_for_clause_element(self, element, argname=None, **kw):
original_element = element
- is_clause_element = False
+ is_clause_element = hasattr(element, "__clause_element__")
- while hasattr(element, "__clause_element__") and not isinstance(
- element, (elements.ClauseElement, schema.SchemaItem)
- ):
- element = element.__clause_element__()
- is_clause_element = True
+ if is_clause_element:
+ while not isinstance(
+ element, (elements.ClauseElement, schema.SchemaItem)
+ ):
+ try:
+ element = element.__clause_element__()
+ except AttributeError:
+ break
if not is_clause_element:
if self._use_inspection:
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index fe83b163c..3c7f904de 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -364,23 +364,19 @@ def surface_selectables_only(clause):
stack.append(elem.table)
-def surface_column_elements(clause, include_scalar_selects=True):
- """traverse and yield only outer-exposed column elements, such as would
- be addressable in the WHERE clause of a SELECT if this element were
- in the columns clause."""
+def extract_first_column_annotation(column, annotation_name):
+ filter_ = (FromGrouping, SelectBase)
- filter_ = (FromGrouping,)
- if not include_scalar_selects:
- filter_ += (SelectBase,)
-
- stack = deque([clause])
+ stack = deque([column])
while stack:
elem = stack.popleft()
- yield elem
+ if annotation_name in elem._annotations:
+ return elem._annotations[annotation_name]
for sub in elem.get_children():
if isinstance(sub, filter_):
continue
stack.append(sub)
+ return None
def selectables_overlap(left, right):