diff options
| author | mike bayer <mike_mp@zzzcomputing.com> | 2019-09-30 15:32:18 +0000 |
|---|---|---|
| committer | Gerrit Code Review <gerrit@bbpush.zzzcomputing.com> | 2019-09-30 15:32:18 +0000 |
| commit | ff1124444e88260184ea807a7cbb16a1f6ee0ff4 (patch) | |
| tree | 30b9afc4875cead27be166669ca4a0de5bd3e908 /lib/sqlalchemy | |
| parent | 9f3539b1745cbb287a1338812872d27cde4ebf24 (diff) | |
| parent | 6ddb62a8ba66b19afd41b967911ce5982250856e (diff) | |
| download | sqlalchemy-ff1124444e88260184ea807a7cbb16a1f6ee0ff4.tar.gz | |
Merge "Simplify _ColumnEntity, related"
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/orm/attributes.py | 20 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/descriptor_props.py | 21 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/interfaces.py | 12 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/properties.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/query.py | 317 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/relationships.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/util.py | 9 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/coercions.py | 17 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/util.py | 16 |
9 files changed, 208 insertions, 211 deletions
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 9d404e00d..117dd4cea 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -168,18 +168,18 @@ class QueryableAttribute( """ return inspection.inspect(self._parententity) - @property + @util.memoized_property def expression(self): - return self.comparator.__clause_element__() - - def __clause_element__(self): - return self.comparator.__clause_element__() + return self.comparator.__clause_element__()._annotate( + {"orm_key": self.key} + ) - def _query_clause_element(self): - """like __clause_element__(), but called specifically - by :class:`.Query` to allow special behavior.""" + @property + def _annotations(self): + return self.__clause_element__()._annotations - return self.comparator._query_clause_element() + def __clause_element__(self): + return self.expression def _bulk_update_tuples(self, value): """Return setter tuples for a bulk UPDATE.""" @@ -207,7 +207,7 @@ class QueryableAttribute( ) def label(self, name): - return self._query_clause_element().label(name) + return self.__clause_element__().label(name) def operate(self, op, *other, **kwargs): return op(self.comparator, *other, **kwargs) diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 28b3bc5db..075638fed 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -413,19 +413,26 @@ class CompositeProperty(DescriptorProperty): __hash__ = None - @property + @util.memoized_property def clauses(self): - return self.__clause_element__() - - def __clause_element__(self): return expression.ClauseList( group=False, *self._comparable_elements ) - def _query_clause_element(self): - return CompositeProperty.CompositeBundle( - self.prop, self.__clause_element__() + def __clause_element__(self): + return self.expression + + @util.memoized_property + def expression(self): + clauses = self.clauses._annotate( + { + "bundle": True, + "parententity": self._parententity, + "parentmapper": self._parententity, + "orm_key": self.prop.key, + } ) + return CompositeProperty.CompositeBundle(self.prop, clauses) def _bulk_update_tuples(self, value): if value is None: diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 5098a55ce..d6bdfb924 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -363,6 +363,7 @@ class PropComparator(operators.ColumnOperators): __slots__ = "prop", "property", "_parententity", "_adapt_to_entity" def __init__(self, prop, parentmapper, adapt_to_entity=None): + # type: (MapperProperty, Mapper, Optional(AliasedInsp)) self.prop = self.property = prop self._parententity = adapt_to_entity or parentmapper self._adapt_to_entity = adapt_to_entity @@ -370,10 +371,15 @@ class PropComparator(operators.ColumnOperators): def __clause_element__(self): raise NotImplementedError("%r" % self) - def _query_clause_element(self): - return self.__clause_element__() - def _bulk_update_tuples(self, value): + # type: (ColumnOperators) -> List[tuple[ColumnOperators, Any]] + """Receive a SQL expression that represents a value in the SET + clause of an UPDATE statement. + + Return a tuple that can be passed to a :class:`.Update` construct. + + """ + return [(self.__clause_element__(), value)] def adapt_to_entity(self, adapt_to_entity): diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index e2c10e50a..f804d6eed 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -292,7 +292,7 @@ class ColumnProperty(StrategizedProperty): def _memoized_method___clause_element__(self): if self.adapter: - return self.adapter(self.prop.columns[0]) + return self.adapter(self.prop.columns[0], self.prop.key) else: # no adapter, so we aren't aliased # assert self._parententity is self._parentmapper @@ -300,6 +300,7 @@ class ColumnProperty(StrategizedProperty): { "parententity": self._parententity, "parentmapper": self._parententity, + "orm_key": self.prop.key, } ) diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 37bd77f63..3d08dce22 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -114,7 +114,6 @@ class Query(Generative): _from_obj = () _join_entities = () _select_from_entity = None - _mapper_adapter_map = {} _filter_aliases = () _from_obj_alias = None _joinpath = _joinpoint = util.immutabledict() @@ -177,61 +176,23 @@ class Query(Generative): self._primary_entity = None self._has_mapper_entities = False - # 1. don't run util.to_list() or _set_entity_selectables - # if no entities were passed - major performance bottleneck - # from lazy loader implementation when it seeks to use Query - # class for an identity lookup, causes test_orm.py to fail - # with thousands of extra function calls, see issue #4228 - # for why this use had to be added - # 2. can't use classmethod on Query because session.query_cls - # is an arbitrary callable in some user recipes, not - # necessarily a class, so we don't have the class available. - # see issue #4256 - # 3. can't do "if entities is not None" because we usually get here - # from session.query() which takes in *entities. - # 4. can't do "if entities" because users make use of undocumented - # to_list() behavior here and they pass clause expressions that - # can't be evaluated as boolean. See issue #4269. - # 5. the empty tuple is a singleton in cPython, take advantage of this - # so that we can skip for the empty "*entities" case without using - # any Python overloadable operators. - # if entities is not (): for ent in util.to_list(entities): entity_wrapper(self, ent) - self._set_entity_selectables(self._entities) - - def _set_entity_selectables(self, entities): - self._mapper_adapter_map = d = self._mapper_adapter_map.copy() - - for ent in entities: - for entity in ent.entities: - if entity not in d: - ext_info = inspect(entity) - if ( - not ext_info.is_aliased_class - and ext_info.mapper.with_polymorphic - ): - if ( - ext_info.mapper.persist_selectable - not in self._polymorphic_adapters - ): - self._mapper_loads_polymorphically_with( - ext_info.mapper, - sql_util.ColumnAdapter( - ext_info.selectable, - ext_info.mapper._equivalent_columns, - ), - ) - aliased_adapter = None - elif ext_info.is_aliased_class: - aliased_adapter = ext_info._adapter - else: - aliased_adapter = None - - d[entity] = (ext_info, aliased_adapter) - ent.setup_entity(*d[entity]) + def _setup_query_adapters(self, entity, ext_info): + if not ext_info.is_aliased_class and ext_info.mapper.with_polymorphic: + if ( + ext_info.mapper.persist_selectable + not in self._polymorphic_adapters + ): + self._mapper_loads_polymorphically_with( + ext_info.mapper, + sql_util.ColumnAdapter( + ext_info.selectable, + ext_info.mapper._equivalent_columns, + ), + ) def _mapper_loads_polymorphically_with(self, mapper, adapter): for m2 in mapper._with_polymorphic_mappers or [mapper]: @@ -1162,8 +1123,7 @@ class Query(Generative): entity = aliased(entity, alias) self._entities = list(self._entities) - m = _MapperEntity(self, entity) - self._set_entity_selectables([m]) + _MapperEntity(self, entity) @_generative def with_session(self, session): @@ -1455,12 +1415,9 @@ class Query(Generative): of result columns to be returned.""" self._entities = list(self._entities) - l = len(self._entities) + for c in column: _ColumnEntity(self, c) - # _ColumnEntity may add many entities if the - # given arg is a FROM clause - self._set_entity_selectables(self._entities[l:]) @util.pending_deprecation( "0.7", @@ -2464,9 +2421,13 @@ class Query(Generative): ) else: # add a new element to the self._from_obj list - if use_entity_index is not None: - # why doesn't this work as .entity_zero_or_selectable? + # make use of _MapperEntity selectable, which is usually + # entity_zero.selectable, but if with_polymorphic() were used + # might be distinct + assert isinstance( + self._entities[use_entity_index], _MapperEntity + ) left_clause = self._entities[use_entity_index].selectable else: left_clause = left @@ -3529,7 +3490,7 @@ class Query(Generative): # we get just "SELECT 1" without any entities. return sql.exists( self.enable_eagerloads(False) - .add_columns("1") + .add_columns(sql.literal_column("1")) .with_labels() .statement.with_only_columns([1]) ) @@ -4029,10 +3990,10 @@ class Query(Generative): """ - search = set(self._mapper_adapter_map.values()) + search = set(context.single_inh_entities.values()) if ( self._select_from_entity - and self._select_from_entity not in self._mapper_adapter_map + and self._select_from_entity not in context.single_inh_entities ): insp = inspect(self._select_from_entity) if insp.is_aliased_class: @@ -4110,23 +4071,27 @@ class _MapperEntity(_QueryEntity): self.entities = [entity] self.expr = entity - supports_single_entity = True - - use_id_for_hash = True + ext_info = self.entity_zero = inspect(entity) - def setup_entity(self, ext_info, aliased_adapter): self.mapper = ext_info.mapper - self.aliased_adapter = aliased_adapter + + if ext_info.is_aliased_class: + self._label_name = ext_info.name + else: + self._label_name = self.mapper.class_.__name__ + self.selectable = ext_info.selectable self.is_aliased_class = ext_info.is_aliased_class self._with_polymorphic = ext_info.with_polymorphic_mappers self._polymorphic_discriminator = ext_info.polymorphic_on - self.entity_zero = ext_info - if ext_info.is_aliased_class: - self._label_name = self.entity_zero.name - else: - self._label_name = self.mapper.class_.__name__ - self.path = self.entity_zero._path_registry + self.path = ext_info._path_registry + + if ext_info.mapper.with_polymorphic: + query._setup_query_adapters(entity, ext_info) + + supports_single_entity = True + + use_id_for_hash = True def set_with_polymorphic( self, query, cls_or_mappers, selectable, polymorphic_on @@ -4185,7 +4150,7 @@ class _MapperEntity(_QueryEntity): if query._polymorphic_adapters: adapter = query._polymorphic_adapters.get(self.mapper, None) else: - adapter = self.aliased_adapter + adapter = self.entity_zero._adapter if adapter: if query._from_obj_alias: @@ -4235,6 +4200,14 @@ class _MapperEntity(_QueryEntity): def setup_context(self, query, context): adapter = self._get_entity_clauses(query, context) + single_table_crit = self.mapper._single_table_criterion + if single_table_crit is not None: + ext_info = self.entity_zero + context.single_inh_entities[ext_info] = ( + ext_info, + ext_info._adapter if ext_info.is_aliased_class else None, + ) + # if self._adapted_selectable is None: context.froms += (self.selectable,) @@ -4352,7 +4325,9 @@ class Bundle(InspectionAttr): return cloned def __clause_element__(self): - return expression.ClauseList(group=False, *self.exprs) + return expression.ClauseList(group=False, *self.exprs)._annotate( + {"bundle": True} + ) @property def clauses(self): @@ -4386,8 +4361,19 @@ class Bundle(InspectionAttr): class _BundleEntity(_QueryEntity): use_id_for_hash = False - def __init__(self, query, bundle, setup_entities=True): - query._entities.append(self) + def __init__(self, query, expr, setup_entities=True, parent_bundle=None): + if parent_bundle: + parent_bundle._entities.append(self) + else: + query._entities.append(self) + + if isinstance( + expr, (attributes.QueryableAttribute, interfaces.PropComparator) + ): + bundle = expr.__clause_element__() + else: + bundle = expr + self.bundle = self.expr = bundle self.type = type(bundle) self._label_name = bundle.name @@ -4396,9 +4382,9 @@ class _BundleEntity(_QueryEntity): if setup_entities: for expr in bundle.exprs: if isinstance(expr, Bundle): - _BundleEntity(self, expr) + _BundleEntity(query, expr, parent_bundle=self) else: - _ColumnEntity(self, expr) + _ColumnEntity(query, expr, parent_bundle=self) self.supports_single_entity = self.bundle.single_entity @@ -4448,18 +4434,19 @@ class _BundleEntity(_QueryEntity): else: return None - def adapt_to_selectable(self, query, sel): - c = _BundleEntity(query, self.bundle, setup_entities=False) + def adapt_to_selectable(self, query, sel, parent_bundle=None): + c = _BundleEntity( + query, + self.bundle, + setup_entities=False, + parent_bundle=parent_bundle, + ) # c._label_name = self._label_name # c.entity_zero = self.entity_zero # c.entities = self.entities for ent in self._entities: - ent.adapt_to_selectable(c, sel) - - def setup_entity(self, ext_info, aliased_adapter): - for ent in self._entities: - ent.setup_entity(ext_info, aliased_adapter) + ent.adapt_to_selectable(query, sel, parent_bundle=c) def setup_context(self, query, context): for ent in self._entities: @@ -4481,76 +4468,52 @@ class _BundleEntity(_QueryEntity): class _ColumnEntity(_QueryEntity): """Column/expression based entity.""" - def __init__(self, query, column, namespace=None): - self.expr = column + froms = frozenset() + + def __init__(self, query, column, namespace=None, parent_bundle=None): + self.expr = expr = column self.namespace = namespace - search_entities = True - check_column = False - - if isinstance(column, util.string_types): - column = sql.literal_column(column) - self._label_name = column.name - search_entities = False - check_column = True - _entity = None - elif isinstance( - column, (attributes.QueryableAttribute, interfaces.PropComparator) - ): - _entity = getattr(column, "_parententity", None) - if _entity is not None: - search_entities = False - self._label_name = column.key - column = column._query_clause_element() - check_column = True - if isinstance(column, Bundle): - _BundleEntity(query, column) - return + _label_name = None - if not isinstance(column, sql.ColumnElement): - if hasattr(column, "_select_iterable"): - # break out an object like Table into - # individual columns - for c in column._select_iterable: - if c is column: - break - _ColumnEntity(query, c, namespace=column) - else: - return + column = coercions.expect(roles.ColumnsClauseRole, column) - raise sa_exc.InvalidRequestError( - "SQL expression, column, or mapped entity " - "expected - got '%r'" % (column,) - ) - elif not check_column: + annotations = column._annotations + + if annotations.get("bundle", False): + _BundleEntity(query, expr, parent_bundle=parent_bundle) + return + + orm_expr = False + + if "parententity" in annotations: + _entity = annotations["parententity"] + self._label_name = _label_name = annotations.get("orm_key", None) + orm_expr = True + + if hasattr(column, "_select_iterable"): + # break out an object like Table into + # individual columns + for c in column._select_iterable: + if c is column: + break + _ColumnEntity(query, c, namespace=column) + else: + return + + if _label_name is None: self._label_name = getattr(column, "key", None) - search_entities = True self.type = type_ = column.type self.use_id_for_hash = not type_.hashable - # If the Column is unnamed, give it a - # label() so that mutable column expressions - # can be located in the result even - # if the expression's identity has been changed - # due to adaption. - - if not column._label and not getattr(column, "is_literal", False): - column = column.label(self._label_name) - - query._entities.append(self) + if parent_bundle: + parent_bundle._entities.append(self) + else: + query._entities.append(self) self.column = column - self.froms = set() - - # look for ORM entities represented within the - # given expression. Try to count only entities - # for columns whose FROM object is in the actual list - # of FROMs for the overall expression - this helps - # subqueries which were built from ORM constructs from - # leaking out their entities into the main select construct - self.actual_froms = set(column._from_objects) - if not search_entities: + if orm_expr: self.entity_zero = _entity if _entity: self.entities = [_entity] @@ -4559,21 +4522,20 @@ class _ColumnEntity(_QueryEntity): self.entities = [] self.mapper = None else: - all_elements = [ - elem - for elem in sql_util.surface_column_elements( - column, include_scalar_selects=False - ) - if "parententity" in elem._annotations - ] - self.entities = util.unique_list( - [elem._annotations["parententity"] for elem in all_elements] + entity = sql_util.extract_first_column_annotation( + column, "parententity" ) + if entity: + self.entities = [entity] + else: + self.entities = [] + if self.entities: self.entity_zero = self.entities[0] self.mapper = self.entity_zero.mapper + elif self.namespace is not None: self.entity_zero = self.namespace self.mapper = None @@ -4581,6 +4543,9 @@ class _ColumnEntity(_QueryEntity): self.entity_zero = None self.mapper = None + if self.entities and self.entity_zero.mapper.with_polymorphic: + query._setup_query_adapters(self.entity_zero, self.entity_zero) + supports_single_entity = False def _deep_entity_zero(self): @@ -4603,24 +4568,21 @@ class _ColumnEntity(_QueryEntity): def entity_zero_or_selectable(self): if self.entity_zero is not None: return self.entity_zero - elif self.actual_froms: - return list(self.actual_froms)[0] + elif self.column._from_objects: + return self.column._from_objects[0] else: return None - def adapt_to_selectable(self, query, sel): - c = _ColumnEntity(query, sel.corresponding_column(self.column)) + def adapt_to_selectable(self, query, sel, parent_bundle=None): + c = _ColumnEntity( + query, + sel.corresponding_column(self.column), + parent_bundle=parent_bundle, + ) c._label_name = self._label_name c.entity_zero = self.entity_zero c.entities = self.entities - def setup_entity(self, ext_info, aliased_adapter): - if "selectable" not in self.__dict__: - self.selectable = ext_info.selectable - - if self.actual_froms.intersection(ext_info.selectable._from_objects): - self.froms.add(ext_info.selectable) - def corresponds_to(self, entity): if self.entity_zero is None: return False @@ -4651,13 +4613,32 @@ class _ColumnEntity(_QueryEntity): def setup_context(self, query, context): column = query._adapt_clause(self.column, False, True) + ezero = self.entity_zero + + if self.mapper: + single_table_crit = self.mapper._single_table_criterion + if single_table_crit is not None: + context.single_inh_entities[ezero] = ( + ezero, + ezero._adapter if ezero.is_aliased_class else None, + ) if column._annotations: # annotated columns perform more slowly in compiler and # result due to the __eq__() method, so use deannotated column = column._deannotate() - context.froms += tuple(self.froms) + if ezero is not None: + # use entity_zero as the from if we have it. this is necessary + # for polymorpic scenarios where our FROM is based on ORM entity, + # not the FROM of the column. but also, don't use it if our column + # doesn't actually have any FROMs that line up, such as when its + # a scalar subquery. + if set(self.column._from_objects).intersection( + ezero.selectable._from_objects + ): + context.froms += (ezero.selectable,) + context.primary_columns.append(column) context.attributes[("fetch_column", self)] = column @@ -4697,6 +4678,7 @@ class QueryContext(object): "partials", "post_load_paths", "identity_token", + "single_inh_entities", ) def __init__(self, query): @@ -4731,6 +4713,7 @@ class QueryContext(object): self.secondary_columns = [] self.eager_order_by = [] self.eager_joins = {} + self.single_inh_entities = {} self.create_eager_joins = [] self.propagate_options = set( o for o in query._with_options if o.propagate_to_loaders diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 63ec21099..731947cba 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -2319,11 +2319,11 @@ class JoinCondition(object): """ self.primaryjoin = _deep_deannotate( - self.primaryjoin, values=("parententity",) + self.primaryjoin, values=("parententity", "orm_key") ) if self.secondaryjoin is not None: self.secondaryjoin = _deep_deannotate( - self.secondaryjoin, values=("parententity",) + self.secondaryjoin, values=("parententity", "orm_key") ) def _determine_joins(self): diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 4b4fa4052..747ec7e65 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -668,10 +668,11 @@ class AliasedInsp(InspectionAttr): state["represents_outer_join"], ) - def _adapt_element(self, elem): - return self._adapter.traverse(elem)._annotate( - {"parententity": self, "parentmapper": self.mapper} - ) + def _adapt_element(self, elem, key=None): + d = {"parententity": self, "parentmapper": self.mapper} + if key: + d["orm_key"] = key + return self._adapter.traverse(elem)._annotate(d) def _entity_for_mapper(self, mapper): self_poly = self.with_polymorphic_mappers diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index a7a856bba..95aee0468 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -57,7 +57,7 @@ def expect(role, element, **kw): else: resolved = element - if issubclass(resolved.__class__, impl._role_class): + if impl._role_class in resolved.__class__.__mro__: if impl._post_coercion: resolved = impl._post_coercion(resolved, **kw) return resolved @@ -102,13 +102,16 @@ class RoleImpl(object): def _resolve_for_clause_element(self, element, argname=None, **kw): original_element = element - is_clause_element = False + is_clause_element = hasattr(element, "__clause_element__") - while hasattr(element, "__clause_element__") and not isinstance( - element, (elements.ClauseElement, schema.SchemaItem) - ): - element = element.__clause_element__() - is_clause_element = True + if is_clause_element: + while not isinstance( + element, (elements.ClauseElement, schema.SchemaItem) + ): + try: + element = element.__clause_element__() + except AttributeError: + break if not is_clause_element: if self._use_inspection: diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index fe83b163c..3c7f904de 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -364,23 +364,19 @@ def surface_selectables_only(clause): stack.append(elem.table) -def surface_column_elements(clause, include_scalar_selects=True): - """traverse and yield only outer-exposed column elements, such as would - be addressable in the WHERE clause of a SELECT if this element were - in the columns clause.""" +def extract_first_column_annotation(column, annotation_name): + filter_ = (FromGrouping, SelectBase) - filter_ = (FromGrouping,) - if not include_scalar_selects: - filter_ += (SelectBase,) - - stack = deque([clause]) + stack = deque([column]) while stack: elem = stack.popleft() - yield elem + if annotation_name in elem._annotations: + return elem._annotations[annotation_name] for sub in elem.get_children(): if isinstance(sub, filter_): continue stack.append(sub) + return None def selectables_overlap(left, right): |
