Diffstat (limited to 'lib/sqlalchemy')
35 files changed, 920 insertions, 447 deletions
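The core of this changeset moves the compiled-SQL cache off of the Dialect and onto the Engine, gives it LRU pruning, and threads cache statistics into the SQL log. As a minimal sketch of the user-facing surface the hunks below add (the in-memory SQLite URL and the statement are illustrative only):

    from sqlalchemy import create_engine, literal_column, select

    # query_cache_size is the new create_engine() parameter introduced
    # below; 0 disables caching, and the LRU prunes from ~1.5x capacity
    # back down to capacity
    engine = create_engine("sqlite://", echo=True, query_cache_size=500)

    stmt = select([literal_column("1")])
    with engine.connect() as conn:
        conn.execute(stmt)  # echo output carries "generated in ...s"
        conn.execute(stmt)  # same structure, cache hit: "cached for ...s"

    engine.clear_compiled_cache()  # new Engine method added in this diff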
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 441e77a37..24e2d13d8 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -3681,7 +3681,7 @@ class PGDialect(default.DefaultDialect): WHERE t.typtype = 'd' """ - s = sql.text(SQL_DOMAINS).columns(attname=sqltypes.Unicode) + s = sql.text(SQL_DOMAINS) c = connection.execution_options(future_result=True).execute(s) domains = {} diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index a36f4eee2..3e02a29fe 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1175,46 +1175,17 @@ class Connection(Connectable): ) compiled_cache = execution_options.get( - "compiled_cache", self.dialect._compiled_cache + "compiled_cache", self.engine._compiled_cache ) - if compiled_cache is not None: - elem_cache_key = elem._generate_cache_key() - else: - elem_cache_key = None - - if elem_cache_key: - cache_key, extracted_params = elem_cache_key - key = ( - dialect, - cache_key, - tuple(keys), - bool(schema_translate_map), - inline, - ) - compiled_sql = compiled_cache.get(key) - - if compiled_sql is None: - compiled_sql = elem.compile( - dialect=dialect, - cache_key=elem_cache_key, - column_keys=keys, - inline=inline, - schema_translate_map=schema_translate_map, - linting=self.dialect.compiler_linting - | compiler.WARN_LINTING, - ) - compiled_cache[key] = compiled_sql - else: - extracted_params = None - compiled_sql = elem.compile( - dialect=dialect, - column_keys=keys, - inline=inline, - schema_translate_map=schema_translate_map, - linting=self.dialect.compiler_linting | compiler.WARN_LINTING, - ) - + compiled_sql, extracted_params, cache_hit = elem._compile_w_cache( + dialect=dialect, + compiled_cache=compiled_cache, + column_keys=keys, + inline=inline, + schema_translate_map=schema_translate_map, + linting=self.dialect.compiler_linting | compiler.WARN_LINTING, + ) ret = self._execute_context( dialect, dialect.execution_ctx_cls._init_compiled, @@ -1225,6 +1196,7 @@ class Connection(Connectable): distilled_params, elem, extracted_params, + cache_hit=cache_hit, ) if has_events: self.dispatch.after_execute( @@ -1389,7 +1361,8 @@ class Connection(Connectable): statement, parameters, execution_options, - *args + *args, + **kw ): """Create an :class:`.ExecutionContext` and execute, returning a :class:`_engine.CursorResult`.""" @@ -1407,7 +1380,7 @@ class Connection(Connectable): conn = self._revalidate_connection() context = constructor( - dialect, self, conn, execution_options, *args + dialect, self, conn, execution_options, *args, **kw ) except (exc.PendingRollbackError, exc.ResourceClosedError): raise @@ -1455,32 +1428,21 @@ class Connection(Connectable): self.engine.logger.info(statement) - # stats = context._get_cache_stats() + stats = context._get_cache_stats() if not self.engine.hide_parameters: - # TODO: I love the stats but a ton of tests that are hardcoded. - # to certain log output are failing. 
self.engine.logger.info( - "%r", + "[%s] %r", + stats, sql_util._repr_params( parameters, batches=10, ismulti=context.executemany ), ) - # self.engine.logger.info( - # "[%s] %r", - # stats, - # sql_util._repr_params( - # parameters, batches=10, ismulti=context.executemany - # ), - # ) else: self.engine.logger.info( - "[SQL parameters hidden due to hide_parameters=True]" + "[%s] [SQL parameters hidden due to hide_parameters=True]" + % (stats,) ) - # self.engine.logger.info( - # "[%s] [SQL parameters hidden due to hide_parameters=True]" - # % (stats,) - # ) evt_handled = False try: @@ -2369,6 +2331,7 @@ class Engine(Connectable, log.Identified): url, logging_name=None, echo=None, + query_cache_size=500, execution_options=None, hide_parameters=False, ): @@ -2379,14 +2342,43 @@ class Engine(Connectable, log.Identified): self.logging_name = logging_name self.echo = echo self.hide_parameters = hide_parameters + if query_cache_size != 0: + self._compiled_cache = util.LRUCache( + query_cache_size, size_alert=self._lru_size_alert + ) + else: + self._compiled_cache = None log.instance_logger(self, echoflag=echo) if execution_options: self.update_execution_options(**execution_options) + def _lru_size_alert(self, cache): + if self._should_log_info: + self.logger.info( + "Compiled cache size pruning from %d items to %d. " + "Increase cache size to reduce the frequency of pruning.", + len(cache), + cache.capacity, + ) + @property def engine(self): return self + def clear_compiled_cache(self): + """Clear the compiled cache associated with the dialect. + + This applies **only** to the built-in cache that is established + via the :paramref:`.create_engine.query_cache_size` parameter. + It will not impact any dictionary caches that were passed via the + :paramref:`.Connection.execution_options.query_cache` parameter. + + .. versionadded:: 1.4 + + """ + if self._compiled_cache: + self._compiled_cache.clear() + def update_execution_options(self, **opt): r"""Update the default execution_options dictionary of this :class:`_engine.Engine`. @@ -2874,6 +2866,7 @@ class OptionEngineMixin(object): self.dialect = proxied.dialect self.logging_name = proxied.logging_name self.echo = proxied.echo + self._compiled_cache = proxied._compiled_cache self.hide_parameters = proxied.hide_parameters log.instance_logger(self, echoflag=self.echo) diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py index 4c912349e..9bf72eb06 100644 --- a/lib/sqlalchemy/engine/create.py +++ b/lib/sqlalchemy/engine/create.py @@ -436,7 +436,13 @@ def create_engine(url, **kwargs): .. versionadded:: 1.2.3 :param query_cache_size: size of the cache used to cache the SQL string - form of queries. Defaults to zero, which disables caching. + form of queries. Set to zero to disable caching. + + The cache is pruned of its least recently used items when its size reaches + N * 1.5. Defaults to 500, meaning the cache will always store at least + 500 SQL statements when filled, and will grow up to 750 items at which + point it is pruned back down to 500 by removing the 250 least recently + used items. Caching is accomplished on a per-statement basis by generating a cache key that represents the statement's structure, then generating @@ -446,6 +452,11 @@ def create_engine(url, **kwargs): bypass the cache. SQL logging will indicate statistics for each statement whether or not it was pulled from the cache. + ..
note:: some ORM functions related to unit-of-work persistence as well + as some attribute loading strategies will make use of individual + per-mapper caches outside of the main cache. + + .. seealso:: ``engine_caching`` - TODO: this will be an upcoming section describing diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index d03d79df7..abffe0d1f 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -51,6 +51,7 @@ class CursorResultMetaData(ResultMetaData): "_keys", "_tuplefilter", "_translated_indexes", + "_safe_for_cache" # don't need _unique_filters support here for now. Can be added # if a need arises. ) @@ -104,11 +105,11 @@ class CursorResultMetaData(ResultMetaData): return new_metadata def _adapt_to_context(self, context): - """When using a cached result metadata against a new context, - we need to rewrite the _keymap so that it has the specific - Column objects in the new context inside of it. this accommodates - for select() constructs that contain anonymized columns and - are cached. + """When using a cached Compiled construct that has a _result_map, + for a new statement that used the cached Compiled, we need to ensure + the keymap has the Column objects from our new statement as keys. + So here we rewrite keymap with new entries for the new columns + as matched to those of the cached statement. """ if not context.compiled._result_columns: @@ -124,14 +125,15 @@ class CursorResultMetaData(ResultMetaData): # to the result map. md = self.__class__.__new__(self.__class__) - md._keymap = self._keymap.copy() + md._keymap = dict(self._keymap) # match up new columns positionally to the result columns for existing, new in zip( context.compiled._result_columns, invoked_statement._exported_columns_iterator(), ): - md._keymap[new] = md._keymap[existing[RM_NAME]] + if existing[RM_NAME] in md._keymap: + md._keymap[new] = md._keymap[existing[RM_NAME]] md.case_sensitive = self.case_sensitive md._processors = self._processors @@ -147,6 +149,7 @@ class CursorResultMetaData(ResultMetaData): self._tuplefilter = None self._translated_indexes = None self.case_sensitive = dialect.case_sensitive + self._safe_for_cache = False if context.result_column_struct: ( @@ -341,6 +344,10 @@ class CursorResultMetaData(ResultMetaData): self._keys = [elem[0] for elem in result_columns] # pure positional 1-1 case; doesn't need to read # the names from cursor.description + + # this metadata is safe to cache because we are guaranteed + # to have the columns in the same order for new executions + self._safe_for_cache = True return [ ( idx, @@ -359,9 +366,12 @@ class CursorResultMetaData(ResultMetaData): for idx, rmap_entry in enumerate(result_columns) ] else: + # name-based or text-positional cases, where we need # to read cursor.description names + if textual_ordered: + self._safe_for_cache = True # textual positional case raw_iterator = self._merge_textual_cols_by_position( context, cursor_description, result_columns @@ -369,6 +379,9 @@ class CursorResultMetaData(ResultMetaData): elif num_ctx_cols: # compiled SQL with a mismatch of description cols # vs. 
compiled cols, or textual w/ unordered columns + # the order of columns can change if the query is + # against a "select *", so not safe to cache + self._safe_for_cache = False raw_iterator = self._merge_cols_by_name( context, cursor_description, @@ -376,7 +389,9 @@ class CursorResultMetaData(ResultMetaData): loose_column_name_matching, ) else: - # no compiled SQL, just a raw string + # no compiled SQL, just a raw string, order of columns + # can change for "select *" + self._safe_for_cache = False raw_iterator = self._merge_cols_by_none( context, cursor_description ) @@ -1152,7 +1167,6 @@ class BaseCursorResult(object): out_parameters = None _metadata = None - _metadata_from_cache = False _soft_closed = False closed = False @@ -1209,33 +1223,38 @@ class BaseCursorResult(object): def _init_metadata(self, context, cursor_description): if context.compiled: if context.compiled._cached_metadata: - cached_md = self.context.compiled._cached_metadata - self._metadata_from_cache = True - - # result rewrite/ adapt step. two translations can occur here. - # one is if we are invoked against a cached statement, we want - # to rewrite the ResultMetaData to reflect the column objects - # that are in our current selectable, not the cached one. the - # other is, the CompileState can return an alternative Result - # object. Finally, CompileState might want to tell us to not - # actually do the ResultMetaData adapt step if it in fact has - # changed the selected columns in any case. - compiled = context.compiled - if ( - compiled - and not compiled._rewrites_selected_columns - and compiled.statement is not context.invoked_statement - ): - cached_md = cached_md._adapt_to_context(context) - - self._metadata = metadata = cached_md - + metadata = self.context.compiled._cached_metadata else: - self._metadata = ( - metadata - ) = context.compiled._cached_metadata = self._cursor_metadata( - self, cursor_description - ) + metadata = self._cursor_metadata(self, cursor_description) + if metadata._safe_for_cache: + context.compiled._cached_metadata = metadata + + # result rewrite/ adapt step. this is to suit the case + # when we are invoked against a cached Compiled object, we want + # to rewrite the ResultMetaData to reflect the Column objects + # that are in our current SQL statement object, not the one + # that is associated with the cached Compiled object. + # the Compiled object may also tell us to not + # actually do this step; this is to support the ORM where + # it is to produce a new Result object in any case, and will + # be using the cached Column objects against this database result + # so we don't want to rewrite them. + # + # Basically this step suits the use case where the end user + # is using Core SQL expressions and is accessing columns in the + # result row using row._mapping[table.c.column]. 
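A sketch of the use case that comment describes, with an illustrative table: structurally identical select() constructs share one cached Compiled, and on a cache hit the keymap rewrite keeps row._mapping lookups working against the statement actually invoked:

    from sqlalchemy import (
        Column, Integer, MetaData, Table, create_engine, select,
    )

    engine = create_engine("sqlite://", query_cache_size=500)
    metadata = MetaData()
    t = Table("t", metadata, Column("x", Integer))
    metadata.create_all(engine)

    with engine.connect() as conn:
        conn.execute(t.insert().values(x=5))
        conn.execute(select([t.c.x]))                # compiles and caches
        row = conn.execute(select([t.c.x])).first()  # cache hit
        assert row._mapping[t.c.x] == 5  # keyed on the current Column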
+ compiled = context.compiled + if ( + compiled + and compiled._result_columns + and context.cache_hit + and not compiled._rewrites_selected_columns + and compiled.statement is not context.invoked_statement + ): + metadata = metadata._adapt_to_context(context) + + self._metadata = metadata + else: self._metadata = metadata = self._cursor_metadata( self, cursor_description diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index c682a8ee1..4d516e97c 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -230,7 +230,6 @@ class DefaultDialect(interfaces.Dialect): supports_native_boolean=None, max_identifier_length=None, label_length=None, - query_cache_size=0, # int() is because the @deprecated_params decorator cannot accommodate # the direct reference to the "NO_LINTING" object compiler_linting=int(compiler.NO_LINTING), @@ -262,10 +261,6 @@ class DefaultDialect(interfaces.Dialect): if supports_native_boolean is not None: self.supports_native_boolean = supports_native_boolean self.case_sensitive = case_sensitive - if query_cache_size != 0: - self._compiled_cache = util.LRUCache(query_cache_size) - else: - self._compiled_cache = None self._user_defined_max_identifier_length = max_identifier_length if self._user_defined_max_identifier_length: @@ -794,6 +789,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): parameters, invoked_statement, extracted_parameters, + cache_hit=False, ): """Initialize execution context for a Compiled construct.""" @@ -804,6 +800,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): self.extracted_parameters = extracted_parameters self.invoked_statement = invoked_statement self.compiled = compiled + self.cache_hit = cache_hit self.execution_options = execution_options @@ -1027,13 +1024,15 @@ class DefaultExecutionContext(interfaces.ExecutionContext): def _get_cache_stats(self): if self.compiled is None: - return "raw SQL" + return "raw sql" now = time.time() if self.compiled.cache_key is None: - return "gen %.5fs" % (now - self.compiled._gen_time,) + return "no key %.5fs" % (now - self.compiled._gen_time,) + elif self.cache_hit: + return "cached for %.4gs" % (now - self.compiled._gen_time,) else: - return "cached %.5fs" % (now - self.compiled._gen_time,) + return "generated in %.5fs" % (now - self.compiled._gen_time,) @util.memoized_property def engine(self): diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py index f95a30fda..4f40637c5 100644 --- a/lib/sqlalchemy/ext/baked.py +++ b/lib/sqlalchemy/ext/baked.py @@ -412,7 +412,6 @@ class Result(object): result = self.session.execute( statement, params, execution_options=execution_options ) - if result._attributes.get("is_single_entity", False): result = result.scalars() diff --git a/lib/sqlalchemy/future/selectable.py b/lib/sqlalchemy/future/selectable.py index 407ec9633..53fc7c107 100644 --- a/lib/sqlalchemy/future/selectable.py +++ b/lib/sqlalchemy/future/selectable.py @@ -11,6 +11,7 @@ class Select(_LegacySelect): _is_future = True _setup_joins = () _legacy_setup_joins = () + inherit_cache = True @classmethod def _create_select(cls, *entities): diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 262a1efc9..bf07061c6 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -85,16 +85,16 @@ class QueryableAttribute( self, class_, key, + parententity, impl=None, comparator=None, - parententity=None, of_type=None, ): self.class_ = class_ self.key = key 
+ self._parententity = parententity self.impl = impl self.comparator = comparator - self._parententity = parententity self._of_type = of_type manager = manager_of_class(class_) @@ -197,10 +197,14 @@ class QueryableAttribute( @util.memoized_property def expression(self): return self.comparator.__clause_element__()._annotate( - {"orm_key": self.key} + {"orm_key": self.key, "entity_namespace": self._entity_namespace} ) @property + def _entity_namespace(self): + return self._parententity + + @property def _annotations(self): return self.__clause_element__()._annotations @@ -230,9 +234,9 @@ class QueryableAttribute( return QueryableAttribute( self.class_, self.key, - self.impl, - self.comparator.of_type(entity), self._parententity, + impl=self.impl, + comparator=self.comparator.of_type(entity), of_type=inspection.inspect(entity), ) @@ -301,6 +305,8 @@ class InstrumentedAttribute(QueryableAttribute): """ + inherit_cache = True + def __set__(self, instance, value): self.impl.set( instance_state(instance), instance_dict(instance), value, None @@ -320,6 +326,11 @@ class InstrumentedAttribute(QueryableAttribute): return self.impl.get(instance_state(instance), dict_) +HasEntityNamespace = util.namedtuple( + "HasEntityNamespace", ["entity_namespace"] +) + + def create_proxied_attribute(descriptor): """Create an QueryableAttribute / user descriptor hybrid. @@ -365,6 +376,15 @@ def create_proxied_attribute(descriptor): ) @property + def _entity_namespace(self): + if hasattr(self._comparator, "_parententity"): + return self._comparator._parententity + else: + # used by hybrid attributes which try to remain + # agnostic of any ORM concepts like mappers + return HasEntityNamespace(self.class_) + + @property def property(self): return self.comparator.property diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index a16db66f6..588b83571 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -63,6 +63,8 @@ class QueryContext(object): "post_load_paths", "identity_token", "yield_per", + "loaders_require_buffering", + "loaders_require_uniquing", ) class default_load_options(Options): @@ -80,21 +82,23 @@ class QueryContext(object): def __init__( self, compile_state, + statement, session, load_options, execution_options=None, bind_arguments=None, ): - self.load_options = load_options self.execution_options = execution_options or _EMPTY_DICT self.bind_arguments = bind_arguments or _EMPTY_DICT self.compile_state = compile_state - self.query = query = compile_state.select_statement + self.query = statement self.session = session + self.loaders_require_buffering = False + self.loaders_require_uniquing = False self.propagated_loader_options = { - o for o in query._with_options if o.propagate_to_loaders + o for o in statement._with_options if o.propagate_to_loaders } self.attributes = dict(compile_state.attributes) @@ -237,6 +241,7 @@ class ORMCompileState(CompileState): ) querycontext = QueryContext( compile_state, + statement, session, load_options, execution_options, @@ -278,8 +283,6 @@ class ORMFromStatementCompileState(ORMCompileState): _has_orm_entities = False multi_row_eager_loaders = False compound_eager_adapter = None - loaders_require_buffering = False - loaders_require_uniquing = False @classmethod def create_for_statement(cls, statement_container, compiler, **kw): @@ -386,8 +389,6 @@ class ORMSelectCompileState(ORMCompileState, SelectState): _has_orm_entities = False multi_row_eager_loaders = False compound_eager_adapter = None - loaders_require_buffering = 
False - loaders_require_uniquing = False correlate = None _where_criteria = () @@ -416,7 +417,14 @@ class ORMSelectCompileState(ORMCompileState, SelectState): self = cls.__new__(cls) - self.select_statement = select_statement + if select_statement._execution_options: + # execution options should not impact the compilation of a + # query, and at the moment subqueryloader is putting some things + # in here that we explicitly don't want stuck in a cache. + self.select_statement = select_statement._clone() + self.select_statement._execution_options = util.immutabledict() + else: + self.select_statement = select_statement # indicates this select() came from Query.statement self.for_statement = ( @@ -654,6 +662,8 @@ class ORMSelectCompileState(ORMCompileState, SelectState): ) self._setup_with_polymorphics() + # entities will also set up polymorphic adapters for mappers + # that have with_polymorphic configured _QueryEntity.to_compile_state(self, query._raw_columns) return self @@ -1810,10 +1820,12 @@ class ORMSelectCompileState(ORMCompileState, SelectState): self._where_criteria += (single_crit,) -def _column_descriptions(query_or_select_stmt): - ctx = ORMSelectCompileState._create_entities_collection( - query_or_select_stmt - ) +def _column_descriptions(query_or_select_stmt, compile_state=None): + if compile_state is None: + compile_state = ORMSelectCompileState._create_entities_collection( + query_or_select_stmt + ) + ctx = compile_state return [ { "name": ent._label_name, @@ -2097,6 +2109,7 @@ class _MapperEntity(_QueryEntity): only_load_props = refresh_state = None _instance = loading._instance_processor( + self, self.mapper, context, result, diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 027f2521b..39cf86e34 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -411,7 +411,6 @@ class CompositeProperty(DescriptorProperty): def expression(self): clauses = self.clauses._annotate( { - "bundle": True, "parententity": self._parententity, "parentmapper": self._parententity, "orm_key": self.prop.key, diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 6c0f5d3ef..9782d92b7 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -158,7 +158,7 @@ class MapperProperty( """ def create_row_processor( - self, context, path, mapper, result, adapter, populators + self, context, query_entity, path, mapper, result, adapter, populators ): """Produce row processing functions and append to the given set of populators lists. 
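The query_entity argument being added throughout these signatures carries the originating mapper entity into each strategy. As a condensed sketch of the populators contract that docstring describes, modeled loosely on the column loader strategy (abbreviated, not the verbatim implementation; "quick" and "expire" are populator keys the ORM actually consumes):

    def create_row_processor(
        self, context, query_entity, path, mapper, result, adapter, populators
    ):
        col = self.columns[0]
        if adapter:
            # re-target the column against an aliased selectable
            col = adapter.columns[col]
        getter = result._getter(col, False)
        if getter:
            # value can be copied straight out of the row
            populators["quick"].append((self.key, getter))
        else:
            # column not present in this row; expire the attribute so
            # a later access triggers a load
            populators["expire"].append((self.key, True))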
@@ -539,7 +539,7 @@ class StrategizedProperty(MapperProperty): "_wildcard_token", "_default_path_loader_key", ) - + inherit_cache = True strategy_wildcard_key = None def _memoized_attr__wildcard_token(self): @@ -600,7 +600,7 @@ class StrategizedProperty(MapperProperty): ) def create_row_processor( - self, context, path, mapper, result, adapter, populators + self, context, query_entity, path, mapper, result, adapter, populators ): loader = self._get_context_loader(context, path) if loader and loader.strategy: @@ -608,7 +608,14 @@ class StrategizedProperty(MapperProperty): else: strat = self.strategy strat.create_row_processor( - context, path, loader, mapper, result, adapter, populators + context, + query_entity, + path, + loader, + mapper, + result, + adapter, + populators, ) def do_init(self): @@ -668,7 +675,7 @@ class StrategizedProperty(MapperProperty): ) -class ORMOption(object): +class ORMOption(HasCacheKey): """Base class for option objects that are passed to ORM queries. These options may be consumed by :meth:`.Query.options`, @@ -696,7 +703,7 @@ class ORMOption(object): _is_compile_state = False -class LoaderOption(HasCacheKey, ORMOption): +class LoaderOption(ORMOption): """Describe a loader modification to an ORM statement at compilation time. .. versionadded:: 1.4 @@ -736,9 +743,6 @@ class UserDefinedOption(ORMOption): def __init__(self, payload=None): self.payload = payload - def _gen_cache_key(self, *arg, **kw): - return () - @util.deprecated_cls( "1.4", @@ -855,7 +859,15 @@ class LoaderStrategy(object): """ def create_row_processor( - self, context, path, loadopt, mapper, result, adapter, populators + self, + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, ): """Establish row processing functions for a given QueryContext. 
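Note also in this hunk that ORMOption now extends HasCacheKey directly, so options participate in the statement cache key through the standard traversal and UserDefinedOption no longer needs its _gen_cache_key() stub. A sketch of a custom option under that scheme (the class name and payload are invented for illustration):

    from sqlalchemy.orm.interfaces import UserDefinedOption

    class TenantOption(UserDefinedOption):
        """Carries per-execution state; not consumed at compile time."""
        propagate_to_loaders = False

    opt = TenantOption(payload="acme")
    # applied via query.options(opt) and read back later by application
    # code or event hooks rather than by the compiler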
diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index 424ed5dfe..a33e1b77d 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -72,8 +72,8 @@ def instances(cursor, context): ) if context.yield_per and ( - context.compile_state.loaders_require_buffering - or context.compile_state.loaders_require_uniquing + context.loaders_require_buffering + or context.loaders_require_uniquing ): raise sa_exc.InvalidRequestError( "Can't use yield_per with eager loaders that require uniquing " @@ -545,6 +545,7 @@ def _warn_for_runid_changed(state): def _instance_processor( + query_entity, mapper, context, result, @@ -648,6 +649,7 @@ def _instance_processor( # to see if one fits prop.create_row_processor( context, + query_entity, path, mapper, result, @@ -667,7 +669,7 @@ def _instance_processor( populators = {key: list(value) for key, value in cached_populators.items()} for prop in getters["todo"]: prop.create_row_processor( - context, path, mapper, result, adapter, populators + context, query_entity, path, mapper, result, adapter, populators ) propagated_loader_options = context.propagated_loader_options @@ -925,6 +927,7 @@ def _instance_processor( _instance = _decorate_polymorphic_switch( _instance, context, + query_entity, mapper, result, path, @@ -1081,6 +1084,7 @@ def _validate_version_id(mapper, state, dict_, row, getter): def _decorate_polymorphic_switch( instance_fn, context, + query_entity, mapper, result, path, @@ -1112,6 +1116,7 @@ def _decorate_polymorphic_switch( return False return _instance_processor( + query_entity, sub_mapper, context, result, diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index c4cb89c03..bec6da74d 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -720,7 +720,7 @@ class Mapper( return self _cache_key_traversal = [ - ("class_", visitors.ExtendedInternalTraversal.dp_plain_obj) + ("mapper", visitors.ExtendedInternalTraversal.dp_plain_obj), ] @property diff --git a/lib/sqlalchemy/orm/path_registry.py b/lib/sqlalchemy/orm/path_registry.py index 2e5941713..ac7a64c30 100644 --- a/lib/sqlalchemy/orm/path_registry.py +++ b/lib/sqlalchemy/orm/path_registry.py @@ -216,6 +216,8 @@ class RootRegistry(PathRegistry): """ + inherit_cache = True + path = natural_path = () has_entity = False is_aliased_class = False @@ -248,6 +250,8 @@ class PathToken(HasCacheKey, str): class TokenRegistry(PathRegistry): __slots__ = ("token", "parent", "path", "natural_path") + inherit_cache = True + def __init__(self, parent, token): token = PathToken.intern(token) @@ -280,6 +284,7 @@ class TokenRegistry(PathRegistry): class PropRegistry(PathRegistry): is_unnatural = False + inherit_cache = True def __init__(self, parent, prop): # restate this path in terms of the @@ -439,6 +444,7 @@ class AbstractEntityRegistry(PathRegistry): class SlotsEntityRegistry(AbstractEntityRegistry): # for aliased class, return lightweight, no-cycles created # version + inherit_cache = True __slots__ = ( "key", @@ -454,6 +460,8 @@ class CachingEntityRegistry(AbstractEntityRegistry, dict): # for long lived mapper, return dict based caching # version that creates reference cycles + inherit_cache = True + def __getitem__(self, entity): if isinstance(entity, (int, slice)): return self.path[entity] diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 19d43d354..8393eaf74 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -38,6 +38,7 @@ from ..sql.base 
import Options from ..sql.dml import DeleteDMLState from ..sql.dml import UpdateDMLState from ..sql.elements import BooleanClauseList +from ..sql.util import _entity_namespace_key def _bulk_insert( @@ -1820,8 +1821,12 @@ class BulkUDCompileState(CompileState): if isinstance(k, util.string_types): desc = sql.util._entity_namespace_key(mapper, k) values.extend(desc._bulk_update_tuples(v)) - elif isinstance(k, attributes.QueryableAttribute): - values.extend(k._bulk_update_tuples(v)) + elif "entity_namespace" in k._annotations: + k_anno = k._annotations + attr = _entity_namespace_key( + k_anno["entity_namespace"], k_anno["orm_key"] + ) + values.extend(attr._bulk_update_tuples(v)) else: values.append((k, v)) else: diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 02f0752a5..5fb3beca3 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -45,6 +45,7 @@ class ColumnProperty(StrategizedProperty): """ strategy_wildcard_key = "column" + inherit_cache = True __slots__ = ( "_orig_columns", diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 284ea9d72..cdad55320 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -61,6 +61,7 @@ from ..sql.selectable import LABEL_STYLE_NONE from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from ..sql.selectable import SelectStatementGrouping from ..sql.util import _entity_namespace_key +from ..sql.visitors import InternalTraversal from ..util import collections_abc __all__ = ["Query", "QueryContext", "aliased"] @@ -423,6 +424,7 @@ class Query( _label_style=self._label_style, compile_options=compile_options, ) + stmt.__dict__.pop("session", None) stmt._propagate_attrs = self._propagate_attrs return stmt @@ -1725,7 +1727,6 @@ class Query( """ from_entity = self._filter_by_zero() - if from_entity is None: raise sa_exc.InvalidRequestError( "Can't use filter_by when the first entity '%s' of a query " @@ -2900,7 +2901,10 @@ class Query( compile_state = self._compile_state(for_statement=False) context = QueryContext( - compile_state, self.session, self.load_options + compile_state, + compile_state.statement, + self.session, + self.load_options, ) result = loading.instances(result_proxy, context) @@ -3376,7 +3380,12 @@ class Query( def _compile_context(self, for_statement=False): compile_state = self._compile_state(for_statement=for_statement) - context = QueryContext(compile_state, self.session, self.load_options) + context = QueryContext( + compile_state, + compile_state.statement, + self.session, + self.load_options, + ) return context @@ -3397,6 +3406,11 @@ class FromStatement(SelectStatementGrouping, Executable): _for_update_arg = None + _traverse_internals = [ + ("_raw_columns", InternalTraversal.dp_clauseelement_list), + ("element", InternalTraversal.dp_clauseelement), + ] + Executable._executable_traverse_internals + def __init__(self, entities, element): self._raw_columns = [ coercions.expect( diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 683f2b978..bedc54153 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -107,6 +107,7 @@ class RelationshipProperty(StrategizedProperty): """ strategy_wildcard_key = "relationship" + inherit_cache = True _persistence_only = dict( passive_deletes=False, diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index f67c23aab..5f039aff7 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ 
b/lib/sqlalchemy/orm/strategies.py @@ -25,6 +25,7 @@ from .base import _DEFER_FOR_STATE from .base import _RAISE_FOR_STATE from .base import _SET_DEFERRED_EXPIRED from .context import _column_descriptions +from .context import ORMCompileState from .interfaces import LoaderStrategy from .interfaces import StrategizedProperty from .session import _state_session @@ -156,7 +157,15 @@ class UninstrumentedColumnLoader(LoaderStrategy): column_collection.append(c) def create_row_processor( - self, context, path, loadopt, mapper, result, adapter, populators + self, + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, ): pass @@ -224,7 +233,15 @@ class ColumnLoader(LoaderStrategy): ) def create_row_processor( - self, context, path, loadopt, mapper, result, adapter, populators + self, + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, ): # look through list of columns represented here # to see which, if any, is present in the row. @@ -281,7 +298,15 @@ class ExpressionColumnLoader(ColumnLoader): memoized_populators[self.parent_property] = fetch def create_row_processor( - self, context, path, loadopt, mapper, result, adapter, populators + self, + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, ): # look through list of columns represented here # to see which, if any, is present in the row. @@ -332,7 +357,15 @@ class DeferredColumnLoader(LoaderStrategy): self.group = self.parent_property.group def create_row_processor( - self, context, path, loadopt, mapper, result, adapter, populators + self, + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, ): # for a DeferredColumnLoader, this method is only used during a @@ -542,7 +575,15 @@ class NoLoader(AbstractRelationshipLoader): ) def create_row_processor( - self, context, path, loadopt, mapper, result, adapter, populators + self, + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, ): def invoke_no_load(state, dict_, row): if self.uselist: @@ -985,7 +1026,15 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): return None def create_row_processor( - self, context, path, loadopt, mapper, result, adapter, populators + self, + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, ): key = self.key @@ -1039,12 +1088,27 @@ class PostLoader(AbstractRelationshipLoader): """A relationship loader that emits a second SELECT statement.""" def _immediateload_create_row_processor( - self, context, path, loadopt, mapper, result, adapter, populators + self, + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, ): return self.parent_property._get_strategy( (("lazy", "immediate"),) ).create_row_processor( - context, path, loadopt, mapper, result, adapter, populators + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, ) @@ -1057,21 +1121,16 @@ class ImmediateLoader(PostLoader): (("lazy", "select"),) ).init_class_attribute(mapper) - def setup_query( + def create_row_processor( self, - compile_state, - entity, + context, + query_entity, path, loadopt, + mapper, + result, adapter, - column_collection=None, - parentmapper=None, - **kwargs - ): - pass - - def create_row_processor( - self, context, path, loadopt, mapper, result, adapter, populators + populators, ): def load_immediate(state, dict_, row): state.get_impl(self.key).get(state, dict_) @@ -1093,120 
+1152,6 @@ class SubqueryLoader(PostLoader): (("lazy", "select"),) ).init_class_attribute(mapper) - def setup_query( - self, - compile_state, - entity, - path, - loadopt, - adapter, - column_collection=None, - parentmapper=None, - **kwargs - ): - if ( - not compile_state.compile_options._enable_eagerloads - or compile_state.compile_options._for_refresh_state - ): - return - - compile_state.loaders_require_buffering = True - - path = path[self.parent_property] - - # build up a path indicating the path from the leftmost - # entity to the thing we're subquery loading. - with_poly_entity = path.get( - compile_state.attributes, "path_with_polymorphic", None - ) - if with_poly_entity is not None: - effective_entity = with_poly_entity - else: - effective_entity = self.entity - - subq_path = compile_state.attributes.get( - ("subquery_path", None), orm_util.PathRegistry.root - ) - - subq_path = subq_path + path - - # if not via query option, check for - # a cycle - if not path.contains(compile_state.attributes, "loader"): - if self.join_depth: - if ( - ( - compile_state.current_path.length - if compile_state.current_path - else 0 - ) - + path.length - ) / 2 > self.join_depth: - return - elif subq_path.contains_mapper(self.mapper): - return - - ( - leftmost_mapper, - leftmost_attr, - leftmost_relationship, - ) = self._get_leftmost(subq_path) - - orig_query = compile_state.attributes.get( - ("orig_query", SubqueryLoader), compile_state.select_statement - ) - - # generate a new Query from the original, then - # produce a subquery from it. - left_alias = self._generate_from_original_query( - compile_state, - orig_query, - leftmost_mapper, - leftmost_attr, - leftmost_relationship, - entity.entity_zero, - ) - - # generate another Query that will join the - # left alias to the target relationships. - # basically doing a longhand - # "from_self()". (from_self() itself not quite industrial - # strength enough for all contingencies...but very close) - - q = query.Query(effective_entity) - - def set_state_options(compile_state): - compile_state.attributes.update( - { - ("orig_query", SubqueryLoader): orig_query, - ("subquery_path", None): subq_path, - } - ) - - q = q._add_context_option(set_state_options, None)._disable_caching() - - q = q._set_enable_single_crit(False) - to_join, local_attr, parent_alias = self._prep_for_joins( - left_alias, subq_path - ) - - q = q.add_columns(*local_attr) - q = self._apply_joins( - q, to_join, left_alias, parent_alias, effective_entity - ) - - q = self._setup_options(q, subq_path, orig_query, effective_entity) - q = self._setup_outermost_orderby(q) - - # add new query to attributes to be picked up - # by create_row_processor - # NOTE: be sure to consult baked.py for some hardcoded logic - # about this structure as well - assert q.session is None - path.set( - compile_state.attributes, "subqueryload_data", {"query": q}, - ) - def _get_leftmost(self, subq_path): subq_path = subq_path.path subq_mapper = orm_util._class_to_mapper(subq_path[0]) @@ -1267,27 +1212,34 @@ class SubqueryLoader(PostLoader): q, *{ ent["entity"] - for ent in _column_descriptions(orig_query) + for ent in _column_descriptions( + orig_query, compile_state=orig_compile_state + ) if ent["entity"] is not None } ) - # for column information, look to the compile state that is - # already being passed through - compile_state = orig_compile_state - # select from the identity columns of the outer (specifically, these - # are the 'local_cols' of the property). 
This will remove - # other columns from the query that might suggest the right entity - # which is why we do _set_select_from above. - target_cols = compile_state._adapt_col_list( + # are the 'local_cols' of the property). This will remove other + # columns from the query that might suggest the right entity which is + # why we do set select_from above. The attributes we have are + # coerced and adapted using the original query's adapter, which is + # needed only for the case of adapting a subclass column to + # that of a polymorphic selectable, e.g. we have + # Engineer.primary_language and the entity is Person. All other + # adaptations, e.g. from_self, select_entity_from(), will occur + # within the new query when it compiles, as the compile_state we are + # using here is only a partial one. If the subqueryload is from a + # with_polymorphic() or other aliased() object, left_attr will already + # be the correct attributes so no adaptation is needed. + target_cols = orig_compile_state._adapt_col_list( [ - sql.coercions.expect(sql.roles.ByOfRole, o) + sql.coercions.expect(sql.roles.ColumnsClauseRole, o) for o in leftmost_attr ], - compile_state._get_current_adapter(), + orig_compile_state._get_current_adapter(), ) - q._set_entities(target_cols) + q._raw_columns = target_cols distinct_target_key = leftmost_relationship.distinct_target_key @@ -1461,13 +1413,13 @@ class SubqueryLoader(PostLoader): "_data", ) - def __init__(self, context, subq_info): + def __init__(self, context, subq): # avoid creating a cycle by storing context # even though that's preferable self.session = context.session self.execution_options = context.execution_options self.load_options = context.load_options - self.subq = subq_info["query"] + self.subq = subq self._data = None def get(self, key, default): @@ -1499,12 +1451,148 @@ class SubqueryLoader(PostLoader): if self._data is None: self._load() + def _setup_query_from_rowproc( + self, context, path, entity, loadopt, adapter, + ): + compile_state = context.compile_state + if ( + not compile_state.compile_options._enable_eagerloads + or compile_state.compile_options._for_refresh_state + ): + return + + context.loaders_require_buffering = True + + path = path[self.parent_property] + + # build up a path indicating the path from the leftmost + # entity to the thing we're subquery loading. + with_poly_entity = path.get( + compile_state.attributes, "path_with_polymorphic", None + ) + if with_poly_entity is not None: + effective_entity = with_poly_entity + else: + effective_entity = self.entity + + subq_path = context.query._execution_options.get( + ("subquery_path", None), orm_util.PathRegistry.root + ) + + subq_path = subq_path + path + + # if not via query option, check for + # a cycle + if not path.contains(compile_state.attributes, "loader"): + if self.join_depth: + if ( + ( + compile_state.current_path.length + if compile_state.current_path + else 0 + ) + + path.length + ) / 2 > self.join_depth: + return + elif subq_path.contains_mapper(self.mapper): + return + + ( + leftmost_mapper, + leftmost_attr, + leftmost_relationship, + ) = self._get_leftmost(subq_path) + + # use the current query being invoked, not the compile state + # one. this is so that we get the current parameters. however, + # it means we can't use the existing compile state, we have to make + # a new one. 
other approaches include possibly using the + # compiled query but swapping the params, seems only marginally + # less time spent but more complicated + orig_query = context.query._execution_options.get( + ("orig_query", SubqueryLoader), context.query + ) + + # make a new compile_state for the query that's probably cached, but + # we're sort of undoing a bit of that caching :( + compile_state_cls = ORMCompileState._get_plugin_class_for_plugin( + orig_query, "orm" + ) + + # this would create the full blown compile state, which we don't + # need + # orig_compile_state = compile_state_cls.create_for_statement( + # orig_query, None) + + # this is the more "quick" version, however it's not clear how + # much of this we need. in particular I can't get a test to + # fail if the "set_base_alias" is missing and not sure why that is. + orig_compile_state = compile_state_cls._create_entities_collection( + orig_query + ) + + # generate a new Query from the original, then + # produce a subquery from it. + left_alias = self._generate_from_original_query( + orig_compile_state, + orig_query, + leftmost_mapper, + leftmost_attr, + leftmost_relationship, + entity, + ) + + # generate another Query that will join the + # left alias to the target relationships. + # basically doing a longhand + # "from_self()". (from_self() itself not quite industrial + # strength enough for all contingencies...but very close) + + q = query.Query(effective_entity) + + q._execution_options = q._execution_options.union( + { + ("orig_query", SubqueryLoader): orig_query, + ("subquery_path", None): subq_path, + } + ) + + q = q._set_enable_single_crit(False) + to_join, local_attr, parent_alias = self._prep_for_joins( + left_alias, subq_path + ) + + q = q.add_columns(*local_attr) + q = self._apply_joins( + q, to_join, left_alias, parent_alias, effective_entity + ) + + q = self._setup_options(q, subq_path, orig_query, effective_entity) + q = self._setup_outermost_orderby(q) + + return q + def create_row_processor( - self, context, path, loadopt, mapper, result, adapter, populators + self, + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, ): if context.refresh_state: return self._immediateload_create_row_processor( - context, path, loadopt, mapper, result, adapter, populators + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, ) if not self.parent.class_manager[self.key].impl.supports_population: @@ -1513,16 +1601,27 @@ class SubqueryLoader(PostLoader): "population - eager loading cannot be applied." % self ) - path = path[self.parent_property] + # a little dance here as the "path" is still something that only + # semi-tracks the exact series of things we are loading, still not + # telling us about with_polymorphic() and stuff like that when it's at + # the root.. the initial MapperEntity is more accurate for this case. 
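For orientation, the strategy being reworked here is the one behind subqueryload(), which emits a second SELECT derived from the original query; that second SELECT is now constructed inside create_row_processor from the invoked query's execution options rather than at compile time, so it stays correct when the compiled form comes from the cache. A self-contained usage sketch (the mapping is invented for the example):

    from sqlalchemy import Column, ForeignKey, Integer, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session, relationship, subqueryload

    Base = declarative_base()

    class User(Base):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        addresses = relationship("Address")

    class Address(Base):
        __tablename__ = "address"
        id = Column(Integer, primary_key=True)
        user_id = Column(Integer, ForeignKey("user.id"))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    session = Session(engine)
    session.add(User(id=1, addresses=[Address(id=1)]))
    session.commit()

    # one SELECT for the users, then one more built per the rewritten
    # create_row_processor logic above
    users = session.query(User).options(subqueryload(User.addresses)).all()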
+ if len(path) == 1: + if not orm_util._entity_isa(query_entity.entity_zero, self.parent): + return + elif not orm_util._entity_isa(path[-1], self.parent): + return - subq_info = path.get(context.attributes, "subqueryload_data") + subq = self._setup_query_from_rowproc( + context, path, path[-1], loadopt, adapter, + ) - if subq_info is None: + if subq is None: return - subq = subq_info["query"] - assert subq.session is None + + path = path[self.parent_property] + local_cols = self.parent_property.local_columns # cache the loaded collections in the context @@ -1530,7 +1629,7 @@ class SubqueryLoader(PostLoader): # call upon create_row_processor again collections = path.get(context.attributes, "collections") if collections is None: - collections = self._SubqCollections(context, subq_info) + collections = self._SubqCollections(context, subq) path.set(context.attributes, "collections", collections) if adapter: @@ -1634,7 +1733,6 @@ class JoinedLoader(AbstractRelationshipLoader): if not compile_state.compile_options._enable_eagerloads: return elif self.uselist: - compile_state.loaders_require_uniquing = True compile_state.multi_row_eager_loaders = True path = path[self.parent_property] @@ -2142,7 +2240,15 @@ class JoinedLoader(AbstractRelationshipLoader): return False def create_row_processor( - self, context, path, loadopt, mapper, result, adapter, populators + self, + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, ): if not self.parent.class_manager[self.key].impl.supports_population: raise sa_exc.InvalidRequestError( @@ -2150,6 +2256,9 @@ class JoinedLoader(AbstractRelationshipLoader): "population - eager loading cannot be applied." % self ) + if self.uselist: + context.loaders_require_uniquing = True + our_path = path[self.parent_property] eager_adapter = self._create_eager_adapter( @@ -2160,6 +2269,7 @@ class JoinedLoader(AbstractRelationshipLoader): key = self.key _instance = loading._instance_processor( + query_entity, self.mapper, context, result, @@ -2177,7 +2287,14 @@ class JoinedLoader(AbstractRelationshipLoader): self.parent_property._get_strategy( (("lazy", "select"),) ).create_row_processor( - context, path, loadopt, mapper, result, adapter, populators + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, ) def _create_collection_loader(self, context, key, _instance, populators): @@ -2382,11 +2499,26 @@ class SelectInLoader(PostLoader, util.MemoizedSlots): return util.preloaded.ext_baked.bakery(size=50) def create_row_processor( - self, context, path, loadopt, mapper, result, adapter, populators + self, + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, ): if context.refresh_state: return self._immediateload_create_row_processor( - context, path, loadopt, mapper, result, adapter, populators + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, ) if not self.parent.class_manager[self.key].impl.supports_population: @@ -2395,13 +2527,20 @@ class SelectInLoader(PostLoader, util.MemoizedSlots): "population - eager loading cannot be applied." % self ) + # a little dance here as the "path" is still something that only + # semi-tracks the exact series of things we are loading, still not + # telling us about with_polymorphic() and stuff like that when it's at + # the root.. the initial MapperEntity is more accurate for this case. 
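One behavioral consequence worth noting from these hunks: loaders_require_uniquing and loaders_require_buffering now live on the per-execution QueryContext rather than on the cacheable compile state, and joined eager loading of a collection sets the uniquing flag at row-processor time. Reusing the User/Address mapping from the previous sketch, the yield_per guard in orm/loading.py is then expected to fire:

    from sqlalchemy import exc
    from sqlalchemy.orm import joinedload

    q = session.query(User).options(joinedload(User.addresses)).yield_per(10)
    try:
        q.all()
    except exc.InvalidRequestError as err:
        print(err)  # "Can't use yield_per with eager loaders that
                    # require uniquing or row buffering ..."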
+ if len(path) == 1: + if not orm_util._entity_isa(query_entity.entity_zero, self.parent): + return + elif not orm_util._entity_isa(path[-1], self.parent): + return + selectin_path = ( context.compile_state.current_path or orm_util.PathRegistry.root ) + path - if not orm_util._entity_isa(path[-1], self.parent): - return - if loading.PostLoad.path_exists( context, selectin_path, self.parent_property ): @@ -2427,7 +2566,6 @@ class SelectInLoader(PostLoader, util.MemoizedSlots): return elif selectin_path_w_prop.contains_mapper(self.mapper): return - loading.PostLoad.callable_for_path( context, selectin_path, @@ -2543,7 +2681,39 @@ class SelectInLoader(PostLoader, util.MemoizedSlots): ) ) - orig_query = context.query + # a test which exercises what these comments talk about is + # test_selectin_relations.py -> test_twolevel_selectin_w_polymorphic + # + # effective_entity above is given to us in terms of the cached + # statement, namely this one: + orig_query = context.compile_state.select_statement + + # the actual statement that was requested is this one: + # context_query = context.query + # + # that's not the cached one, however. So while it is of the identical + # structure, if it has entities like AliasedInsp, which we get from + # aliased() or with_polymorphic(), the AliasedInsp will likely be a + # different object identity each time, and will not match up + # hashing-wise to the corresponding AliasedInsp that's in the + # cached query, meaning it won't match on paths and loader lookups + # and loaders like this one will be skipped if it is used in options. + # + # Now we want to transfer loader options from the parent query to the + # "selectinload" query we're about to run. Which query do we transfer + # the options from? We use the cached query, because the options in + # that query will be in terms of the effective entity we were just + # handed. + # + # But now the selectinload/ baked query we are running is *also* + # cached. What if it's cached and running from some previous iteration + # of that AliasedInsp? Well in that case it will also use the previous + # iteration of the loader options. If the baked query expires and + # gets generated again, it will be handed the current effective_entity + # and the current _with_options, again in terms of whatever + # compile_state.select_statement happens to be right now, so the + # query will still be internally consistent and loader callables + # will be correctly invoked. 
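To ground that comment, a usage sketch reusing the mapping above: selectinload() emits its second SELECT through a baked (cached) query, and per the reasoning above the loader options handed to it are taken from the cached statement so that entity identities such as AliasedInsp line up between the two:

    from sqlalchemy.orm import selectinload

    # second SELECT loads addresses for the whole batch of users; its
    # loader options come from compile_state.select_statement, i.e. the
    # cached statement, per the comment above
    users = session.query(User).options(selectinload(User.addresses)).all()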
q._add_lazyload_options( orig_query._with_options, path[self.parent_property] diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 85f4f85d1..f7a97bfe5 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -1187,9 +1187,9 @@ class Bundle(ORMColumnsClauseRole, SupportsCloneAnnotations, InspectionAttr): return cloned def __clause_element__(self): - annotations = self._annotations.union( - {"bundle": self, "entity_namespace": self} - ) + # ensure existing entity_namespace remains + annotations = {"bundle": self, "entity_namespace": self} + annotations.update(self._annotations) return expression.ClauseList( _literal_as_text_role=roles.ColumnsClauseRole, group=False, @@ -1258,6 +1258,8 @@ class _ORMJoin(expression.Join): __visit_name__ = expression.Join.__visit_name__ + inherit_cache = True + def __init__( self, left, diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index 78de80734..a25c1b083 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -100,6 +100,7 @@ def __go(lcls): from .elements import AnnotatedColumnElement from .elements import ClauseList # noqa from .selectable import AnnotatedFromClause # noqa + from .traversals import _preconfigure_traversals from . import base from . import coercions @@ -122,6 +123,8 @@ def __go(lcls): _prepare_annotations(FromClause, AnnotatedFromClause) _prepare_annotations(ClauseList, Annotated) + _preconfigure_traversals(ClauseElement) + _sa_util.preloaded.import_prefix("sqlalchemy.sql") from . import naming # noqa diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index 08ed121d3..8a0d6ec28 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -338,6 +338,15 @@ def _new_annotation_type(cls, base_cls): anno_cls._traverse_internals = list(cls._traverse_internals) + [ ("_annotations", InternalTraversal.dp_annotations_key) ] + elif cls.__dict__.get("inherit_cache", False): + anno_cls._traverse_internals = list(cls._traverse_internals) + [ + ("_annotations", InternalTraversal.dp_annotations_key) + ] + + # some classes include this even if they have traverse_internals + # e.g. BindParameter, add it if present. + if cls.__dict__.get("inherit_cache", False): + anno_cls.inherit_cache = True anno_cls._is_column_operators = issubclass(cls, operators.ColumnOperators) diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 5dd3b519a..5f2ce8f14 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -624,19 +624,14 @@ class Executable(Generative): _bind = None _with_options = () _with_context_options = () - _cache_enable = True _executable_traverse_internals = [ ("_with_options", ExtendedInternalTraversal.dp_has_cache_key_list), ("_with_context_options", ExtendedInternalTraversal.dp_plain_obj), - ("_cache_enable", ExtendedInternalTraversal.dp_plain_obj), + ("_propagate_attrs", ExtendedInternalTraversal.dp_propagate_attrs), ] @_generative - def _disable_caching(self): - self._cache_enable = HasCacheKey() - - @_generative def options(self, *options): """Apply options to this statement. 
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 2519438d1..61178291a 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -373,6 +373,8 @@ class Compiled(object): _cached_metadata = None + _result_columns = None + schema_translate_map = None execution_options = util.immutabledict() @@ -433,7 +435,6 @@ class Compiled(object): self, dialect, statement, - bind=None, schema_translate_map=None, render_schema_translate=False, compile_kwargs=util.immutabledict(), @@ -463,7 +464,6 @@ class Compiled(object): """ self.dialect = dialect - self.bind = bind self.preparer = self.dialect.identifier_preparer if schema_translate_map: self.schema_translate_map = schema_translate_map @@ -527,24 +527,6 @@ class Compiled(object): """Return the bind params for this compiled object.""" return self.construct_params() - def execute(self, *multiparams, **params): - """Execute this compiled object.""" - - e = self.bind - if e is None: - raise exc.UnboundExecutionError( - "This Compiled object is not bound to any Engine " - "or Connection.", - code="2afi", - ) - return e._execute_compiled(self, multiparams, params) - - def scalar(self, *multiparams, **params): - """Execute this compiled object and return the result's - scalar value.""" - - return self.execute(*multiparams, **params).scalar() - class TypeCompiler(util.with_metaclass(util.EnsureKWArgType, object)): """Produces DDL specification for TypeEngine objects.""" @@ -687,6 +669,13 @@ class SQLCompiler(Compiled): insert_prefetch = update_prefetch = () + _cache_key_bind_match = None + """a mapping that will relate the BindParameter object we compile + to those that are part of the extracted collection of parameters + in the cache key, if we were given a cache key. + + """ + def __init__( self, dialect, @@ -717,6 +706,9 @@ class SQLCompiler(Compiled): self.cache_key = cache_key + if cache_key: + self._cache_key_bind_match = {b: b for b in cache_key[1]} + # compile INSERT/UPDATE defaults/sequences inlined (no pre- # execute) self.inline = inline or getattr(statement, "_inline", False) @@ -875,8 +867,9 @@ class SQLCompiler(Compiled): replace_context=err, ) + ckbm = self._cache_key_bind_match resolved_extracted = { - b.key: extracted + ckbm[b]: extracted for b, extracted in zip(orig_extracted, extracted_parameters) } else: @@ -907,7 +900,7 @@ class SQLCompiler(Compiled): else: if resolved_extracted: value_param = resolved_extracted.get( - bindparam.key, bindparam + bindparam, bindparam ) else: value_param = bindparam @@ -936,9 +929,7 @@ class SQLCompiler(Compiled): ) if resolved_extracted: - value_param = resolved_extracted.get( - bindparam.key, bindparam - ) + value_param = resolved_extracted.get(bindparam, bindparam) else: value_param = bindparam @@ -2021,6 +2012,19 @@ class SQLCompiler(Compiled): ) self.binds[bindparam.key] = self.binds[name] = bindparam + + # if we are given a cache key that we're going to match against, + # relate the bindparam here to one that is most likely present + # in the "extracted params" portion of the cache key. this is used + # to set up a positional mapping that is used to determine the + # correct parameters for a subsequent use of this compiled with + # a different set of parameter values. here, we accommodate for + # parameters that may have been cloned both before and after the cache + # key was been generated. 
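The _cache_key_bind_match mapping exists because two statements that differ only in bound literal values share a single cache key and a single compiled form; the current values arrive as "extracted parameters" and must be matched back to the compiled statement's bindparams by identity, surviving clones made on either side of key generation. A sketch against the private cache-key API as it appears in this changeset (internal, subject to change; at this point _generate_cache_key() returns a (key, extracted_parameters) pair):

    from sqlalchemy import column, select

    s1 = select([column("x")]).where(column("q") == 5)
    s2 = select([column("x")]).where(column("q") == 10)

    key1, extracted1 = s1._generate_cache_key()
    key2, extracted2 = s2._generate_cache_key()

    assert key1 == key2                          # same structure, same key
    assert [b.value for b in extracted1] == [5]  # values travel separately
    assert [b.value for b in extracted2] == [10]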
+ ckbm = self._cache_key_bind_match + if ckbm: + ckbm.update({bp: bindparam for bp in bindparam._cloned_set}) + if bindparam.isoutparam: self.has_out_parameters = True diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index 569030651..d3730b124 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -28,6 +28,9 @@ class _DDLCompiles(ClauseElement): return dialect.ddl_compiler(dialect, self, **kw) + def _compile_w_cache(self, *arg, **kw): + raise NotImplementedError() + class DDLElement(roles.DDLRole, Executable, _DDLCompiles): """Base class for DDL expression constructs. diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index a82641d77..50b2a935a 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -641,7 +641,7 @@ class ValuesBase(UpdateBase): if self._preserve_parameter_order: arg = [ ( - k, + coercions.expect(roles.DMLColumnRole, k), coercions.expect( roles.ExpressionElementRole, v, @@ -654,7 +654,7 @@ class ValuesBase(UpdateBase): self._ordered_values = arg else: arg = { - k: coercions.expect( + coercions.expect(roles.DMLColumnRole, k): coercions.expect( roles.ExpressionElementRole, v, type_=NullType(), @@ -772,6 +772,7 @@ class Insert(ValuesBase): ] + HasPrefixes._has_prefixes_traverse_internals + DialectKWArgs._dialect_kwargs_traverse_internals + + Executable._executable_traverse_internals ) @ValuesBase._constructor_20_deprecations( @@ -997,6 +998,7 @@ class Update(DMLWhereBase, ValuesBase): ] + HasPrefixes._has_prefixes_traverse_internals + DialectKWArgs._dialect_kwargs_traverse_internals + + Executable._executable_traverse_internals ) @ValuesBase._constructor_20_deprecations( @@ -1187,7 +1189,7 @@ class Update(DMLWhereBase, ValuesBase): ) arg = [ ( - k, + coercions.expect(roles.DMLColumnRole, k), coercions.expect( roles.ExpressionElementRole, v, @@ -1238,6 +1240,7 @@ class Delete(DMLWhereBase, UpdateBase): ] + HasPrefixes._has_prefixes_traverse_internals + DialectKWArgs._dialect_kwargs_traverse_internals + + Executable._executable_traverse_internals ) @ValuesBase._constructor_20_deprecations( diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 8e1b623a7..60c816ee6 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -381,6 +381,7 @@ class ClauseElement( try: traverse_internals = self._traverse_internals except AttributeError: + # user-defined classes may not have a _traverse_internals return for attrname, obj, meth in _copy_internals.run_generated_dispatch( @@ -410,6 +411,7 @@ class ClauseElement( try: traverse_internals = self._traverse_internals except AttributeError: + # user-defined classes may not have a _traverse_internals return [] return itertools.chain.from_iterable( @@ -516,10 +518,62 @@ class ClauseElement( dialect = bind.dialect elif self.bind: dialect = self.bind.dialect - bind = self.bind else: dialect = default.StrCompileDialect() - return self._compiler(dialect, bind=bind, **kw) + + return self._compiler(dialect, **kw) + + def _compile_w_cache( + self, + dialect, + compiled_cache=None, + column_keys=None, + inline=False, + schema_translate_map=None, + **kw + ): + if compiled_cache is not None: + elem_cache_key = self._generate_cache_key() + else: + elem_cache_key = None + + cache_hit = False + + if elem_cache_key: + cache_key, extracted_params = elem_cache_key + key = ( + dialect, + cache_key, + tuple(column_keys), + bool(schema_translate_map), + inline, + ) + compiled_sql = compiled_cache.get(key) + + if compiled_sql is None: + 
compiled_sql = self._compiler( + dialect, + cache_key=elem_cache_key, + column_keys=column_keys, + inline=inline, + schema_translate_map=schema_translate_map, + **kw + ) + compiled_cache[key] = compiled_sql + else: + cache_hit = True + else: + extracted_params = None + compiled_sql = self._compiler( + dialect, + cache_key=elem_cache_key, + column_keys=column_keys, + inline=inline, + schema_translate_map=schema_translate_map, + **kw + ) + + return compiled_sql, extracted_params, cache_hit def _compiler(self, dialect, **kw): """Return a compiler appropriate for this ClauseElement, given a @@ -1035,6 +1089,10 @@ class BindParameter(roles.InElementRole, ColumnElement): _is_bind_parameter = True _key_is_anon = False + # bindparam implements its own _gen_cache_key() method however + # we check subclasses for this flag, else no cache key is generated + inherit_cache = True + def __init__( self, key, @@ -1396,6 +1454,13 @@ class BindParameter(roles.InElementRole, ColumnElement): return c def _gen_cache_key(self, anon_map, bindparams): + _gen_cache_ok = self.__class__.__dict__.get("inherit_cache", False) + + if not _gen_cache_ok: + if anon_map is not None: + anon_map[NO_CACHE] = True + return None + idself = id(self) if idself in anon_map: return (anon_map[idself], self.__class__) @@ -2082,6 +2147,7 @@ class ClauseList( roles.InElementRole, roles.OrderByRole, roles.ColumnsClauseRole, + roles.DMLColumnRole, ClauseElement, ): """Describe a list of clauses, separated by an operator. @@ -2174,6 +2240,7 @@ class ClauseList( class BooleanClauseList(ClauseList, ColumnElement): __visit_name__ = "clauselist" + inherit_cache = True _tuple_values = False @@ -3428,6 +3495,8 @@ class CollectionAggregate(UnaryExpression): class AsBoolean(WrapsColumnExpression, UnaryExpression): + inherit_cache = True + def __init__(self, element, operator, negate): self.element = element self.type = type_api.BOOLEANTYPE @@ -3474,6 +3543,7 @@ class BinaryExpression(ColumnElement): ("operator", InternalTraversal.dp_operator), ("negate", InternalTraversal.dp_operator), ("modifiers", InternalTraversal.dp_plain_dict), + ("type", InternalTraversal.dp_type,), # affects JSON CAST operators ] _is_implicitly_boolean = True @@ -3482,41 +3552,6 @@ class BinaryExpression(ColumnElement): """ - def _gen_cache_key(self, anon_map, bindparams): - # inlined for performance - - idself = id(self) - - if idself in anon_map: - return (anon_map[idself], self.__class__) - else: - # inline of - # id_ = anon_map[idself] - anon_map[idself] = id_ = str(anon_map.index) - anon_map.index += 1 - - if self._cache_key_traversal is NO_CACHE: - anon_map[NO_CACHE] = True - return None - - result = (id_, self.__class__) - - return result + ( - ("left", self.left._gen_cache_key(anon_map, bindparams)), - ("right", self.right._gen_cache_key(anon_map, bindparams)), - ("operator", self.operator), - ("negate", self.negate), - ( - "modifiers", - tuple( - (key, self.modifiers[key]) - for key in sorted(self.modifiers) - ) - if self.modifiers - else None, - ), - ) - def __init__( self, left, right, operator, type_=None, negate=None, modifiers=None ): @@ -3587,15 +3622,30 @@ class Slice(ColumnElement): __visit_name__ = "slice" _traverse_internals = [ - ("start", InternalTraversal.dp_plain_obj), - ("stop", InternalTraversal.dp_plain_obj), - ("step", InternalTraversal.dp_plain_obj), + ("start", InternalTraversal.dp_clauseelement), + ("stop", InternalTraversal.dp_clauseelement), + ("step", InternalTraversal.dp_clauseelement), ] - def __init__(self, start, stop, step): - self.start 
= start - self.stop = stop - self.step = step + def __init__(self, start, stop, step, _name=None): + self.start = coercions.expect( + roles.ExpressionElementRole, + start, + name=_name, + type_=type_api.INTEGERTYPE, + ) + self.stop = coercions.expect( + roles.ExpressionElementRole, + stop, + name=_name, + type_=type_api.INTEGERTYPE, + ) + self.step = coercions.expect( + roles.ExpressionElementRole, + step, + name=_name, + type_=type_api.INTEGERTYPE, + ) self.type = type_api.NULLTYPE def self_group(self, against=None): diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 6b1172eba..7b723f371 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -744,6 +744,7 @@ class GenericFunction(util.with_metaclass(_GenericMeta, Function)): coerce_arguments = True _register = False + inherit_cache = True def __init__(self, *args, **kwargs): parsed_args = kwargs.pop("_parsed_args", None) @@ -808,6 +809,8 @@ class next_value(GenericFunction): class AnsiFunction(GenericFunction): + inherit_cache = True + def __init__(self, *args, **kwargs): GenericFunction.__init__(self, *args, **kwargs) @@ -815,6 +818,8 @@ class AnsiFunction(GenericFunction): class ReturnTypeFromArgs(GenericFunction): """Define a function whose return type is the same as its arguments.""" + inherit_cache = True + def __init__(self, *args, **kwargs): args = [ coercions.expect( @@ -832,30 +837,34 @@ class ReturnTypeFromArgs(GenericFunction): class coalesce(ReturnTypeFromArgs): _has_args = True + inherit_cache = True class max(ReturnTypeFromArgs): # noqa - pass + inherit_cache = True class min(ReturnTypeFromArgs): # noqa - pass + inherit_cache = True class sum(ReturnTypeFromArgs): # noqa - pass + inherit_cache = True class now(GenericFunction): # noqa type = sqltypes.DateTime + inherit_cache = True class concat(GenericFunction): type = sqltypes.String + inherit_cache = True class char_length(GenericFunction): type = sqltypes.Integer + inherit_cache = True def __init__(self, arg, **kwargs): GenericFunction.__init__(self, arg, **kwargs) @@ -863,6 +872,7 @@ class char_length(GenericFunction): class random(GenericFunction): _has_args = True + inherit_cache = True class count(GenericFunction): @@ -887,6 +897,7 @@ class count(GenericFunction): """ type = sqltypes.Integer + inherit_cache = True def __init__(self, expression=None, **kwargs): if expression is None: @@ -896,38 +907,47 @@ class count(GenericFunction): class current_date(AnsiFunction): type = sqltypes.Date + inherit_cache = True class current_time(AnsiFunction): type = sqltypes.Time + inherit_cache = True class current_timestamp(AnsiFunction): type = sqltypes.DateTime + inherit_cache = True class current_user(AnsiFunction): type = sqltypes.String + inherit_cache = True class localtime(AnsiFunction): type = sqltypes.DateTime + inherit_cache = True class localtimestamp(AnsiFunction): type = sqltypes.DateTime + inherit_cache = True class session_user(AnsiFunction): type = sqltypes.String + inherit_cache = True class sysdate(AnsiFunction): type = sqltypes.DateTime + inherit_cache = True class user(AnsiFunction): type = sqltypes.String + inherit_cache = True class array_agg(GenericFunction): @@ -951,6 +971,7 @@ class array_agg(GenericFunction): """ type = sqltypes.ARRAY + inherit_cache = True def __init__(self, *args, **kwargs): args = [ @@ -978,6 +999,7 @@ class OrderedSetAgg(GenericFunction): :meth:`.FunctionElement.within_group` method.""" array_for_multi_clause = False + inherit_cache = True def within_group_type(self, 
within_group): func_clauses = self.clause_expr.element @@ -1000,6 +1022,8 @@ class mode(OrderedSetAgg): """ + inherit_cache = True + class percentile_cont(OrderedSetAgg): """implement the ``percentile_cont`` ordered-set aggregate function. @@ -1016,6 +1040,7 @@ class percentile_cont(OrderedSetAgg): """ array_for_multi_clause = True + inherit_cache = True class percentile_disc(OrderedSetAgg): @@ -1033,6 +1058,7 @@ class percentile_disc(OrderedSetAgg): """ array_for_multi_clause = True + inherit_cache = True class rank(GenericFunction): @@ -1048,6 +1074,7 @@ class rank(GenericFunction): """ type = sqltypes.Integer() + inherit_cache = True class dense_rank(GenericFunction): @@ -1063,6 +1090,7 @@ class dense_rank(GenericFunction): """ type = sqltypes.Integer() + inherit_cache = True class percent_rank(GenericFunction): @@ -1078,6 +1106,7 @@ class percent_rank(GenericFunction): """ type = sqltypes.Numeric() + inherit_cache = True class cume_dist(GenericFunction): @@ -1093,6 +1122,7 @@ class cume_dist(GenericFunction): """ type = sqltypes.Numeric() + inherit_cache = True class cube(GenericFunction): @@ -1109,6 +1139,7 @@ class cube(GenericFunction): """ _has_args = True + inherit_cache = True class rollup(GenericFunction): @@ -1125,6 +1156,7 @@ class rollup(GenericFunction): """ _has_args = True + inherit_cache = True class grouping_sets(GenericFunction): @@ -1158,3 +1190,4 @@ class grouping_sets(GenericFunction): """ _has_args = True + inherit_cache = True diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index ee411174c..29ca81d26 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -1013,6 +1013,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): __visit_name__ = "column" + inherit_cache = True + def __init__(self, *args, **kwargs): r""" Construct a new ``Column`` object. diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index a95fc561a..54f293967 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -61,6 +61,8 @@ if util.TYPE_CHECKING: class _OffsetLimitParam(BindParameter): + inherit_cache = True + @property def _limit_offset_value(self): return self.effective_value @@ -1426,6 +1428,8 @@ class Alias(roles.DMLTableRole, AliasedReturnsRows): __visit_name__ = "alias" + inherit_cache = True + @classmethod def _factory(cls, selectable, name=None, flat=False): """Return an :class:`_expression.Alias` object. @@ -1500,6 +1504,8 @@ class Lateral(AliasedReturnsRows): __visit_name__ = "lateral" _is_lateral = True + inherit_cache = True + @classmethod def _factory(cls, selectable, name=None): """Return a :class:`_expression.Lateral` object. 
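Stepping back to _compile_w_cache() in elements.py above, the control flow is an ordinary keyed lookup. A hedged sketch of that flow with a plain dict standing in for the engine's LRUCache; compile_cached is an illustrative helper, not a library API:

cache = {}

def compile_cached(stmt, dialect, column_keys=(), inline=False,
                   schema_translate_map=None):
    key_info = stmt._generate_cache_key()
    if key_info is None:
        # an uncacheable element somewhere in the tree; bypass the cache
        return stmt.compile(dialect=dialect), None, False
    cache_key, extracted_params = key_info
    key = (dialect, cache_key, tuple(column_keys),
           bool(schema_translate_map), inline)
    compiled = cache.get(key)
    cache_hit = compiled is not None
    if not cache_hit:
        # assumes compile() forwards cache_key to the compiler, the same
        # way _compiler() is invoked in the diff above
        compiled = stmt.compile(dialect=dialect, cache_key=key_info)
        cache[key] = compiled
    return compiled, extracted_params, cache_hit

A second call with a structurally identical statement returns cache_hit=True, and only the extracted parameters differ.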
@@ -1626,7 +1632,7 @@ class CTE(Generative, HasPrefixes, HasSuffixes, AliasedReturnsRows): AliasedReturnsRows._traverse_internals + [ ("_cte_alias", InternalTraversal.dp_clauseelement), - ("_restates", InternalTraversal.dp_clauseelement_unordered_set), + ("_restates", InternalTraversal.dp_clauseelement_list), ("recursive", InternalTraversal.dp_boolean), ] + HasPrefixes._has_prefixes_traverse_internals @@ -1651,7 +1657,7 @@ class CTE(Generative, HasPrefixes, HasSuffixes, AliasedReturnsRows): name=None, recursive=False, _cte_alias=None, - _restates=frozenset(), + _restates=(), _prefixes=None, _suffixes=None, ): @@ -1692,7 +1698,7 @@ class CTE(Generative, HasPrefixes, HasSuffixes, AliasedReturnsRows): self.element.union(other), name=self.name, recursive=self.recursive, - _restates=self._restates.union([self]), + _restates=self._restates + (self,), _prefixes=self._prefixes, _suffixes=self._suffixes, ) @@ -1702,7 +1708,7 @@ class CTE(Generative, HasPrefixes, HasSuffixes, AliasedReturnsRows): self.element.union_all(other), name=self.name, recursive=self.recursive, - _restates=self._restates.union([self]), + _restates=self._restates + (self,), _prefixes=self._prefixes, _suffixes=self._suffixes, ) @@ -1918,6 +1924,8 @@ class Subquery(AliasedReturnsRows): _is_subquery = True + inherit_cache = True + @classmethod def _factory(cls, selectable, name=None): """Return a :class:`.Subquery` object. @@ -3783,15 +3791,15 @@ class Select( ("_group_by_clauses", InternalTraversal.dp_clauseelement_list,), ("_setup_joins", InternalTraversal.dp_setup_join_tuple,), ("_legacy_setup_joins", InternalTraversal.dp_setup_join_tuple,), - ("_correlate", InternalTraversal.dp_clauseelement_unordered_set), - ( - "_correlate_except", - InternalTraversal.dp_clauseelement_unordered_set, - ), + ("_correlate", InternalTraversal.dp_clauseelement_list), + ("_correlate_except", InternalTraversal.dp_clauseelement_list,), + ("_limit_clause", InternalTraversal.dp_clauseelement), + ("_offset_clause", InternalTraversal.dp_clauseelement), ("_for_update_arg", InternalTraversal.dp_clauseelement), ("_distinct", InternalTraversal.dp_boolean), ("_distinct_on", InternalTraversal.dp_clauseelement_list), ("_label_style", InternalTraversal.dp_plain_obj), + ("_is_future", InternalTraversal.dp_boolean), ] + HasPrefixes._has_prefixes_traverse_internals + HasSuffixes._has_suffixes_traverse_internals @@ -4522,7 +4530,7 @@ class Select( if fromclauses and fromclauses[0] is None: self._correlate = () else: - self._correlate = set(self._correlate).union( + self._correlate = self._correlate + tuple( coercions.expect(roles.FromClauseRole, f) for f in fromclauses ) @@ -4560,7 +4568,7 @@ class Select( if fromclauses and fromclauses[0] is None: self._correlate_except = () else: - self._correlate_except = set(self._correlate_except or ()).union( + self._correlate_except = (self._correlate_except or ()) + tuple( coercions.expect(roles.FromClauseRole, f) for f in fromclauses ) @@ -4866,6 +4874,7 @@ class ScalarSelect(roles.InElementRole, Generative, Grouping): _from_objects = [] _is_from_container = True _is_implicitly_boolean = False + inherit_cache = True def __init__(self, element): self.element = element @@ -4899,6 +4908,7 @@ class Exists(UnaryExpression): """ _from_objects = [] + inherit_cache = True def __init__(self, *args, **kwargs): """Construct a new :class:`_expression.Exists` against an existing diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 732b775f6..9cd9d5058 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ 
b/lib/sqlalchemy/sql/sqltypes.py @@ -2616,26 +2616,10 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): return_type = self.type if self.type.zero_indexes: index = slice(index.start + 1, index.stop + 1, index.step) - index = Slice( - coercions.expect( - roles.ExpressionElementRole, - index.start, - name=self.expr.key, - type_=type_api.INTEGERTYPE, - ), - coercions.expect( - roles.ExpressionElementRole, - index.stop, - name=self.expr.key, - type_=type_api.INTEGERTYPE, - ), - coercions.expect( - roles.ExpressionElementRole, - index.step, - name=self.expr.key, - type_=type_api.INTEGERTYPE, - ), + slice_ = Slice( + index.start, index.stop, index.step, _name=self.expr.key ) + return operators.getitem, slice_, return_type else: if self.type.zero_indexes: index += 1 @@ -2647,7 +2631,7 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): self.type.__class__, **adapt_kw ) - return operators.getitem, index, return_type + return operators.getitem, index, return_type def contains(self, *arg, **kw): raise NotImplementedError( diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 68281f33d..ed0bfa27a 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -19,6 +19,7 @@ NO_CACHE = util.symbol("no_cache") CACHE_IN_PLACE = util.symbol("cache_in_place") CALL_GEN_CACHE_KEY = util.symbol("call_gen_cache_key") STATIC_CACHE_KEY = util.symbol("static_cache_key") +PROPAGATE_ATTRS = util.symbol("propagate_attrs") ANON_NAME = util.symbol("anon_name") @@ -31,10 +32,74 @@ def compare(obj1, obj2, **kw): return strategy.compare(obj1, obj2, **kw) +def _preconfigure_traversals(target_hierarchy): + + stack = [target_hierarchy] + while stack: + cls = stack.pop() + stack.extend(cls.__subclasses__()) + + if hasattr(cls, "_traverse_internals"): + cls._generate_cache_attrs() + _copy_internals.generate_dispatch( + cls, + cls._traverse_internals, + "_generated_copy_internals_traversal", + ) + _get_children.generate_dispatch( + cls, + cls._traverse_internals, + "_generated_get_children_traversal", + ) + + class HasCacheKey(object): _cache_key_traversal = NO_CACHE __slots__ = () + @classmethod + def _generate_cache_attrs(cls): + """generate cache key dispatcher for a new class. + + This sets the _generated_cache_key_traversal attribute once called + so should only be called once per class. + + """ + inherit = cls.__dict__.get("inherit_cache", False) + + if inherit: + _cache_key_traversal = getattr(cls, "_cache_key_traversal", None) + if _cache_key_traversal is None: + try: + _cache_key_traversal = cls._traverse_internals + except AttributeError: + cls._generated_cache_key_traversal = NO_CACHE + return NO_CACHE + + # TODO: wouldn't we instead get this from our superclass? + # also, our superclass may not have this yet, but in any case, + # we'd generate for the superclass that has it. this is a little + # more complicated, so for the moment this is a little less + # efficient on startup but simpler. 
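The inherit check above reads cls.__dict__ rather than doing a normal attribute lookup, so each subclass has to declare the flag for itself; that is why functions.py repeats inherit_cache = True on every class. A toy contrasting the two lookups:

class Base:
    inherit_cache = True

class Sub(Base):
    pass

assert getattr(Sub, "inherit_cache", False) is True       # inherited lookup
assert Sub.__dict__.get("inherit_cache", False) is False  # per-class lookup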
+ return _cache_key_traversal_visitor.generate_dispatch( + cls, _cache_key_traversal, "_generated_cache_key_traversal" + ) + else: + _cache_key_traversal = cls.__dict__.get( + "_cache_key_traversal", None + ) + if _cache_key_traversal is None: + _cache_key_traversal = cls.__dict__.get( + "_traverse_internals", None + ) + if _cache_key_traversal is None: + cls._generated_cache_key_traversal = NO_CACHE + return NO_CACHE + + return _cache_key_traversal_visitor.generate_dispatch( + cls, _cache_key_traversal, "_generated_cache_key_traversal" + ) + @util.preload_module("sqlalchemy.sql.elements") def _gen_cache_key(self, anon_map, bindparams): """return an optional cache key. @@ -72,14 +137,18 @@ class HasCacheKey(object): else: id_ = None - _cache_key_traversal = self._cache_key_traversal - if _cache_key_traversal is None: - try: - _cache_key_traversal = self._traverse_internals - except AttributeError: - _cache_key_traversal = NO_CACHE + try: + dispatcher = self.__class__.__dict__[ + "_generated_cache_key_traversal" + ] + except KeyError: + # most of the dispatchers are generated up front + # in sqlalchemy/sql/__init__.py -> + # traversals.py-> _preconfigure_traversals(). + # this block will generate any remaining dispatchers. + dispatcher = self.__class__._generate_cache_attrs() - if _cache_key_traversal is NO_CACHE: + if dispatcher is NO_CACHE: if anon_map is not None: anon_map[NO_CACHE] = True return None @@ -87,19 +156,13 @@ class HasCacheKey(object): result = (id_, self.__class__) # inline of _cache_key_traversal_visitor.run_generated_dispatch() - try: - dispatcher = self.__class__.__dict__[ - "_generated_cache_key_traversal" - ] - except KeyError: - dispatcher = _cache_key_traversal_visitor.generate_dispatch( - self, _cache_key_traversal, "_generated_cache_key_traversal" - ) for attrname, obj, meth in dispatcher( self, _cache_key_traversal_visitor ): if obj is not None: + # TODO: see if C code can help here as Python lacks an + # efficient switch construct if meth is CACHE_IN_PLACE: # cache in place is always going to be a Python # tuple, dict, list, etc. so we can do a boolean check @@ -116,6 +179,15 @@ class HasCacheKey(object): attrname, obj._gen_cache_key(anon_map, bindparams), ) + elif meth is PROPAGATE_ATTRS: + if obj: + result += ( + attrname, + obj["compile_state_plugin"], + obj["plugin_subject"]._gen_cache_key( + anon_map, bindparams + ), + ) elif meth is InternalTraversal.dp_annotations_key: # obj is here is the _annotations dict. however, # we want to use the memoized cache key version of it. 
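The NO_CACHE branch above is what ultimately disables caching for a whole statement when any element lacks a usable traversal. A sketch of the observable effect, assuming 2.0-style select() and the internal _generate_cache_key() API:

from sqlalchemy import column, select
from sqlalchemy.sql.expression import ColumnClause

class OpaqueCol(ColumnClause):
    # no inherit_cache and no _traverse_internals of its own, so
    # _generate_cache_attrs() falls through to NO_CACHE
    pass

assert select(column("a"))._generate_cache_key() is not None
assert select(OpaqueCol("a"))._generate_cache_key() is None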
@@ -332,6 +404,8 @@ class _CacheKey(ExtendedInternalTraversal): visit_type = STATIC_CACHE_KEY visit_anon_name = ANON_NAME + visit_propagate_attrs = PROPAGATE_ATTRS + def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams): return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams)) @@ -445,10 +519,16 @@ class _CacheKey(ExtendedInternalTraversal): def visit_setup_join_tuple( self, attrname, obj, parent, anon_map, bindparams ): + is_legacy = "legacy" in attrname + return tuple( ( - target._gen_cache_key(anon_map, bindparams), - onclause._gen_cache_key(anon_map, bindparams) + target + if is_legacy and isinstance(target, str) + else target._gen_cache_key(anon_map, bindparams), + onclause + if is_legacy and isinstance(onclause, str) + else onclause._gen_cache_key(anon_map, bindparams) if onclause is not None else None, from_._gen_cache_key(anon_map, bindparams) @@ -711,6 +791,11 @@ class _CopyInternals(InternalTraversal): for sequence in element ] + def visit_propagate_attrs( + self, attrname, parent, element, clone=_clone, **kw + ): + return element + _copy_internals = _CopyInternals() @@ -782,6 +867,9 @@ class _GetChildren(InternalTraversal): def visit_dml_multi_values(self, element, **kw): return () + def visit_propagate_attrs(self, element, **kw): + return () + _get_children = _GetChildren() @@ -916,6 +1004,13 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): ): return COMPARE_FAILED + def visit_propagate_attrs( + self, attrname, left_parent, left, right_parent, right, **kw + ): + return self.compare_inner( + left.get("plugin_subject", None), right.get("plugin_subject", None) + ) + def visit_has_cache_key_list( self, attrname, left_parent, left, right_parent, right, **kw ): diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index ccda21e11..fe3634bad 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -555,7 +555,12 @@ class TypeEngine(Traversible): def _static_cache_key(self): names = util.get_cls_kwargs(self.__class__) return (self.__class__,) + tuple( - (k, self.__dict__[k]) + ( + k, + self.__dict__[k]._static_cache_key + if isinstance(self.__dict__[k], TypeEngine) + else self.__dict__[k], + ) for k in names if k in self.__dict__ and not k.startswith("_") ) diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 5de68f504..904702003 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -217,18 +217,23 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)): try: dispatcher = target.__class__.__dict__[generate_dispatcher_name] except KeyError: + # most of the dispatchers are generated up front + # in sqlalchemy/sql/__init__.py -> + # traversals.py-> _preconfigure_traversals(). + # this block will generate any remaining dispatchers. 
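The type_api.py change above matters for types that hold other TypeEngine instances as constructor arguments. A hedged sketch of the effect, assuming ARRAY stores its item_type as an instance:

from sqlalchemy import ARRAY, Integer

k1 = ARRAY(Integer)._static_cache_key
k2 = ARRAY(Integer)._static_cache_key

# with the nested reduction, two equivalent types produce equal keys;
# previously the raw Integer() instance was embedded in the tuple, so
# equivalent types could compare unequal across instances
assert k1 == k2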
dispatcher = self.generate_dispatch( - target, internal_dispatch, generate_dispatcher_name + target.__class__, internal_dispatch, generate_dispatcher_name ) return dispatcher(target, self) def generate_dispatch( - self, target, internal_dispatch, generate_dispatcher_name + self, target_cls, internal_dispatch, generate_dispatcher_name ): dispatcher = _generate_dispatcher( self, internal_dispatch, generate_dispatcher_name ) - setattr(target.__class__, generate_dispatcher_name, dispatcher) + # assert isinstance(target_cls, type) + setattr(target_cls, generate_dispatcher_name, dispatcher) return dispatcher dp_has_cache_key = symbol("HC") @@ -263,10 +268,6 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)): """ - dp_clauseelement_unordered_set = symbol("CU") - """Visit an unordered set of :class:`_expression.ClauseElement` - objects. """ - dp_fromclause_ordered_set = symbol("CO") """Visit an ordered set of :class:`_expression.FromClause` objects. """ @@ -414,6 +415,10 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)): """ + dp_propagate_attrs = symbol("PA") + """Visit the propagate attrs dict. this hardcodes to the particular + elements we care about right now.""" + class ExtendedInternalTraversal(InternalTraversal): """defines additional symbols that are useful in caching applications. diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index 7988b4ec9..48cbb4694 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -97,13 +97,13 @@ class CompiledSQL(SQLMatchRule): else: map_ = None - if isinstance(context.compiled.statement, _DDLCompiles): + if isinstance(execute_observed.clauseelement, _DDLCompiles): - compiled = context.compiled.statement.compile( + compiled = execute_observed.clauseelement.compile( dialect=compare_dialect, schema_translate_map=map_ ) else: - compiled = context.compiled.statement.compile( + compiled = execute_observed.clauseelement.compile( dialect=compare_dialect, column_keys=context.compiled.column_keys, inline=context.compiled.inline, |

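Taken together, the Slice rework in elements.py and the simplified ARRAY._setup_getitem() in sqltypes.py turn slice bounds into bound expressions that participate in the cache key (dp_clauseelement) instead of defeating it (dp_plain_obj holding raw integers). A closing sketch, comparing only the key portion of the internal cache key tuple:

from sqlalchemy import Column, Integer, MetaData, Table, select
from sqlalchemy.dialects.postgresql import ARRAY

t = Table("t", MetaData(), Column("data", ARRAY(Integer)))

stmt1 = select(t.c.data[2:5])
stmt2 = select(t.c.data[3:7])

k1 = stmt1._generate_cache_key()
k2 = stmt2._generate_cache_key()

# same structure -> same key; the differing bounds live in the
# extracted parameters, not in the key itself
assert k1[0] == k2[0]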