diff options
| author | mike bayer <mike_mp@zzzcomputing.com> | 2020-06-01 20:03:28 +0000 |
|---|---|---|
| committer | Gerrit Code Review <gerrit@bbpush.zzzcomputing.com> | 2020-06-01 20:03:28 +0000 |
| commit | 03c4166fa852d20b1bc1f4693e70a13725a3ba2e (patch) | |
| tree | 4d2fe6fef7e7999ee6a0ae4017907cca9c4a294c /lib/sqlalchemy | |
| parent | 1287853bf59a3e50b0a518d4410ff60305058131 (diff) | |
| parent | 4ecd352a9fbb9dbac7b428fe0f098f665c1f0cb1 (diff) | |
| download | sqlalchemy-03c4166fa852d20b1bc1f4693e70a13725a3ba2e.tar.gz | |
Merge "Improve rendering of core statements w/ ORM elements"
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/engine/cursor.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/baked.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/compiler.py | 9 | ||||
| -rw-r--r-- | lib/sqlalchemy/future/selectable.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/context.py | 518 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/loading.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 28 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/properties.py | 77 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/query.py | 97 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/relationships.py | 22 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/strategies.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/util.py | 37 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/base.py | 32 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 23 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 32 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/traversals.py | 18 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/util.py | 7 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 11 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/compat.py | 13 |
19 files changed, 572 insertions, 364 deletions
diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index c32427644..1d832e4af 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -117,8 +117,6 @@ class CursorResultMetaData(ResultMetaData): compiled_statement = context.compiled.statement invoked_statement = context.invoked_statement - # same statement was invoked as the one we cached against, - # return self if compiled_statement is invoked_statement: return self diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py index 7ac556dcc..f95a30fda 100644 --- a/lib/sqlalchemy/ext/baked.py +++ b/lib/sqlalchemy/ext/baked.py @@ -228,7 +228,7 @@ class BakedQuery(object): # in 1.4, this is where before_compile() event is # invoked - statement = query._statement_20(orm_results=True) + statement = query._statement_20() # if the query is not safe to cache, we still do everything as though # we did cache it, since the receiver of _bake() assumes subqueryload diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py index 32975a949..7736a1290 100644 --- a/lib/sqlalchemy/ext/compiler.py +++ b/lib/sqlalchemy/ext/compiler.py @@ -401,7 +401,6 @@ Example usage:: from .. import exc from .. import util from ..sql import sqltypes -from ..sql import visitors def compiles(class_, *specs): @@ -456,12 +455,12 @@ def compiles(class_, *specs): def deregister(class_): """Remove all custom compilers associated with a given - :class:`_expression.ClauseElement` type.""" + :class:`_expression.ClauseElement` type. + + """ if hasattr(class_, "_compiler_dispatcher"): - # regenerate default _compiler_dispatch - visitors._generate_compiler_dispatch(class_) - # remove custom directive + class_._compiler_dispatch = class_._original_compiler_dispatch del class_._compiler_dispatcher diff --git a/lib/sqlalchemy/future/selectable.py b/lib/sqlalchemy/future/selectable.py index 74cc13501..407ec9633 100644 --- a/lib/sqlalchemy/future/selectable.py +++ b/lib/sqlalchemy/future/selectable.py @@ -71,6 +71,10 @@ class Select(_LegacySelect): return self.where(*criteria) + def _exported_columns_iterator(self): + meth = SelectState.get_plugin_class(self).exported_columns_iterator + return meth(self) + def _filter_by_zero(self): if self._setup_joins: meth = SelectState.get_plugin_class( diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 5589f0e0c..ba30d203b 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -5,7 +5,6 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php - from . import attributes from . import interfaces from . import loading @@ -27,6 +26,7 @@ from ..sql import expression from ..sql import roles from ..sql import util as sql_util from ..sql import visitors +from ..sql.base import _select_iterables from ..sql.base import CacheableOptions from ..sql.base import CompileState from ..sql.base import Options @@ -90,7 +90,7 @@ class QueryContext(object): self.execution_options = execution_options or _EMPTY_DICT self.bind_arguments = bind_arguments or _EMPTY_DICT self.compile_state = compile_state - self.query = query = compile_state.query + self.query = query = compile_state.select_statement self.session = session self.propagated_loader_options = { @@ -119,10 +119,14 @@ class QueryContext(object): class ORMCompileState(CompileState): + # note this is a dictionary, but the + # default_compile_options._with_polymorphic_adapt_map is a tuple + _with_polymorphic_adapt_map = _EMPTY_DICT + class default_compile_options(CacheableOptions): _cache_key_traversal = [ ("_use_legacy_query_style", InternalTraversal.dp_boolean), - ("_orm_results", InternalTraversal.dp_boolean), + ("_for_statement", InternalTraversal.dp_boolean), ("_bake_ok", InternalTraversal.dp_boolean), ( "_with_polymorphic_adapt_map", @@ -137,8 +141,18 @@ class ORMCompileState(CompileState): ("_for_refresh_state", InternalTraversal.dp_boolean), ] + # set to True by default from Query._statement_20(), to indicate + # the rendered query should look like a legacy ORM query. right + # now this basically indicates we should use tablename_columnname + # style labels. Generally indicates the statement originated + # from a Query object. _use_legacy_query_style = False - _orm_results = True + + # set *only* when we are coming from the Query.statement + # accessor, or a Query-level equivalent such as + # query.subquery(). this supersedes "toplevel". + _for_statement = False + _bake_ok = True _with_polymorphic_adapt_map = () _current_path = _path_registry @@ -149,42 +163,24 @@ class ORMCompileState(CompileState): _set_base_alias = False _for_refresh_state = False - @classmethod - def merge(cls, other): - return cls + other._state_dict() - current_path = _path_registry def __init__(self, *arg, **kw): raise NotImplementedError() - def dispose(self): - self.attributes.clear() - @classmethod def create_for_statement(cls, statement_container, compiler, **kw): - raise NotImplementedError() + """Create a context for a statement given a :class:`.Compiler`. - @classmethod - def _create_for_legacy_query(cls, query, toplevel, for_statement=False): - stmt = query._statement_20(orm_results=not for_statement) - - # this chooses between ORMFromStatementCompileState and - # ORMSelectCompileState. We could also base this on - # query._statement is not None as we have the ORM Query here - # however this is the more general path. - compile_state_cls = CompileState._get_plugin_class_for_plugin( - stmt, "orm" - ) + This method is always invoked in the context of SQLCompiler.process(). - return compile_state_cls._create_for_statement_or_query( - stmt, toplevel, for_statement=for_statement - ) + For a Select object, this would be invoked from + SQLCompiler.visit_select(). For the special FromStatement object used + by Query to indicate "Query.from_statement()", this is called by + FromStatement._compiler_dispatch() that would be called by + SQLCompiler.process(). - @classmethod - def _create_for_statement_or_query( - cls, statement_container, for_statement=False, - ): + """ raise NotImplementedError() @classmethod @@ -266,21 +262,20 @@ class ORMCompileState(CompileState): and ext_info.mapper.persist_selectable not in self._polymorphic_adapters ): - self._mapper_loads_polymorphically_with( - ext_info.mapper, - sql_util.ColumnAdapter( - selectable, ext_info.mapper._equivalent_columns - ), - ) + for mp in ext_info.mapper.iterate_to_root(): + self._mapper_loads_polymorphically_with( + mp, + sql_util.ColumnAdapter(selectable, mp._equivalent_columns), + ) def _mapper_loads_polymorphically_with(self, mapper, adapter): for m2 in mapper._with_polymorphic_mappers or [mapper]: self._polymorphic_adapters[m2] = adapter - for m in m2.iterate_to_root(): + for m in m2.iterate_to_root(): # TODO: redundant ? self._polymorphic_adapters[m.local_table] = adapter -@sql.base.CompileState.plugin_for("orm", "grouping") +@sql.base.CompileState.plugin_for("orm", "orm_from_statement") class ORMFromStatementCompileState(ORMCompileState): _aliased_generations = util.immutabledict() _from_obj_alias = None @@ -294,31 +289,23 @@ class ORMFromStatementCompileState(ORMCompileState): @classmethod def create_for_statement(cls, statement_container, compiler, **kw): - compiler._rewrites_selected_columns = True - toplevel = not compiler.stack - return cls._create_for_statement_or_query( - statement_container, toplevel - ) - @classmethod - def _create_for_statement_or_query( - cls, statement_container, toplevel, for_statement=False, - ): - # from .query import FromStatement - - # assert isinstance(statement_container, FromStatement) + if compiler is not None: + compiler._rewrites_selected_columns = True + toplevel = not compiler.stack + else: + toplevel = True self = cls.__new__(cls) self._primary_entity = None - self.use_orm_style = ( + self.use_legacy_query_style = ( statement_container.compile_options._use_legacy_query_style ) - self.statement_container = self.query = statement_container - self.requested_statement = statement_container.element + self.statement_container = self.select_statement = statement_container + self.requested_statement = statement = statement_container.element self._entities = [] - self._with_polymorphic_adapt_map = {} self._polymorphic_adapters = {} self._no_yield_pers = set() @@ -349,12 +336,6 @@ class ORMFromStatementCompileState(ORMCompileState): self.create_eager_joins = [] self._fallback_from_clauses = [] - self._setup_for_statement() - - return self - - def _setup_for_statement(self): - statement = self.requested_statement if ( isinstance(statement, expression.SelectBase) and not statement._is_textual @@ -392,6 +373,8 @@ class ORMFromStatementCompileState(ORMCompileState): # for entity in self._entities: # entity.setup_compile_state(self) + return self + def _adapt_col_list(self, cols, current_adapter): return cols @@ -401,7 +384,8 @@ class ORMFromStatementCompileState(ORMCompileState): @sql.base.CompileState.plugin_for("orm", "select") class ORMSelectCompileState(ORMCompileState, SelectState): - _joinpath = _joinpoint = util.immutabledict() + _joinpath = _joinpoint = _EMPTY_DICT + _from_obj_alias = None _has_mapper_entities = False @@ -417,77 +401,71 @@ class ORMSelectCompileState(ORMCompileState, SelectState): @classmethod def create_for_statement(cls, statement, compiler, **kw): + """compiler hook, we arrive here from compiler.visit_select() only.""" + if not statement._is_future: return SelectState(statement, compiler, **kw) - toplevel = not compiler.stack + if compiler is not None: + toplevel = not compiler.stack + compiler._rewrites_selected_columns = True + else: + toplevel = True - compiler._rewrites_selected_columns = True + select_statement = statement - orm_state = cls._create_for_statement_or_query( - statement, for_statement=True, toplevel=toplevel - ) - SelectState.__init__(orm_state, orm_state.statement, compiler, **kw) - return orm_state - - @classmethod - def _create_for_statement_or_query( - cls, query, toplevel, for_statement=False, _entities_only=False - ): - assert isinstance(query, future.Select) - - query.compile_options = cls.default_compile_options.merge( - query.compile_options + # if we are a select() that was never a legacy Query, we won't + # have ORM level compile options. + statement.compile_options = cls.default_compile_options.safe_merge( + statement.compile_options ) self = cls.__new__(cls) - self._primary_entity = None - - self.query = query - self.use_orm_style = query.compile_options._use_legacy_query_style + self.select_statement = select_statement - self.select_statement = select_statement = query + # indicates this select() came from Query.statement + self.for_statement = ( + for_statement + ) = select_statement.compile_options._for_statement - if not hasattr(select_statement.compile_options, "_orm_results"): - select_statement.compile_options = cls.default_compile_options - select_statement.compile_options += {"_orm_results": for_statement} - else: - for_statement = not select_statement.compile_options._orm_results + if not for_statement and not toplevel: + # for subqueries, turn off eagerloads. + # if "for_statement" mode is set, Query.subquery() + # would have set this flag to False already if that's what's + # desired + select_statement.compile_options += { + "_enable_eagerloads": False, + } - self.query = query + # generally if we are from Query or directly from a select() + self.use_legacy_query_style = ( + select_statement.compile_options._use_legacy_query_style + ) self._entities = [] - + self._primary_entity = None self._aliased_generations = {} self._polymorphic_adapters = {} self._no_yield_pers = set() # legacy: only for query.with_polymorphic() - self._with_polymorphic_adapt_map = wpam = dict( - select_statement.compile_options._with_polymorphic_adapt_map - ) - if wpam: + if select_statement.compile_options._with_polymorphic_adapt_map: + self._with_polymorphic_adapt_map = dict( + select_statement.compile_options._with_polymorphic_adapt_map + ) self._setup_with_polymorphics() _QueryEntity.to_compile_state(self, select_statement._raw_columns) - if _entities_only: - return self - - self.compile_options = query.compile_options - - # TODO: the name of this flag "for_statement" has to change, - # as it is difficult to distinguish from the "query._statement" use - # case which is something totally different - self.for_statement = for_statement + self.compile_options = select_statement.compile_options # determine label style. we can make different decisions here. # at the moment, trying to see if we can always use DISAMBIGUATE_ONLY # rather than LABEL_STYLE_NONE, and if we can use disambiguate style # for new style ORM selects too. if self.select_statement._label_style is LABEL_STYLE_NONE: - if self.use_orm_style and not for_statement: + if self.use_legacy_query_style and not self.for_statement: self.label_style = LABEL_STYLE_TABLENAME_PLUS_COL else: self.label_style = LABEL_STYLE_DISAMBIGUATE_ONLY @@ -522,129 +500,16 @@ class ORMSelectCompileState(ORMCompileState, SelectState): info.selectable for info in select_statement._from_obj ] + # this is a fairly arbitrary break into a second method, + # so it might be nicer to break up create_for_statement() + # and _setup_for_generate into three or four logical sections self._setup_for_generate() - return self - - @classmethod - def _create_entities_collection(cls, query): - """Creates a partial ORMSelectCompileState that includes - the full collection of _MapperEntity and other _QueryEntity objects. - - Supports a few remaining use cases that are pre-compilation - but still need to gather some of the column / adaption information. - - """ - self = cls.__new__(cls) - - self._entities = [] - self._primary_entity = None - self._aliased_generations = {} - self._polymorphic_adapters = {} - - # legacy: only for query.with_polymorphic() - self._with_polymorphic_adapt_map = wpam = dict( - query.compile_options._with_polymorphic_adapt_map - ) - if wpam: - self._setup_with_polymorphics() + if compiler is not None: + SelectState.__init__(self, self.statement, compiler, **kw) - _QueryEntity.to_compile_state(self, query._raw_columns) return self - @classmethod - def determine_last_joined_entity(cls, statement): - setup_joins = statement._setup_joins - - if not setup_joins: - return None - - (target, onclause, from_, flags) = setup_joins[-1] - - if isinstance(target, interfaces.PropComparator): - return target.entity - else: - return target - - def _setup_with_polymorphics(self): - # legacy: only for query.with_polymorphic() - for ext_info, wp in self._with_polymorphic_adapt_map.items(): - self._mapper_loads_polymorphically_with(ext_info, wp._adapter) - - def _set_select_from_alias(self): - - query = self.select_statement # query - - assert self.compile_options._set_base_alias - assert len(query._from_obj) == 1 - - adapter = self._get_select_from_alias_from_obj(query._from_obj[0]) - if adapter: - self.compile_options += {"_enable_single_crit": False} - self._from_obj_alias = adapter - - def _get_select_from_alias_from_obj(self, from_obj): - info = from_obj - - if "parententity" in info._annotations: - info = info._annotations["parententity"] - - if hasattr(info, "mapper"): - if not info.is_aliased_class: - raise sa_exc.ArgumentError( - "A selectable (FromClause) instance is " - "expected when the base alias is being set." - ) - else: - return info._adapter - - elif isinstance(info.selectable, sql.selectable.AliasedReturnsRows): - equivs = self._all_equivs() - return sql_util.ColumnAdapter(info, equivs) - else: - return None - - def _mapper_zero(self): - """return the Mapper associated with the first QueryEntity.""" - return self._entities[0].mapper - - def _entity_zero(self): - """Return the 'entity' (mapper or AliasedClass) associated - with the first QueryEntity, or alternatively the 'select from' - entity if specified.""" - - for ent in self.from_clauses: - if "parententity" in ent._annotations: - return ent._annotations["parententity"] - for qent in self._entities: - if qent.entity_zero: - return qent.entity_zero - - return None - - def _only_full_mapper_zero(self, methname): - if self._entities != [self._primary_entity]: - raise sa_exc.InvalidRequestError( - "%s() can only be used against " - "a single mapped class." % methname - ) - return self._primary_entity.entity_zero - - def _only_entity_zero(self, rationale=None): - if len(self._entities) > 1: - raise sa_exc.InvalidRequestError( - rationale - or "This operation requires a Query " - "against a single mapper." - ) - return self._entity_zero() - - def _all_equivs(self): - equivs = {} - for ent in self._mapper_entities: - equivs.update(ent.mapper._equivalent_columns) - return equivs - def _setup_for_generate(self): query = self.select_statement @@ -772,6 +637,140 @@ class ORMSelectCompileState(ORMCompileState, SelectState): {"deepentity": ezero} ) + @classmethod + def _create_entities_collection(cls, query): + """Creates a partial ORMSelectCompileState that includes + the full collection of _MapperEntity and other _QueryEntity objects. + + Supports a few remaining use cases that are pre-compilation + but still need to gather some of the column / adaption information. + + """ + self = cls.__new__(cls) + + self._entities = [] + self._primary_entity = None + self._aliased_generations = {} + self._polymorphic_adapters = {} + + # legacy: only for query.with_polymorphic() + if query.compile_options._with_polymorphic_adapt_map: + self._with_polymorphic_adapt_map = dict( + query.compile_options._with_polymorphic_adapt_map + ) + self._setup_with_polymorphics() + + _QueryEntity.to_compile_state(self, query._raw_columns) + return self + + @classmethod + def determine_last_joined_entity(cls, statement): + setup_joins = statement._setup_joins + + if not setup_joins: + return None + + (target, onclause, from_, flags) = setup_joins[-1] + + if isinstance(target, interfaces.PropComparator): + return target.entity + else: + return target + + @classmethod + def exported_columns_iterator(cls, statement): + for element in statement._raw_columns: + if ( + element.is_selectable + and "entity_namespace" in element._annotations + ): + for elem in _select_iterables( + element._annotations["entity_namespace"].columns + ): + yield elem + else: + for elem in _select_iterables([element]): + yield elem + + def _setup_with_polymorphics(self): + # legacy: only for query.with_polymorphic() + for ext_info, wp in self._with_polymorphic_adapt_map.items(): + self._mapper_loads_polymorphically_with(ext_info, wp._adapter) + + def _set_select_from_alias(self): + + query = self.select_statement # query + + assert self.compile_options._set_base_alias + assert len(query._from_obj) == 1 + + adapter = self._get_select_from_alias_from_obj(query._from_obj[0]) + if adapter: + self.compile_options += {"_enable_single_crit": False} + self._from_obj_alias = adapter + + def _get_select_from_alias_from_obj(self, from_obj): + info = from_obj + + if "parententity" in info._annotations: + info = info._annotations["parententity"] + + if hasattr(info, "mapper"): + if not info.is_aliased_class: + raise sa_exc.ArgumentError( + "A selectable (FromClause) instance is " + "expected when the base alias is being set." + ) + else: + return info._adapter + + elif isinstance(info.selectable, sql.selectable.AliasedReturnsRows): + equivs = self._all_equivs() + return sql_util.ColumnAdapter(info, equivs) + else: + return None + + def _mapper_zero(self): + """return the Mapper associated with the first QueryEntity.""" + return self._entities[0].mapper + + def _entity_zero(self): + """Return the 'entity' (mapper or AliasedClass) associated + with the first QueryEntity, or alternatively the 'select from' + entity if specified.""" + + for ent in self.from_clauses: + if "parententity" in ent._annotations: + return ent._annotations["parententity"] + for qent in self._entities: + if qent.entity_zero: + return qent.entity_zero + + return None + + def _only_full_mapper_zero(self, methname): + if self._entities != [self._primary_entity]: + raise sa_exc.InvalidRequestError( + "%s() can only be used against " + "a single mapped class." % methname + ) + return self._primary_entity.entity_zero + + def _only_entity_zero(self, rationale=None): + if len(self._entities) > 1: + raise sa_exc.InvalidRequestError( + rationale + or "This operation requires a Query " + "against a single mapper." + ) + return self._entity_zero() + + def _all_equivs(self): + equivs = {} + for ent in self._mapper_entities: + equivs.update(ent.mapper._equivalent_columns) + return equivs + def _compound_eager_statement(self): # for eager joins present and LIMIT/OFFSET/DISTINCT, # wrap the query inside a select, @@ -920,6 +919,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): statement = Select.__new__(Select) statement._raw_columns = raw_columns statement._from_obj = from_obj + statement._label_style = label_style if where_criteria: @@ -1653,31 +1653,10 @@ class ORMSelectCompileState(ORMCompileState, SelectState): "target." ) - aliased_entity = ( - right_mapper - and not right_is_aliased - and ( - # TODO: there is a reliance here on aliasing occurring - # when we join to a polymorphic mapper that doesn't actually - # need aliasing. When this condition is present, we should - # be able to say mapper_loads_polymorphically_with() - # and render the straight polymorphic selectable. this - # does not appear to be possible at the moment as the - # adapter no longer takes place on the rest of the query - # and it's not clear where that's failing to happen. - ( - right_mapper.with_polymorphic - and isinstance( - right_mapper._with_polymorphic_selectable, - expression.AliasedReturnsRows, - ) - ) - or overlap - # test for overlap: - # orm/inheritance/relationships.py - # SelfReferentialM2MTest - ) - ) + # test for overlap: + # orm/inheritance/relationships.py + # SelfReferentialM2MTest + aliased_entity = right_mapper and not right_is_aliased and overlap if not need_adapter and (create_aliases or aliased_entity): # there are a few places in the ORM that automatic aliasing @@ -1707,7 +1686,30 @@ class ORMSelectCompileState(ORMCompileState, SelectState): self._aliased_generations[aliased_generation] = ( adapter, ) + self._aliased_generations.get(aliased_generation, ()) - + elif ( + not r_info.is_clause_element + and not right_is_aliased + and right_mapper.with_polymorphic + and isinstance( + right_mapper._with_polymorphic_selectable, + expression.AliasedReturnsRows, + ) + ): + # for the case where the target mapper has a with_polymorphic + # set up, ensure an adapter is set up for criteria that works + # against this mapper. Previously, this logic used to + # use the "create_aliases or aliased_entity" case to generate + # an aliased() object, but this creates an alias that isn't + # strictly necessary. + # see test/orm/test_core_compilation.py + # ::RelNaturalAliasedJoinsTest::test_straight + # and similar + self._mapper_loads_polymorphically_with( + right_mapper, + sql_util.ColumnAdapter( + right_mapper.selectable, right_mapper._equivalent_columns, + ), + ) # if the onclause is a ClauseElement, adapt it with any # adapters that are in place right now if isinstance(onclause, expression.ClauseElement): @@ -1755,8 +1757,8 @@ class ORMSelectCompileState(ORMCompileState, SelectState): "offset_clause": self.select_statement._offset_clause, "distinct": self.distinct, "distinct_on": self.distinct_on, - "prefixes": self.query._prefixes, - "suffixes": self.query._suffixes, + "prefixes": self.select_statement._prefixes, + "suffixes": self.select_statement._suffixes, "group_by": self.group_by or None, } @@ -2036,7 +2038,14 @@ class _MapperEntity(_QueryEntity): self._with_polymorphic_mappers = ext_info.with_polymorphic_mappers self._polymorphic_discriminator = ext_info.polymorphic_on - if mapper.with_polymorphic or mapper._requires_row_aliasing: + if ( + mapper.with_polymorphic + # controversy - only if inheriting mapper is also + # polymorphic? + # or (mapper.inherits and mapper.inherits.with_polymorphic) + or mapper.inherits + or mapper._requires_row_aliasing + ): compile_state._create_with_polymorphic_adapter( ext_info, self.selectable ) @@ -2361,7 +2370,7 @@ class _ORMColumnEntity(_ColumnEntity): _entity._post_inspect self.entity_zero = self.entity_zero_or_selectable = ezero = _entity - self.mapper = _entity.mapper + self.mapper = mapper = _entity.mapper if parent_bundle: parent_bundle._entities.append(self) @@ -2373,7 +2382,11 @@ class _ORMColumnEntity(_ColumnEntity): self._extra_entities = (self.expr, self.column) - if self.mapper.with_polymorphic: + if ( + mapper.with_polymorphic + or mapper.inherits + or mapper._requires_row_aliasing + ): compile_state._create_with_polymorphic_adapter( ezero, ezero.selectable ) @@ -2414,6 +2427,7 @@ class _ORMColumnEntity(_ColumnEntity): column = current_adapter(self.column, False) else: column = self.column + ezero = self.entity_zero single_table_crit = self.mapper._single_table_criterion diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index 88d01eb0f..424ed5dfe 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -345,7 +345,7 @@ def load_on_pk_identity( if load_options is None: load_options = QueryContext.default_load_options - compile_options = ORMCompileState.default_compile_options.merge( + compile_options = ORMCompileState.default_compile_options.safe_merge( q.compile_options ) diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 7bfe70c36..4166e6d2a 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1743,6 +1743,11 @@ class Mapper( or prop.columns[0] is self.polymorphic_on ) + if isinstance(col, expression.Label): + # new in 1.4, get column property against expressions + # to be addressable in subqueries + col.key = col._key_label = key + self.columns.add(col, key) for col in prop.columns + prop._orig_columns: for col in col.proxy_set: @@ -2282,6 +2287,29 @@ class Mapper( ) ) + def _columns_plus_keys(self, polymorphic_mappers=()): + if polymorphic_mappers: + poly_properties = self._iterate_polymorphic_properties( + polymorphic_mappers + ) + else: + poly_properties = self._polymorphic_properties + + return [ + (prop.key, prop.columns[0]) + for prop in poly_properties + if isinstance(prop, properties.ColumnProperty) + ] + + @HasMemoized.memoized_attribute + def _polymorphic_adapter(self): + if self.with_polymorphic: + return sql_util.ColumnAdapter( + self.selectable, equivalents=self._equivalent_columns + ) + else: + return None + def _iterate_polymorphic_properties(self, mappers=None): """Return an iterator of MapperProperty objects which will render into a SELECT.""" diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 4cf501e3f..02f0752a5 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -264,6 +264,7 @@ class ColumnProperty(StrategizedProperty): def do_init(self): super(ColumnProperty, self).do_init() + if len(self.columns) > 1 and set(self.parent.primary_key).issuperset( self.columns ): @@ -339,28 +340,51 @@ class ColumnProperty(StrategizedProperty): __slots__ = "__clause_element__", "info", "expressions" + def _orm_annotate_column(self, column): + """annotate and possibly adapt a column to be returned + as the mapped-attribute exposed version of the column. + + The column in this context needs to act as much like the + column in an ORM mapped context as possible, so includes + annotations to give hints to various ORM functions as to + the source entity of this column. It also adapts it + to the mapper's with_polymorphic selectable if one is + present. + + """ + + pe = self._parententity + annotations = { + "entity_namespace": pe, + "parententity": pe, + "parentmapper": pe, + "orm_key": self.prop.key, + } + + col = column + + # for a mapper with polymorphic_on and an adapter, return + # the column against the polymorphic selectable. + # see also orm.util._orm_downgrade_polymorphic_columns + # for the reverse operation. + if self._parentmapper._polymorphic_adapter: + mapper_local_col = col + col = self._parentmapper._polymorphic_adapter.traverse(col) + + # this is a clue to the ORM Query etc. that this column + # was adapted to the mapper's polymorphic_adapter. the + # ORM uses this hint to know which column its adapting. + annotations["adapt_column"] = mapper_local_col + + return col._annotate(annotations)._set_propagate_attrs( + {"compile_state_plugin": "orm", "plugin_subject": pe} + ) + def _memoized_method___clause_element__(self): if self.adapter: return self.adapter(self.prop.columns[0], self.prop.key) else: - pe = self._parententity - # no adapter, so we aren't aliased - # assert self._parententity is self._parentmapper - return ( - self.prop.columns[0] - ._annotate( - { - "entity_namespace": pe, - "parententity": pe, - "parentmapper": pe, - "orm_key": self.prop.key, - "compile_state_plugin": "orm", - } - ) - ._set_propagate_attrs( - {"compile_state_plugin": "orm", "plugin_subject": pe} - ) - ) + return self._orm_annotate_column(self.prop.columns[0]) def _memoized_attr_info(self): """The .info dictionary for this attribute.""" @@ -384,23 +408,8 @@ class ColumnProperty(StrategizedProperty): for col in self.prop.columns ] else: - # no adapter, so we aren't aliased - # assert self._parententity is self._parentmapper return [ - col._annotate( - { - "parententity": self._parententity, - "parentmapper": self._parententity, - "orm_key": self.prop.key, - "compile_state_plugin": "orm", - } - )._set_propagate_attrs( - { - "compile_state_plugin": "orm", - "plugin_subject": self._parententity, - } - ) - for col in self.prop.columns + self._orm_annotate_column(col) for col in self.prop.columns ] def _fallback_getattr(self, key): diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 97a81e30f..5137f9b1d 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -360,7 +360,7 @@ class Query( ): # if we don't have legacy top level aliasing features in use # then convert to a future select() directly - stmt = self._statement_20() + stmt = self._statement_20(for_statement=True) else: stmt = self._compile_state(for_statement=True).statement @@ -371,7 +371,24 @@ class Query( return stmt - def _statement_20(self, orm_results=False): + def _final_statement(self, legacy_query_style=True): + """Return the 'final' SELECT statement for this :class:`.Query`. + + This is the Core-only select() that will be rendered by a complete + compilation of this query, and is what .statement used to return + in 1.3. + + This method creates a complete compile state so is fairly expensive. + + """ + + q = self._clone() + + return q._compile_state( + use_legacy_query_style=legacy_query_style + ).statement + + def _statement_20(self, for_statement=False, use_legacy_query_style=True): # TODO: this event needs to be deprecated, as it currently applies # only to ORM query and occurs at this spot that is now more # or less an artificial spot @@ -384,7 +401,10 @@ class Query( self.compile_options += {"_bake_ok": False} compile_options = self.compile_options - compile_options += {"_use_legacy_query_style": True} + compile_options += { + "_for_statement": for_statement, + "_use_legacy_query_style": use_legacy_query_style, + } if self._statement is not None: stmt = FromStatement(self._raw_columns, self._statement) @@ -404,13 +424,16 @@ class Query( compile_options=compile_options, ) - if not orm_results: - stmt.compile_options += {"_orm_results": False} - stmt._propagate_attrs = self._propagate_attrs return stmt - def subquery(self, name=None, with_labels=False, reduce_columns=False): + def subquery( + self, + name=None, + with_labels=False, + reduce_columns=False, + _legacy_core_statement=False, + ): """return the full SELECT statement represented by this :class:`_query.Query`, embedded within an :class:`_expression.Alias`. @@ -436,7 +459,11 @@ class Query( q = self.enable_eagerloads(False) if with_labels: q = q.with_labels() - q = q.statement + + if _legacy_core_statement: + q = q._compile_state(for_statement=True).statement + else: + q = q.statement if reduce_columns: q = q.reduce_columns() @@ -943,7 +970,7 @@ class Query( # tablename_colname style is used which at the moment is asserted # in a lot of unit tests :) - statement = self._statement_20(orm_results=True).apply_labels() + statement = self._statement_20().apply_labels() return db_load_fn( self.session, statement, @@ -1328,13 +1355,13 @@ class Query( self.with_labels() .enable_eagerloads(False) .correlate(None) - .subquery() + .subquery(_legacy_core_statement=True) ._anonymous_fromclause() ) parententity = self._raw_columns[0]._annotations.get("parententity") if parententity: - ac = aliased(parententity, alias=fromclause) + ac = aliased(parententity.mapper, alias=fromclause) q = self._from_selectable(ac) else: q = self._from_selectable(fromclause) @@ -2782,7 +2809,7 @@ class Query( def _iter(self): # new style execution. params = self.load_options._params - statement = self._statement_20(orm_results=True) + statement = self._statement_20() result = self.session.execute( statement, params, @@ -2808,7 +2835,7 @@ class Query( ) def __str__(self): - statement = self._statement_20(orm_results=True) + statement = self._statement_20() try: bind = ( @@ -2879,9 +2906,8 @@ class Query( "for linking ORM results to arbitrary select constructs.", version="1.4", ) - compile_state = ORMCompileState._create_for_legacy_query( - self, toplevel=True - ) + compile_state = self._compile_state(for_statement=False) + context = QueryContext( compile_state, self.session, self.load_options ) @@ -3294,10 +3320,35 @@ class Query( return update_op.rowcount def _compile_state(self, for_statement=False, **kw): - return ORMCompileState._create_for_legacy_query( - self, toplevel=True, for_statement=for_statement, **kw + """Create an out-of-compiler ORMCompileState object. + + The ORMCompileState object is normally created directly as a result + of the SQLCompiler.process() method being handed a Select() + or FromStatement() object that uses the "orm" plugin. This method + provides a means of creating this ORMCompileState object directly + without using the compiler. + + This method is used only for deprecated cases, which include + the .from_self() method for a Query that has multiple levels + of .from_self() in use, as well as the instances() method. It is + also used within the test suite to generate ORMCompileState objects + for test purposes. + + """ + + stmt = self._statement_20(for_statement=for_statement, **kw) + assert for_statement == stmt.compile_options._for_statement + + # this chooses between ORMFromStatementCompileState and + # ORMSelectCompileState. We could also base this on + # query._statement is not None as we have the ORM Query here + # however this is the more general path. + compile_state_cls = ORMCompileState._get_plugin_class_for_plugin( + stmt, "orm" ) + return compile_state_cls.create_for_statement(stmt, None) + def _compile_context(self, for_statement=False): compile_state = self._compile_state(for_statement=for_statement) context = QueryContext(compile_state, self.session, self.load_options) @@ -3311,6 +3362,8 @@ class FromStatement(SelectStatementGrouping, Executable): """ + __visit_name__ = "orm_from_statement" + compile_options = ORMFromStatementCompileState.default_compile_options _compile_state_factory = ORMFromStatementCompileState.create_for_statement @@ -3329,6 +3382,14 @@ class FromStatement(SelectStatementGrouping, Executable): super(FromStatement, self).__init__(element) def _compiler_dispatch(self, compiler, **kw): + + """provide a fixed _compiler_dispatch method. + + This is roughly similar to using the sqlalchemy.ext.compiler + ``@compiles`` extension. + + """ + compile_state = self._compile_state_factory(self, compiler, **kw) toplevel = not compiler.stack diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index e82cd174f..683f2b978 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -1170,9 +1170,9 @@ class RelationshipProperty(StrategizedProperty): def __clause_element__(self): adapt_from = self._source_selectable() if self._of_type: - of_type_mapper = inspect(self._of_type).mapper + of_type_entity = inspect(self._of_type) else: - of_type_mapper = None + of_type_entity = None ( pj, @@ -1184,7 +1184,7 @@ class RelationshipProperty(StrategizedProperty): ) = self.property._create_joins( source_selectable=adapt_from, source_polymorphic=True, - of_type_mapper=of_type_mapper, + of_type_entity=of_type_entity, alias_secondary=True, ) if sj is not None: @@ -1311,7 +1311,6 @@ class RelationshipProperty(StrategizedProperty): secondary, target_adapter, ) = self.property._create_joins( - dest_polymorphic=True, dest_selectable=to_selectable, source_selectable=source_selectable, ) @@ -2424,9 +2423,8 @@ class RelationshipProperty(StrategizedProperty): self, source_polymorphic=False, source_selectable=None, - dest_polymorphic=False, dest_selectable=None, - of_type_mapper=None, + of_type_entity=None, alias_secondary=False, ): @@ -2439,9 +2437,17 @@ class RelationshipProperty(StrategizedProperty): if source_polymorphic and self.parent.with_polymorphic: source_selectable = self.parent._with_polymorphic_selectable + if of_type_entity: + dest_mapper = of_type_entity.mapper + if dest_selectable is None: + dest_selectable = of_type_entity.selectable + aliased = True + else: + dest_mapper = self.mapper + if dest_selectable is None: dest_selectable = self.entity.selectable - if dest_polymorphic and self.mapper.with_polymorphic: + if self.mapper.with_polymorphic: aliased = True if self._is_self_referential and source_selectable is None: @@ -2453,8 +2459,6 @@ class RelationshipProperty(StrategizedProperty): ): aliased = True - dest_mapper = of_type_mapper or self.mapper - single_crit = dest_mapper._single_table_criterion aliased = aliased or ( source_selectable is not None diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 626018997..2b8c384c9 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -1143,7 +1143,7 @@ class SubqueryLoader(PostLoader): ) = self._get_leftmost(subq_path) orig_query = compile_state.attributes.get( - ("orig_query", SubqueryLoader), compile_state.query + ("orig_query", SubqueryLoader), compile_state.select_statement ) # generate a new Query from the original, then diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index ce37d962e..85f4f85d1 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -466,7 +466,7 @@ class AliasedClass(object): def __init__( self, - cls, + mapped_class_or_ac, alias=None, name=None, flat=False, @@ -478,7 +478,9 @@ class AliasedClass(object): use_mapper_path=False, represents_outer_join=False, ): - mapper = _class_to_mapper(cls) + insp = inspection.inspect(mapped_class_or_ac) + mapper = insp.mapper + if alias is None: alias = mapper._with_polymorphic_selectable._anonymous_fromclause( name=name, flat=flat @@ -486,7 +488,7 @@ class AliasedClass(object): self._aliased_insp = AliasedInsp( self, - mapper, + insp, alias, name, with_polymorphic_mappers @@ -617,7 +619,7 @@ class AliasedInsp( def __init__( self, entity, - mapper, + inspected, selectable, name, with_polymorphic_mappers, @@ -627,6 +629,10 @@ class AliasedInsp( adapt_on_names, represents_outer_join, ): + + mapped_class_or_ac = inspected.entity + mapper = inspected.mapper + self._weak_entity = weakref.ref(entity) self.mapper = mapper self.selectable = ( @@ -665,9 +671,12 @@ class AliasedInsp( adapt_on_names=adapt_on_names, anonymize_labels=True, ) + if inspected.is_aliased_class: + self._adapter = inspected._adapter.wrap(self._adapter) self._adapt_on_names = adapt_on_names - self._target = mapper.class_ + self._target = mapped_class_or_ac + # self._target = mapper.class_ # mapped_class_or_ac @property def entity(self): @@ -795,6 +804,21 @@ class AliasedInsp( def _memoized_values(self): return {} + @util.memoized_property + def columns(self): + if self._is_with_polymorphic: + cols_plus_keys = self.mapper._columns_plus_keys( + [ent.mapper for ent in self._with_polymorphic_entities] + ) + else: + cols_plus_keys = self.mapper._columns_plus_keys() + + cols_plus_keys = [ + (key, self._adapt_element(col)) for key, col in cols_plus_keys + ] + + return ColumnCollection(cols_plus_keys) + def _memo(self, key, callable_, *args, **kw): if key in self._memoized_values: return self._memoized_values[key] @@ -1290,8 +1314,7 @@ class _ORMJoin(expression.Join): source_selectable=adapt_from, dest_selectable=adapt_to, source_polymorphic=True, - dest_polymorphic=True, - of_type_mapper=right_info.mapper, + of_type_entity=right_info, alias_secondary=True, ) diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 6415d4b37..f14319089 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -522,7 +522,12 @@ class _MetaOptions(type): def __init__(cls, classname, bases, dict_): cls._cache_attrs = tuple( - sorted(d for d in dict_ if not d.startswith("__")) + sorted( + d + for d in dict_ + if not d.startswith("__") + and d not in ("_cache_key_traversal",) + ) ) type.__init__(cls, classname, bases, dict_) @@ -561,6 +566,31 @@ class Options(util.with_metaclass(_MetaOptions)): def _state_dict(cls): return cls._state_dict_const + @classmethod + def safe_merge(cls, other): + d = other._state_dict() + + # only support a merge with another object of our class + # and which does not have attrs that we dont. otherwise + # we risk having state that might not be part of our cache + # key strategy + + if ( + cls is not other.__class__ + and other._cache_attrs + and set(other._cache_attrs).difference(cls._cache_attrs) + ): + raise TypeError( + "other element %r is not empty, is not of type %s, " + "and contains attributes not covered here %r" + % ( + other, + cls, + set(other._cache_attrs).difference(cls._cache_attrs), + ) + ) + return cls + d + class CacheableOptions(Options, HasCacheKey): @hybridmethod diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 287e53724..fa2888a23 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -878,6 +878,7 @@ class ColumnElement( key = self._proxy_key else: key = name + co = ColumnClause( coercions.expect(roles.TruncatedLabelRole, name) if name_is_truncatable @@ -885,6 +886,7 @@ class ColumnElement( type_=getattr(self, "type", None), _selectable=selectable, ) + co._propagate_attrs = selectable._propagate_attrs co._proxies = [self] if selectable._is_clone_of is not None: @@ -1284,6 +1286,7 @@ class BindParameter(roles.InElementRole, ColumnElement): """ + if required is NO_ARG: required = value is NO_ARG and callable_ is None if value is NO_ARG: @@ -1302,6 +1305,7 @@ class BindParameter(roles.InElementRole, ColumnElement): id(self), re.sub(r"[%\(\) \$]+", "_", key).strip("_") if key is not None + and not isinstance(key, _anonymous_label) else "param", ) ) @@ -4182,16 +4186,27 @@ class Label(roles.LabeledColumnExprRole, ColumnElement): return self.element._from_objects def _make_proxy(self, selectable, name=None, **kw): + name = self.name if not name else name + key, e = self.element._make_proxy( selectable, - name=name if name else self.name, + name=name, disallow_is_literal=True, + name_is_truncatable=isinstance(name, _truncated_label), ) + # TODO: want to remove this assertion at some point. all + # _make_proxy() implementations will give us back the key that + # is our "name" in the first place. based on this we can + # safely return our "self.key" as the key here, to support a new + # case where the key and name are separate. + assert key == self.name + e._propagate_attrs = selectable._propagate_attrs e._proxies.append(self) if self._type is not None: e.type = self._type - return key, e + + return self.key, e class ColumnClause( @@ -4240,7 +4255,7 @@ class ColumnClause( __visit_name__ = "column" _traverse_internals = [ - ("name", InternalTraversal.dp_string), + ("name", InternalTraversal.dp_anon_name), ("type", InternalTraversal.dp_type), ("table", InternalTraversal.dp_clauseelement), ("is_literal", InternalTraversal.dp_boolean), @@ -4410,10 +4425,8 @@ class ColumnClause( def _gen_label(self, name, dedupe_on_key=True): t = self.table - if self.is_literal: return None - elif t is not None and t.named_with_column: if getattr(t, "schema", None): label = t.schema.replace(".", "_") + "_" + t.name + "_" + name diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 170e016a5..d6845e05f 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -3451,8 +3451,8 @@ class SelectState(util.MemoizedSlots, CompileState): self.columns_plus_names = statement._generate_columns_plus_names(True) def _get_froms(self, statement): - froms = [] seen = set() + froms = [] for item in itertools.chain( itertools.chain.from_iterable( @@ -3474,6 +3474,16 @@ class SelectState(util.MemoizedSlots, CompileState): froms.append(item) seen.update(item._cloned_set) + toremove = set( + itertools.chain.from_iterable( + [_expand_cloned(f._hide_froms) for f in froms] + ) + ) + if toremove: + # filter out to FROM clauses not in the list, + # using a list to maintain ordering + froms = [f for f in froms if f not in toremove] + return froms def _get_display_froms( @@ -3490,16 +3500,6 @@ class SelectState(util.MemoizedSlots, CompileState): froms = self.froms - toremove = set( - itertools.chain.from_iterable( - [_expand_cloned(f._hide_froms) for f in froms] - ) - ) - if toremove: - # filter out to FROM clauses not in the list, - # using a list to maintain ordering - froms = [f for f in froms if f not in toremove] - if self.statement._correlate: to_correlate = self.statement._correlate if to_correlate: @@ -3557,7 +3557,7 @@ class SelectState(util.MemoizedSlots, CompileState): def _memoized_attr__label_resolve_dict(self): with_cols = dict( (c._resolve_label or c._label or c.key, c) - for c in _select_iterables(self.statement._raw_columns) + for c in self.statement._exported_columns_iterator() if c._allow_label_resolve ) only_froms = dict( @@ -3578,6 +3578,10 @@ class SelectState(util.MemoizedSlots, CompileState): else: return None + @classmethod + def exported_columns_iterator(cls, statement): + return _select_iterables(statement._raw_columns) + def _setup_joins(self, args): for (right, onclause, left, flags) in args: isouter = flags["isouter"] @@ -4599,7 +4603,7 @@ class Select( pa = None collection = [] - for c in _select_iterables(self._raw_columns): + for c in self._exported_columns_iterator(): # we use key_label since this name is intended for targeting # within the ColumnCollection only, it's not related to SQL # rendering which always uses column name for SQL label names @@ -4630,7 +4634,7 @@ class Select( return self def _generate_columns_plus_names(self, anon_for_dupe_key): - cols = _select_iterables(self._raw_columns) + cols = self._exported_columns_iterator() # when use_labels is on: # in all cases == if we see the same label name, use _label_anon_label diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index a38088a27..388097e45 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -18,6 +18,7 @@ NO_CACHE = util.symbol("no_cache") CACHE_IN_PLACE = util.symbol("cache_in_place") CALL_GEN_CACHE_KEY = util.symbol("call_gen_cache_key") STATIC_CACHE_KEY = util.symbol("static_cache_key") +ANON_NAME = util.symbol("anon_name") def compare(obj1, obj2, **kw): @@ -33,6 +34,7 @@ class HasCacheKey(object): _cache_key_traversal = NO_CACHE __slots__ = () + @util.preload_module("sqlalchemy.sql.elements") def _gen_cache_key(self, anon_map, bindparams): """return an optional cache key. @@ -54,6 +56,8 @@ class HasCacheKey(object): """ + elements = util.preloaded.sql_elements + idself = id(self) if anon_map is not None: @@ -102,6 +106,10 @@ class HasCacheKey(object): result += (attrname, obj) elif meth is STATIC_CACHE_KEY: result += (attrname, obj._static_cache_key) + elif meth is ANON_NAME: + if elements._anonymous_label in obj.__class__.__mro__: + obj = obj.apply_map(anon_map) + result += (attrname, obj) elif meth is CALL_GEN_CACHE_KEY: result += ( attrname, @@ -321,6 +329,7 @@ class _CacheKey(ExtendedInternalTraversal): ) = visit_operator = visit_plain_obj = CACHE_IN_PLACE visit_statement_hint_list = CACHE_IN_PLACE visit_type = STATIC_CACHE_KEY + visit_anon_name = ANON_NAME def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams): return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams)) @@ -387,15 +396,6 @@ class _CacheKey(ExtendedInternalTraversal): attrname, obj, parent, anon_map, bindparams ) - def visit_anon_name(self, attrname, obj, parent, anon_map, bindparams): - from . import elements - - name = obj - if isinstance(name, elements._anonymous_label): - name = name.apply_map(anon_map) - - return (attrname, name) - def visit_fromclause_ordered_set( self, attrname, obj, parent, anon_map, bindparams ): diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 377aa4fe0..e8726000b 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -822,9 +822,14 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): # is another join or selectable that contains a table which our # selectable derives from, that we want to process return None + elif not isinstance(col, ColumnElement): return None - elif self.include_fn and not self.include_fn(col): + + if "adapt_column" in col._annotations: + col = col._annotations["adapt_column"] + + if self.include_fn and not self.include_fn(col): return None elif self.exclude_fn and self.exclude_fn(col): return None diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 683f545dd..5de68f504 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -50,6 +50,13 @@ def _generate_compiler_dispatch(cls): """ visit_name = cls.__visit_name__ + if "_compiler_dispatch" in cls.__dict__: + # class has a fixed _compiler_dispatch() method. + # copy it to "original" so that we can get it back if + # sqlalchemy.ext.compiles overrides it. + cls._original_compiler_dispatch = cls._compiler_dispatch + return + if not isinstance(visit_name, util.compat.string_types): raise exc.InvalidRequestError( "__visit_name__ on class %s must be a string at the class level" @@ -76,7 +83,9 @@ def _generate_compiler_dispatch(cls): + self.__visit_name__ on the visitor, and call it with the same kw params. """ - cls._compiler_dispatch = _compiler_dispatch + cls._compiler_dispatch = ( + cls._original_compiler_dispatch + ) = _compiler_dispatch class TraversibleType(type): diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index 79de3c978..247dbc13c 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -399,7 +399,7 @@ def reraise(tp, value, tb=None, cause=None): raise_(value, with_traceback=tb, from_=cause) -def with_metaclass(meta, *bases): +def with_metaclass(meta, *bases, **kw): """Create a base class with a metaclass. Drops the middle class upon creation. @@ -414,8 +414,15 @@ def with_metaclass(meta, *bases): def __new__(cls, name, this_bases, d): if this_bases is None: - return type.__new__(cls, name, (), d) - return meta(name, bases, d) + cls = type.__new__(cls, name, (), d) + else: + cls = meta(name, bases, d) + + if hasattr(cls, "__init_subclass__") and hasattr( + cls.__init_subclass__, "__func__" + ): + cls.__init_subclass__.__func__(cls, **kw) + return cls return metaclass("temporary_class", None, {}) |
