diff options
| -rw-r--r-- | doc/build/changelog/unreleased_20/7837.rst | 40 | ||||
| -rw-r--r-- | doc/build/glossary.rst | 16 | ||||
| -rw-r--r-- | doc/build/orm/queryguide/api.rst | 132 | ||||
| -rw-r--r-- | examples/sharding/separate_databases.py | 10 | ||||
| -rw-r--r-- | examples/sharding/separate_schema_translates.py | 13 | ||||
| -rw-r--r-- | examples/sharding/separate_tables.py | 10 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/horizontal_shard.py | 138 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/bulk_persistence.py | 118 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/context.py | 27 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/loading.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/query.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/session.py | 86 | ||||
| -rw-r--r-- | test/ext/test_deprecations.py | 31 | ||||
| -rw-r--r-- | test/ext/test_horizontal_shard.py | 107 | ||||
| -rw-r--r-- | test/orm/test_events.py | 39 | ||||
| -rw-r--r-- | test/orm/test_query.py | 79 | ||||
| -rw-r--r-- | test/orm/test_session.py | 33 |
17 files changed, 677 insertions, 206 deletions
diff --git a/doc/build/changelog/unreleased_20/7837.rst b/doc/build/changelog/unreleased_20/7837.rst new file mode 100644 index 000000000..1abb3e157 --- /dev/null +++ b/doc/build/changelog/unreleased_20/7837.rst @@ -0,0 +1,40 @@ +.. change:: + :tags: usecase, orm + :tickets: 7837 + + Adjustments to the :class:`_orm.Session` in terms of extensibility, + as well as updates to the :class:`.ShardedSession` extension: + + * :meth:`_orm.Session.get` now accepts + :paramref:`_orm.Session.get.bind_arguments`, which in particular may be + useful when using the horizontal sharding extension. + + * :meth:`_orm.Session.get_bind` accepts arbitrary kw arguments, which + assists in developing code that uses a :class:`_orm.Session` class which + overrides this method with additional arguments. + + * Added a new ORM execution option ``identity_token`` which may be used + to directly affect the "identity token" that will be associated with + newly loaded ORM objects. This token is how sharding approaches + (namely the :class:`.ShardedSession`, but can be used in other cases + as well) separate object identities across different "shards". + + .. seealso:: + + :ref:`queryguide_identity_token` + + * The :meth:`_orm.SessionEvents.do_orm_execute` event hook may now be used + to affect all ORM-related options, including ``autoflush``, + ``populate_existing``, and ``yield_per``; these options are re-consumed + subsequent to event hooks being invoked before they are acted upon. + Previously, options like ``autoflush`` would have been already evaluated + at this point. The new ``identity_token`` option is also supported in + this mode and is now used by the horizontal sharding extension. + + + * The :class:`.ShardedSession` class replaces the + :paramref:`.ShardedSession.id_chooser` hook with a new hook + :paramref:`.ShardedSession.identity_chooser`, which no longer relies upon + the legacy :class:`_orm.Query` object. + :paramref:`.ShardedSession.id_chooser` is still accepted in place of + :paramref:`.ShardedSession.identity_chooser` with a deprecation warning. diff --git a/doc/build/glossary.rst b/doc/build/glossary.rst index d0bc4f814..70eb05e64 100644 --- a/doc/build/glossary.rst +++ b/doc/build/glossary.rst @@ -488,6 +488,19 @@ Glossary primary key identity within the database, as well as their unique identity within a :class:`_orm.Session` :term:`identity map`. + In SQLAlchemy, you can view the identity key for an ORM object + using the :func:`_sa.inspect` API to return the :class:`_orm.InstanceState` + tracking object, then looking at the :attr:`_orm.InstanceState.key` + attribute:: + + >>> from sqlalchemy import inspect + >>> inspect(some_object).key + (<class '__main__.MyTable'>, (1,), None) + + .. seealso:: + + :term:`identity map` + identity map A mapping between Python objects and their database identities. The identity map is a collection that's associated with an @@ -505,6 +518,9 @@ Glossary `Identity Map (via Martin Fowler) <https://martinfowler.com/eaaCatalog/identityMap.html>`_ + :ref:`session_get` - how to look up an object in the identity map + by primary key + lazy initialization A tactic of delaying some initialization action, such as creating objects, populating data, or establishing connectivity to other services, until diff --git a/doc/build/orm/queryguide/api.rst b/doc/build/orm/queryguide/api.rst index 136b4b39b..35259a3b3 100644 --- a/doc/build/orm/queryguide/api.rst +++ b/doc/build/orm/queryguide/api.rst @@ -280,6 +280,138 @@ will have the same result as that of the ``yield_per`` execution option. :ref:`engine_stream_results` +.. _queryguide_identity_token: + +Identity Token +^^^^^^^^^^^^^^ + +.. doctest-disable: + +.. deepalchemy:: This option is an advanced-use feature mostly intended + to be used with the :ref:`horizontal_sharding_toplevel` extension. For + typical cases of loading objects with identical primary keys from different + "shards" or partitions, consider using individual :class:`_orm.Session` + objects per shard first. + + +The "identity token" is an arbitrary value that can be associated within +the :term:`identity key` of newly loaded objects. This element exists +first and foremost to support extensions which perform per-row "sharding", +where objects may be loaded from any number of replicas of a particular +database table that nonetheless have overlapping primary key values. +The primary consumer of "identity token" is the +:ref:`horizontal_sharding_toplevel` extension, which supplies a general +framework for persisting objects among multiple "shards" of a particular +database table. + +The ``identity_token`` execution option may be used on a per-query basis +to directly affect this token. Using it directly, one can populate a +:class:`_orm.Session` with multiple instances of an object that have the +same primary key and source table, but different "identities". + +One such example is to populate a :class:`_orm.Session` with objects that +come from same-named tables in different schemas, using the +:ref:`schema_translating` feature which can affect the choice of schema +within the scope of queries. Given a mapping as: + +.. sourcecode:: python + + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + + + class Base(DeclarativeBase): + pass + + + class MyTable(Base): + __tablename__ = "my_table" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + +The default "schema" name for the class above is ``None``, meaning, no +schema qualification will be written into SQL statements. However, +if we make use of :paramref:`_engine.Connection.execution_options.schema_translate_map`, +mapping ``None`` to an alternate schema, we can place instances of +``MyTable`` into two different schemas: + +.. sourcecode:: python + + engine = create_engine( + "postgresql+psycopg://scott:tiger@localhost/test", + ) + + with Session( + engine.execution_options(schema_translate_map={None: "test_schema"}) + ) as sess: + sess.add(MyTable(name="this is schema one")) + sess.commit() + + with Session( + engine.execution_options(schema_translate_map={None: "test_schema_2"}) + ) as sess: + sess.add(MyTable(name="this is schema two")) + sess.commit() + +The above two blocks create a :class:`_orm.Session` object linked to a different +schema translate map each time, and an instance of ``MyTable`` is persisted +into both ``test_schema.my_table`` as well as ``test_schema_2.my_table``. + +The :class:`_orm.Session` objects above are independent. If we wanted to +persist both objects in one transaction, we would need to use the +:ref:`horizontal_sharding_toplevel` extension to do this. + +However, we can illustrate querying for these objects in one session as follows: + +.. sourcecode:: python + + with Session(engine) as sess: + obj1 = sess.scalar( + select(MyTable) + .where(MyTable.id == 1) + .execution_options( + schema_translate_map={None: "test_schema"}, + identity_token="test_schema", + ) + ) + obj2 = sess.scalar( + select(MyTable) + .where(MyTable.id == 1) + .execution_options( + schema_translate_map={None: "test_schema_2"}, + identity_token="test_schema_2", + ) + ) + +Both ``obj1`` and ``obj2`` are distinct from each other. However, they both +refer to primary key id 1 for the ``MyTable`` class, yet are distinct. +This is how the ``identity_token`` comes into play, which we can see in the +inspection of each object, where we look at :attr:`_orm.InstanceState.key` +to view the two distinct identity tokens:: + + >>> from sqlalchemy import inspect + >>> inspect(obj1).key + (<class '__main__.MyTable'>, (1,), 'test_schema') + >>> inspect(obj2).key + (<class '__main__.MyTable'>, (1,), 'test_schema_2') + + +The above logic takes place automatically when using the +:ref:`horizontal_sharding_toplevel` extension. + +.. versionadded:: 2.0.0b5 - added the ``identity_token`` ORM level execution + option. + +.. seealso:: + + :ref:`examples_sharding` - in the :ref:`examples_toplevel` section. + See the script ``separate_schema_translates.py`` for a demonstration of + the above use case using the full sharding API. + + +.. doctest-enable: .. _queryguide_inspection: diff --git a/examples/sharding/separate_databases.py b/examples/sharding/separate_databases.py index a45182f42..fe92fd3ba 100644 --- a/examples/sharding/separate_databases.py +++ b/examples/sharding/separate_databases.py @@ -135,8 +135,8 @@ def shard_chooser(mapper, instance, clause=None): return shard_chooser(mapper, instance.location) -def id_chooser(query, ident): - """id chooser. +def identity_chooser(mapper, primary_key, *, lazy_loaded_from, **kw): + """identity chooser. given a primary key, returns a list of shards to search. here, we don't have any particular information from a @@ -145,11 +145,11 @@ def id_chooser(query, ident): distributed among DBs. """ - if query.lazy_loaded_from: + if lazy_loaded_from: # if we are in a lazy load, we can look at the parent object # and limit our search to that same shard, assuming that's how we've # set things up. - return [query.lazy_loaded_from.identity_token] + return [lazy_loaded_from.identity_token] else: return ["north_america", "asia", "europe", "south_america"] @@ -237,7 +237,7 @@ def _get_select_comparisons(statement): # further configure create_session to use these functions Session.configure( shard_chooser=shard_chooser, - id_chooser=id_chooser, + identity_chooser=identity_chooser, execute_chooser=execute_chooser, ) diff --git a/examples/sharding/separate_schema_translates.py b/examples/sharding/separate_schema_translates.py index 2d4c2a046..f7bdc6250 100644 --- a/examples/sharding/separate_schema_translates.py +++ b/examples/sharding/separate_schema_translates.py @@ -130,21 +130,20 @@ def shard_chooser(mapper, instance, clause=None): return shard_chooser(mapper, instance.location) -def id_chooser(query, ident): - """id chooser. +def identity_chooser(mapper, primary_key, *, lazy_loaded_from, **kw): + """identity chooser. - given a primary key identity and a legacy :class:`_orm.Query`, - return which shard we should look at. + given a primary key identity, return which shard we should look at. in this case, we only want to support this for lazy-loaded items; any primary query should have shard id set up front. """ - if query.lazy_loaded_from: + if lazy_loaded_from: # if we are in a lazy load, we can look at the parent object # and limit our search to that same shard, assuming that's how we've # set things up. - return [query.lazy_loaded_from.identity_token] + return [lazy_loaded_from.identity_token] else: raise NotImplementedError() @@ -169,7 +168,7 @@ def execute_chooser(context): # configure shard chooser Session.configure( shard_chooser=shard_chooser, - id_chooser=id_chooser, + identity_chooser=identity_chooser, execute_chooser=execute_chooser, ) diff --git a/examples/sharding/separate_tables.py b/examples/sharding/separate_tables.py index 8f39471e8..97c6a07f6 100644 --- a/examples/sharding/separate_tables.py +++ b/examples/sharding/separate_tables.py @@ -149,8 +149,8 @@ def shard_chooser(mapper, instance, clause=None): return shard_chooser(mapper, instance.location) -def id_chooser(query, ident): - """id chooser. +def identity_chooser(mapper, primary_key, *, lazy_loaded_from, **kw): + """identity chooser. given a primary key, returns a list of shards to search. here, we don't have any particular information from a @@ -159,11 +159,11 @@ def id_chooser(query, ident): distributed among DBs. """ - if query.lazy_loaded_from: + if lazy_loaded_from: # if we are in a lazy load, we can look at the parent object # and limit our search to that same shard, assuming that's how we've # set things up. - return [query.lazy_loaded_from.identity_token] + return [lazy_loaded_from.identity_token] else: return ["north_america", "asia", "europe", "south_america"] @@ -251,7 +251,7 @@ def _get_select_comparisons(statement): # further configure create_session to use these functions Session.configure( shard_chooser=shard_chooser, - id_chooser=id_chooser, + identity_chooser=identity_chooser, execute_chooser=execute_chooser, ) diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index 69767ad6c..fd53c6046 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -13,11 +13,14 @@ distribute queries and persistence operations across multiple databases. For a usage example, see the :ref:`examples_sharding` example included in the source distribution. -.. legacy:: The horizontal sharding API is not fully updated for the - SQLAlchemy 2.0 API, and still relies in part on the - legacy :class:`.Query` architecture, in particular as part of the - signature for the :paramref:`.ShardedSession.id_chooser` parameter. - This may change in a future release. +.. deepalchemy:: The horizontal sharding extension is an advanced feature, + involving a complex statement -> database interaction as well as + use of semi-public APIs for non-trivial cases. Simpler approaches to + refering to multiple database "shards", most commonly using a distinct + :class:`_orm.Session` per "shard", should always be considered first + before using this more complex and less-production-tested system. + + """ from __future__ import annotations @@ -38,8 +41,11 @@ from .. import exc from .. import inspect from .. import util from ..orm import PassiveFlag +from ..orm._typing import OrmExecuteOptionsParameter from ..orm.mapper import Mapper from ..orm.query import Query +from ..orm.session import _BindArguments +from ..orm.session import _PKIdentityArgument from ..orm.session import Session from ..util.typing import Protocol @@ -80,6 +86,20 @@ class ShardChooser(Protocol): ... +class IdentityChooser(Protocol): + def __call__( + self, + mapper: Mapper[_T], + primary_key: _PKIdentityArgument, + *, + lazy_loaded_from: Optional[InstanceState[Any]], + execution_options: OrmExecuteOptionsParameter, + bind_arguments: _BindArguments, + **kw: Any, + ) -> Any: + ... + + class ShardedQuery(Query[_T]): """Query class used with :class:`.ShardedSession`. @@ -94,8 +114,7 @@ class ShardedQuery(Query[_T]): super().__init__(*args, **kwargs) assert isinstance(self.session, ShardedSession) - self.id_chooser = self.session.id_chooser - self.query_chooser = self.session.query_chooser + self.identity_chooser = self.session.identity_chooser self.execute_chooser = self.session.execute_chooser self._shard_id = None @@ -119,19 +138,22 @@ class ShardedQuery(Query[_T]): class ShardedSession(Session): shard_chooser: ShardChooser - id_chooser: Callable[[Query[Any], Iterable[Any]], Iterable[Any]] + identity_chooser: IdentityChooser execute_chooser: Callable[[ORMExecuteState], Iterable[Any]] def __init__( self, shard_chooser: ShardChooser, - id_chooser: Callable[[Query[_T], Iterable[_T]], Iterable[Any]], + identity_chooser: Optional[IdentityChooser] = None, execute_chooser: Optional[ Callable[[ORMExecuteState], Iterable[Any]] ] = None, shards: Optional[Dict[str, Any]] = None, query_cls: Type[Query[_T]] = ShardedQuery, *, + id_chooser: Optional[ + Callable[[Query[_T], Iterable[_T]], Iterable[Any]] + ] = None, query_chooser: Optional[Callable[[Executable], Iterable[Any]]] = None, **kwargs: Any, ) -> None: @@ -171,12 +193,41 @@ class ShardedSession(Session): self, "do_orm_execute", execute_and_instances, retval=True ) self.shard_chooser = shard_chooser - self.id_chooser = id_chooser + + if id_chooser: + _id_chooser = id_chooser + util.warn_deprecated( + "The ``id_chooser`` parameter is deprecated; " + "please use ``identity_chooser``.", + "2.0", + ) + + def _legacy_identity_chooser( + mapper: Mapper[_T], + primary_key: _PKIdentityArgument, + *, + lazy_loaded_from: Optional[InstanceState[Any]], + execution_options: OrmExecuteOptionsParameter, + bind_arguments: _BindArguments, + **kw: Any, + ) -> Any: + q = self.query(mapper) + if lazy_loaded_from: + q = q._set_lazyload_from(lazy_loaded_from) + return _id_chooser(q, primary_key) + + self.identity_chooser = _legacy_identity_chooser + elif identity_chooser: + self.identity_chooser = identity_chooser + else: + raise exc.ArgumentError( + "identity_chooser or id_chooser is required" + ) if query_chooser: _query_chooser = query_chooser util.warn_deprecated( - "The ``query_choser`` parameter is deprecated; " + "The ``query_chooser`` parameter is deprecated; " "please use ``execute_chooser``.", "1.4", ) @@ -199,7 +250,6 @@ class ShardedSession(Session): "execute_chooser or query_chooser is required" ) self.execute_chooser = execute_chooser - self.query_chooser = query_chooser self.__shards: Dict[_ShardKey, _SessionBind] = {} if shards is not None: for k in shards: @@ -212,6 +262,8 @@ class ShardedSession(Session): identity_token: Optional[Any] = None, passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, lazy_loaded_from: Optional[InstanceState[Any]] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, **kw: Any, ) -> Union[Optional[_O], LoaderCallableStatus]: """override the default :meth:`.Session._identity_lookup` method so @@ -233,10 +285,13 @@ class ShardedSession(Session): return obj else: - q = self.query(mapper) - if lazy_loaded_from: - q = q._set_lazyload_from(lazy_loaded_from) - for shard_id in self.id_chooser(q, primary_key_identity): + for shard_id in self.identity_chooser( + mapper, + primary_key_identity, + lazy_loaded_from=lazy_loaded_from, + execution_options=execution_options, + bind_arguments=dict(bind_arguments) if bind_arguments else {}, + ): obj2 = super()._identity_lookup( mapper, primary_key_identity, @@ -325,11 +380,6 @@ class ShardedSession(Session): def execute_and_instances( orm_context: ORMExecuteState, ) -> Union[Result[_T], IteratorResult[_TP]]: - update_options: Union[ - None, - BulkUDCompileState.default_update_options, - Type[BulkUDCompileState.default_update_options], - ] active_options: Union[ None, QueryContext.default_load_options, @@ -337,58 +387,30 @@ def execute_and_instances( BulkUDCompileState.default_update_options, Type[BulkUDCompileState.default_update_options], ] - load_options: Union[ - None, - QueryContext.default_load_options, - Type[QueryContext.default_load_options], - ] if orm_context.is_select: - load_options = active_options = orm_context.load_options - update_options = None + active_options = orm_context.load_options elif orm_context.is_update or orm_context.is_delete: - load_options = None - update_options = active_options = orm_context.update_delete_options + active_options = orm_context.update_delete_options else: - load_options = update_options = active_options = None + active_options = None session = orm_context.session assert isinstance(session, ShardedSession) def iter_for_shard( shard_id: str, - load_options: Union[ - None, - QueryContext.default_load_options, - Type[QueryContext.default_load_options], - ], - update_options: Union[ - None, - BulkUDCompileState.default_update_options, - Type[BulkUDCompileState.default_update_options], - ], ) -> Union[Result[_T], IteratorResult[_TP]]: - execution_options = dict(orm_context.local_execution_options) bind_arguments = dict(orm_context.bind_arguments) bind_arguments["shard_id"] = shard_id - if orm_context.is_select: - assert load_options is not None - load_options += {"_refresh_identity_token": shard_id} - execution_options["_sa_orm_load_options"] = load_options - elif orm_context.is_update or orm_context.is_delete: - assert update_options is not None - update_options += {"_refresh_identity_token": shard_id} - execution_options["_sa_orm_update_options"] = update_options - - return orm_context.invoke_statement( - bind_arguments=bind_arguments, execution_options=execution_options - ) + orm_context.update_execution_options(identity_token=shard_id) + return orm_context.invoke_statement(bind_arguments=bind_arguments) - if active_options and active_options._refresh_identity_token is not None: - shard_id = active_options._refresh_identity_token + if active_options and active_options._identity_token is not None: + shard_id = active_options._identity_token elif "_sa_shard_id" in orm_context.execution_options: shard_id = orm_context.execution_options["_sa_shard_id"] elif "shard_id" in orm_context.bind_arguments: @@ -397,10 +419,10 @@ def execute_and_instances( shard_id = None if shard_id is not None: - return iter_for_shard(shard_id, load_options, update_options) + return iter_for_shard(shard_id) else: partial = [] for shard_id in session.execute_chooser(orm_context): - result_ = iter_for_shard(shard_id, load_options, update_options) + result_ = iter_for_shard(shard_id) partial.append(result_) return partial[0].merge(*partial[1:]) diff --git a/lib/sqlalchemy/orm/bulk_persistence.py b/lib/sqlalchemy/orm/bulk_persistence.py index 181dbd4a2..805bfdc65 100644 --- a/lib/sqlalchemy/orm/bulk_persistence.py +++ b/lib/sqlalchemy/orm/bulk_persistence.py @@ -555,7 +555,7 @@ class BulkUDCompileState(ORMDMLState): _resolved_values = EMPTY_DICT _eval_condition = None _matched_rows = None - _refresh_identity_token = None + _identity_token = None @classmethod def can_use_returning( @@ -577,10 +577,8 @@ class BulkUDCompileState(ORMDMLState): params, execution_options, bind_arguments, - is_reentrant_invoke, + is_pre_event, ): - if is_reentrant_invoke: - return statement, execution_options ( update_options, @@ -590,6 +588,7 @@ class BulkUDCompileState(ORMDMLState): { "synchronize_session", "autoflush", + "identity_token", "is_delete_using", "is_update_from", "dml_strategy", @@ -637,55 +636,56 @@ class BulkUDCompileState(ORMDMLState): "for 'bulk' ORM updates (i.e. multiple parameter sets)" ) - if update_options._autoflush: - session._autoflush() - - if update_options._dml_strategy == "orm": + if not is_pre_event: + if update_options._autoflush: + session._autoflush() - if update_options._synchronize_session == "auto": - update_options = cls._do_pre_synchronize_auto( - session, - statement, - params, - execution_options, - bind_arguments, - update_options, - ) - elif update_options._synchronize_session == "evaluate": - update_options = cls._do_pre_synchronize_evaluate( - session, - statement, - params, - execution_options, - bind_arguments, - update_options, - ) - elif update_options._synchronize_session == "fetch": - update_options = cls._do_pre_synchronize_fetch( - session, - statement, - params, - execution_options, - bind_arguments, - update_options, - ) - elif update_options._dml_strategy == "bulk": - if update_options._synchronize_session == "auto": - update_options += {"_synchronize_session": "evaluate"} + if update_options._dml_strategy == "orm": - # indicators from the "pre exec" step that are then - # added to the DML statement, which will also be part of the cache - # key. The compile level create_for_statement() method will then - # consume these at compiler time. - statement = statement._annotate( - { - "synchronize_session": update_options._synchronize_session, - "is_delete_using": update_options._is_delete_using, - "is_update_from": update_options._is_update_from, - "dml_strategy": update_options._dml_strategy, - "can_use_returning": update_options._can_use_returning, - } - ) + if update_options._synchronize_session == "auto": + update_options = cls._do_pre_synchronize_auto( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) + elif update_options._synchronize_session == "evaluate": + update_options = cls._do_pre_synchronize_evaluate( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) + elif update_options._synchronize_session == "fetch": + update_options = cls._do_pre_synchronize_fetch( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) + elif update_options._dml_strategy == "bulk": + if update_options._synchronize_session == "auto": + update_options += {"_synchronize_session": "evaluate"} + + # indicators from the "pre exec" step that are then + # added to the DML statement, which will also be part of the cache + # key. The compile level create_for_statement() method will then + # consume these at compiler time. + statement = statement._annotate( + { + "synchronize_session": update_options._synchronize_session, + "is_delete_using": update_options._is_delete_using, + "is_update_from": update_options._is_update_from, + "dml_strategy": update_options._dml_strategy, + "can_use_returning": update_options._can_use_returning, + } + ) return ( statement, @@ -836,7 +836,7 @@ class BulkUDCompileState(ORMDMLState): if state.mapper.isa(mapper) and not state.expired ] - identity_token = update_options._refresh_identity_token + identity_token = update_options._identity_token if identity_token is not None: raw_data = [ (obj, state, dict_) @@ -1091,7 +1091,7 @@ class BulkORMInsert(ORMDMLState, InsertDMLState): params, execution_options, bind_arguments, - is_reentrant_invoke, + is_pre_event, ): ( @@ -1143,7 +1143,7 @@ class BulkORMInsert(ORMDMLState, InsertDMLState): context._orm_load_exec_options ) - if insert_options._autoflush: + if not is_pre_event and insert_options._autoflush: session._autoflush() statement = statement._annotate( @@ -1577,7 +1577,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): for param in params: identity_key = mapper.identity_key_from_primary_key( (param[key] for key in pk_keys), - update_options._refresh_identity_token, + update_options._identity_token, ) state = identity_map.fast_get_state(identity_key) if not state: @@ -1635,7 +1635,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): ) matched_rows = [ - tuple(row) + (update_options._refresh_identity_token,) + tuple(row) + (update_options._identity_token,) for row in pk_rows ] else: @@ -1651,8 +1651,8 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): for primary_key, identity_token in [ (row[0:-1], row[-1]) for row in matched_rows ] - if update_options._refresh_identity_token is None - or identity_token == update_options._refresh_identity_token + if update_options._identity_token is None + or identity_token == update_options._identity_token ] if identity_key in session.identity_map ] @@ -1912,7 +1912,7 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState): ) matched_rows = [ - tuple(row) + (update_options._refresh_identity_token,) + tuple(row) + (update_options._identity_token,) for row in pk_rows ] else: diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 3bd8b02a7..b3478b83e 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -135,7 +135,7 @@ class QueryContext: _version_check = False _invoke_all_eagers = True _autoflush = True - _refresh_identity_token = None + _identity_token = None _yield_per = None _refresh_state = None _lazy_loaded_from = None @@ -194,14 +194,14 @@ class QueryContext: self.version_check = load_options._version_check self.refresh_state = load_options._refresh_state self.yield_per = load_options._yield_per - self.identity_token = load_options._refresh_identity_token + self.identity_token = load_options._identity_token def _get_top_level_context(self) -> QueryContext: return self.top_level_context or self _orm_load_exec_options = util.immutabledict( - {"_result_disable_adapt_to_context": True, "future_result": True} + {"_result_disable_adapt_to_context": True} ) @@ -235,7 +235,7 @@ class AbstractORMCompileState(CompileState): params, execution_options, bind_arguments, - is_reentrant_invoke, + is_pre_event, ): raise NotImplementedError() @@ -384,11 +384,11 @@ class ORMCompileState(AbstractORMCompileState): params, execution_options, bind_arguments, - is_reentrant_invoke, + is_pre_event, ): - if is_reentrant_invoke: - return statement, execution_options + # consume result-level load_options. These may have been set up + # in an ORMExecuteState hook ( load_options, execution_options, @@ -398,26 +398,24 @@ class ORMCompileState(AbstractORMCompileState): "populate_existing", "autoflush", "yield_per", + "identity_token", "sa_top_level_orm_context", }, execution_options, statement._execution_options, ) + # default execution options for ORM results: # 1. _result_disable_adapt_to_context=True # this will disable the ResultSetMetadata._adapt_to_context() # step which we don't need, as we have result processors cached # against the original SELECT statement before caching. - # 2. future_result=True. The ORM should **never** resolve columns - # in a result set based on names, only on Column objects that - # are correctly adapted to the context. W the legacy result - # it will still attempt name-based resolution and also emit a - # warning. if not execution_options: execution_options = _orm_load_exec_options else: execution_options = execution_options.union(_orm_load_exec_options) + # would have been placed here by legacy Query only if load_options._yield_per: execution_options = execution_options.union( {"yield_per": load_options._yield_per} @@ -457,7 +455,7 @@ class ORMCompileState(AbstractORMCompileState): if plugin_subject: bind_arguments["mapper"] = plugin_subject.mapper - if load_options._autoflush: + if not is_pre_event and load_options._autoflush: session._autoflush() return statement, execution_options @@ -483,6 +481,7 @@ class ORMCompileState(AbstractORMCompileState): load_options = execution_options.get( "_sa_orm_load_options", QueryContext.default_load_options ) + if compile_state.compile_options._is_star: return result @@ -3119,6 +3118,6 @@ class _IdentityTokenEntity(_ORMColumnEntity): def row_processor(self, context, result): def getter(row): - return context.load_options._refresh_identity_token + return context.load_options._identity_token return getter, self._label_name, self._extra_entities diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index 6e7695f86..f331cd63b 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -701,7 +701,7 @@ def _set_get_options( if only_load_props: compile_options["_only_load_props"] = frozenset(only_load_props) if identity_token: - load_options["_refresh_identity_token"] = identity_token + load_options["_identity_token"] = identity_token if load_options: load_opt += load_options diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 01db08eb4..d2bd930ff 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -470,7 +470,7 @@ class Query( if only_load_props: compile_options["_only_load_props"] = frozenset(only_load_props) if identity_token: - load_options["_refresh_identity_token"] = identity_token + load_options["_identity_token"] = identity_token if load_options: self.load_options += load_options diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index bf3df0599..8b5f7c88a 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -267,6 +267,7 @@ class ORMExecuteState(util.MemoizedSlots): "execution_options", "local_execution_options", "bind_arguments", + "identity_token", "_compile_state_cls", "_starting_event_idx", "_events_todo", @@ -579,9 +580,8 @@ class ORMExecuteState(util.MemoizedSlots): def _is_crud(self) -> bool: return isinstance(self.statement, (dml.Update, dml.Delete)) - def update_execution_options(self, **opts: _ExecuteOptions) -> None: + def update_execution_options(self, **opts: Any) -> None: """Update the local execution options with new values.""" - # TODO: no coverage self.local_execution_options = self.local_execution_options.union(opts) def _orm_compile_options( @@ -1912,27 +1912,10 @@ class Session(_SessionClassMethods, EventTarget): ) else: compile_state_cls = None + bind_arguments.setdefault("clause", statement) execution_options = util.coerce_to_immutabledict(execution_options) - if compile_state_cls is not None: - ( - statement, - execution_options, - ) = compile_state_cls.orm_pre_session_exec( - self, - statement, - params, - execution_options, - bind_arguments, - _parent_execute_state is not None, - ) - else: - bind_arguments.setdefault("clause", statement) - execution_options = execution_options.union( - {"future_result": True} - ) - if _parent_execute_state: events_todo = _parent_execute_state._remaining_events() else: @@ -1941,6 +1924,25 @@ class Session(_SessionClassMethods, EventTarget): events_todo = list(events_todo) + [_add_event] if events_todo: + if compile_state_cls is not None: + # for event handlers, do the orm_pre_session_exec + # pass ahead of the event handlers, so that things like + # .load_options, .update_delete_options etc. are populated. + # is_pre_event=True allows the hook to hold off on things + # it doesn't want to do twice, including autoflush as well + # as "pre fetch" for DML, etc. + ( + statement, + execution_options, + ) = compile_state_cls.orm_pre_session_exec( + self, + statement, + params, + execution_options, + bind_arguments, + True, + ) + orm_exec_state = ORMExecuteState( self, statement, @@ -1962,6 +1964,24 @@ class Session(_SessionClassMethods, EventTarget): statement = orm_exec_state.statement execution_options = orm_exec_state.local_execution_options + if compile_state_cls is not None: + # now run orm_pre_session_exec() "for real". if there were + # event hooks, this will re-run the steps that interpret + # new execution_options into load_options / update_delete_options, + # which we assume the event hook might have updated. + # autoflush will also be invoked in this step if enabled. + ( + statement, + execution_options, + ) = compile_state_cls.orm_pre_session_exec( + self, + statement, + params, + execution_options, + bind_arguments, + False, + ) + bind = self.get_bind(**bind_arguments) conn = self._connection_for_bind(bind) @@ -2379,6 +2399,7 @@ class Session(_SessionClassMethods, EventTarget): bind: Optional[_SessionBind] = None, _sa_skip_events: Optional[bool] = None, _sa_skip_for_implicit_returning: bool = False, + **kw: Any, ) -> Union[Engine, Connection]: """Return a "bind" to which this :class:`.Session` is bound. @@ -2653,6 +2674,8 @@ class Session(_SessionClassMethods, EventTarget): identity_token: Any = None, passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, lazy_loaded_from: Optional[InstanceState[Any]] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, ) -> Union[Optional[_O], LoaderCallableStatus]: """Locate an object in the identity map. @@ -3262,6 +3285,7 @@ class Session(_SessionClassMethods, EventTarget): with_for_update: Optional[ForUpdateArg] = None, identity_token: Optional[Any] = None, execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, ) -> Optional[_O]: """Return an instance based on the given primary key identifier, or ``None`` if not found. @@ -3355,6 +3379,13 @@ class Session(_SessionClassMethods, EventTarget): :ref:`orm_queryguide_execution_options` - ORM-specific execution options + :param bind_arguments: dictionary of additional arguments to determine + the bind. May include "mapper", "bind", or other custom arguments. + Contents of this dictionary are passed to the + :meth:`.Session.get_bind` method. + + .. versionadded: 2.0.0b5 + :return: The object instance, or ``None``. """ @@ -3367,6 +3398,7 @@ class Session(_SessionClassMethods, EventTarget): with_for_update=with_for_update, identity_token=identity_token, execution_options=execution_options, + bind_arguments=bind_arguments, ) def _get_impl( @@ -3379,7 +3411,8 @@ class Session(_SessionClassMethods, EventTarget): populate_existing: bool = False, with_for_update: Optional[ForUpdateArg] = None, identity_token: Optional[Any] = None, - execution_options: Optional[OrmExecuteOptionsParameter] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, ) -> Optional[_O]: # convert composite types to individual args @@ -3453,7 +3486,11 @@ class Session(_SessionClassMethods, EventTarget): ): instance = self._identity_lookup( - mapper, primary_key_identity, identity_token=identity_token + mapper, + primary_key_identity, + identity_token=identity_token, + execution_options=execution_options, + bind_arguments=bind_arguments, ) if instance is not None: @@ -3484,13 +3521,14 @@ class Session(_SessionClassMethods, EventTarget): if options: statement = statement.options(*options) - if execution_options: - statement = statement.execution_options(**execution_options) return db_load_fn( self, statement, primary_key_identity, load_options=load_options, + identity_token=identity_token, + execution_options=execution_options, + bind_arguments=bind_arguments, ) def merge( diff --git a/test/ext/test_deprecations.py b/test/ext/test_deprecations.py index 09f904487..97c4172ba 100644 --- a/test/ext/test_deprecations.py +++ b/test/ext/test_deprecations.py @@ -1,3 +1,5 @@ +from sqlalchemy import Column +from sqlalchemy import Integer from sqlalchemy import testing from sqlalchemy.ext.automap import automap_base from sqlalchemy.ext.horizontal_shard import ShardedSession @@ -68,7 +70,7 @@ class HorizontalShardTest(fixtures.TestBase): m1 = mock.Mock() with testing.expect_deprecated( - "The ``query_choser`` parameter is deprecated; please use" + "The ``query_chooser`` parameter is deprecated; please use" ): s = ShardedSession( shard_chooser=m1.shard_chooser, @@ -80,3 +82,30 @@ class HorizontalShardTest(fixtures.TestBase): s.execute_chooser(m2) eq_(m1.mock_calls, [mock.call.query_chooser(m2.statement)]) + + def test_id_chooser(self, decl_base): + class A(decl_base): + __tablename__ = "a" + id = Column(Integer, primary_key=True) + + m1 = mock.Mock() + + with testing.expect_deprecated( + "The ``id_chooser`` parameter is deprecated; please use" + ): + s = ShardedSession( + shard_chooser=m1.shard_chooser, + id_chooser=m1.id_chooser, + execute_chooser=m1.execute_chooser, + ) + + m2 = mock.Mock() + s.identity_chooser( + A.__mapper__, + m2.primary_key, + lazy_loaded_from=m2.lazy_loaded_from, + execution_options=m2.execution_options, + bind_arguments=m2.bind_arguments, + ) + + eq_(m1.mock_calls, [mock.call.id_chooser(mock.ANY, m2.primary_key)]) diff --git a/test/ext/test_horizontal_shard.py b/test/ext/test_horizontal_shard.py index ab4a24f71..8e5d09cab 100644 --- a/test/ext/test_horizontal_shard.py +++ b/test/ext/test_horizontal_shard.py @@ -28,6 +28,7 @@ from sqlalchemy.pool import SingletonThreadPool from sqlalchemy.sql import operators from sqlalchemy.sql import Select from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_deprecated from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import provision @@ -109,7 +110,15 @@ class ShardTest: else: return shard_chooser(mapper, instance.location) - def id_chooser(query, ident): + def identity_chooser( + mapper, + primary_key, + *, + lazy_loaded_from, + execution_options, + bind_arguments, + **kw, + ): return ["north_america", "asia", "europe", "south_america"] def execute_chooser(orm_context): @@ -144,7 +153,7 @@ class ShardTest: "south_america": db4, }, shard_chooser=shard_chooser, - id_chooser=id_chooser, + identity_chooser=identity_chooser, execute_chooser=execute_chooser, ) @@ -189,7 +198,7 @@ class ShardTest: tokyo.reports.append(Report(80.0, id_=1)) newyork.reports.append(Report(75, id_=1)) quito.reports.append(Report(85)) - sess = sharded_session(future=True) + sess = sharded_session() for c in [tokyo, newyork, toronto, london, dublin, brasilia, quito]: sess.add(c) sess.flush() @@ -589,6 +598,68 @@ class DistinctEngineShardTest(ShardTest, fixtures.MappedTest): ) +class LegacyAPIShardTest(DistinctEngineShardTest): + @classmethod + def setup_session(cls): + global sharded_session + shard_lookup = { + "North America": "north_america", + "Asia": "asia", + "Europe": "europe", + "South America": "south_america", + } + + def shard_chooser(mapper, instance, clause=None): + if isinstance(instance, WeatherLocation): + return shard_lookup[instance.continent] + else: + return shard_chooser(mapper, instance.location) + + def id_chooser(query, primary_key): + return ["north_america", "asia", "europe", "south_america"] + + def query_chooser(query): + ids = [] + + class FindContinent(sql.ClauseVisitor): + def visit_binary(self, binary): + if binary.left.shares_lineage( + weather_locations.c.continent + ): + if binary.operator == operators.eq: + ids.append(shard_lookup[binary.right.value]) + elif binary.operator == operators.in_op: + for value in binary.right.value: + ids.append(shard_lookup[value]) + + if isinstance(query, Select) and query.whereclause is not None: + FindContinent().traverse(query.whereclause) + if len(ids) == 0: + return ["north_america", "asia", "europe", "south_america"] + else: + return ids + + sm = sessionmaker(class_=ShardedSession, autoflush=True) + sm.configure( + shards={ + "north_america": db1, + "asia": db2, + "europe": db3, + "south_america": db4, + }, + shard_chooser=shard_chooser, + id_chooser=id_chooser, + query_chooser=query_chooser, + ) + + def sharded_session(): + with expect_deprecated( + "The ``id_chooser`` parameter is deprecated", + "The ``query_chooser`` parameter is deprecated", + ): + return sm() + + class AttachedFileShardTest(ShardTest, fixtures.MappedTest): """Use modern schema conventions along with SQLite ATTACH.""" @@ -723,7 +794,7 @@ class SelectinloadRegressionTest(fixtures.DeclarativeMappedTest): session = ShardedSession( shards={"test": testing.db}, shard_chooser=lambda *args: "test", - id_chooser=lambda *args: None, + identity_chooser=lambda *args: None, execute_chooser=lambda *args: ["test"], ) @@ -764,7 +835,7 @@ class RefreshDeferExpireTest(fixtures.DeclarativeMappedTest): return ShardedSession( shards={"main": testing.db}, shard_chooser=lambda *args: "main", - id_chooser=lambda *args: ["fake", "main"], + identity_chooser=lambda *args: ["fake", "main"], execute_chooser=lambda *args: ["fake", "main"], **kw, ) @@ -843,15 +914,23 @@ class LazyLoadIdentityKeyTest(fixtures.DeclarativeMappedTest): else: assert False - def id_chooser(query, ident): - assert query.lazy_loaded_from - if isinstance(query.lazy_loaded_from.obj(), Book): - token = shard_for_book(query.lazy_loaded_from.obj()) - assert query.lazy_loaded_from.identity_token == token + def identity_chooser( + mapper, + primary_key, + *, + lazy_loaded_from, + execution_options, + bind_arguments, + **kw, + ): + assert lazy_loaded_from + if isinstance(lazy_loaded_from.obj(), Book): + token = shard_for_book(lazy_loaded_from.obj()) + assert lazy_loaded_from.identity_token == token - return [query.lazy_loaded_from.identity_token] + return [lazy_loaded_from.identity_token] - def no_query_chooser(orm_context): + def execute_chooser(orm_context): if ( orm_context.statement.column_descriptions[0]["type"] is Book and lazy_load_book @@ -878,8 +957,8 @@ class LazyLoadIdentityKeyTest(fixtures.DeclarativeMappedTest): session = ShardedSession( shards={"test": db1, "test2": db2}, shard_chooser=shard_chooser, - id_chooser=id_chooser, - execute_chooser=no_query_chooser, + identity_chooser=identity_chooser, + execute_chooser=execute_chooser, ) return session diff --git a/test/orm/test_events.py b/test/orm/test_events.py index 56d2815fa..05d5d376d 100644 --- a/test/orm/test_events.py +++ b/test/orm/test_events.py @@ -1,6 +1,7 @@ from unittest.mock import ANY from unittest.mock import call from unittest.mock import Mock +from unittest.mock import patch import sqlalchemy as sa from sqlalchemy import bindparam @@ -375,7 +376,6 @@ class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest): result.context.execution_options, { "four": True, - "future_result": True, "one": True, "three": True, "two": True, @@ -741,7 +741,6 @@ class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest): { "statement_two": True, "statement_four": True, - "future_result": True, "one": True, "two": True, "three": True, @@ -751,6 +750,42 @@ class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest): }, ) + @testing.variation("session_start", [True, False]) + @testing.variation("dest_autoflush", [True, False]) + @testing.variation("stmt_type", ["select", "bulk", "dml"]) + def test_autoflush_change(self, session_start, dest_autoflush, stmt_type): + User = self.classes.User + + sess = fixture_session(autoflush=session_start) + + @event.listens_for(sess, "do_orm_execute") + def do_orm_execute(ctx): + ctx.update_execution_options(autoflush=dest_autoflush) + + with patch.object(sess, "_autoflush") as m1: + if stmt_type.select: + sess.execute(select(User)) + elif stmt_type.bulk: + sess.execute( + insert(User), + [ + {"id": 1, "name": "n1"}, + {"id": 2, "name": "n2"}, + {"id": 3, "name": "n3"}, + ], + ) + elif stmt_type.dml: + sess.execute( + update(User).where(User.id == 2).values(name="nn") + ) + else: + stmt_type.fail() + + if dest_autoflush: + eq_(m1.mock_calls, [call()]) + else: + eq_(m1.mock_calls, []) + class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): run_inserts = None diff --git a/test/orm/test_query.py b/test/orm/test_query.py index 7966006cf..9e303a778 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -5507,6 +5507,7 @@ class YieldTest(_fixtures.FixtureTest): @event.listens_for(sess, "do_orm_execute") def check(ctx): eq_(ctx.load_options._yield_per, 15) + return eq_( { k: v @@ -5516,7 +5517,6 @@ class YieldTest(_fixtures.FixtureTest): { "yield_per": 15, "foo": "bar", - "future_result": True, }, ) @@ -5535,6 +5535,7 @@ class YieldTest(_fixtures.FixtureTest): @event.listens_for(sess, "do_orm_execute") def check(ctx): eq_(ctx.load_options._yield_per, 15) + eq_( { k: v @@ -5543,7 +5544,6 @@ class YieldTest(_fixtures.FixtureTest): }, { "yield_per": 15, - "future_result": True, }, ) @@ -5553,8 +5553,8 @@ class YieldTest(_fixtures.FixtureTest): assert isinstance( result.raw.cursor_strategy, _cursor.BufferedRowCursorFetchStrategy ) + eq_(result._yield_per, 15) eq_(result.raw.cursor_strategy._max_row_buffer, 15) - eq_(len(result.all()), 4) def test_no_joinedload_opt(self): @@ -7515,23 +7515,80 @@ class ExecutionOptionsTest(QueryTest): assert u.addresses[0].email_address == "jack@bean.com" assert u.orders[1].items[2].description == "item 5" - def test_option_transfer_future(self): + @testing.variation("source", ["statement", "do_orm_exec"]) + def test_execution_options_to_load_options(self, source): User = self.classes.User - stmt = select(User).execution_options( - populate_existing=True, autoflush=False, yield_per=10 - ) + + stmt = select(User) + + if source.statement: + stmt = stmt.execution_options( + populate_existing=True, + autoflush=False, + yield_per=10, + identity_token="some_token", + ) s = fixture_session() m1 = mock.Mock() - event.listen(s, "do_orm_execute", m1) + def do_orm_execute(ctx): + m1(ctx) + if source.do_orm_exec: + ctx.update_execution_options( + autoflush=False, + populate_existing=True, + yield_per=10, + identity_token="some_token", + ) + + event.listen(s, "do_orm_execute", do_orm_execute) + + from sqlalchemy.orm import loading + + with mock.patch.object(loading, "instances") as m2: + s.execute(stmt) + + if source.do_orm_exec: + # in do_orm_exec version, load options are empty, our new + # execution options have not yet been transferred. + eq_( + m1.mock_calls[0][1][0].load_options, + QueryContext.default_load_options, + ) + elif source.statement: + # in statement version, the incoming exc options have been + # transferred, because the fact that do_orm_exec is used + # means the options were set up up front for the benefit + # of the do_orm_exec hook itself. + eq_( + m1.mock_calls[0][1][0].load_options, + QueryContext.default_load_options( + _autoflush=False, + _populate_existing=True, + _yield_per=10, + _identity_token="some_token", + ), + ) + + # py37 mock does not have .args + call_args = m2.mock_calls[0][1] - s.execute(stmt) + cursor = call_args[0] + cursor.all() + # the orm_pre_session_exec() method + # was called unconditionally after the event handler + # in both cases (i.e. a second time) so options were transferred + # even if we set them up in the do_orm_exec hook only. + query_context = call_args[1] eq_( - m1.mock_calls[0][1][0].load_options, + query_context.load_options, QueryContext.default_load_options( - _autoflush=False, _populate_existing=True, _yield_per=10 + _autoflush=False, + _populate_existing=True, + _yield_per=10, + _identity_token="some_token", ), ) diff --git a/test/orm/test_session.py b/test/orm/test_session.py index 79ea5d170..921c55f74 100644 --- a/test/orm/test_session.py +++ b/test/orm/test_session.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import inspect as _py_inspect import pickle +from typing import TYPE_CHECKING import sqlalchemy as sa from sqlalchemy import delete @@ -48,6 +51,9 @@ from sqlalchemy.testing.util import gc_collect from sqlalchemy.util.compat import inspect_getfullargspec from test.orm import _fixtures +if TYPE_CHECKING: + from sqlalchemy.orm import ORMExecuteState + class ExecutionTest(_fixtures.FixtureTest): run_inserts = None @@ -563,7 +569,10 @@ class SessionUtilTest(_fixtures.FixtureTest): u1, ) - def test_get_execution_option(self): + @testing.variation( + "arg", ["execution_options", "identity_token", "bind_arguments"] + ) + def test_get_arguments(self, arg: testing.Variation) -> None: users, User = self.tables.users, self.classes.User self.mapper_registry.map_imperatively(User, users) @@ -571,12 +580,28 @@ class SessionUtilTest(_fixtures.FixtureTest): called = False @event.listens_for(sess, "do_orm_execute") - def check(ctx): + def check(ctx: ORMExecuteState) -> None: nonlocal called called = True - eq_(ctx.execution_options["foo"], "bar") - sess.get(User, 42, execution_options={"foo": "bar"}) + if arg.execution_options: + eq_(ctx.execution_options["foo"], "bar") + elif arg.bind_arguments: + eq_(ctx.bind_arguments["foo"], "bar") + elif arg.identity_token: + eq_(ctx.load_options._identity_token, "foobar") + else: + arg.fail() + + if arg.execution_options: + sess.get(User, 42, execution_options={"foo": "bar"}) + elif arg.bind_arguments: + sess.get(User, 42, bind_arguments={"foo": "bar"}) + elif arg.identity_token: + sess.get(User, 42, identity_token="foobar") + else: + arg.fail() + sess.close() is_true(called) |
