diff options
335 files changed, 8207 insertions, 6916 deletions
@@ -1,4 +1,4 @@ -Copyright 2005-2020 SQLAlchemy authors and contributors <see AUTHORS file>. +Copyright 2005-2021 SQLAlchemy authors and contributors <see AUTHORS file>. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in diff --git a/doc/build/changelog/unreleased_13/5808.rst b/doc/build/changelog/unreleased_13/5808.rst new file mode 100644 index 000000000..b6625c050 --- /dev/null +++ b/doc/build/changelog/unreleased_13/5808.rst @@ -0,0 +1,6 @@ +.. change:: + :tags: usecase, mysql + :tickets: 5808 + + Casting to ``FLOAT`` is now supported in MySQL >= (8, 0, 17) and + MariaDb >= (10, 4, 5).
\ No newline at end of file diff --git a/doc/build/changelog/unreleased_13/5813.rst b/doc/build/changelog/unreleased_13/5813.rst new file mode 100644 index 000000000..d6483a26f --- /dev/null +++ b/doc/build/changelog/unreleased_13/5813.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, oracle + :tickets: 5813 + + Fixed regression in Oracle dialect introduced by :ticket:`4894` in + SQLAlchemy 1.3.11 where use of a SQL expression in RETURNING for an UPDATE + would fail to compile, due to a check for "server_default" when an + arbitrary SQL expression is not a column. + diff --git a/doc/build/changelog/unreleased_13/5816.rst b/doc/build/changelog/unreleased_13/5816.rst new file mode 100644 index 000000000..5049622a8 --- /dev/null +++ b/doc/build/changelog/unreleased_13/5816.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: bug, sql + :tickets: 5816 + + Fixed bug where making use of the :meth:`.TypeEngine.with_variant` method + on a :class:`.TypeDecorator` type would fail to take into account the + dialect-specific mappings in use, due to a rule in :class:`.TypeDecorator` + that was instead attempting to check for chains of :class:`.TypeDecorator` + instances. + diff --git a/doc/build/changelog/unreleased_13/5821.rst b/doc/build/changelog/unreleased_13/5821.rst new file mode 100644 index 000000000..a2c5d4082 --- /dev/null +++ b/doc/build/changelog/unreleased_13/5821.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, mysql + :tickets: 5821 + + Fixed deprecation warnings that arose as a result of the release of PyMySQL + 1.0, including deprecation warnings for the "db" and "passwd" parameters + now replaced with "database" and "password". + diff --git a/doc/build/changelog/unreleased_14/5811.rst b/doc/build/changelog/unreleased_14/5811.rst new file mode 100644 index 000000000..5ce358ca4 --- /dev/null +++ b/doc/build/changelog/unreleased_14/5811.rst @@ -0,0 +1,30 @@ +.. change:: + :tags: bug, asyncio + :tickets: 5811 + + Implemented "connection-binding" for :class:`.AsyncSession`, the ability to + pass an :class:`.AsyncConnection` to create an :class:`.AsyncSession`. + Previously, this use case was not implemented and would use the associated + engine when the connection were passed. This fixes the issue where the + "join a session to an external transaction" use case would not work + correctly for the :class:`.AsyncSession`. Additionally, added methods + :meth:`.AsyncConnection.in_transaction`, + :meth:`.AsyncConnection.in_nested_transaction`, + :meth:`.AsyncConnection.get_transaction`, + :meth:`.AsyncConnection.get_nested_transaction` and + :attr:`.AsyncConnection.info` attribute. + +.. change:: + :tags: usecase, asyncio + + The :class:`.AsyncEngine`, :class:`.AsyncConnection` and + :class:`.AsyncTransaction` objects may be compared using Python ``==`` or + ``!=``, which will compare the two given objects based on the "sync" object + they are proxying towards. This is useful as there are cases particularly + for :class:`.AsyncTransaction` where multiple instances of + :class:`.AsyncTransaction` can be proxying towards the same sync + :class:`_engine.Transaction`, and are actually equivalent. The + :meth:`.AsyncConnection.get_transaction` method will currently return a new + proxying :class:`.AsyncTransaction` each time as the + :class:`.AsyncTransaction` is not otherwise statefully associated with its + originating :class:`.AsyncConnection`.
\ No newline at end of file diff --git a/doc/build/changelog/unreleased_14/asyncpg_prepared_cache.rst b/doc/build/changelog/unreleased_14/asyncpg_prepared_cache.rst new file mode 100644 index 000000000..eee6fb105 --- /dev/null +++ b/doc/build/changelog/unreleased_14/asyncpg_prepared_cache.rst @@ -0,0 +1,12 @@ +.. change:: + :tags: postgresql, performance + + Enhanced the performance of the asyncpg dialect by caching the asyncpg + PreparedStatement objects on a per-connection basis. For a test case that + makes use of the same statement on a set of pooled connections this appears + to grant a 10-20% speed improvement. The cache size is adjustable and may + also be disabled. + + .. seealso:: + + :ref:`asyncpg_prepared_statement_cache` diff --git a/doc/build/conf.py b/doc/build/conf.py index de1612d69..c2c055caf 100644 --- a/doc/build/conf.py +++ b/doc/build/conf.py @@ -183,7 +183,7 @@ master_doc = "contents" # General information about the project. project = u"SQLAlchemy" -copyright = u"2007-2020, the SQLAlchemy authors and contributors" # noqa +copyright = u"2007-2021, the SQLAlchemy authors and contributors" # noqa # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the diff --git a/doc/build/copyright.rst b/doc/build/copyright.rst index 4df6e9634..1fbdbbee1 100644 --- a/doc/build/copyright.rst +++ b/doc/build/copyright.rst @@ -6,7 +6,7 @@ Appendix: Copyright This is the MIT license: `<http://www.opensource.org/licenses/mit-license.php>`_ -Copyright (c) 2005-2020 Michael Bayer and contributors. +Copyright (c) 2005-2021 Michael Bayer and contributors. SQLAlchemy is a trademark of Michael Bayer. Permission is hereby granted, free of charge, to any person obtaining a copy of this diff --git a/doc/build/orm/extensions/asyncio.rst b/doc/build/orm/extensions/asyncio.rst index 09e76f2ee..aed01678a 100644 --- a/doc/build/orm/extensions/asyncio.rst +++ b/doc/build/orm/extensions/asyncio.rst @@ -149,6 +149,8 @@ It is then used in a Python asynchronous context manager (i.e. ``async with:`` s so that it is automatically closed at the end of the block; this is equivalent to calling the :meth:`_asyncio.AsyncSession.close` method. +.. _session_run_sync: + Adapting ORM Lazy loads to asyncio ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py index e3b054d0e..607585460 100644 --- a/lib/sqlalchemy/__init__.py +++ b/lib/sqlalchemy/__init__.py @@ -1,5 +1,5 @@ # sqlalchemy/__init__.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -61,10 +61,10 @@ from .sql import literal_column # noqa from .sql import modifier # noqa from .sql import not_ # noqa from .sql import null # noqa -from .sql import nullsfirst # noqa; deprecated 1.4; see #5435 -from .sql import nullslast # noqa; deprecated 1.4; see #5435 from .sql import nulls_first # noqa from .sql import nulls_last # noqa +from .sql import nullsfirst # noqa +from .sql import nullslast # noqa from .sql import or_ # noqa from .sql import outerjoin # noqa from .sql import outparam # noqa diff --git a/lib/sqlalchemy/cextension/immutabledict.c b/lib/sqlalchemy/cextension/immutabledict.c index 2a19cf3ad..53e59c195 100644 --- a/lib/sqlalchemy/cextension/immutabledict.c +++ b/lib/sqlalchemy/cextension/immutabledict.c @@ -1,6 +1,6 @@ /* immuatbledict.c -Copyright (C) 2020 the SQLAlchemy authors and contributors <see AUTHORS file> +Copyright (C) 2005-2021 the SQLAlchemy authors and contributors <see AUTHORS file> This module is part of SQLAlchemy and is released under the MIT License: http://www.opensource.org/licenses/mit-license.php diff --git a/lib/sqlalchemy/cextension/processors.c b/lib/sqlalchemy/cextension/processors.c index 0dd526d5d..b6f37a7bb 100644 --- a/lib/sqlalchemy/cextension/processors.c +++ b/lib/sqlalchemy/cextension/processors.c @@ -1,6 +1,6 @@ /* processors.c -Copyright (C) 2010-2020 the SQLAlchemy authors and contributors <see AUTHORS file> +Copyright (C) 2010-2021 the SQLAlchemy authors and contributors <see AUTHORS file> Copyright (C) 2010-2011 Gaetan de Menten gdementen@gmail.com This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/cextension/resultproxy.c b/lib/sqlalchemy/cextension/resultproxy.c index f99236e1e..89fd6947a 100644 --- a/lib/sqlalchemy/cextension/resultproxy.c +++ b/lib/sqlalchemy/cextension/resultproxy.c @@ -1,6 +1,6 @@ /* resultproxy.c -Copyright (C) 2010-2020 the SQLAlchemy authors and contributors <see AUTHORS file> +Copyright (C) 2010-2021 the SQLAlchemy authors and contributors <see AUTHORS file> Copyright (C) 2010-2011 Gaetan de Menten gdementen@gmail.com This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/cextension/utils.c b/lib/sqlalchemy/cextension/utils.c index c612094dc..e06843c9d 100644 --- a/lib/sqlalchemy/cextension/utils.c +++ b/lib/sqlalchemy/cextension/utils.c @@ -1,6 +1,6 @@ /* utils.c -Copyright (C) 2012-2020 the SQLAlchemy authors and contributors <see AUTHORS file> +Copyright (C) 2012-2021 the SQLAlchemy authors and contributors <see AUTHORS file> This module is part of SQLAlchemy and is released under the MIT License: http://www.opensource.org/licenses/mit-license.php diff --git a/lib/sqlalchemy/connectors/__init__.py b/lib/sqlalchemy/connectors/__init__.py index c1a3c1ef6..d2b9ba09f 100644 --- a/lib/sqlalchemy/connectors/__init__.py +++ b/lib/sqlalchemy/connectors/__init__.py @@ -1,5 +1,5 @@ # connectors/__init__.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/connectors/mxodbc.py b/lib/sqlalchemy/connectors/mxodbc.py index e630f36e3..cd40863e4 100644 --- a/lib/sqlalchemy/connectors/mxodbc.py +++ b/lib/sqlalchemy/connectors/mxodbc.py @@ -1,5 +1,5 @@ # connectors/mxodbc.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py index 780161304..aa14cd9aa 100644 --- a/lib/sqlalchemy/connectors/pyodbc.py +++ b/lib/sqlalchemy/connectors/pyodbc.py @@ -1,5 +1,5 @@ # connectors/pyodbc.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/databases/__init__.py b/lib/sqlalchemy/databases/__init__.py index 3e636871b..276441be6 100644 --- a/lib/sqlalchemy/databases/__init__.py +++ b/lib/sqlalchemy/databases/__init__.py @@ -1,5 +1,5 @@ # databases/__init__.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/__init__.py b/lib/sqlalchemy/dialects/__init__.py index 4a79608d9..22b47597a 100644 --- a/lib/sqlalchemy/dialects/__init__.py +++ b/lib/sqlalchemy/dialects/__init__.py @@ -1,5 +1,5 @@ # dialects/__init__.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/firebird/__init__.py b/lib/sqlalchemy/dialects/firebird/__init__.py index dae499c62..24a2daad0 100644 --- a/lib/sqlalchemy/dialects/firebird/__init__.py +++ b/lib/sqlalchemy/dialects/firebird/__init__.py @@ -1,5 +1,5 @@ # firebird/__init__.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py index 8a110fbd3..82861e30f 100644 --- a/lib/sqlalchemy/dialects/firebird/base.py +++ b/lib/sqlalchemy/dialects/firebird/base.py @@ -1,5 +1,5 @@ # firebird/base.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/firebird/fdb.py b/lib/sqlalchemy/dialects/firebird/fdb.py index a20aab8d8..14954b073 100644 --- a/lib/sqlalchemy/dialects/firebird/fdb.py +++ b/lib/sqlalchemy/dialects/firebird/fdb.py @@ -1,5 +1,5 @@ # firebird/fdb.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/firebird/kinterbasdb.py b/lib/sqlalchemy/dialects/firebird/kinterbasdb.py index c6be8367b..4c937e0de 100644 --- a/lib/sqlalchemy/dialects/firebird/kinterbasdb.py +++ b/lib/sqlalchemy/dialects/firebird/kinterbasdb.py @@ -1,5 +1,5 @@ # firebird/kinterbasdb.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/mssql/__init__.py b/lib/sqlalchemy/dialects/mssql/__init__.py index d987efa51..d6d0a4711 100644 --- a/lib/sqlalchemy/dialects/mssql/__init__.py +++ b/lib/sqlalchemy/dialects/mssql/__init__.py @@ -1,5 +1,5 @@ # mssql/__init__.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index bc5480e2c..538679fcf 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -1,5 +1,5 @@ # mssql/base.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/mssql/information_schema.py b/lib/sqlalchemy/dialects/mssql/information_schema.py index 974a55963..c37920797 100644 --- a/lib/sqlalchemy/dialects/mssql/information_schema.py +++ b/lib/sqlalchemy/dialects/mssql/information_schema.py @@ -1,5 +1,5 @@ # mssql/information_schema.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/mssql/mxodbc.py b/lib/sqlalchemy/dialects/mssql/mxodbc.py index b274c2a2b..da4e45f07 100644 --- a/lib/sqlalchemy/dialects/mssql/mxodbc.py +++ b/lib/sqlalchemy/dialects/mssql/mxodbc.py @@ -1,5 +1,5 @@ # mssql/mxodbc.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/mssql/pymssql.py b/lib/sqlalchemy/dialects/mssql/pymssql.py index 962d1af01..5110badb9 100644 --- a/lib/sqlalchemy/dialects/mssql/pymssql.py +++ b/lib/sqlalchemy/dialects/mssql/pymssql.py @@ -1,5 +1,5 @@ # mssql/pymssql.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py index dac6098c4..c94b50678 100644 --- a/lib/sqlalchemy/dialects/mssql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py @@ -1,5 +1,5 @@ # mssql/pyodbc.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/mysql/__init__.py b/lib/sqlalchemy/dialects/mysql/__init__.py index c6781c168..20dd68d8f 100644 --- a/lib/sqlalchemy/dialects/mysql/__init__.py +++ b/lib/sqlalchemy/dialects/mysql/__init__.py @@ -1,5 +1,5 @@ # mysql/__init__.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/mysql/aiomysql.py b/lib/sqlalchemy/dialects/mysql/aiomysql.py index f560ece33..f0665133f 100644 --- a/lib/sqlalchemy/dialects/mysql/aiomysql.py +++ b/lib/sqlalchemy/dialects/mysql/aiomysql.py @@ -1,5 +1,5 @@ # mysql/aiomysql.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors <see AUTHORS +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors <see AUTHORS # file> # # This module is part of SQLAlchemy and is released under @@ -34,6 +34,7 @@ handling. from .pymysql import MySQLDialect_pymysql from ... import pool +from ... import util from ...util.concurrency import await_fallback from ...util.concurrency import await_only @@ -226,7 +227,7 @@ class AsyncAdapt_aiomysql_dbapi: def connect(self, *arg, **kw): async_fallback = kw.pop("async_fallback", False) - if async_fallback: + if util.asbool(async_fallback): return AsyncAdaptFallback_aiomysql_connection( self, await_fallback(self.aiomysql.connect(*arg, **kw)), @@ -244,6 +245,8 @@ class MySQLDialect_aiomysql(MySQLDialect_pymysql): supports_server_side_cursors = True _sscursor = AsyncAdapt_aiomysql_ss_cursor + is_async = True + @classmethod def dbapi(cls): return AsyncAdapt_aiomysql_dbapi( @@ -251,14 +254,19 @@ class MySQLDialect_aiomysql(MySQLDialect_pymysql): ) @classmethod - def get_pool_class(self, url): - return pool.AsyncAdaptedQueuePool + def get_pool_class(cls, url): + + async_fallback = url.query.get("async_fallback", False) + + if util.asbool(async_fallback): + return pool.FallbackAsyncAdaptedQueuePool + else: + return pool.AsyncAdaptedQueuePool def create_connect_args(self, url): - args, kw = super(MySQLDialect_aiomysql, self).create_connect_args(url) - if "passwd" in kw: - kw["password"] = kw.pop("passwd") - return args, kw + return super(MySQLDialect_aiomysql, self).create_connect_args( + url, _translate_args=dict(username="user", database="db") + ) def is_disconnect(self, e, connection, cursor): if super(MySQLDialect_aiomysql, self).is_disconnect( diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 7a4d3261f..63dbbd83e 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1,5 +1,5 @@ # mysql/base.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -1624,6 +1624,11 @@ class MySQLCompiler(compiler.SQLCompiler): return self.dialect.type_compiler.process(type_).replace( "NUMERIC", "DECIMAL" ) + elif ( + isinstance(type_, sqltypes.Float) + and self.dialect._support_float_cast + ): + return self.dialect.type_compiler.process(type_) else: return None @@ -1631,7 +1636,7 @@ class MySQLCompiler(compiler.SQLCompiler): type_ = self.process(cast.typeclause) if type_ is None: util.warn( - "Datatype %s does not support CAST on MySQL; " + "Datatype %s does not support CAST on MySQL/MariaDb; " "the CAST will be skipped." % self.dialect.type_compiler.process(cast.typeclause.type) ) @@ -2900,6 +2905,17 @@ class MySQLDialect(default.DefaultDialect): ) @property + def _support_float_cast(self): + if not self.server_version_info: + return False + elif self.is_mariadb: + # ref https://mariadb.com/kb/en/mariadb-1045-release-notes/ + return self.server_version_info >= (10, 4, 5) + else: + # ref https://dev.mysql.com/doc/relnotes/mysql/8.0/en/news-8-0-17.html#mysqld-8-0-17-feature # noqa + return self.server_version_info >= (8, 0, 17) + + @property def _is_mariadb(self): return self.is_mariadb diff --git a/lib/sqlalchemy/dialects/mysql/cymysql.py b/lib/sqlalchemy/dialects/mysql/cymysql.py index f1d0aedaf..0d7ba5594 100644 --- a/lib/sqlalchemy/dialects/mysql/cymysql.py +++ b/lib/sqlalchemy/dialects/mysql/cymysql.py @@ -1,5 +1,5 @@ # mysql/cymysql.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/mysql/enumerated.py b/lib/sqlalchemy/dialects/mysql/enumerated.py index 2bc25585e..c44b60226 100644 --- a/lib/sqlalchemy/dialects/mysql/enumerated.py +++ b/lib/sqlalchemy/dialects/mysql/enumerated.py @@ -1,5 +1,5 @@ # mysql/enumerated.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/mysql/json.py b/lib/sqlalchemy/dialects/mysql/json.py index a1c8258b0..655e68cad 100644 --- a/lib/sqlalchemy/dialects/mysql/json.py +++ b/lib/sqlalchemy/dialects/mysql/json.py @@ -1,5 +1,5 @@ # mysql/json.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/mysql/mariadbconnector.py b/lib/sqlalchemy/dialects/mysql/mariadbconnector.py index 4e0b4e0a9..ddc11f6e6 100644 --- a/lib/sqlalchemy/dialects/mysql/mariadbconnector.py +++ b/lib/sqlalchemy/dialects/mysql/mariadbconnector.py @@ -1,5 +1,5 @@ # mysql/mariadbconnector.py -# Copyright (C) 2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py index 66a429d35..5ed675b13 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py +++ b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py @@ -1,5 +1,5 @@ # mysql/mysqlconnector.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py index 605407f46..0318b5077 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqldb.py +++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py @@ -1,5 +1,5 @@ # mysql/mysqldb.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -177,10 +177,13 @@ class MySQLDialect_mysqldb(MySQLDialect): connection, additional_tests ) - def create_connect_args(self, url): - opts = url.translate_connect_args( - database="db", username="user", password="passwd" - ) + def create_connect_args(self, url, _translate_args=None): + if _translate_args is None: + _translate_args = dict( + database="db", username="user", password="passwd" + ) + + opts = url.translate_connect_args(**_translate_args) opts.update(url.query) util.coerce_kw_type(opts, "compress", bool) diff --git a/lib/sqlalchemy/dialects/mysql/oursql.py b/lib/sqlalchemy/dialects/mysql/oursql.py index 7c2b220b4..5c8c7b7c2 100644 --- a/lib/sqlalchemy/dialects/mysql/oursql.py +++ b/lib/sqlalchemy/dialects/mysql/oursql.py @@ -1,5 +1,5 @@ # mysql/oursql.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/mysql/pymysql.py b/lib/sqlalchemy/dialects/mysql/pymysql.py index 7d7770105..0c321f854 100644 --- a/lib/sqlalchemy/dialects/mysql/pymysql.py +++ b/lib/sqlalchemy/dialects/mysql/pymysql.py @@ -1,5 +1,5 @@ # mysql/pymysql.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -57,6 +57,13 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb): def dbapi(cls): return __import__("pymysql") + def create_connect_args(self, url, _translate_args=None): + if _translate_args is None: + _translate_args = dict(username="user") + return super(MySQLDialect_pymysql, self).create_connect_args( + url, _translate_args=_translate_args + ) + def is_disconnect(self, e, connection, cursor): if super(MySQLDialect_pymysql, self).is_disconnect( e, connection, cursor diff --git a/lib/sqlalchemy/dialects/mysql/pyodbc.py b/lib/sqlalchemy/dialects/mysql/pyodbc.py index 5a696562e..048586b59 100644 --- a/lib/sqlalchemy/dialects/mysql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mysql/pyodbc.py @@ -1,5 +1,5 @@ # mysql/pyodbc.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/mysql/reflection.py b/lib/sqlalchemy/dialects/mysql/reflection.py index 14fb97c64..453a15d7d 100644 --- a/lib/sqlalchemy/dialects/mysql/reflection.py +++ b/lib/sqlalchemy/dialects/mysql/reflection.py @@ -1,5 +1,5 @@ # mysql/reflection.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/mysql/types.py b/lib/sqlalchemy/dialects/mysql/types.py index 594975000..9a6b804b3 100644 --- a/lib/sqlalchemy/dialects/mysql/types.py +++ b/lib/sqlalchemy/dialects/mysql/types.py @@ -1,5 +1,5 @@ # mysql/types.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/oracle/__init__.py b/lib/sqlalchemy/dialects/oracle/__init__.py index a4dee02ff..a6af7d8c4 100644 --- a/lib/sqlalchemy/dialects/oracle/__init__.py +++ b/lib/sqlalchemy/dialects/oracle/__init__.py @@ -1,5 +1,5 @@ # oracle/__init__.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 94c4aacbd..1cc8b7aef 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -1,5 +1,5 @@ # oracle/base.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -1017,6 +1017,7 @@ class OracleCompiler(compiler.SQLCompiler): ): if ( self.isupdate + and isinstance(column, sa_schema.Column) and isinstance(column.server_default, Computed) and not self.dialect._supports_update_returning_computed_cols ): diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index e7db9272f..042443692 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -1,4 +1,4 @@ -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py index 2762a9971..108e27c8f 100644 --- a/lib/sqlalchemy/dialects/postgresql/__init__.py +++ b/lib/sqlalchemy/dialects/postgresql/__init__.py @@ -1,5 +1,5 @@ # postgresql/__init__.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py index dacf1e2c2..ad71db89e 100644 --- a/lib/sqlalchemy/dialects/postgresql/array.py +++ b/lib/sqlalchemy/dialects/postgresql/array.py @@ -1,5 +1,5 @@ # postgresql/array.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 889293eab..7c6e8fb02 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -1,5 +1,5 @@ # postgresql/asyncpg.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors <see AUTHORS +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors <see AUTHORS # file> # # This module is part of SQLAlchemy and is released under @@ -45,6 +45,57 @@ in conjunction with :func:`_sa.create_engine`:: ``json_deserializer`` when creating the engine with :func:`create_engine` or :func:`create_async_engine`. + +.. _asyncpg_prepared_statement_cache: + +Prepared Statement Cache +-------------------------- + +The asyncpg SQLAlchemy dialect makes use of ``asyncpg.connection.prepare()`` +for all statements. The prepared statement objects are cached after +construction which appears to grant a 10% or more performance improvement for +statement invocation. The cache is on a per-DBAPI connection basis, which +means that the primary storage for prepared statements is within DBAPI +connections pooled within the connection pool. The size of this cache +defaults to 100 statements per DBAPI connection and may be adjusted using the +``prepared_statement_cache_size`` DBAPI argument (note that while this argument +is implemented by SQLAlchemy, it is part of the DBAPI emulation portion of the +asyncpg dialect, therefore is handled as a DBAPI argument, not a dialect +argument):: + + + engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=500") + +To disable the prepared statement cache, use a value of zero:: + + engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=0") + +.. versionadded:: 1.4.0b2 Added ``prepared_statement_cache_size`` for asyncpg. + + +.. warning:: The ``asyncpg`` database driver necessarily uses caches for + PostgreSQL type OIDs, which become stale when custom PostgreSQL datatypes + such as ``ENUM`` objects are changed via DDL operations. Additionally, + prepared statements themselves which are optionally cached by SQLAlchemy's + driver as described above may also become "stale" when DDL has been emitted + to the PostgreSQL database which modifies the tables or other objects + involved in a particular prepared statement. + + The SQLAlchemy asyncpg dialect will invalidate these caches within its local + process when statements that represent DDL are emitted on a local + connection, but this is only controllable within a single Python process / + database engine. If DDL changes are made from other database engines + and/or processes, a running application may encounter asyncpg exceptions + ``InvalidCachedStatementError`` and/or ``InternalServerError("cache lookup + failed for type <oid>")`` if it refers to pooled database connections which + operated upon the previous structures. The SQLAlchemy asyncpg dialect will + recover from these error cases when the driver raises these exceptions by + clearing its internal caches as well as those of the asyncpg driver in + response to them, but cannot prevent them from being raised in the first + place if the cached prepared statement or asyncpg type caches have gone + stale, nor can it retry the statement as the PostgreSQL transaction is + invalidated when these errors occur. + """ # noqa import collections @@ -52,6 +103,7 @@ import decimal import itertools import json as _py_json import re +import time from . import json from .base import _DECIMAL_TYPES @@ -235,9 +287,23 @@ class AsyncpgOID(OID): class PGExecutionContext_asyncpg(PGExecutionContext): + def handle_dbapi_exception(self, e): + if isinstance( + e, + ( + self.dialect.dbapi.InvalidCachedStatementError, + self.dialect.dbapi.InternalServerError, + ), + ): + self.dialect._invalidate_schema_cache() + def pre_exec(self): if self.isddl: - self._dbapi_connection.reset_schema_state() + self.dialect._invalidate_schema_cache() + + self.cursor._invalidate_schema_cache_asof = ( + self.dialect._invalidate_schema_cache_asof + ) if not self.compiled: return @@ -269,6 +335,7 @@ class AsyncAdapt_asyncpg_cursor: "rowcount", "_inputsizes", "_cursor", + "_invalidate_schema_cache_asof", ) server_side = False @@ -282,6 +349,7 @@ class AsyncAdapt_asyncpg_cursor: self.arraysize = 1 self.rowcount = -1 self._inputsizes = None + self._invalidate_schema_cache_asof = 0 def close(self): self._rows[:] = [] @@ -302,25 +370,25 @@ class AsyncAdapt_asyncpg_cursor: ) async def _prepare_and_execute(self, operation, parameters): - # TODO: I guess cache these in an LRU cache, or see if we can - # use some asyncpg concept - - # TODO: would be nice to support the dollar numeric thing - # directly, this is much easier for now if not self._adapt_connection._started: await self._adapt_connection._start_transaction() params = self._parameters() + + # TODO: would be nice to support the dollar numeric thing + # directly, this is much easier for now operation = re.sub(r"\?", lambda m: next(params), operation) + try: - prepared_stmt = await self._connection.prepare(operation) + prepared_stmt, attributes = await self._adapt_connection._prepare( + operation, self._invalidate_schema_cache_asof + ) - attributes = prepared_stmt.get_attributes() if attributes: self.description = [ (attr.name, attr.type.oid, None, None, None, None, None) - for attr in prepared_stmt.get_attributes() + for attr in attributes ] else: self.description = None @@ -350,15 +418,21 @@ class AsyncAdapt_asyncpg_cursor: self._handle_exception(error) def executemany(self, operation, seq_of_parameters): - if not self._adapt_connection._started: - self._adapt_connection.await_( - self._adapt_connection._start_transaction() + adapt_connection = self._adapt_connection + + adapt_connection.await_( + adapt_connection._check_type_cache_invalidation( + self._invalidate_schema_cache_asof ) + ) + + if not adapt_connection._started: + adapt_connection.await_(adapt_connection._start_transaction()) params = self._parameters() operation = re.sub(r"\?", lambda m: next(params), operation) try: - return self._adapt_connection.await_( + return adapt_connection.await_( self._connection.executemany(operation, seq_of_parameters) ) except Exception as error: @@ -485,11 +559,13 @@ class AsyncAdapt_asyncpg_connection: "deferrable", "_transaction", "_started", + "_prepared_statement_cache", + "_invalidate_schema_cache_asof", ) await_ = staticmethod(await_only) - def __init__(self, dbapi, connection): + def __init__(self, dbapi, connection, prepared_statement_cache_size=100): self.dbapi = dbapi self._connection = connection self.isolation_level = self._isolation_setting = "read_committed" @@ -497,6 +573,46 @@ class AsyncAdapt_asyncpg_connection: self.deferrable = False self._transaction = None self._started = False + self._invalidate_schema_cache_asof = time.time() + + if prepared_statement_cache_size: + self._prepared_statement_cache = util.LRUCache( + prepared_statement_cache_size + ) + else: + self._prepared_statement_cache = None + + async def _check_type_cache_invalidation(self, invalidate_timestamp): + if invalidate_timestamp > self._invalidate_schema_cache_asof: + await self._connection.reload_schema_state() + self._invalidate_schema_cache_asof = invalidate_timestamp + + async def _prepare(self, operation, invalidate_timestamp): + await self._check_type_cache_invalidation(invalidate_timestamp) + + cache = self._prepared_statement_cache + if cache is None: + prepared_stmt = await self._connection.prepare(operation) + attributes = prepared_stmt.get_attributes() + return prepared_stmt, attributes + + # asyncpg uses a type cache for the "attributes" which seems to go + # stale independently of the PreparedStatement itself, so place that + # collection in the cache as well. + if operation in cache: + prepared_stmt, attributes, cached_timestamp = cache[operation] + + # preparedstatements themselves also go stale for certain DDL + # changes such as size of a VARCHAR changing, so there is also + # a cross-connection invalidation timestamp + if cached_timestamp > invalidate_timestamp: + return prepared_stmt, attributes + + prepared_stmt = await self._connection.prepare(operation) + attributes = prepared_stmt.get_attributes() + cache[operation] = (prepared_stmt, attributes, time.time()) + + return prepared_stmt, attributes def _handle_exception(self, error): if not isinstance(error, AsyncAdapt_asyncpg_dbapi.Error): @@ -551,9 +667,6 @@ class AsyncAdapt_asyncpg_connection: else: return AsyncAdapt_asyncpg_cursor(self) - def reset_schema_state(self): - self.await_(self._connection.reload_schema_state()) - def rollback(self): if self._started: self.await_(self._transaction.rollback()) @@ -586,16 +699,20 @@ class AsyncAdapt_asyncpg_dbapi: def connect(self, *arg, **kw): async_fallback = kw.pop("async_fallback", False) - - if async_fallback: + prepared_statement_cache_size = kw.pop( + "prepared_statement_cache_size", 100 + ) + if util.asbool(async_fallback): return AsyncAdaptFallback_asyncpg_connection( self, await_fallback(self.asyncpg.connect(*arg, **kw)), + prepared_statement_cache_size=prepared_statement_cache_size, ) else: return AsyncAdapt_asyncpg_connection( self, await_only(self.asyncpg.connect(*arg, **kw)), + prepared_statement_cache_size=prepared_statement_cache_size, ) class Error(Exception): @@ -628,15 +745,29 @@ class AsyncAdapt_asyncpg_dbapi: class NotSupportedError(DatabaseError): pass + class InternalServerError(InternalError): + pass + + class InvalidCachedStatementError(NotSupportedError): + def __init__(self, message): + super( + AsyncAdapt_asyncpg_dbapi.InvalidCachedStatementError, self + ).__init__( + message + " (SQLAlchemy asyncpg dialect will now invalidate " + "all prepared caches in response to this exception)", + ) + @util.memoized_property def _asyncpg_error_translate(self): import asyncpg return { - asyncpg.exceptions.IntegrityConstraintViolationError: self.IntegrityError, # noqa + asyncpg.exceptions.IntegrityConstraintViolationError: self.IntegrityError, # noqa: E501 asyncpg.exceptions.PostgresError: self.Error, asyncpg.exceptions.SyntaxOrAccessError: self.ProgrammingError, asyncpg.exceptions.InterfaceError: self.InterfaceError, + asyncpg.exceptions.InvalidCachedStatementError: self.InvalidCachedStatementError, # noqa: E501 + asyncpg.exceptions.InternalServerError: self.InternalServerError, } def Binary(self, value): @@ -729,6 +860,11 @@ class PGDialect_asyncpg(PGDialect): REGCLASS: AsyncpgREGCLASS, }, ) + is_async = True + _invalidate_schema_cache_asof = 0 + + def _invalidate_schema_cache(self): + self._invalidate_schema_cache_asof = time.time() @util.memoized_property def _dbapi_version(self): @@ -786,14 +922,21 @@ class PGDialect_asyncpg(PGDialect): def create_connect_args(self, url): opts = url.translate_connect_args(username="user") - if "port" in opts: - opts["port"] = int(opts["port"]) + opts.update(url.query) + util.coerce_kw_type(opts, "prepared_statement_cache_size", int) + util.coerce_kw_type(opts, "port", int) return ([], opts) @classmethod - def get_pool_class(self, url): - return pool.AsyncAdaptedQueuePool + def get_pool_class(cls, url): + + async_fallback = url.query.get("async_fallback", False) + + if util.asbool(async_fallback): + return pool.FallbackAsyncAdaptedQueuePool + else: + return pool.AsyncAdaptedQueuePool def is_disconnect(self, e, connection, cursor): if connection: diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 9b6c632da..735990a20 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1,5 +1,5 @@ # postgresql/base.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/postgresql/dml.py b/lib/sqlalchemy/dialects/postgresql/dml.py index 78cad974f..76dfafd04 100644 --- a/lib/sqlalchemy/dialects/postgresql/dml.py +++ b/lib/sqlalchemy/dialects/postgresql/dml.py @@ -1,5 +1,5 @@ # postgresql/on_conflict.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/postgresql/ext.py b/lib/sqlalchemy/dialects/postgresql/ext.py index 4c8c3fc22..908a0b675 100644 --- a/lib/sqlalchemy/dialects/postgresql/ext.py +++ b/lib/sqlalchemy/dialects/postgresql/ext.py @@ -1,5 +1,5 @@ # postgresql/ext.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/postgresql/hstore.py b/lib/sqlalchemy/dialects/postgresql/hstore.py index 15ec2a585..cfd94b9b3 100644 --- a/lib/sqlalchemy/dialects/postgresql/hstore.py +++ b/lib/sqlalchemy/dialects/postgresql/hstore.py @@ -1,5 +1,5 @@ # postgresql/hstore.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/postgresql/json.py b/lib/sqlalchemy/dialects/postgresql/json.py index 63e1656e0..a5dfa40d6 100644 --- a/lib/sqlalchemy/dialects/postgresql/json.py +++ b/lib/sqlalchemy/dialects/postgresql/json.py @@ -1,5 +1,5 @@ # postgresql/json.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py index 439249157..6e7318272 100644 --- a/lib/sqlalchemy/dialects/postgresql/pg8000.py +++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py @@ -1,5 +1,5 @@ # postgresql/pg8000.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors <see AUTHORS +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors <see AUTHORS # file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/postgresql/provision.py b/lib/sqlalchemy/dialects/postgresql/provision.py index 575316c61..d345cdfdf 100644 --- a/lib/sqlalchemy/dialects/postgresql/provision.py +++ b/lib/sqlalchemy/dialects/postgresql/provision.py @@ -1,8 +1,11 @@ import time from ... import exc +from ... import inspect from ... import text from ...testing.provision import create_db +from ...testing.provision import drop_all_schema_objects_post_tables +from ...testing.provision import drop_all_schema_objects_pre_tables from ...testing.provision import drop_db from ...testing.provision import log from ...testing.provision import set_default_schema_on_connection @@ -78,3 +81,24 @@ def _postgresql_set_default_schema_on_connection( cursor.execute("SET SESSION search_path='%s'" % schema_name) cursor.close() dbapi_connection.autocommit = existing_autocommit + + +@drop_all_schema_objects_pre_tables.for_db("postgresql") +def drop_all_schema_objects_pre_tables(cfg, eng): + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: + for xid in conn.execute("select gid from pg_prepared_xacts").scalars(): + conn.execute("ROLLBACK PREPARED '%s'" % xid) + + +@drop_all_schema_objects_post_tables.for_db("postgresql") +def drop_all_schema_objects_post_tables(cfg, eng): + from sqlalchemy.dialects import postgresql + + inspector = inspect(eng) + with eng.begin() as conn: + for enum in inspector.get_enums("*"): + conn.execute( + postgresql.DropEnumType( + postgresql.ENUM(name=enum["name"], schema=enum["schema"]) + ) + ) diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index 72c36b4a8..2a2a7fa53 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -1,5 +1,5 @@ # postgresql/psycopg2.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py b/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py index e4ebbb262..a449f9e65 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py @@ -1,5 +1,5 @@ # testing/engines.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/postgresql/pygresql.py b/lib/sqlalchemy/dialects/postgresql/pygresql.py index 8dbd23fe9..64dd7262d 100644 --- a/lib/sqlalchemy/dialects/postgresql/pygresql.py +++ b/lib/sqlalchemy/dialects/postgresql/pygresql.py @@ -1,5 +1,5 @@ # postgresql/pygresql.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/postgresql/pypostgresql.py b/lib/sqlalchemy/dialects/postgresql/pypostgresql.py index bd015a5b8..6e4db217d 100644 --- a/lib/sqlalchemy/dialects/postgresql/pypostgresql.py +++ b/lib/sqlalchemy/dialects/postgresql/pypostgresql.py @@ -1,5 +1,5 @@ # postgresql/pypostgresql.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py index ddc12c096..1f6f75f6d 100644 --- a/lib/sqlalchemy/dialects/postgresql/ranges.py +++ b/lib/sqlalchemy/dialects/postgresql/ranges.py @@ -1,4 +1,4 @@ -# Copyright (C) 2013-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2013-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/sqlite/__init__.py b/lib/sqlalchemy/dialects/sqlite/__init__.py index 72402dd92..d12203cbd 100644 --- a/lib/sqlalchemy/dialects/sqlite/__init__.py +++ b/lib/sqlalchemy/dialects/sqlite/__init__.py @@ -1,5 +1,5 @@ # sqlite/__init__.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index c53b3c228..a4c25e764 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -1,5 +1,5 @@ # sqlite/base.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/sqlite/dml.py b/lib/sqlalchemy/dialects/sqlite/dml.py index 2d7ea6e4a..9c8f10f7b 100644 --- a/lib/sqlalchemy/dialects/sqlite/dml.py +++ b/lib/sqlalchemy/dialects/sqlite/dml.py @@ -1,4 +1,4 @@ -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py b/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py index 8f72e12fa..659043366 100644 --- a/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py +++ b/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py @@ -1,5 +1,5 @@ # sqlite/pysqlcipher.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py index dac04e0ca..8636b7519 100644 --- a/lib/sqlalchemy/dialects/sqlite/pysqlite.py +++ b/lib/sqlalchemy/dialects/sqlite/pysqlite.py @@ -1,5 +1,5 @@ # sqlite/pysqlite.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/sybase/__init__.py b/lib/sqlalchemy/dialects/sybase/__init__.py index 03a685b3f..5d3eb8290 100644 --- a/lib/sqlalchemy/dialects/sybase/__init__.py +++ b/lib/sqlalchemy/dialects/sybase/__init__.py @@ -1,5 +1,5 @@ # sybase/__init__.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/sybase/base.py b/lib/sqlalchemy/dialects/sybase/base.py index e3848a9b2..49243be78 100644 --- a/lib/sqlalchemy/dialects/sybase/base.py +++ b/lib/sqlalchemy/dialects/sybase/base.py @@ -1,5 +1,5 @@ # sybase/base.py -# Copyright (C) 2010-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2010-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # get_select_precolumns(), limit_clause() implementation # copyright (C) 2007 Fisch Asset Management diff --git a/lib/sqlalchemy/dialects/sybase/mxodbc.py b/lib/sqlalchemy/dialects/sybase/mxodbc.py index d23482357..6b2f07c54 100644 --- a/lib/sqlalchemy/dialects/sybase/mxodbc.py +++ b/lib/sqlalchemy/dialects/sybase/mxodbc.py @@ -1,5 +1,5 @@ # sybase/mxodbc.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/sybase/pyodbc.py b/lib/sqlalchemy/dialects/sybase/pyodbc.py index d11aae1c5..bbd6d968a 100644 --- a/lib/sqlalchemy/dialects/sybase/pyodbc.py +++ b/lib/sqlalchemy/dialects/sybase/pyodbc.py @@ -1,5 +1,5 @@ # sybase/pyodbc.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/sybase/pysybase.py b/lib/sqlalchemy/dialects/sybase/pysybase.py index a36cd74ca..d6d2f2ed2 100644 --- a/lib/sqlalchemy/dialects/sybase/pysybase.py +++ b/lib/sqlalchemy/dialects/sybase/pysybase.py @@ -1,5 +1,5 @@ # sybase/pysybase.py -# Copyright (C) 2010-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2010-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py index 7523f3b26..2b98261ef 100644 --- a/lib/sqlalchemy/engine/__init__.py +++ b/lib/sqlalchemy/engine/__init__.py @@ -1,5 +1,5 @@ # engine/__init__.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index d36cf30e9..50f00c025 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1,5 +1,5 @@ # engine/base.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py index b48cead79..f89be1809 100644 --- a/lib/sqlalchemy/engine/create.py +++ b/lib/sqlalchemy/engine/create.py @@ -1,5 +1,5 @@ # engine/create.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index b06c622e0..80cb98b1c 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -1,5 +1,5 @@ # engine/cursor.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index a754ebe58..e5a6384d2 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -1,5 +1,5 @@ # engine/default.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -210,6 +210,8 @@ class DefaultDialect(interfaces.Dialect): """ + is_async = False + CACHE_HIT = CACHE_HIT CACHE_MISS = CACHE_MISS CACHING_DISABLED = CACHING_DISABLED diff --git a/lib/sqlalchemy/engine/events.py b/lib/sqlalchemy/engine/events.py index ccc6c5968..fb8e5aeb2 100644 --- a/lib/sqlalchemy/engine/events.py +++ b/lib/sqlalchemy/engine/events.py @@ -1,5 +1,5 @@ # sqlalchemy/engine/events.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index b6f7bc49a..b5e23a4c2 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -1,5 +1,5 @@ # engine/interfaces.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/engine/mock.py b/lib/sqlalchemy/engine/mock.py index 6c91d1434..f6cb71d40 100644 --- a/lib/sqlalchemy/engine/mock.py +++ b/lib/sqlalchemy/engine/mock.py @@ -1,5 +1,5 @@ # engine/mock.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index c5d66081a..cff209575 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -1,5 +1,5 @@ # engine/reflection.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index 73b07e540..cc3877e05 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -1,5 +1,5 @@ # engine/result.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/engine/row.py b/lib/sqlalchemy/engine/row.py index 60954fcec..ac65d1b18 100644 --- a/lib/sqlalchemy/engine/row.py +++ b/lib/sqlalchemy/engine/row.py @@ -1,5 +1,5 @@ # engine/row.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py index a99815390..1e9b707e8 100644 --- a/lib/sqlalchemy/engine/strategies.py +++ b/lib/sqlalchemy/engine/strategies.py @@ -1,5 +1,5 @@ # engine/strategies.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py index 58f59642c..85e206019 100644 --- a/lib/sqlalchemy/engine/url.py +++ b/lib/sqlalchemy/engine/url.py @@ -1,5 +1,5 @@ # engine/url.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/engine/util.py b/lib/sqlalchemy/engine/util.py index 26dd5ddd0..4e302f464 100644 --- a/lib/sqlalchemy/engine/util.py +++ b/lib/sqlalchemy/engine/util.py @@ -1,5 +1,5 @@ # engine/util.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/event/__init__.py b/lib/sqlalchemy/event/__init__.py index c5c27b078..093f596cb 100644 --- a/lib/sqlalchemy/event/__init__.py +++ b/lib/sqlalchemy/event/__init__.py @@ -1,5 +1,5 @@ # event/__init__.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/event/api.py b/lib/sqlalchemy/event/api.py index cd09235c3..f1a2bb774 100644 --- a/lib/sqlalchemy/event/api.py +++ b/lib/sqlalchemy/event/api.py @@ -1,5 +1,5 @@ # event/api.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/event/attr.py b/lib/sqlalchemy/event/attr.py index 122221d40..245eaab60 100644 --- a/lib/sqlalchemy/event/attr.py +++ b/lib/sqlalchemy/event/attr.py @@ -1,5 +1,5 @@ # event/attr.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/event/base.py b/lib/sqlalchemy/event/base.py index 1ba88f3d2..181db0cf2 100644 --- a/lib/sqlalchemy/event/base.py +++ b/lib/sqlalchemy/event/base.py @@ -1,5 +1,5 @@ # event/base.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/event/legacy.py b/lib/sqlalchemy/event/legacy.py index 14115d377..ce2ed2d4f 100644 --- a/lib/sqlalchemy/event/legacy.py +++ b/lib/sqlalchemy/event/legacy.py @@ -1,5 +1,5 @@ # event/legacy.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/event/registry.py b/lib/sqlalchemy/event/registry.py index d1009eca9..13310b11b 100644 --- a/lib/sqlalchemy/event/registry.py +++ b/lib/sqlalchemy/event/registry.py @@ -1,5 +1,5 @@ # event/registry.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/events.py b/lib/sqlalchemy/events.py index 93ef43815..2d7f4fcbd 100644 --- a/lib/sqlalchemy/events.py +++ b/lib/sqlalchemy/events.py @@ -1,5 +1,5 @@ # sqlalchemy/events.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py index 63c56c34d..b031c1610 100644 --- a/lib/sqlalchemy/exc.py +++ b/lib/sqlalchemy/exc.py @@ -1,5 +1,5 @@ # sqlalchemy/exc.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/ext/__init__.py b/lib/sqlalchemy/ext/__init__.py index 1f842fc2a..5f1783f75 100644 --- a/lib/sqlalchemy/ext/__init__.py +++ b/lib/sqlalchemy/ext/__init__.py @@ -1,5 +1,5 @@ # ext/__init__.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index a2c6b596f..0b4d5954e 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -1,5 +1,5 @@ # ext/associationproxy.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/ext/asyncio/base.py b/lib/sqlalchemy/ext/asyncio/base.py index 051f9e21a..fa8c5006e 100644 --- a/lib/sqlalchemy/ext/asyncio/base.py +++ b/lib/sqlalchemy/ext/asyncio/base.py @@ -23,3 +23,20 @@ class StartableContext(abc.ABC): "%s context has not been started and object has not been awaited." % (self.__class__.__name__) ) + + +class ProxyComparable: + def __hash__(self): + return id(self) + + def __eq__(self, other): + return ( + isinstance(other, self.__class__) + and self._proxied == other._proxied + ) + + def __ne__(self, other): + return ( + not isinstance(other, self.__class__) + or self._proxied != other._proxied + ) diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index 16edcc2b2..829e89b71 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -4,6 +4,7 @@ from typing import Mapping from typing import Optional from . import exc as async_exc +from .base import ProxyComparable from .base import StartableContext from .result import AsyncResult from ... import exc @@ -41,7 +42,7 @@ def create_async_engine(*arg, **kw): class AsyncConnectable: - __slots__ = "_slots_dispatch" + __slots__ = "_slots_dispatch", "__weakref__" @util.create_proxy_methods( @@ -57,7 +58,7 @@ class AsyncConnectable: "default_isolation_level", ], ) -class AsyncConnection(StartableContext, AsyncConnectable): +class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): """An asyncio proxy for a :class:`_engine.Connection`. :class:`_asyncio.AsyncConnection` is acquired using the @@ -131,6 +132,24 @@ class AsyncConnection(StartableContext, AsyncConnectable): def _proxied(self): return self.sync_connection + @property + def info(self): + """Return the :attr:`_engine.Connection.info` dictionary of the + underlying :class:`_engine.Connection`. + + This dictionary is freely writable for user-defined state to be + associated with the database connection. + + This attribute is only available if the :class:`.AsyncConnection` is + currently connected. If the :attr:`.AsyncConnection.closed` attribute + is ``True``, then accessing this attribute will raise + :class:`.ResourceClosedError`. + + .. versionadded:: 1.4.0b2 + + """ + return self.sync_connection.info + def _sync_connection(self): if not self.sync_connection: self._raise_for_not_started() @@ -166,6 +185,69 @@ class AsyncConnection(StartableContext, AsyncConnectable): conn = self._sync_connection() return await greenlet_spawn(conn.get_isolation_level) + def in_transaction(self): + """Return True if a transaction is in progress. + + .. versionadded:: 1.4.0b2 + + """ + + conn = self._sync_connection() + + return conn.in_transaction() + + def in_nested_transaction(self): + """Return True if a transaction is in progress. + + .. versionadded:: 1.4.0b2 + + """ + conn = self._sync_connection() + + return conn.in_nested_transaction() + + def get_transaction(self): + """Return an :class:`.AsyncTransaction` representing the current + transaction, if any. + + This makes use of the underlying synchronous connection's + :meth:`_engine.Connection.get_transaction` method to get the current + :class:`_engine.Transaction`, which is then proxied in a new + :class:`.AsyncTransaction` object. + + .. versionadded:: 1.4.0b2 + + """ + conn = self._sync_connection() + + trans = conn.get_transaction() + if trans is not None: + return AsyncTransaction._from_existing_transaction(self, trans) + else: + return None + + def get_nested_transaction(self): + """Return an :class:`.AsyncTransaction` representing the current + nested (savepoint) transaction, if any. + + This makes use of the underlying synchronous connection's + :meth:`_engine.Connection.get_nested_transaction` method to get the + current :class:`_engine.Transaction`, which is then proxied in a new + :class:`.AsyncTransaction` object. + + .. versionadded:: 1.4.0b2 + + """ + conn = self._sync_connection() + + trans = conn.get_nested_transaction() + if trans is not None: + return AsyncTransaction._from_existing_transaction( + self, trans, True + ) + else: + return None + async def execution_options(self, **opt): r"""Set non-SQL options for the connection which take effect during execution. @@ -366,6 +448,16 @@ class AsyncConnection(StartableContext, AsyncConnectable): with async_engine.begin() as conn: await conn.run_sync(metadata.create_all) + .. note:: + + The provided callable is invoked inline within the asyncio event + loop, and will block on traditional IO calls. IO within this + callable should only call into SQLAlchemy's asyncio database + APIs which will be properly adapted to the greenlet context. + + .. seealso:: + + :ref:`session_run_sync` """ conn = self._sync_connection() @@ -391,7 +483,7 @@ class AsyncConnection(StartableContext, AsyncConnectable): ], attributes=["url", "pool", "dialect", "engine", "name", "driver", "echo"], ) -class AsyncEngine(AsyncConnectable): +class AsyncEngine(ProxyComparable, AsyncConnectable): """An asyncio proxy for a :class:`_engine.Engine`. :class:`_asyncio.AsyncEngine` is acquired using the @@ -513,7 +605,7 @@ class AsyncEngine(AsyncConnectable): return await greenlet_spawn(self.sync_engine.dispose) -class AsyncTransaction(StartableContext): +class AsyncTransaction(ProxyComparable, StartableContext): """An asyncio proxy for a :class:`_engine.Transaction`.""" __slots__ = ("connection", "sync_transaction", "nested") @@ -523,12 +615,29 @@ class AsyncTransaction(StartableContext): self.sync_transaction: Optional[Transaction] = None self.nested = nested + @classmethod + def _from_existing_transaction( + cls, + connection: AsyncConnection, + sync_transaction: Transaction, + nested: bool = False, + ): + obj = cls.__new__(cls) + obj.connection = connection + obj.sync_transaction = sync_transaction + obj.nested = nested + return obj + def _sync_transaction(self): if not self.sync_transaction: self._raise_for_not_started() return self.sync_transaction @property + def _proxied(self): + return self.sync_transaction + + @property def is_valid(self) -> bool: return self._sync_transaction().is_valid @@ -582,7 +691,10 @@ class AsyncTransaction(StartableContext): await self.rollback() -def _get_sync_engine(async_engine): +def _get_sync_engine_or_connection(async_engine): + if isinstance(async_engine, AsyncConnection): + return async_engine.sync_connection + try: return async_engine.sync_engine except AttributeError as e: diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index bac2aa44b..faa279cf9 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -75,12 +75,13 @@ class AsyncSession: kw["future"] = True if bind: self.bind = engine - bind = engine._get_sync_engine(bind) + bind = engine._get_sync_engine_or_connection(bind) if binds: self.binds = binds binds = { - key: engine._get_sync_engine(b) for key, b in binds.items() + key: engine._get_sync_engine_or_connection(b) + for key, b in binds.items() } self.sync_session = self._proxied = Session( @@ -120,6 +121,16 @@ class AsyncSession: with AsyncSession(async_engine) as session: await session.run_sync(some_business_method) + .. note:: + + The provided callable is invoked inline within the asyncio event + loop, and will block on traditional IO calls. IO within this + callable should only call into SQLAlchemy's asyncio database + APIs which will be properly adapted to the greenlet context. + + .. seealso:: + + :ref:`session_run_sync` """ return await greenlet_spawn(fn, self.sync_session, *arg, **kw) diff --git a/lib/sqlalchemy/ext/automap.py b/lib/sqlalchemy/ext/automap.py index 8fe318dfb..6d90214bf 100644 --- a/lib/sqlalchemy/ext/automap.py +++ b/lib/sqlalchemy/ext/automap.py @@ -1,5 +1,5 @@ # ext/automap.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py index 8a2023e96..5d33bc76d 100644 --- a/lib/sqlalchemy/ext/baked.py +++ b/lib/sqlalchemy/ext/baked.py @@ -1,5 +1,5 @@ # sqlalchemy/ext/baked.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py index f0447d8df..5a31173ec 100644 --- a/lib/sqlalchemy/ext/compiler.py +++ b/lib/sqlalchemy/ext/compiler.py @@ -1,5 +1,5 @@ # ext/compiler.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/ext/declarative/__init__.py b/lib/sqlalchemy/ext/declarative/__init__.py index 8b38945b2..9fc45dfaa 100644 --- a/lib/sqlalchemy/ext/declarative/__init__.py +++ b/lib/sqlalchemy/ext/declarative/__init__.py @@ -1,5 +1,5 @@ # ext/declarative/__init__.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/ext/declarative/extensions.py b/lib/sqlalchemy/ext/declarative/extensions.py index 32344d538..a2e295280 100644 --- a/lib/sqlalchemy/ext/declarative/extensions.py +++ b/lib/sqlalchemy/ext/declarative/extensions.py @@ -1,5 +1,5 @@ # ext/declarative/extensions.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index 220cdbfc8..0829d9f13 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -1,5 +1,5 @@ # ext/horizontal_shard.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py index 8679d907a..5fcb4fac0 100644 --- a/lib/sqlalchemy/ext/hybrid.py +++ b/lib/sqlalchemy/ext/hybrid.py @@ -1,5 +1,5 @@ # ext/hybrid.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/ext/indexable.py b/lib/sqlalchemy/ext/indexable.py index f58acceeb..b64a358ce 100644 --- a/lib/sqlalchemy/ext/indexable.py +++ b/lib/sqlalchemy/ext/indexable.py @@ -1,5 +1,5 @@ # ext/index.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/ext/mutable.py b/lib/sqlalchemy/ext/mutable.py index 32a22a495..4bb9e795b 100644 --- a/lib/sqlalchemy/ext/mutable.py +++ b/lib/sqlalchemy/ext/mutable.py @@ -1,5 +1,5 @@ # ext/mutable.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/ext/orderinglist.py b/lib/sqlalchemy/ext/orderinglist.py index 03ea096e7..2cb85588b 100644 --- a/lib/sqlalchemy/ext/orderinglist.py +++ b/lib/sqlalchemy/ext/orderinglist.py @@ -1,5 +1,5 @@ # ext/orderinglist.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/ext/serializer.py b/lib/sqlalchemy/ext/serializer.py index afd44ca3d..08c5cfc06 100644 --- a/lib/sqlalchemy/ext/serializer.py +++ b/lib/sqlalchemy/ext/serializer.py @@ -1,5 +1,5 @@ # ext/serializer.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/future/__init__.py b/lib/sqlalchemy/future/__init__.py index b07b9b040..4e4054e84 100644 --- a/lib/sqlalchemy/future/__init__.py +++ b/lib/sqlalchemy/future/__init__.py @@ -1,5 +1,5 @@ # sql/future/__init__.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/future/orm/__init__.py b/lib/sqlalchemy/future/orm/__init__.py index 56b5dfa46..abf6476e3 100644 --- a/lib/sqlalchemy/future/orm/__init__.py +++ b/lib/sqlalchemy/future/orm/__init__.py @@ -1,5 +1,5 @@ # sql/future/orm/__init__.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/inspection.py b/lib/sqlalchemy/inspection.py index 69b160ce7..3341bfac8 100644 --- a/lib/sqlalchemy/inspection.py +++ b/lib/sqlalchemy/inspection.py @@ -1,5 +1,5 @@ # sqlalchemy/inspect.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/log.py b/lib/sqlalchemy/log.py index 44f8c4ff8..687cc066b 100644 --- a/lib/sqlalchemy/log.py +++ b/lib/sqlalchemy/log.py @@ -1,5 +1,5 @@ # sqlalchemy/log.py -# Copyright (C) 2006-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2006-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # Includes alterations by Vinay Sajip vinay_sajip@yahoo.co.uk # diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index 7d35856f3..2b7ad7bbd 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -1,5 +1,5 @@ # orm/__init__.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 92650c1d0..dd354c4e0 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -1,5 +1,5 @@ # orm/attributes.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index 6fdf1f372..b805c6f93 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -1,5 +1,5 @@ # orm/base.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/clsregistry.py b/lib/sqlalchemy/orm/clsregistry.py index ad1d9adcd..0bc888197 100644 --- a/lib/sqlalchemy/orm/clsregistry.py +++ b/lib/sqlalchemy/orm/clsregistry.py @@ -1,5 +1,5 @@ # ext/declarative/clsregistry.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index 262aeaf04..63278fb7e 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -1,5 +1,5 @@ # orm/collections.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 41b2146a3..584a07970 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -1,5 +1,5 @@ # orm/context.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index 4d9766204..89d18bcac 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -1,5 +1,5 @@ # ext/declarative/api.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index 353f44e43..db6d274c8 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -1,5 +1,5 @@ # ext/declarative/base.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py index 9c2c5ade3..5a329b28c 100644 --- a/lib/sqlalchemy/orm/dependency.py +++ b/lib/sqlalchemy/orm/dependency.py @@ -1,5 +1,5 @@ # orm/dependency.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 713891d91..695b1a7b4 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -1,5 +1,5 @@ # orm/descriptor_props.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py index 4426041e3..32eb23199 100644 --- a/lib/sqlalchemy/orm/dynamic.py +++ b/lib/sqlalchemy/orm/dynamic.py @@ -1,5 +1,5 @@ # orm/dynamic.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/evaluator.py b/lib/sqlalchemy/orm/evaluator.py index 8763f0a3d..135125663 100644 --- a/lib/sqlalchemy/orm/evaluator.py +++ b/lib/sqlalchemy/orm/evaluator.py @@ -1,5 +1,5 @@ # orm/evaluator.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index 9f9ebd461..0824ae7de 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -1,5 +1,5 @@ # orm/events.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/exc.py b/lib/sqlalchemy/orm/exc.py index 19f8ca3bc..9aab78e06 100644 --- a/lib/sqlalchemy/orm/exc.py +++ b/lib/sqlalchemy/orm/exc.py @@ -1,5 +1,5 @@ # orm/exc.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/identity.py b/lib/sqlalchemy/orm/identity.py index e4795a92d..c0ed38365 100644 --- a/lib/sqlalchemy/orm/identity.py +++ b/lib/sqlalchemy/orm/identity.py @@ -1,5 +1,5 @@ # orm/identity.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py index f390c49a7..d2ff72180 100644 --- a/lib/sqlalchemy/orm/instrumentation.py +++ b/lib/sqlalchemy/orm/instrumentation.py @@ -1,5 +1,5 @@ # orm/instrumentation.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index bacec422c..bbe39af60 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -1,5 +1,5 @@ # orm/interfaces.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index ecb704a04..3a85a2d7f 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -1,5 +1,5 @@ # orm/loading.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index e8f98d150..9ac0c85c6 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1,5 +1,5 @@ # orm/mapper.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/path_registry.py b/lib/sqlalchemy/orm/path_registry.py index f6c03d007..13ff90cdb 100644 --- a/lib/sqlalchemy/orm/path_registry.py +++ b/lib/sqlalchemy/orm/path_registry.py @@ -1,5 +1,5 @@ # orm/path_registry.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index cfb6d9265..b57963eb3 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -1,5 +1,5 @@ # orm/persistence.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 5fb3beca3..4d0c7528b 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -1,5 +1,5 @@ # orm/properties.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 334283bb9..5dd2a8a21 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1,5 +1,5 @@ # orm/query.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 31a3b9ec9..550ff3833 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -1,5 +1,5 @@ # orm/relationships.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py index 29d845c0a..6aef243c7 100644 --- a/lib/sqlalchemy/orm/scoping.py +++ b/lib/sqlalchemy/orm/scoping.py @@ -1,5 +1,5 @@ # orm/scoping.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index a5f0894f6..e8312f393 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -1,5 +1,5 @@ # orm/session.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py index b139d5933..cda98b890 100644 --- a/lib/sqlalchemy/orm/state.py +++ b/lib/sqlalchemy/orm/state.py @@ -1,5 +1,5 @@ # orm/state.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 6838011b1..837d4d548 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -1,5 +1,5 @@ # orm/strategies.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index c5d5b146d..fbecfedeb 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -1,4 +1,4 @@ -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py index ceaf54e5d..691961f38 100644 --- a/lib/sqlalchemy/orm/sync.py +++ b/lib/sqlalchemy/orm/sync.py @@ -1,5 +1,5 @@ # orm/sync.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 868f8e087..c293d90cb 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -1,5 +1,5 @@ # orm/unitofwork.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index bbb428683..1bc0ceb4d 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -1,5 +1,5 @@ # orm/util.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/pool/__init__.py b/lib/sqlalchemy/pool/__init__.py index 353f34333..29f589acc 100644 --- a/lib/sqlalchemy/pool/__init__.py +++ b/lib/sqlalchemy/pool/__init__.py @@ -1,5 +1,5 @@ # sqlalchemy/pool/__init__.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -29,6 +29,7 @@ from .dbapi_proxy import clear_managers from .dbapi_proxy import manage from .impl import AssertionPool from .impl import AsyncAdaptedQueuePool +from .impl import FallbackAsyncAdaptedQueuePool from .impl import NullPool from .impl import QueuePool from .impl import SingletonThreadPool @@ -46,6 +47,7 @@ __all__ = [ "NullPool", "QueuePool", "AsyncAdaptedQueuePool", + "FallbackAsyncAdaptedQueuePool", "SingletonThreadPool", "StaticPool", ] diff --git a/lib/sqlalchemy/pool/base.py b/lib/sqlalchemy/pool/base.py index 68fa5fe85..7c9509e45 100644 --- a/lib/sqlalchemy/pool/base.py +++ b/lib/sqlalchemy/pool/base.py @@ -1,5 +1,5 @@ # sqlalchemy/pool.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/pool/dbapi_proxy.py b/lib/sqlalchemy/pool/dbapi_proxy.py index 6e11d2e59..96b5e8cba 100644 --- a/lib/sqlalchemy/pool/dbapi_proxy.py +++ b/lib/sqlalchemy/pool/dbapi_proxy.py @@ -1,5 +1,5 @@ # sqlalchemy/pool/dbapi_proxy.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/pool/events.py b/lib/sqlalchemy/pool/events.py index 363afdd78..03106019e 100644 --- a/lib/sqlalchemy/pool/events.py +++ b/lib/sqlalchemy/pool/events.py @@ -1,5 +1,5 @@ # sqlalchemy/pool/events.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/pool/impl.py b/lib/sqlalchemy/pool/impl.py index 38afbc7a1..825ac0307 100644 --- a/lib/sqlalchemy/pool/impl.py +++ b/lib/sqlalchemy/pool/impl.py @@ -1,5 +1,5 @@ # sqlalchemy/pool.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -226,6 +226,10 @@ class AsyncAdaptedQueuePool(QueuePool): _queue_class = sqla_queue.AsyncAdaptedQueue +class FallbackAsyncAdaptedQueuePool(AsyncAdaptedQueuePool): + _queue_class = sqla_queue.FallbackAsyncAdaptedQueue + + class NullPool(Pool): """A Pool which does not pool connections. diff --git a/lib/sqlalchemy/processors.py b/lib/sqlalchemy/processors.py index 8618d5e2a..c090548e8 100644 --- a/lib/sqlalchemy/processors.py +++ b/lib/sqlalchemy/processors.py @@ -1,5 +1,5 @@ # sqlalchemy/processors.py -# Copyright (C) 2010-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2010-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com # diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index b83b5525f..9bd780f7e 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -1,5 +1,5 @@ # schema.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index 645189e76..11795c3b2 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -1,5 +1,5 @@ # sql/__init__.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index 94d37573c..b29c9cc66 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -1,5 +1,5 @@ # sql/annotation.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 5178a7ab1..a1426b628 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -1,5 +1,5 @@ # sql/base.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 43c89ee82..05e0a4fcf 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -1,5 +1,5 @@ # sql/coercions.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 8ee575cca..45cb75c25 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1,5 +1,5 @@ # sql/compiler.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 1c68d6450..f67e76181 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -1,5 +1,5 @@ # sql/crud.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index d471f29d5..a166a6bdf 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -1,5 +1,5 @@ # sql/ddl.py -# Copyright (C) 2009-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2009-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py index be6d0787b..88f9e6523 100644 --- a/lib/sqlalchemy/sql/default_comparator.py +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -1,5 +1,5 @@ # sql/default_comparator.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index b755eef73..c402de121 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -1,5 +1,5 @@ # sql/dml.py -# Copyright (C) 2009-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2009-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 75c1fc1bf..d3c767b5d 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -1,5 +1,5 @@ # sql/elements.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/sql/events.py b/lib/sqlalchemy/sql/events.py index 797ca697f..deaa992af 100644 --- a/lib/sqlalchemy/sql/events.py +++ b/lib/sqlalchemy/sql/events.py @@ -1,5 +1,5 @@ # sqlalchemy/sql/events.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index a3a4ec351..359f5e533 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1,5 +1,5 @@ # sql/expression.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 6d331910d..a9ea98d04 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -1,5 +1,5 @@ # sql/functions.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py index 3f0ca477e..92c3ac9f7 100644 --- a/lib/sqlalchemy/sql/lambdas.py +++ b/lib/sqlalchemy/sql/lambdas.py @@ -1,5 +1,5 @@ # sql/lambdas.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/sql/naming.py b/lib/sqlalchemy/sql/naming.py index f530177dd..130ec6875 100644 --- a/lib/sqlalchemy/sql/naming.py +++ b/lib/sqlalchemy/sql/naming.py @@ -1,5 +1,5 @@ # sqlalchemy/naming.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index 29a2f191e..a7dfa9b6d 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -1,5 +1,5 @@ # sql/operators.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index b88625b88..2c4ff75c4 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -1,5 +1,5 @@ # sql/roles.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index b5e45c18d..36d69456e 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -1,5 +1,5 @@ # sql/schema.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -2685,9 +2685,12 @@ class Sequence(IdentityOptions, roles.StatementRole, DefaultGenerator): for this :class:`.Sequence` within any SQL expression. """ - return util.preloaded.sql_functions.func.next_value( - self, bind=self.bind - ) + if self.bind: + return util.preloaded.sql_functions.func.next_value( + self, bind=self.bind + ) + else: + return util.preloaded.sql_functions.func.next_value(self) def _set_parent(self, column, **kw): super(Sequence, self)._set_parent(column) diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index b49fe92df..a1dfa8b56 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -1,5 +1,5 @@ # sql/selectable.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 072afe46e..d20c8168d 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -1,5 +1,5 @@ # sql/sqltypes.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index b48886cca..462a8763b 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -1,5 +1,5 @@ # sql/types_api.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -1082,8 +1082,7 @@ class TypeDecorator(SchemaEventTarget, TypeEngine): In most cases this returns a dialect-adapted form of the :class:`.TypeEngine` type represented by ``self.impl``. - Makes usage of :meth:`dialect_impl` but also traverses - into wrapped :class:`.TypeDecorator` instances. + Makes usage of :meth:`dialect_impl`. Behavior can be customized here by overriding :meth:`load_dialect_impl`. @@ -1091,8 +1090,6 @@ class TypeDecorator(SchemaEventTarget, TypeEngine): adapted = dialect.type_descriptor(self) if not isinstance(adapted, type(self)): return adapted - elif isinstance(self.impl, TypeDecorator): - return self.impl.type_engine(dialect) else: return self.load_dialect_impl(dialect) @@ -1117,7 +1114,6 @@ class TypeDecorator(SchemaEventTarget, TypeEngine): method. """ - # some dialects have a lookup for a TypeDecorator subclass directly. # postgresql.INTERVAL being the main example typ = self.dialect_impl(dialect) diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 55c17a193..df805c557 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -1,5 +1,5 @@ # sql/util.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 43b7cb4af..fe0fbf669 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -1,5 +1,5 @@ # sql/visitors.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py index c46a3fa89..c1afeb907 100644 --- a/lib/sqlalchemy/testing/__init__.py +++ b/lib/sqlalchemy/testing/__init__.py @@ -1,5 +1,5 @@ # testing/__init__.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -29,6 +29,7 @@ from .assertions import in_ # noqa from .assertions import is_ # noqa from .assertions import is_false # noqa from .assertions import is_instance_of # noqa +from .assertions import is_none # noqa from .assertions import is_not # noqa from .assertions import is_not_ # noqa from .assertions import is_true # noqa diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index 17a0acf20..b2a4ac66e 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -1,5 +1,5 @@ # testing/assertions.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -224,12 +224,16 @@ def is_instance_of(a, b, msg=None): assert isinstance(a, b), msg or "%r is not an instance of %r" % (a, b) +def is_none(a, msg=None): + is_(a, None, msg=msg) + + def is_true(a, msg=None): - is_(a, True, msg=msg) + is_(bool(a), True, msg=msg) def is_false(a, msg=None): - is_(a, False, msg=msg) + is_(bool(a), False, msg=msg) def is_(a, b, msg=None): diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index e20209ba5..1bdd11585 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -1,5 +1,5 @@ # testing/assertsql.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/testing/asyncio.py b/lib/sqlalchemy/testing/asyncio.py new file mode 100644 index 000000000..ef92fa5b9 --- /dev/null +++ b/lib/sqlalchemy/testing/asyncio.py @@ -0,0 +1,129 @@ +# testing/asyncio.py +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + + +# functions and wrappers to run tests, fixtures, provisioning and +# setup/teardown in an asyncio event loop, conditionally based on the +# current DB driver being used for a test. + +# note that SQLAlchemy's asyncio integration also supports a method +# of running individual asyncio functions inside of separate event loops +# using "async_fallback" mode; however running whole functions in the event +# loop is a more accurate test for how SQLAlchemy's asyncio features +# would run in the real world. + + +from functools import wraps +import inspect + +from . import config +from ..util.concurrency import _util_async_run +from ..util.concurrency import _util_async_run_coroutine_function + +# may be set to False if the +# --disable-asyncio flag is passed to the test runner. +ENABLE_ASYNCIO = True + + +def _run_coroutine_function(fn, *args, **kwargs): + return _util_async_run_coroutine_function(fn, *args, **kwargs) + + +def _assume_async(fn, *args, **kwargs): + """Run a function in an asyncio loop unconditionally. + + This function is used for provisioning features like + testing a database connection for server info. + + Note that for blocking IO database drivers, this means they block the + event loop. + + """ + + if not ENABLE_ASYNCIO: + return fn(*args, **kwargs) + + return _util_async_run(fn, *args, **kwargs) + + +def _maybe_async_provisioning(fn, *args, **kwargs): + """Run a function in an asyncio loop if any current drivers might need it. + + This function is used for provisioning features that take + place outside of a specific database driver being selected, so if the + current driver that happens to be used for the provisioning operation + is an async driver, it will run in asyncio and not fail. + + Note that for blocking IO database drivers, this means they block the + event loop. + + """ + if not ENABLE_ASYNCIO: + + return fn(*args, **kwargs) + + if config.any_async: + return _util_async_run(fn, *args, **kwargs) + else: + return fn(*args, **kwargs) + + +def _maybe_async(fn, *args, **kwargs): + """Run a function in an asyncio loop if the current selected driver is + async. + + This function is used for test setup/teardown and tests themselves + where the current DB driver is known. + + + """ + if not ENABLE_ASYNCIO: + + return fn(*args, **kwargs) + + is_async = config._current.is_async + + if is_async: + return _util_async_run(fn, *args, **kwargs) + else: + return fn(*args, **kwargs) + + +def _maybe_async_wrapper(fn): + """Apply the _maybe_async function to an existing function and return + as a wrapped callable, supporting generator functions as well. + + This is currently used for pytest fixtures that support generator use. + + """ + + if inspect.isgeneratorfunction(fn): + _stop = object() + + def call_next(gen): + try: + return next(gen) + # can't raise StopIteration in an awaitable. + except StopIteration: + return _stop + + @wraps(fn) + def wrap_fixture(*args, **kwargs): + gen = fn(*args, **kwargs) + while True: + value = _maybe_async(call_next, gen) + if value is _stop: + break + yield value + + else: + + @wraps(fn) + def wrap_fixture(*args, **kwargs): + return _maybe_async(fn, *args, **kwargs) + + return wrap_fixture diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index 0b8027b84..f64153f33 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -1,5 +1,5 @@ # testing/config.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -7,6 +7,8 @@ import collections +from .. import util + requirements = None db = None db_url = None @@ -14,6 +16,7 @@ db_opts = None file_config = None test_schema = None test_schema_2 = None +any_async = False _current = None ident = "main" @@ -104,6 +107,10 @@ class Config(object): self.test_schema = "test_schema" self.test_schema_2 = "test_schema_2" + self.is_async = db.dialect.is_async and not util.asbool( + db.url.query.get("async_fallback", False) + ) + _stack = collections.deque() _configs = set() @@ -121,7 +128,15 @@ class Config(object): If there are no configs set up yet, this config also gets set as the "_current". """ + global any_async + cfg = Config(db, db_opts, options, file_config) + + # if any backends include an async driver, then ensure + # all setup/teardown and tests are wrapped in the maybe_async() + # decorator that will set up a greenlet context for async drivers. + any_async = any_async or cfg.is_async + cls._configs.add(cfg) return cfg diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index bb137cb32..a4c1f3973 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -1,5 +1,5 @@ # testing/engines.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -46,7 +46,7 @@ class ConnectionKiller(object): fn() except Exception as e: warnings.warn( - "testing_reaper couldn't " "rollback/close connection: %s" % e + "testing_reaper couldn't rollback/close connection: %s" % e ) def rollback_all(self): @@ -97,7 +97,10 @@ class ConnectionKiller(object): self.conns = set() for rec in list(self.testing_engines): - rec.dispose() + if hasattr(rec, "sync_engine"): + rec.sync_engine.dispose() + else: + rec.dispose() def assert_all_closed(self): for rec in self.proxy_refs: @@ -199,9 +202,7 @@ class ReconnectFixture(object): try: fn() except Exception as e: - warnings.warn( - "ReconnectFixture couldn't " "close connection: %s" % e - ) + warnings.warn("ReconnectFixture couldn't close connection: %s" % e) def shutdown(self, stop=False): # TODO: this doesn't cover all cases @@ -238,10 +239,12 @@ def reconnecting_engine(url=None, options=None): return engine -def testing_engine(url=None, options=None, future=False): +def testing_engine(url=None, options=None, future=False, asyncio=False): """Produce an engine configured by --options with optional overrides.""" - if future or config.db and config.db._is_future: + if asyncio: + from sqlalchemy.ext.asyncio import create_async_engine as create_engine + elif future or config.db and config.db._is_future: from sqlalchemy.future import create_engine else: from sqlalchemy import create_engine @@ -265,11 +268,14 @@ def testing_engine(url=None, options=None, future=False): default_opt.update(options) engine = create_engine(url, **options) - engine._has_events = True # enable event blocks, helps with profiling + if asyncio: + engine.sync_engine._has_events = True + else: + engine._has_events = True # enable event blocks, helps with profiling if isinstance(engine.pool, pool.QueuePool): engine.pool._timeout = 0 - engine.pool._max_overflow = 0 + engine.pool._max_overflow = 5 if use_reaper: testing_reaper.add_engine(engine) diff --git a/lib/sqlalchemy/testing/entities.py b/lib/sqlalchemy/testing/entities.py index 085c19196..050a30a89 100644 --- a/lib/sqlalchemy/testing/entities.py +++ b/lib/sqlalchemy/testing/entities.py @@ -1,5 +1,5 @@ # testing/entities.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py index 6ec438193..b20e17442 100644 --- a/lib/sqlalchemy/testing/exclusions.py +++ b/lib/sqlalchemy/testing/exclusions.py @@ -1,5 +1,5 @@ # testing/exclusions.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index 0ede25176..ac4d3d8fa 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -1,5 +1,5 @@ # testing/fixtures.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -13,6 +13,7 @@ from . import assertions from . import config from . import schema from .engines import drop_all_tables +from .engines import testing_engine from .entities import BasicEntity from .entities import ComparableEntity from .entities import ComparableMixin # noqa @@ -24,7 +25,6 @@ from ..orm import registry from ..orm.decl_api import DeclarativeMeta from ..schema import sort_tables_and_constraints - # whether or not we use unittest changes things dramatically, # as far as how pytest collection works. @@ -73,21 +73,31 @@ class TestBase(object): trans.rollback() conn.close() - # propose a replacement for @testing.provide_metadata. - # the problem with this is that TablesTest below has a ".metadata" - # attribute already which is accessed directly as part of the - # @testing.provide_metadata pattern. Might need to call this _metadata - # for it to be useful. - # @config.fixture() - # def metadata(self): - # """Provide bound MetaData for a single test, dropping afterwards.""" - # - # from . import engines - # metadata = schema.MetaData(config.db) - # try: - # yield metadata - # finally: - # engines.drop_all_tables(metadata, config.db) + @config.fixture() + def future_connection(self): + + eng = testing_engine(future=True) + conn = eng.connect() + trans = conn.begin() + try: + yield conn + finally: + if trans.is_active: + trans.rollback() + conn.close() + + @config.fixture() + def metadata(self): + """Provide bound MetaData for a single test, dropping afterwards.""" + + from . import engines + from ..sql import schema + + metadata = schema.MetaData() + try: + yield metadata + finally: + engines.drop_all_tables(metadata, config.db) class FutureEngineMixin(object): @@ -136,11 +146,15 @@ class TablesTest(TestBase): run_dispose_bind = None bind = None - metadata = None + _tables_metadata = None tables = None other = None sequences = None + @property + def tables_test_metadata(self): + return self._tables_metadata + @classmethod def setup_class(cls): cls._init_class() @@ -161,8 +175,7 @@ class TablesTest(TestBase): cls.sequences = adict() cls.bind = cls.setup_bind() - cls.metadata = sa.MetaData() - cls.metadata.bind = cls.bind + cls._tables_metadata = sa.MetaData() @classmethod def _setup_once_inserts(cls): @@ -174,21 +187,21 @@ class TablesTest(TestBase): @classmethod def _setup_once_tables(cls): if cls.run_define_tables == "once": - cls.define_tables(cls.metadata) + cls.define_tables(cls._tables_metadata) if cls.run_create_tables == "once": - cls.metadata.create_all(cls.bind) - cls.tables.update(cls.metadata.tables) - cls.sequences.update(cls.metadata._sequences) + cls._tables_metadata.create_all(cls.bind) + cls.tables.update(cls._tables_metadata.tables) + cls.sequences.update(cls._tables_metadata._sequences) def _setup_each_tables(self): if self.run_define_tables == "each": - self.define_tables(self.metadata) + self.define_tables(self._tables_metadata) if self.run_create_tables == "each": - self.metadata.create_all(self.bind) - self.tables.update(self.metadata.tables) - self.sequences.update(self.metadata._sequences) + self._tables_metadata.create_all(self.bind) + self.tables.update(self._tables_metadata.tables) + self.sequences.update(self._tables_metadata._sequences) elif self.run_create_tables == "each": - self.metadata.create_all(self.bind) + self._tables_metadata.create_all(self.bind) def _setup_each_inserts(self): if self.run_inserts == "each": @@ -200,10 +213,10 @@ class TablesTest(TestBase): if self.run_define_tables == "each": self.tables.clear() if self.run_create_tables == "each": - drop_all_tables(self.metadata, self.bind) - self.metadata.clear() + drop_all_tables(self._tables_metadata, self.bind) + self._tables_metadata.clear() elif self.run_create_tables == "each": - drop_all_tables(self.metadata, self.bind) + drop_all_tables(self._tables_metadata, self.bind) # no need to run deletes if tables are recreated on setup if ( @@ -216,7 +229,7 @@ class TablesTest(TestBase): [ t for (t, fks) in sort_tables_and_constraints( - self.metadata.tables.values() + self._tables_metadata.tables.values() ) if t is not None ] @@ -239,12 +252,12 @@ class TablesTest(TestBase): @classmethod def _teardown_once_metadata_bind(cls): if cls.run_create_tables: - drop_all_tables(cls.metadata, cls.bind) + drop_all_tables(cls._tables_metadata, cls.bind) if cls.run_dispose_bind == "once": cls.dispose_bind(cls.bind) - cls.metadata.bind = None + cls._tables_metadata.bind = None if cls.run_setup_bind is not None: cls.bind = None @@ -294,7 +307,7 @@ class TablesTest(TestBase): headers[table] = data[0] rows[table] = data[1:] for table, fks in sort_tables_and_constraints( - cls.metadata.tables.values() + cls._tables_metadata.tables.values() ): if table is None: continue @@ -340,6 +353,12 @@ def create_session(**kw): return sa.orm.Session(config.db, **kw) +def fixture_session(**kw): + kw.setdefault("autoflush", True) + kw.setdefault("expire_on_commit", True) + return sa.orm.Session(config.db, **kw) + + class ORMTest(_ORMTest, TestBase): pass @@ -480,7 +499,7 @@ class DeclarativeMappedTest(MappedTest): __table_cls__ = schema.Table _DeclBase = declarative_base( - metadata=cls.metadata, + metadata=cls._tables_metadata, metaclass=FindFixtureDeclarative, cls=DeclarativeBasic, ) @@ -490,8 +509,8 @@ class DeclarativeMappedTest(MappedTest): # classes super(DeclarativeMappedTest, cls)._with_register_classes(fn) - if cls.metadata.tables and cls.run_create_tables: - cls.metadata.create_all(config.db) + if cls._tables_metadata.tables and cls.run_create_tables: + cls._tables_metadata.create_all(config.db) class ComputedReflectionFixtureTest(TablesTest): diff --git a/lib/sqlalchemy/testing/mock.py b/lib/sqlalchemy/testing/mock.py index a63082b9b..1bbde3a1e 100644 --- a/lib/sqlalchemy/testing/mock.py +++ b/lib/sqlalchemy/testing/mock.py @@ -1,5 +1,5 @@ # testing/mock.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/testing/pickleable.py b/lib/sqlalchemy/testing/pickleable.py index 8f8e26913..dcdfcb1a9 100644 --- a/lib/sqlalchemy/testing/pickleable.py +++ b/lib/sqlalchemy/testing/pickleable.py @@ -1,5 +1,5 @@ # testing/pickleable.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -52,9 +52,9 @@ class Screen(object): class Foo(object): - def __init__(self, moredata): + def __init__(self, moredata, stuff="im stuff"): self.data = "im data" - self.stuff = "im stuff" + self.stuff = stuff self.moredata = moredata __hash__ = object.__hash__ diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py index 5e41f2cdf..3594cd276 100644 --- a/lib/sqlalchemy/testing/plugin/plugin_base.py +++ b/lib/sqlalchemy/testing/plugin/plugin_base.py @@ -1,5 +1,5 @@ # plugin/plugin_base.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -63,21 +63,21 @@ def setup_options(make_option): make_option( "--log-info", action="callback", - type="string", + type=str, callback=_log, help="turn on info logging for <LOG> (multiple OK)", ) make_option( "--log-debug", action="callback", - type="string", + type=str, callback=_log, help="turn on debug logging for <LOG> (multiple OK)", ) make_option( "--db", action="append", - type="string", + type=str, dest="db", help="Use prefab database uri. Multiple OK, " "first one is run by default.", @@ -91,7 +91,7 @@ def setup_options(make_option): make_option( "--dburi", action="append", - type="string", + type=str, dest="dburi", help="Database uri. Multiple OK, " "first one is run by default.", ) @@ -111,6 +111,11 @@ def setup_options(make_option): help="Drop all tables in the target database first", ) make_option( + "--disable-asyncio", + action="store_true", + help="disable test / fixtures / provisoning running in asyncio", + ) + make_option( "--backend-only", action="store_true", dest="backend_only", @@ -130,20 +135,20 @@ def setup_options(make_option): ) make_option( "--profile-sort", - type="string", + type=str, default="cumulative", dest="profilesort", help="Type of sort for profiling standard output", ) make_option( "--profile-dump", - type="string", + type=str, dest="profiledump", help="Filename where a single profile run will be dumped", ) make_option( "--postgresql-templatedb", - type="string", + type=str, help="name of template database to use for PostgreSQL " "CREATE DATABASE (defaults to current database)", ) @@ -156,7 +161,7 @@ def setup_options(make_option): ) make_option( "--write-idents", - type="string", + type=str, dest="write_idents", help="write out generated follower idents to <file>, " "when -n<num> is used", @@ -172,7 +177,7 @@ def setup_options(make_option): make_option( "--requirements", action="callback", - type="string", + type=str, callback=_requirements_opt, help="requirements class for testing, overrides setup.cfg", ) @@ -188,14 +193,14 @@ def setup_options(make_option): "--include-tag", action="callback", callback=_include_tag, - type="string", + type=str, help="Include tests with tag <tag>", ) make_option( "--exclude-tag", action="callback", callback=_exclude_tag, - type="string", + type=str, help="Exclude tests with tag <tag>", ) make_option( @@ -375,10 +380,18 @@ def _init_symbols(options, file_config): @post +def _set_disable_asyncio(opt, file_config): + if opt.disable_asyncio: + from sqlalchemy.testing import asyncio + + asyncio.ENABLE_ASYNCIO = False + + +@post def _engine_uri(options, file_config): - from sqlalchemy.testing import config from sqlalchemy import testing + from sqlalchemy.testing import config from sqlalchemy.testing import provision if options.dburi: @@ -448,73 +461,13 @@ def _setup_requirements(argument): @post def _prep_testing_database(options, file_config): - from sqlalchemy.testing import config, util - from sqlalchemy.testing.exclusions import against - from sqlalchemy import schema, inspect + from sqlalchemy.testing import config if options.dropfirst: - for cfg in config.Config.all_configs(): - e = cfg.db - - # TODO: this has to be part of provision.py in postgresql - if against(cfg, "postgresql"): - with e.connect().execution_options( - isolation_level="AUTOCOMMIT" - ) as conn: - for xid in conn.execute( - "select gid from pg_prepared_xacts" - ).scalars(): - conn.execute("ROLLBACK PREPARED '%s'" % xid) - - inspector = inspect(e) - try: - view_names = inspector.get_view_names() - except NotImplementedError: - pass - else: - for vname in view_names: - e.execute( - schema._DropView( - schema.Table(vname, schema.MetaData()) - ) - ) + from sqlalchemy.testing import provision - if config.requirements.schemas.enabled_for_config(cfg): - try: - view_names = inspector.get_view_names(schema="test_schema") - except NotImplementedError: - pass - else: - for vname in view_names: - e.execute( - schema._DropView( - schema.Table( - vname, - schema.MetaData(), - schema="test_schema", - ) - ) - ) - - util.drop_all_tables(e, inspector) - - if config.requirements.schemas.enabled_for_config(cfg): - util.drop_all_tables(e, inspector, schema=cfg.test_schema) - - # TODO: this has to be part of provision.py in postgresql - if against(cfg, "postgresql"): - from sqlalchemy.dialects import postgresql - - for enum in inspector.get_enums("*"): - e.execute( - postgresql.DropEnumType( - postgresql.ENUM( - name=enum["name"], schema=enum["schema"] - ) - ) - ) - - # TODO: need to do a get_sequences and drop them also after tables + for cfg in config.Config.all_configs(): + provision.drop_all_schema_objects(cfg, cfg.db) @post diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index 644ea6dc2..46468a07d 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -26,11 +26,6 @@ else: from typing import Sequence try: - import asyncio -except ImportError: - pass - -try: import xdist # noqa has_xdist = True @@ -126,11 +121,15 @@ def collect_types_fixture(): def pytest_sessionstart(session): - plugin_base.post_begin() + from sqlalchemy.testing import asyncio + + asyncio._assume_async(plugin_base.post_begin) def pytest_sessionfinish(session): - plugin_base.final_process_cleanup() + from sqlalchemy.testing import asyncio + + asyncio._maybe_async_provisioning(plugin_base.final_process_cleanup) if session.config.option.dump_pyannotate: from pyannotate_runtime import collect_types @@ -162,23 +161,31 @@ if has_xdist: import uuid def pytest_configure_node(node): + from sqlalchemy.testing import provision + from sqlalchemy.testing import asyncio + # the master for each node fills workerinput dictionary # which pytest-xdist will transfer to the subprocess plugin_base.memoize_important_follower_config(node.workerinput) node.workerinput["follower_ident"] = "test_%s" % uuid.uuid4().hex[0:12] - from sqlalchemy.testing import provision - provision.create_follower_db(node.workerinput["follower_ident"]) + asyncio._maybe_async_provisioning( + provision.create_follower_db, node.workerinput["follower_ident"] + ) def pytest_testnodedown(node, error): from sqlalchemy.testing import provision + from sqlalchemy.testing import asyncio - provision.drop_follower_db(node.workerinput["follower_ident"]) + asyncio._maybe_async_provisioning( + provision.drop_follower_db, node.workerinput["follower_ident"] + ) def pytest_collection_modifyitems(session, config, items): + # look for all those classes that specify __backend__ and # expand them out into per-database test cases. @@ -189,6 +196,8 @@ def pytest_collection_modifyitems(session, config, items): # it's to suit the rather odd use case here which is that we are adding # new classes to a module on the fly. + from sqlalchemy.testing import asyncio + rebuilt_items = collections.defaultdict( lambda: collections.defaultdict(list) ) @@ -201,20 +210,26 @@ def pytest_collection_modifyitems(session, config, items): ] test_classes = set(item.parent for item in items) - for test_class in test_classes: - for sub_cls in plugin_base.generate_sub_tests( - test_class.cls, test_class.parent.module - ): - if sub_cls is not test_class.cls: - per_cls_dict = rebuilt_items[test_class.cls] - # support pytest 5.4.0 and above pytest.Class.from_parent - ctor = getattr(pytest.Class, "from_parent", pytest.Class) - for inst in ctor( - name=sub_cls.__name__, parent=test_class.parent.parent - ).collect(): - for t in inst.collect(): - per_cls_dict[t.name].append(t) + def setup_test_classes(): + for test_class in test_classes: + for sub_cls in plugin_base.generate_sub_tests( + test_class.cls, test_class.parent.module + ): + if sub_cls is not test_class.cls: + per_cls_dict = rebuilt_items[test_class.cls] + + # support pytest 5.4.0 and above pytest.Class.from_parent + ctor = getattr(pytest.Class, "from_parent", pytest.Class) + for inst in ctor( + name=sub_cls.__name__, parent=test_class.parent.parent + ).collect(): + for t in inst.collect(): + per_cls_dict[t.name].append(t) + + # class requirements will sometimes need to access the DB to check + # capabilities, so need to do this for async + asyncio._maybe_async_provisioning(setup_test_classes) newitems = [] for item in items: @@ -238,6 +253,10 @@ def pytest_collection_modifyitems(session, config, items): def pytest_pycollect_makeitem(collector, name, obj): if inspect.isclass(obj) and plugin_base.want_class(name, obj): + from sqlalchemy.testing import config + + if config.any_async: + obj = _apply_maybe_async(obj) ctor = getattr(pytest.Class, "from_parent", pytest.Class) @@ -258,6 +277,46 @@ def pytest_pycollect_makeitem(collector, name, obj): return [] +def _is_wrapped_coroutine_function(fn): + while hasattr(fn, "__wrapped__"): + fn = fn.__wrapped__ + + return inspect.iscoroutinefunction(fn) + + +def _apply_maybe_async(obj, recurse=True): + from sqlalchemy.testing import asyncio + + setup_names = {"setup", "setup_class", "teardown", "teardown_class"} + for name, value in vars(obj).items(): + if ( + (callable(value) or isinstance(value, classmethod)) + and not getattr(value, "_maybe_async_applied", False) + and (name.startswith("test_") or name in setup_names) + and not _is_wrapped_coroutine_function(value) + ): + is_classmethod = False + if isinstance(value, classmethod): + value = value.__func__ + is_classmethod = True + + @_pytest_fn_decorator + def make_async(fn, *args, **kwargs): + return asyncio._maybe_async(fn, *args, **kwargs) + + do_async = make_async(value) + if is_classmethod: + do_async = classmethod(do_async) + do_async._maybe_async_applied = True + + setattr(obj, name, do_async) + if recurse: + for cls in obj.mro()[1:]: + if cls != object: + _apply_maybe_async(cls, False) + return obj + + _current_class = None @@ -297,6 +356,8 @@ def _parametrize_cls(module, cls): def pytest_runtest_setup(item): + from sqlalchemy.testing import asyncio + # here we seem to get called only based on what we collected # in pytest_collection_modifyitems. So to do class-based stuff # we have to tear that out. @@ -307,7 +368,7 @@ def pytest_runtest_setup(item): # ... so we're doing a little dance here to figure it out... if _current_class is None: - class_setup(item.parent.parent) + asyncio._maybe_async(class_setup, item.parent.parent) _current_class = item.parent.parent # this is needed for the class-level, to ensure that the @@ -315,20 +376,22 @@ def pytest_runtest_setup(item): # class-level teardown... def finalize(): global _current_class - class_teardown(item.parent.parent) + asyncio._maybe_async(class_teardown, item.parent.parent) _current_class = None item.parent.parent.addfinalizer(finalize) - test_setup(item) + asyncio._maybe_async(test_setup, item) def pytest_runtest_teardown(item): + from sqlalchemy.testing import asyncio + # ...but this works better as the hook here rather than # using a finalizer, as the finalizer seems to get in the way # of the test reporting failures correctly (you get a bunch of # pytest assertion stuff instead) - test_teardown(item) + asyncio._maybe_async(test_teardown, item) def test_setup(item): @@ -342,7 +405,9 @@ def test_teardown(item): def class_setup(item): - plugin_base.start_test_class(item.cls) + from sqlalchemy.testing import asyncio + + asyncio._maybe_async_provisioning(plugin_base.start_test_class, item.cls) def class_teardown(item): @@ -372,17 +437,19 @@ def _pytest_fn_decorator(target): if add_positional_parameters: spec.args.extend(add_positional_parameters) - metadata = dict(target="target", fn="__fn", name=fn.__name__) + metadata = dict( + __target_fn="__target_fn", __orig_fn="__orig_fn", name=fn.__name__ + ) metadata.update(format_argspec_plus(spec, grouped=False)) code = ( """\ def %(name)s(%(args)s): - return %(target)s(%(fn)s, %(apply_kw)s) + return %(__target_fn)s(%(__orig_fn)s, %(apply_kw)s) """ % metadata ) decorated = _exec_code_in_env( - code, {"target": target, "__fn": fn}, fn.__name__ + code, {"__target_fn": target, "__orig_fn": fn}, fn.__name__ ) if not add_positional_parameters: decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__ @@ -554,14 +621,49 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): return pytest.param(*parameters[1:], id=ident) def fixture(self, *arg, **kw): - return pytest.fixture(*arg, **kw) + from sqlalchemy.testing import config + from sqlalchemy.testing import asyncio + + # wrapping pytest.fixture function. determine if + # decorator was called as @fixture or @fixture(). + if len(arg) > 0 and callable(arg[0]): + # was called as @fixture(), we have the function to wrap. + fn = arg[0] + arg = arg[1:] + else: + # was called as @fixture, don't have the function yet. + fn = None + + # create a pytest.fixture marker. because the fn is not being + # passed, this is always a pytest.FixtureFunctionMarker() + # object (or whatever pytest is calling it when you read this) + # that is waiting for a function. + fixture = pytest.fixture(*arg, **kw) + + # now apply wrappers to the function, including fixture itself + + def wrap(fn): + if config.any_async: + fn = asyncio._maybe_async_wrapper(fn) + # other wrappers may be added here + + # now apply FixtureFunctionMarker + fn = fixture(fn) + return fn + + if fn: + return wrap(fn) + else: + return wrap def get_current_test_name(self): return os.environ.get("PYTEST_CURRENT_TEST") def async_test(self, fn): + from sqlalchemy.testing import asyncio + @_pytest_fn_decorator def decorate(fn, *args, **kwargs): - asyncio.get_event_loop().run_until_complete(fn(*args, **kwargs)) + asyncio._run_coroutine_function(fn, *args, **kwargs) return decorate(fn) diff --git a/lib/sqlalchemy/testing/profiling.py b/lib/sqlalchemy/testing/profiling.py index 48e44428b..16c6d458c 100644 --- a/lib/sqlalchemy/testing/profiling.py +++ b/lib/sqlalchemy/testing/profiling.py @@ -1,5 +1,5 @@ # testing/profiling.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py index c4f489a69..4ee0567f2 100644 --- a/lib/sqlalchemy/testing/provision.py +++ b/lib/sqlalchemy/testing/provision.py @@ -3,10 +3,15 @@ import logging from . import config from . import engines +from . import util from .. import exc +from .. import inspect from ..engine import url as sa_url +from ..sql import ddl +from ..sql import schema from ..util import compat + log = logging.getLogger(__name__) FOLLOWER_IDENT = None @@ -70,7 +75,8 @@ def setup_config(db_url, options, file_config, follower_ident): # a symbolic name that tests can use if they need to disambiguate # names across databases - config.ident = follower_ident + if follower_ident: + config.ident = follower_ident if follower_ident: configure_follower(cfg, follower_ident) @@ -94,11 +100,11 @@ def generate_db_urls(db_urls, extra_drivers): --dburi postgresql://db2 \ --dbdriver=psycopg2 --dbdriver=asyncpg?async_fallback=true - Noting that the default postgresql driver is psycopg2. the output + Noting that the default postgresql driver is psycopg2, the output would be:: postgresql+psycopg2://db1 - postgresql+asyncpg://db1?async_fallback=true + postgresql+asyncpg://db1 postgresql+psycopg2://db2 postgresql+psycopg2://db3 @@ -108,6 +114,12 @@ def generate_db_urls(db_urls, extra_drivers): for a driver that is both coming from --dburi as well as --dbdrivers, we want to keep it in that dburi. + Driver specific query options can be specified by added them to the + driver name. For example, to enable the async fallback option for + asyncpg:: + + --dburi postgresql://db1 \ + --dbdriver=asyncpg?async_fallback=true """ urls = set() @@ -205,6 +217,63 @@ def _configs_for_db_operation(): @register.init +def drop_all_schema_objects_pre_tables(cfg, eng): + pass + + +@register.init +def drop_all_schema_objects_post_tables(cfg, eng): + pass + + +def drop_all_schema_objects(cfg, eng): + + drop_all_schema_objects_pre_tables(cfg, eng) + + inspector = inspect(eng) + try: + view_names = inspector.get_view_names() + except NotImplementedError: + pass + else: + with eng.begin() as conn: + for vname in view_names: + conn.execute( + ddl._DropView(schema.Table(vname, schema.MetaData())) + ) + + if config.requirements.schemas.enabled_for_config(cfg): + try: + view_names = inspector.get_view_names(schema="test_schema") + except NotImplementedError: + pass + else: + with eng.begin() as conn: + for vname in view_names: + conn.execute( + ddl._DropView( + schema.Table( + vname, + schema.MetaData(), + schema="test_schema", + ) + ) + ) + + util.drop_all_tables(eng, inspector) + + if config.requirements.schemas.enabled_for_config(cfg): + util.drop_all_tables(eng, inspector, schema=cfg.test_schema) + + drop_all_schema_objects_post_tables(cfg, eng) + + if config.requirements.sequences.enabled_for_config(cfg): + with eng.begin() as conn: + for seq in inspector.get_sequence_names(): + conn.execute(ddl.DropSequence(schema.Sequence(seq))) + + +@register.init def create_db(cfg, eng, ident): """Dynamically create a database for testing. diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index 30b42cbf3..3999c5793 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -1,5 +1,5 @@ # testing/requirements.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/testing/schema.py b/lib/sqlalchemy/testing/schema.py index 8e26d2eaf..22b1f7b77 100644 --- a/lib/sqlalchemy/testing/schema.py +++ b/lib/sqlalchemy/testing/schema.py @@ -1,5 +1,5 @@ # testing/schema.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index f1c573662..6c3c1005a 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -207,10 +207,10 @@ class QuotedNameArgumentTest(fixtures.TablesTest): ] for name in names: query = "CREATE VIEW %s AS SELECT * FROM %s" % ( - testing.db.dialect.identifier_preparer.quote( + config.db.dialect.identifier_preparer.quote( "view %s" % name ), - testing.db.dialect.identifier_preparer.quote(name), + config.db.dialect.identifier_preparer.quote(name), ) event.listen(metadata, "after_create", DDL(query)) @@ -219,7 +219,7 @@ class QuotedNameArgumentTest(fixtures.TablesTest): "before_drop", DDL( "DROP VIEW %s" - % testing.db.dialect.identifier_preparer.quote( + % config.db.dialect.identifier_preparer.quote( "view %s" % name ) ), @@ -233,52 +233,52 @@ class QuotedNameArgumentTest(fixtures.TablesTest): @quote_fixtures def test_get_table_options(self, name): - insp = inspect(testing.db) + insp = inspect(config.db) insp.get_table_options(name) @quote_fixtures @testing.requires.view_column_reflection def test_get_view_definition(self, name): - insp = inspect(testing.db) + insp = inspect(config.db) assert insp.get_view_definition("view %s" % name) @quote_fixtures def test_get_columns(self, name): - insp = inspect(testing.db) + insp = inspect(config.db) assert insp.get_columns(name) @quote_fixtures def test_get_pk_constraint(self, name): - insp = inspect(testing.db) + insp = inspect(config.db) assert insp.get_pk_constraint(name) @quote_fixtures def test_get_foreign_keys(self, name): - insp = inspect(testing.db) + insp = inspect(config.db) assert insp.get_foreign_keys(name) @quote_fixtures def test_get_indexes(self, name): - insp = inspect(testing.db) + insp = inspect(config.db) assert insp.get_indexes(name) @quote_fixtures @testing.requires.unique_constraint_reflection def test_get_unique_constraints(self, name): - insp = inspect(testing.db) + insp = inspect(config.db) assert insp.get_unique_constraints(name) @quote_fixtures @testing.requires.comment_reflection def test_get_table_comment(self, name): - insp = inspect(testing.db) + insp = inspect(config.db) assert insp.get_table_comment(name) @quote_fixtures @testing.requires.check_constraint_reflection def test_get_check_constraints(self, name): - insp = inspect(testing.db) + insp = inspect(config.db) assert insp.get_check_constraints(name) @@ -451,7 +451,9 @@ class ComponentReflectionTest(fixtures.TablesTest): @classmethod def define_temp_tables(cls, metadata): kw = temp_table_keyword_args(config, config.db) - table_name = get_temp_table_name(config, config.db, "user_tmp") + table_name = get_temp_table_name( + config, config.db, "user_tmp_%s" % config.ident + ) user_tmp = Table( table_name, metadata, @@ -477,7 +479,7 @@ class ComponentReflectionTest(fixtures.TablesTest): "after_create", DDL( "create temporary view user_tmp_v as " - "select * from user_tmp" + "select * from user_tmp_%s" % config.ident ), ) event.listen(user_tmp, "before_drop", DDL("drop view user_tmp_v")) @@ -506,7 +508,7 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.schema_reflection def test_get_schema_names(self): - insp = inspect(testing.db) + insp = inspect(self.bind) self.assert_(testing.config.test_schema in insp.get_schema_names()) @@ -518,13 +520,28 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.schema_reflection def test_get_default_schema_name(self): - insp = inspect(testing.db) - eq_(insp.default_schema_name, testing.db.dialect.default_schema_name) - - @testing.provide_metadata - def _test_get_table_names( - self, schema=None, table_type="table", order_by=None + insp = inspect(self.bind) + eq_(insp.default_schema_name, self.bind.dialect.default_schema_name) + + @testing.combinations( + (None, True, False, False), + (None, True, False, True, testing.requires.schemas), + ("foreign_key", True, False, False), + (None, False, True, False), + (None, False, True, True, testing.requires.schemas), + (None, True, True, False), + (None, True, True, True, testing.requires.schemas), + argnames="order_by,include_plain,include_views,use_schema", + ) + def test_get_table_names( + self, connection, order_by, include_plain, include_views, use_schema ): + + if use_schema: + schema = config.test_schema + else: + schema = None + _ignore_tables = [ "comment_test", "noncol_idx_test_pk", @@ -533,16 +550,16 @@ class ComponentReflectionTest(fixtures.TablesTest): "remote_table", "remote_table_2", ] - meta = self.metadata - insp = inspect(meta.bind) + insp = inspect(connection) - if table_type == "view": + if include_views: table_names = insp.get_view_names(schema) table_names.sort() answer = ["email_addresses_v", "users_v"] eq_(sorted(table_names), answer) - else: + + if include_plain: if order_by: tables = [ rec[0] @@ -564,7 +581,7 @@ class ComponentReflectionTest(fixtures.TablesTest): def test_get_temp_table_names(self): insp = inspect(self.bind) temp_table_names = insp.get_temp_table_names() - eq_(sorted(temp_table_names), ["user_tmp"]) + eq_(sorted(temp_table_names), ["user_tmp_%s" % config.ident]) @testing.requires.view_reflection @testing.requires.temp_table_names @@ -574,15 +591,6 @@ class ComponentReflectionTest(fixtures.TablesTest): temp_table_names = insp.get_temp_view_names() eq_(sorted(temp_table_names), ["user_tmp_v"]) - @testing.requires.table_reflection - def test_get_table_names(self): - self._test_get_table_names() - - @testing.requires.table_reflection - @testing.requires.foreign_key_constraint_reflection - def test_get_table_names_fks(self): - self._test_get_table_names(order_by="foreign_key") - @testing.requires.comment_reflection def test_get_comments(self): self._test_get_comments() @@ -593,7 +601,7 @@ class ComponentReflectionTest(fixtures.TablesTest): self._test_get_comments(testing.config.test_schema) def _test_get_comments(self, schema=None): - insp = inspect(testing.db) + insp = inspect(self.bind) eq_( insp.get_table_comment("comment_test", schema=schema), @@ -619,35 +627,27 @@ class ComponentReflectionTest(fixtures.TablesTest): ], ) - @testing.requires.table_reflection - @testing.requires.schemas - def test_get_table_names_with_schema(self): - self._test_get_table_names(testing.config.test_schema) - - @testing.requires.view_column_reflection - def test_get_view_names(self): - self._test_get_table_names(table_type="view") - - @testing.requires.view_column_reflection - @testing.requires.schemas - def test_get_view_names_with_schema(self): - self._test_get_table_names( - testing.config.test_schema, table_type="view" - ) - - @testing.requires.table_reflection - @testing.requires.view_column_reflection - def test_get_tables_and_views(self): - self._test_get_table_names() - self._test_get_table_names(table_type="view") + @testing.combinations( + (False, False), + (False, True, testing.requires.schemas), + (True, False), + (False, True, testing.requires.schemas), + argnames="use_views,use_schema", + ) + def test_get_columns(self, connection, use_views, use_schema): + + if use_schema: + schema = config.test_schema + else: + schema = None - def _test_get_columns(self, schema=None, table_type="table"): - meta = MetaData(testing.db) users, addresses = (self.tables.users, self.tables.email_addresses) - table_names = ["users", "email_addresses"] - if table_type == "view": + if use_views: table_names = ["users_v", "email_addresses_v"] - insp = inspect(meta.bind) + else: + table_names = ["users", "email_addresses"] + + insp = inspect(connection) for table_name, table in zip(table_names, (users, addresses)): schema_name = schema cols = insp.get_columns(table_name, schema=schema_name) @@ -697,65 +697,13 @@ class ComponentReflectionTest(fixtures.TablesTest): if not col.primary_key: assert cols[i]["default"] is None - @testing.requires.table_reflection - def test_get_columns(self): - self._test_get_columns() - - @testing.provide_metadata - def _type_round_trip(self, *types): - t = Table( - "t", - self.metadata, - *[Column("t%d" % i, type_) for i, type_ in enumerate(types)] - ) - t.create() - - return [ - c["type"] for c in inspect(self.metadata.bind).get_columns("t") - ] - - @testing.requires.table_reflection - def test_numeric_reflection(self): - for typ in self._type_round_trip(sql_types.Numeric(18, 5)): - assert isinstance(typ, sql_types.Numeric) - eq_(typ.precision, 18) - eq_(typ.scale, 5) - - @testing.requires.table_reflection - def test_varchar_reflection(self): - typ = self._type_round_trip(sql_types.String(52))[0] - assert isinstance(typ, sql_types.String) - eq_(typ.length, 52) - - @testing.requires.table_reflection - @testing.provide_metadata - def test_nullable_reflection(self): - t = Table( - "t", - self.metadata, - Column("a", Integer, nullable=True), - Column("b", Integer, nullable=False), - ) - t.create() - eq_( - dict( - (col["name"], col["nullable"]) - for col in inspect(self.metadata.bind).get_columns("t") - ), - {"a": True, "b": False}, - ) - - @testing.requires.table_reflection - @testing.requires.schemas - def test_get_columns_with_schema(self): - self._test_get_columns(schema=testing.config.test_schema) - @testing.requires.temp_table_reflection def test_get_temp_table_columns(self): - table_name = get_temp_table_name(config, config.db, "user_tmp") - meta = MetaData(self.bind) + table_name = get_temp_table_name( + config, self.bind, "user_tmp_%s" % config.ident + ) user_tmp = self.tables[table_name] - insp = inspect(meta.bind) + insp = inspect(self.bind) cols = insp.get_columns(table_name) self.assert_(len(cols) > 0, len(cols)) @@ -770,22 +718,18 @@ class ComponentReflectionTest(fixtures.TablesTest): cols = insp.get_columns("user_tmp_v") eq_([col["name"] for col in cols], ["id", "name", "foo"]) - @testing.requires.view_column_reflection - def test_get_view_columns(self): - self._test_get_columns(table_type="view") - - @testing.requires.view_column_reflection - @testing.requires.schemas - def test_get_view_columns_with_schema(self): - self._test_get_columns( - schema=testing.config.test_schema, table_type="view" - ) + @testing.combinations( + (False,), (True, testing.requires.schemas), argnames="use_schema" + ) + @testing.requires.primary_key_constraint_reflection + def test_get_pk_constraint(self, connection, use_schema): + if use_schema: + schema = testing.config.test_schema + else: + schema = None - @testing.provide_metadata - def _test_get_pk_constraint(self, schema=None): - meta = self.metadata users, addresses = self.tables.users, self.tables.email_addresses - insp = inspect(meta.bind) + insp = inspect(connection) users_cons = insp.get_pk_constraint(users.name, schema=schema) users_pkeys = users_cons["constrained_columns"] @@ -798,21 +742,18 @@ class ComponentReflectionTest(fixtures.TablesTest): with testing.requires.reflects_pk_names.fail_if(): eq_(addr_cons["name"], "email_ad_pk") - @testing.requires.primary_key_constraint_reflection - def test_get_pk_constraint(self): - self._test_get_pk_constraint() - - @testing.requires.table_reflection - @testing.requires.primary_key_constraint_reflection - @testing.requires.schemas - def test_get_pk_constraint_with_schema(self): - self._test_get_pk_constraint(schema=testing.config.test_schema) + @testing.combinations( + (False,), (True, testing.requires.schemas), argnames="use_schema" + ) + @testing.requires.foreign_key_constraint_reflection + def test_get_foreign_keys(self, connection, use_schema): + if use_schema: + schema = config.test_schema + else: + schema = None - @testing.provide_metadata - def _test_get_foreign_keys(self, schema=None): - meta = self.metadata users, addresses = (self.tables.users, self.tables.email_addresses) - insp = inspect(meta.bind) + insp = inspect(connection) expected_schema = schema # users @@ -841,25 +782,16 @@ class ComponentReflectionTest(fixtures.TablesTest): eq_(fkey1["referred_columns"], ["user_id"]) eq_(fkey1["constrained_columns"], ["remote_user_id"]) - @testing.requires.foreign_key_constraint_reflection - def test_get_foreign_keys(self): - self._test_get_foreign_keys() - - @testing.requires.foreign_key_constraint_reflection - @testing.requires.schemas - def test_get_foreign_keys_with_schema(self): - self._test_get_foreign_keys(schema=testing.config.test_schema) - @testing.requires.cross_schema_fk_reflection @testing.requires.schemas def test_get_inter_schema_foreign_keys(self): local_table, remote_table, remote_table_2 = self.tables( - "%s.local_table" % testing.db.dialect.default_schema_name, + "%s.local_table" % self.bind.dialect.default_schema_name, "%s.remote_table" % testing.config.test_schema, "%s.remote_table_2" % testing.config.test_schema, ) - insp = inspect(config.db) + insp = inspect(self.bind) local_fkeys = insp.get_foreign_keys(local_table.name) eq_(len(local_fkeys), 1) @@ -879,85 +811,12 @@ class ComponentReflectionTest(fixtures.TablesTest): assert fkey2["referred_schema"] in ( None, - testing.db.dialect.default_schema_name, + self.bind.dialect.default_schema_name, ) eq_(fkey2["referred_table"], local_table.name) eq_(fkey2["referred_columns"], ["id"]) eq_(fkey2["constrained_columns"], ["local_id"]) - @testing.requires.foreign_key_constraint_option_reflection_ondelete - def test_get_foreign_key_options_ondelete(self): - self._test_get_foreign_key_options(ondelete="CASCADE") - - @testing.requires.foreign_key_constraint_option_reflection_onupdate - def test_get_foreign_key_options_onupdate(self): - self._test_get_foreign_key_options(onupdate="SET NULL") - - @testing.requires.foreign_key_constraint_option_reflection_onupdate - def test_get_foreign_key_options_onupdate_noaction(self): - self._test_get_foreign_key_options(onupdate="NO ACTION", expected={}) - - @testing.requires.fk_constraint_option_reflection_ondelete_noaction - def test_get_foreign_key_options_ondelete_noaction(self): - self._test_get_foreign_key_options(ondelete="NO ACTION", expected={}) - - @testing.requires.fk_constraint_option_reflection_onupdate_restrict - def test_get_foreign_key_options_onupdate_restrict(self): - self._test_get_foreign_key_options(onupdate="RESTRICT") - - @testing.requires.fk_constraint_option_reflection_ondelete_restrict - def test_get_foreign_key_options_ondelete_restrict(self): - self._test_get_foreign_key_options(ondelete="RESTRICT") - - @testing.provide_metadata - def _test_get_foreign_key_options(self, expected=None, **options): - meta = self.metadata - - if expected is None: - expected = options - - Table( - "x", - meta, - Column("id", Integer, primary_key=True), - test_needs_fk=True, - ) - - Table( - "table", - meta, - Column("id", Integer, primary_key=True), - Column("x_id", Integer, sa.ForeignKey("x.id", name="xid")), - Column("test", String(10)), - test_needs_fk=True, - ) - - Table( - "user", - meta, - Column("id", Integer, primary_key=True), - Column("name", String(50), nullable=False), - Column("tid", Integer), - sa.ForeignKeyConstraint( - ["tid"], ["table.id"], name="myfk", **options - ), - test_needs_fk=True, - ) - - meta.create_all() - - insp = inspect(meta.bind) - - # test 'options' is always present for a backend - # that can reflect these, since alembic looks for this - opts = insp.get_foreign_keys("table")[0]["options"] - - eq_(dict((k, opts[k]) for k in opts if opts[k]), {}) - - opts = insp.get_foreign_keys("user")[0]["options"] - eq_(opts, expected) - # eq_(dict((k, opts[k]) for k in opts if opts[k]), expected) - def _assert_insp_indexes(self, indexes, expected_indexes): index_names = [d["name"] for d in indexes] for e_index in expected_indexes: @@ -966,13 +825,19 @@ class ComponentReflectionTest(fixtures.TablesTest): for key in e_index: eq_(e_index[key], index[key]) - @testing.provide_metadata - def _test_get_indexes(self, schema=None): - meta = self.metadata + @testing.combinations( + (False,), (True, testing.requires.schemas), argnames="use_schema" + ) + def test_get_indexes(self, connection, use_schema): + + if use_schema: + schema = config.test_schema + else: + schema = None # The database may decide to create indexes for foreign keys, etc. # so there may be more indexes than expected. - insp = inspect(meta.bind) + insp = inspect(self.bind) indexes = insp.get_indexes("users", schema=schema) expected_indexes = [ { @@ -988,19 +853,15 @@ class ComponentReflectionTest(fixtures.TablesTest): ] self._assert_insp_indexes(indexes, expected_indexes) + @testing.combinations( + ("noncol_idx_test_nopk", "noncol_idx_nopk"), + ("noncol_idx_test_pk", "noncol_idx_pk"), + argnames="tname,ixname", + ) @testing.requires.index_reflection - def test_get_indexes(self): - self._test_get_indexes() - - @testing.requires.index_reflection - @testing.requires.schemas - def test_get_indexes_with_schema(self): - self._test_get_indexes(schema=testing.config.test_schema) - - @testing.provide_metadata - def _test_get_noncol_index(self, tname, ixname): - meta = self.metadata - insp = inspect(meta.bind) + @testing.requires.indexes_with_ascdesc + def test_get_noncol_index(self, connection, tname, ixname): + insp = inspect(connection) indexes = insp.get_indexes(tname) # reflecting an index that has "x DESC" in it as the column. @@ -1009,90 +870,16 @@ class ComponentReflectionTest(fixtures.TablesTest): expected_indexes = [{"unique": False, "name": ixname}] self._assert_insp_indexes(indexes, expected_indexes) - t = Table(tname, meta, autoload_with=meta.bind) + t = Table(tname, MetaData(), autoload_with=connection) eq_(len(t.indexes), 1) is_(list(t.indexes)[0].table, t) eq_(list(t.indexes)[0].name, ixname) - @testing.requires.index_reflection - @testing.requires.indexes_with_ascdesc - def test_get_noncol_index_no_pk(self): - self._test_get_noncol_index("noncol_idx_test_nopk", "noncol_idx_nopk") - - @testing.requires.index_reflection - @testing.requires.indexes_with_ascdesc - def test_get_noncol_index_pk(self): - self._test_get_noncol_index("noncol_idx_test_pk", "noncol_idx_pk") - - @testing.requires.indexes_with_expressions - @testing.provide_metadata - def test_reflect_expression_based_indexes(self): - t = Table( - "t", - self.metadata, - Column("x", String(30)), - Column("y", String(30)), - ) - - Index("t_idx", func.lower(t.c.x), func.lower(t.c.y)) - - Index("t_idx_2", t.c.x) - - self.metadata.create_all(testing.db) - - insp = inspect(testing.db) - - expected = [ - {"name": "t_idx_2", "column_names": ["x"], "unique": False} - ] - if testing.requires.index_reflects_included_columns.enabled: - expected[0]["include_columns"] = [] - - with expect_warnings( - "Skipped unsupported reflection of expression-based index t_idx" - ): - eq_( - insp.get_indexes("t"), - expected, - ) - - @testing.requires.index_reflects_included_columns - @testing.provide_metadata - def test_reflect_covering_index(self): - t = Table( - "t", - self.metadata, - Column("x", String(30)), - Column("y", String(30)), - ) - idx = Index("t_idx", t.c.x) - idx.dialect_options[testing.db.name]["include"] = ["y"] - - self.metadata.create_all(testing.db) - - insp = inspect(testing.db) - - eq_( - insp.get_indexes("t"), - [ - { - "name": "t_idx", - "column_names": ["x"], - "include_columns": ["y"], - "unique": False, - } - ], - ) - - @testing.requires.unique_constraint_reflection - def test_get_unique_constraints(self): - self._test_get_unique_constraints() - @testing.requires.temp_table_reflection @testing.requires.unique_constraint_reflection def test_get_temp_table_unique_constraints(self): insp = inspect(self.bind) - reflected = insp.get_unique_constraints("user_tmp") + reflected = insp.get_unique_constraints("user_tmp_%s" % config.ident) for refl in reflected: # Different dialects handle duplicate index and constraints # differently, so ignore this flag @@ -1110,7 +897,9 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.temp_table_reflect_indexes def test_get_temp_table_indexes(self): insp = inspect(self.bind) - table_name = get_temp_table_name(config, config.db, "user_tmp") + table_name = get_temp_table_name( + config, config.db, "user_tmp_%s" % config.ident + ) indexes = insp.get_indexes(table_name) for ind in indexes: ind.pop("dialect_options", None) @@ -1124,19 +913,22 @@ class ComponentReflectionTest(fixtures.TablesTest): expected, ) + @testing.combinations( + (True, testing.requires.schemas), (False,), argnames="use_schema" + ) @testing.requires.unique_constraint_reflection - @testing.requires.schemas - def test_get_unique_constraints_with_schema(self): - self._test_get_unique_constraints(schema=testing.config.test_schema) - - @testing.provide_metadata - def _test_get_unique_constraints(self, schema=None): + def test_get_unique_constraints(self, metadata, connection, use_schema): # SQLite dialect needs to parse the names of the constraints # separately from what it gets from PRAGMA index_list(), and # then matches them up. so same set of column_names in two # constraints will confuse it. Perhaps we should no longer # bother with index_list() here since we have the whole # CREATE TABLE? + + if use_schema: + schema = config.test_schema + else: + schema = None uniques = sorted( [ {"name": "unique_a", "column_names": ["a"]}, @@ -1148,10 +940,9 @@ class ComponentReflectionTest(fixtures.TablesTest): ], key=operator.itemgetter("name"), ) - orig_meta = self.metadata table = Table( "testtbl", - orig_meta, + metadata, Column("a", sa.String(20)), Column("b", sa.String(30)), Column("c", sa.Integer), @@ -1164,9 +955,9 @@ class ComponentReflectionTest(fixtures.TablesTest): table.append_constraint( sa.UniqueConstraint(*uc["column_names"], name=uc["name"]) ) - orig_meta.create_all() + table.create(connection) - inspector = inspect(orig_meta.bind) + inspector = inspect(connection) reflected = sorted( inspector.get_unique_constraints("testtbl", schema=schema), key=operator.itemgetter("name"), @@ -1186,7 +977,7 @@ class ComponentReflectionTest(fixtures.TablesTest): reflected = Table( "testtbl", reflected_metadata, - autoload_with=orig_meta.bind, + autoload_with=connection, schema=schema, ) @@ -1208,30 +999,90 @@ class ComponentReflectionTest(fixtures.TablesTest): eq_(names_that_duplicate_index, idx_names) eq_(uq_names, set()) - @testing.requires.check_constraint_reflection - def test_get_check_constraints(self): - self._test_get_check_constraints() + @testing.combinations( + (False,), (True, testing.requires.schemas), argnames="use_schema" + ) + def test_get_view_definition(self, connection, use_schema): + if use_schema: + schema = config.test_schema + else: + schema = None + view_name1 = "users_v" + view_name2 = "email_addresses_v" + insp = inspect(connection) + v1 = insp.get_view_definition(view_name1, schema=schema) + self.assert_(v1) + v2 = insp.get_view_definition(view_name2, schema=schema) + self.assert_(v2) + + # why is this here if it's PG specific ? + @testing.combinations( + ("users", False), + ("users", True, testing.requires.schemas), + argnames="table_name,use_schema", + ) + @testing.only_on("postgresql", "PG specific feature") + def test_get_table_oid(self, connection, table_name, use_schema): + if use_schema: + schema = config.test_schema + else: + schema = None + insp = inspect(connection) + oid = insp.get_table_oid(table_name, schema) + self.assert_(isinstance(oid, int)) + + @testing.requires.table_reflection + def test_autoincrement_col(self): + """test that 'autoincrement' is reflected according to sqla's policy. + + Don't mark this test as unsupported for any backend ! + (technically it fails with MySQL InnoDB since "id" comes before "id2") + + A backend is better off not returning "autoincrement" at all, + instead of potentially returning "False" for an auto-incrementing + primary key column. + + """ + + insp = inspect(self.bind) + + for tname, cname in [ + ("users", "user_id"), + ("email_addresses", "address_id"), + ("dingalings", "dingaling_id"), + ]: + cols = insp.get_columns(tname) + id_ = {c["name"]: c for c in cols}[cname] + assert id_.get("autoincrement", True) + + +class ComponentReflectionTestExtra(fixtures.TestBase): + + __backend__ = True + + @testing.combinations( + (True, testing.requires.schemas), (False,), argnames="use_schema" + ) @testing.requires.check_constraint_reflection - @testing.requires.schemas - def test_get_check_constraints_schema(self): - self._test_get_check_constraints(schema=testing.config.test_schema) + def test_get_check_constraints(self, metadata, connection, use_schema): + if use_schema: + schema = config.test_schema + else: + schema = None - @testing.provide_metadata - def _test_get_check_constraints(self, schema=None): - orig_meta = self.metadata Table( "sa_cc", - orig_meta, + metadata, Column("a", Integer()), sa.CheckConstraint("a > 1 AND a < 5", name="cc1"), sa.CheckConstraint("a = 1 OR (a > 2 AND a < 5)", name="cc2"), schema=schema, ) - orig_meta.create_all() + metadata.create_all(connection) - inspector = inspect(orig_meta.bind) + inspector = inspect(connection) reflected = sorted( inspector.get_check_constraints("sa_cc", schema=schema), key=operator.itemgetter("name"), @@ -1257,67 +1108,200 @@ class ComponentReflectionTest(fixtures.TablesTest): ], ) - @testing.provide_metadata - def _test_get_view_definition(self, schema=None): - meta = self.metadata - view_name1 = "users_v" - view_name2 = "email_addresses_v" - insp = inspect(meta.bind) - v1 = insp.get_view_definition(view_name1, schema=schema) - self.assert_(v1) - v2 = insp.get_view_definition(view_name2, schema=schema) - self.assert_(v2) + @testing.requires.indexes_with_expressions + def test_reflect_expression_based_indexes(self, metadata, connection): + t = Table( + "t", + metadata, + Column("x", String(30)), + Column("y", String(30)), + ) - @testing.requires.view_reflection - def test_get_view_definition(self): - self._test_get_view_definition() + Index("t_idx", func.lower(t.c.x), func.lower(t.c.y)) - @testing.requires.view_reflection - @testing.requires.schemas - def test_get_view_definition_with_schema(self): - self._test_get_view_definition(schema=testing.config.test_schema) + Index("t_idx_2", t.c.x) - @testing.only_on("postgresql", "PG specific feature") - @testing.provide_metadata - def _test_get_table_oid(self, table_name, schema=None): - meta = self.metadata - insp = inspect(meta.bind) - oid = insp.get_table_oid(table_name, schema) - self.assert_(isinstance(oid, int)) + metadata.create_all(connection) - def test_get_table_oid(self): - self._test_get_table_oid("users") + insp = inspect(connection) - @testing.requires.schemas - def test_get_table_oid_with_schema(self): - self._test_get_table_oid("users", schema=testing.config.test_schema) + expected = [ + {"name": "t_idx_2", "column_names": ["x"], "unique": False} + ] + if testing.requires.index_reflects_included_columns.enabled: + expected[0]["include_columns"] = [] + + with expect_warnings( + "Skipped unsupported reflection of expression-based index t_idx" + ): + eq_( + insp.get_indexes("t"), + expected, + ) + + @testing.requires.index_reflects_included_columns + def test_reflect_covering_index(self, metadata, connection): + t = Table( + "t", + metadata, + Column("x", String(30)), + Column("y", String(30)), + ) + idx = Index("t_idx", t.c.x) + idx.dialect_options[connection.engine.name]["include"] = ["y"] + + metadata.create_all(connection) + + insp = inspect(connection) + + eq_( + insp.get_indexes("t"), + [ + { + "name": "t_idx", + "column_names": ["x"], + "include_columns": ["y"], + "unique": False, + } + ], + ) + + def _type_round_trip(self, connection, metadata, *types): + t = Table( + "t", + metadata, + *[Column("t%d" % i, type_) for i, type_ in enumerate(types)] + ) + t.create(connection) + + return [c["type"] for c in inspect(connection).get_columns("t")] @testing.requires.table_reflection - @testing.provide_metadata - def test_autoincrement_col(self): - """test that 'autoincrement' is reflected according to sqla's policy. + def test_numeric_reflection(self, connection, metadata): + for typ in self._type_round_trip( + connection, metadata, sql_types.Numeric(18, 5) + ): + assert isinstance(typ, sql_types.Numeric) + eq_(typ.precision, 18) + eq_(typ.scale, 5) - Don't mark this test as unsupported for any backend ! + @testing.requires.table_reflection + def test_varchar_reflection(self, connection, metadata): + typ = self._type_round_trip( + connection, metadata, sql_types.String(52) + )[0] + assert isinstance(typ, sql_types.String) + eq_(typ.length, 52) - (technically it fails with MySQL InnoDB since "id" comes before "id2") + @testing.requires.table_reflection + def test_nullable_reflection(self, connection, metadata): + t = Table( + "t", + metadata, + Column("a", Integer, nullable=True), + Column("b", Integer, nullable=False), + ) + t.create(connection) + eq_( + dict( + (col["name"], col["nullable"]) + for col in inspect(connection).get_columns("t") + ), + {"a": True, "b": False}, + ) - A backend is better off not returning "autoincrement" at all, - instead of potentially returning "False" for an auto-incrementing - primary key column. + @testing.combinations( + ( + None, + "CASCADE", + None, + testing.requires.foreign_key_constraint_option_reflection_ondelete, + ), + ( + None, + None, + "SET NULL", + testing.requires.foreign_key_constraint_option_reflection_onupdate, + ), + ( + {}, + None, + "NO ACTION", + testing.requires.foreign_key_constraint_option_reflection_onupdate, + ), + ( + {}, + "NO ACTION", + None, + testing.requires.fk_constraint_option_reflection_ondelete_noaction, + ), + ( + None, + None, + "RESTRICT", + testing.requires.fk_constraint_option_reflection_onupdate_restrict, + ), + ( + None, + "RESTRICT", + None, + testing.requires.fk_constraint_option_reflection_ondelete_restrict, + ), + argnames="expected,ondelete,onupdate", + ) + def test_get_foreign_key_options( + self, connection, metadata, expected, ondelete, onupdate + ): + options = {} + if ondelete: + options["ondelete"] = ondelete + if onupdate: + options["onupdate"] = onupdate - """ + if expected is None: + expected = options - meta = self.metadata - insp = inspect(meta.bind) + Table( + "x", + metadata, + Column("id", Integer, primary_key=True), + test_needs_fk=True, + ) - for tname, cname in [ - ("users", "user_id"), - ("email_addresses", "address_id"), - ("dingalings", "dingaling_id"), - ]: - cols = insp.get_columns(tname) - id_ = {c["name"]: c for c in cols}[cname] - assert id_.get("autoincrement", True) + Table( + "table", + metadata, + Column("id", Integer, primary_key=True), + Column("x_id", Integer, sa.ForeignKey("x.id", name="xid")), + Column("test", String(10)), + test_needs_fk=True, + ) + + Table( + "user", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("tid", Integer), + sa.ForeignKeyConstraint( + ["tid"], ["table.id"], name="myfk", **options + ), + test_needs_fk=True, + ) + + metadata.create_all(connection) + + insp = inspect(connection) + + # test 'options' is always present for a backend + # that can reflect these, since alembic looks for this + opts = insp.get_foreign_keys("table")[0]["options"] + + eq_(dict((k, opts[k]) for k in opts if opts[k]), {}) + + opts = insp.get_foreign_keys("user")[0]["options"] + eq_(opts, expected) + # eq_(dict((k, opts[k]) for k in opts if opts[k]), expected) class NormalizedNameTest(fixtures.TablesTest): @@ -1342,21 +1326,21 @@ class NormalizedNameTest(fixtures.TablesTest): m2 = MetaData() t2_ref = Table( - quoted_name("t2", quote=True), m2, autoload_with=testing.db + quoted_name("t2", quote=True), m2, autoload_with=config.db ) t1_ref = m2.tables["t1"] assert t2_ref.c.t1id.references(t1_ref.c.id) m3 = MetaData() m3.reflect( - testing.db, only=lambda name, m: name.lower() in ("t1", "t2") + config.db, only=lambda name, m: name.lower() in ("t1", "t2") ) assert m3.tables["t2"].c.t1id.references(m3.tables["t1"].c.id) def test_get_table_names(self): tablenames = [ t - for t in inspect(testing.db).get_table_names() + for t in inspect(config.db).get_table_names() if t.lower() in ("t1", "t2") ] @@ -1631,20 +1615,16 @@ class CompositeKeyReflectionTest(fixtures.TablesTest): ) @testing.requires.primary_key_constraint_reflection - @testing.provide_metadata def test_pk_column_order(self): # test for issue #5661 - meta = self.metadata - insp = inspect(meta.bind) + insp = inspect(self.bind) primary_key = insp.get_pk_constraint(self.tables.tb1.name) eq_(primary_key.get("constrained_columns"), ["name", "id", "attr"]) @testing.requires.foreign_key_constraint_reflection - @testing.provide_metadata def test_fk_column_order(self): # test for issue #5661 - meta = self.metadata - insp = inspect(meta.bind) + insp = inspect(self.bind) foreign_keys = insp.get_foreign_keys(self.tables.tb2.name) eq_(len(foreign_keys), 1) fkey1 = foreign_keys[0] @@ -1654,6 +1634,7 @@ class CompositeKeyReflectionTest(fixtures.TablesTest): __all__ = ( "ComponentReflectionTest", + "ComponentReflectionTestExtra", "QuotedNameArgumentTest", "HasTableTest", "HasIndexTest", diff --git a/lib/sqlalchemy/testing/suite/test_rowcount.py b/lib/sqlalchemy/testing/suite/test_rowcount.py index f3f902abd..bb344237a 100644 --- a/lib/sqlalchemy/testing/suite/test_rowcount.py +++ b/lib/sqlalchemy/testing/suite/test_rowcount.py @@ -1,6 +1,7 @@ from sqlalchemy import bindparam from sqlalchemy import Column from sqlalchemy import Integer +from sqlalchemy import select from sqlalchemy import Sequence from sqlalchemy import String from sqlalchemy import Table @@ -51,12 +52,14 @@ class RowCountTest(fixtures.TablesTest): [{"name": n, "department": d} for n, d in data], ) - def test_basic(self): + def test_basic(self, connection): employees_table = self.tables.employees - s = employees_table.select() - r = s.execute().fetchall() + s = select( + employees_table.c.name, employees_table.c.department + ).order_by(employees_table.c.employee_id) + rows = connection.execute(s).fetchall() - assert len(r) == len(self.data) + eq_(rows, self.data) def test_update_rowcount1(self, connection): employees_table = self.tables.employees diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index 43777239c..3a5e02c32 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -47,18 +47,19 @@ from ...util import u class _LiteralRoundTripFixture(object): supports_whereclause = True - @testing.provide_metadata - def _literal_round_trip(self, type_, input_, output, filter_=None): + @testing.fixture + def literal_round_trip(self, metadata, connection): """test literal rendering """ # for literal, we test the literal render in an INSERT # into a typed column. we can then SELECT it back as its # official type; ideally we'd be able to use CAST here # but MySQL in particular can't CAST fully - t = Table("t", self.metadata, Column("x", type_)) - t.create() - with testing.db.begin() as conn: + def run(type_, input_, output, filter_=None): + t = Table("t", metadata, Column("x", type_)) + t.create(connection) + for value in input_: ins = ( t.insert() @@ -68,7 +69,7 @@ class _LiteralRoundTripFixture(object): compile_kwargs=dict(literal_binds=True), ) ) - conn.execute(ins) + connection.execute(ins) if self.supports_whereclause: stmt = t.select().where(t.c.x == literal(value)) @@ -79,12 +80,14 @@ class _LiteralRoundTripFixture(object): dialect=testing.db.dialect, compile_kwargs=dict(literal_binds=True), ) - for row in conn.execute(stmt): + for row in connection.execute(stmt): value = row[0] if filter_ is not None: value = filter_(value) assert value in output + return run + class _UnicodeFixture(_LiteralRoundTripFixture, fixtures.TestBase): __requires__ = ("unicode_data",) @@ -149,11 +152,11 @@ class _UnicodeFixture(_LiteralRoundTripFixture, fixtures.TestBase): row = connection.execute(select(unicode_table.c.unicode_data)).first() eq_(row, (u(""),)) - def test_literal(self): - self._literal_round_trip(self.datatype, [self.data], [self.data]) + def test_literal(self, literal_round_trip): + literal_round_trip(self.datatype, [self.data], [self.data]) - def test_literal_non_ascii(self): - self._literal_round_trip( + def test_literal_non_ascii(self, literal_round_trip): + literal_round_trip( self.datatype, [util.u("réve🐍 illé")], [util.u("réve🐍 illé")] ) @@ -227,25 +230,25 @@ class TextTest(_LiteralRoundTripFixture, fixtures.TablesTest): row = connection.execute(select(text_table.c.text_data)).first() eq_(row, (None,)) - def test_literal(self): - self._literal_round_trip(Text, ["some text"], ["some text"]) + def test_literal(self, literal_round_trip): + literal_round_trip(Text, ["some text"], ["some text"]) - def test_literal_non_ascii(self): - self._literal_round_trip( + def test_literal_non_ascii(self, literal_round_trip): + literal_round_trip( Text, [util.u("réve🐍 illé")], [util.u("réve🐍 illé")] ) - def test_literal_quoting(self): + def test_literal_quoting(self, literal_round_trip): data = """some 'text' hey "hi there" that's text""" - self._literal_round_trip(Text, [data], [data]) + literal_round_trip(Text, [data], [data]) - def test_literal_backslashes(self): + def test_literal_backslashes(self, literal_round_trip): data = r"backslash one \ backslash two \\ end" - self._literal_round_trip(Text, [data], [data]) + literal_round_trip(Text, [data], [data]) - def test_literal_percentsigns(self): + def test_literal_percentsigns(self, literal_round_trip): data = r"percent % signs %% percent" - self._literal_round_trip(Text, [data], [data]) + literal_round_trip(Text, [data], [data]) class StringTest(_LiteralRoundTripFixture, fixtures.TestBase): @@ -259,23 +262,23 @@ class StringTest(_LiteralRoundTripFixture, fixtures.TestBase): foo.create(config.db) foo.drop(config.db) - def test_literal(self): + def test_literal(self, literal_round_trip): # note that in Python 3, this invokes the Unicode # datatype for the literal part because all strings are unicode - self._literal_round_trip(String(40), ["some text"], ["some text"]) + literal_round_trip(String(40), ["some text"], ["some text"]) - def test_literal_non_ascii(self): - self._literal_round_trip( + def test_literal_non_ascii(self, literal_round_trip): + literal_round_trip( String(40), [util.u("réve🐍 illé")], [util.u("réve🐍 illé")] ) - def test_literal_quoting(self): + def test_literal_quoting(self, literal_round_trip): data = """some 'text' hey "hi there" that's text""" - self._literal_round_trip(String(40), [data], [data]) + literal_round_trip(String(40), [data], [data]) - def test_literal_backslashes(self): + def test_literal_backslashes(self, literal_round_trip): data = r"backslash one \ backslash two \\ end" - self._literal_round_trip(String(40), [data], [data]) + literal_round_trip(String(40), [data], [data]) class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase): @@ -331,9 +334,9 @@ class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase): eq_(row, (None,)) @testing.requires.datetime_literals - def test_literal(self): + def test_literal(self, literal_round_trip): compare = self.compare or self.data - self._literal_round_trip(self.datatype, [self.data], [compare]) + literal_round_trip(self.datatype, [self.data], [compare]) @testing.requires.standalone_null_binds_whereclause def test_null_bound_comparison(self): @@ -430,36 +433,41 @@ class DateHistoricTest(_DateFixture, fixtures.TablesTest): class IntegerTest(_LiteralRoundTripFixture, fixtures.TestBase): __backend__ = True - def test_literal(self): - self._literal_round_trip(Integer, [5], [5]) + def test_literal(self, literal_round_trip): + literal_round_trip(Integer, [5], [5]) - def test_huge_int(self, connection): - self._round_trip(BigInteger, 1376537018368127, connection) + def test_huge_int(self, integer_round_trip): + integer_round_trip(BigInteger, 1376537018368127) - @testing.provide_metadata - def _round_trip(self, datatype, data, connection): - metadata = self.metadata - int_table = Table( - "integer_table", - metadata, - Column( - "id", Integer, primary_key=True, test_needs_autoincrement=True - ), - Column("integer_data", datatype), - ) + @testing.fixture + def integer_round_trip(self, metadata, connection): + def run(datatype, data): + int_table = Table( + "integer_table", + metadata, + Column( + "id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("integer_data", datatype), + ) - metadata.create_all(config.db) + metadata.create_all(config.db) - connection.execute(int_table.insert(), {"integer_data": data}) + connection.execute(int_table.insert(), {"integer_data": data}) - row = connection.execute(select(int_table.c.integer_data)).first() + row = connection.execute(select(int_table.c.integer_data)).first() - eq_(row, (data,)) + eq_(row, (data,)) - if util.py3k: - assert isinstance(row[0], int) - else: - assert isinstance(row[0], (long, int)) # noqa + if util.py3k: + assert isinstance(row[0], int) + else: + assert isinstance(row[0], (long, int)) # noqa + + return run class CastTypeDecoratorTest(_LiteralRoundTripFixture, fixtures.TestBase): @@ -481,12 +489,10 @@ class CastTypeDecoratorTest(_LiteralRoundTripFixture, fixtures.TestBase): return StringAsInt() - @testing.provide_metadata - def test_special_type(self, connection, string_as_int): + def test_special_type(self, metadata, connection, string_as_int): type_ = string_as_int - metadata = self.metadata t = Table("t", metadata, Column("x", type_)) t.create(connection) @@ -504,42 +510,46 @@ class CastTypeDecoratorTest(_LiteralRoundTripFixture, fixtures.TestBase): class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): __backend__ = True - @testing.emits_warning(r".*does \*not\* support Decimal objects natively") - @testing.provide_metadata - def _do_test(self, type_, input_, output, filter_=None, check_scale=False): - metadata = self.metadata - t = Table("t", metadata, Column("x", type_)) - t.create() - with config.db.begin() as conn: - conn.execute(t.insert(), [{"x": x} for x in input_]) - - result = {row[0] for row in conn.execute(t.select())} - output = set(output) - if filter_: - result = set(filter_(x) for x in result) - output = set(filter_(x) for x in output) - eq_(result, output) - if check_scale: - eq_([str(x) for x in result], [str(x) for x in output]) + @testing.fixture + def do_numeric_test(self, metadata): + @testing.emits_warning( + r".*does \*not\* support Decimal objects natively" + ) + def run(type_, input_, output, filter_=None, check_scale=False): + t = Table("t", metadata, Column("x", type_)) + t.create(testing.db) + with config.db.begin() as conn: + conn.execute(t.insert(), [{"x": x} for x in input_]) + + result = {row[0] for row in conn.execute(t.select())} + output = set(output) + if filter_: + result = set(filter_(x) for x in result) + output = set(filter_(x) for x in output) + eq_(result, output) + if check_scale: + eq_([str(x) for x in result], [str(x) for x in output]) + + return run @testing.emits_warning(r".*does \*not\* support Decimal objects natively") - def test_render_literal_numeric(self): - self._literal_round_trip( + def test_render_literal_numeric(self, literal_round_trip): + literal_round_trip( Numeric(precision=8, scale=4), [15.7563, decimal.Decimal("15.7563")], [decimal.Decimal("15.7563")], ) @testing.emits_warning(r".*does \*not\* support Decimal objects natively") - def test_render_literal_numeric_asfloat(self): - self._literal_round_trip( + def test_render_literal_numeric_asfloat(self, literal_round_trip): + literal_round_trip( Numeric(precision=8, scale=4, asdecimal=False), [15.7563, decimal.Decimal("15.7563")], [15.7563], ) - def test_render_literal_float(self): - self._literal_round_trip( + def test_render_literal_float(self, literal_round_trip): + literal_round_trip( Float(4), [15.7563, decimal.Decimal("15.7563")], [15.7563], @@ -547,49 +557,49 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): ) @testing.requires.precision_generic_float_type - def test_float_custom_scale(self): - self._do_test( + def test_float_custom_scale(self, do_numeric_test): + do_numeric_test( Float(None, decimal_return_scale=7, asdecimal=True), [15.7563827, decimal.Decimal("15.7563827")], [decimal.Decimal("15.7563827")], check_scale=True, ) - def test_numeric_as_decimal(self): - self._do_test( + def test_numeric_as_decimal(self, do_numeric_test): + do_numeric_test( Numeric(precision=8, scale=4), [15.7563, decimal.Decimal("15.7563")], [decimal.Decimal("15.7563")], ) - def test_numeric_as_float(self): - self._do_test( + def test_numeric_as_float(self, do_numeric_test): + do_numeric_test( Numeric(precision=8, scale=4, asdecimal=False), [15.7563, decimal.Decimal("15.7563")], [15.7563], ) @testing.requires.fetch_null_from_numeric - def test_numeric_null_as_decimal(self): - self._do_test(Numeric(precision=8, scale=4), [None], [None]) + def test_numeric_null_as_decimal(self, do_numeric_test): + do_numeric_test(Numeric(precision=8, scale=4), [None], [None]) @testing.requires.fetch_null_from_numeric - def test_numeric_null_as_float(self): - self._do_test( + def test_numeric_null_as_float(self, do_numeric_test): + do_numeric_test( Numeric(precision=8, scale=4, asdecimal=False), [None], [None] ) @testing.requires.floats_to_four_decimals - def test_float_as_decimal(self): - self._do_test( + def test_float_as_decimal(self, do_numeric_test): + do_numeric_test( Float(precision=8, asdecimal=True), [15.7563, decimal.Decimal("15.7563"), None], [decimal.Decimal("15.7563"), None], filter_=lambda n: n is not None and round(n, 4) or None, ) - def test_float_as_float(self): - self._do_test( + def test_float_as_float(self, do_numeric_test): + do_numeric_test( Float(precision=8), [15.7563, decimal.Decimal("15.7563")], [15.7563], @@ -621,7 +631,7 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): eq_(val, expr) @testing.requires.precision_numerics_general - def test_precision_decimal(self): + def test_precision_decimal(self, do_numeric_test): numbers = set( [ decimal.Decimal("54.234246451650"), @@ -630,10 +640,10 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): ] ) - self._do_test(Numeric(precision=18, scale=12), numbers, numbers) + do_numeric_test(Numeric(precision=18, scale=12), numbers, numbers) @testing.requires.precision_numerics_enotation_large - def test_enotation_decimal(self): + def test_enotation_decimal(self, do_numeric_test): """test exceedingly small decimals. Decimal reports values with E notation when the exponent @@ -657,10 +667,10 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): decimal.Decimal("696E-12"), ] ) - self._do_test(Numeric(precision=18, scale=14), numbers, numbers) + do_numeric_test(Numeric(precision=18, scale=14), numbers, numbers) @testing.requires.precision_numerics_enotation_large - def test_enotation_decimal_large(self): + def test_enotation_decimal_large(self, do_numeric_test): """test exceedingly large decimals.""" numbers = set( @@ -671,10 +681,10 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): decimal.Decimal("00000000000000.1E+12"), ] ) - self._do_test(Numeric(precision=25, scale=2), numbers, numbers) + do_numeric_test(Numeric(precision=25, scale=2), numbers, numbers) @testing.requires.precision_numerics_many_significant_digits - def test_many_significant_digits(self): + def test_many_significant_digits(self, do_numeric_test): numbers = set( [ decimal.Decimal("31943874831932418390.01"), @@ -682,12 +692,12 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): decimal.Decimal("87673.594069654243"), ] ) - self._do_test(Numeric(precision=38, scale=12), numbers, numbers) + do_numeric_test(Numeric(precision=38, scale=12), numbers, numbers) @testing.requires.precision_numerics_retains_significant_digits - def test_numeric_no_decimal(self): + def test_numeric_no_decimal(self, do_numeric_test): numbers = set([decimal.Decimal("1.000")]) - self._do_test( + do_numeric_test( Numeric(precision=5, scale=3), numbers, numbers, check_scale=True ) @@ -705,8 +715,8 @@ class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest): Column("unconstrained_value", Boolean(create_constraint=False)), ) - def test_render_literal_bool(self): - self._literal_round_trip(Boolean(), [True, False], [True, False]) + def test_render_literal_bool(self, literal_round_trip): + literal_round_trip(Boolean(), [True, False], [True, False]) def test_round_trip(self, connection): boolean_table = self.tables.boolean_table diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py index c6626b9e0..eb9fcd1cd 100644 --- a/lib/sqlalchemy/testing/util.py +++ b/lib/sqlalchemy/testing/util.py @@ -1,5 +1,5 @@ # testing/util.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -11,13 +11,24 @@ import random import sys import types +from . import config from . import mock +from .. import inspect +from ..schema import Column +from ..schema import DropConstraint +from ..schema import DropTable +from ..schema import ForeignKeyConstraint +from ..schema import MetaData +from ..schema import Table +from ..sql import schema +from ..sql.sqltypes import Integer from ..util import decorator from ..util import defaultdict from ..util import has_refcount_gc from ..util import inspect_getfullargspec from ..util import py2k + if not has_refcount_gc: def non_refcount_gc_collect(*args): @@ -198,11 +209,11 @@ def fail(msg): def provide_metadata(fn, *args, **kw): """Provide bound MetaData for a single test, dropping afterwards.""" - from . import config + # import cycle that only occurs with py2k's import resolver + # in py3k this can be moved top level. from . import engines - from sqlalchemy import schema - metadata = schema.MetaData(config.db) + metadata = schema.MetaData() self = args[0] prev_meta = getattr(self, "metadata", None) self.metadata = metadata @@ -243,8 +254,6 @@ def flag_combinations(*combinations): """ - from . import config - keys = set() for d in combinations: @@ -264,8 +273,6 @@ def flag_combinations(*combinations): def lambda_combinations(lambda_arg_sets, **kw): - from . import config - args = inspect_getfullargspec(lambda_arg_sets) arg_sets = lambda_arg_sets(*[mock.Mock() for arg in args[0]]) @@ -302,11 +309,8 @@ def resolve_lambda(__fn, **kw): def metadata_fixture(ddl="function"): """Provide MetaData for a pytest fixture.""" - from . import config - def decorate(fn): def run_ddl(self): - from sqlalchemy import schema metadata = self.metadata = schema.MetaData() try: @@ -328,8 +332,6 @@ def force_drop_names(*names): isolating for foreign key cycles """ - from . import config - from sqlalchemy import inspect @decorator def go(fn, *args, **kw): @@ -358,14 +360,6 @@ class adict(dict): def drop_all_tables(engine, inspector, schema=None, include_names=None): - from sqlalchemy import ( - Column, - Table, - Integer, - MetaData, - ForeignKeyConstraint, - ) - from sqlalchemy.schema import DropTable, DropConstraint if include_names is not None: include_names = set(include_names) diff --git a/lib/sqlalchemy/testing/warnings.py b/lib/sqlalchemy/testing/warnings.py index 735fb82e4..1b078a263 100644 --- a/lib/sqlalchemy/testing/warnings.py +++ b/lib/sqlalchemy/testing/warnings.py @@ -1,5 +1,5 @@ # testing/warnings.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -62,12 +62,6 @@ def setup_filters(): r"The Connection.connect\(\) method is considered legacy", # r".*DefaultGenerator.execute\(\)", # - # bound metadaa - # - r"The MetaData.bind argument is deprecated", - r"The ``bind`` argument for schema methods that invoke SQL ", - r"The Function.bind argument", - r"The select.bind argument", # # result sets # diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index 719b05018..9340546ca 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -1,5 +1,5 @@ # types.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index 2e3f68722..5f8788a6e 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -1,5 +1,5 @@ # util/__init__.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index 69457994a..b18cc13de 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -1,5 +1,5 @@ # util/_collections.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/util/_concurrency_py3k.py b/lib/sqlalchemy/util/_concurrency_py3k.py index 8ad3be543..663d3e0f4 100644 --- a/lib/sqlalchemy/util/_concurrency_py3k.py +++ b/lib/sqlalchemy/util/_concurrency_py3k.py @@ -64,7 +64,6 @@ def await_fallback(awaitable: Coroutine) -> Any: :param awaitable: The coroutine to call. """ - # this is called in the context greenlet while running fn current = greenlet.getcurrent() if not isinstance(current, _AsyncIoGreenlet): @@ -135,3 +134,27 @@ class AsyncAdaptedLock: def __exit__(self, *arg, **kw): self.mutex.release() + + +def _util_async_run_coroutine_function(fn, *args, **kwargs): + """for test suite/ util only""" + + loop = asyncio.get_event_loop() + if loop.is_running(): + raise Exception( + "for async run coroutine we expect that no greenlet or event " + "loop is running when we start out" + ) + return loop.run_until_complete(fn(*args, **kwargs)) + + +def _util_async_run(fn, *args, **kwargs): + """for test suite/ util only""" + + loop = asyncio.get_event_loop() + if not loop.is_running(): + return loop.run_until_complete(greenlet_spawn(fn, *args, **kwargs)) + else: + # allow for a wrapped test function to call another + assert isinstance(greenlet.getcurrent(), _AsyncIoGreenlet) + return fn(*args, **kwargs) diff --git a/lib/sqlalchemy/util/_preloaded.py b/lib/sqlalchemy/util/_preloaded.py index e267c008c..2e0c2625d 100644 --- a/lib/sqlalchemy/util/_preloaded.py +++ b/lib/sqlalchemy/util/_preloaded.py @@ -1,5 +1,5 @@ # util/_preloaded.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index 77c913640..1eed2c3af 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -1,5 +1,5 @@ # util/compat.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/util/concurrency.py b/lib/sqlalchemy/util/concurrency.py index f78c0971c..c44efba62 100644 --- a/lib/sqlalchemy/util/concurrency.py +++ b/lib/sqlalchemy/util/concurrency.py @@ -13,6 +13,10 @@ if compat.py3k: from ._concurrency_py3k import await_fallback from ._concurrency_py3k import greenlet_spawn from ._concurrency_py3k import AsyncAdaptedLock + from ._concurrency_py3k import _util_async_run # noqa F401 + from ._concurrency_py3k import ( + _util_async_run_coroutine_function, + ) # noqa F401, E501 from ._concurrency_py3k import asyncio # noqa F401 if not have_greenlet: @@ -38,3 +42,9 @@ if not have_greenlet: def AsyncAdaptedLock(*args, **kw): # noqa F81 _not_implemented() + + def _util_async_run(fn, *arg, **kw): # noqa F81 + return fn(*arg, **kw) + + def _util_async_run_coroutine_function(fn, *arg, **kw): # noqa F81 + _not_implemented() diff --git a/lib/sqlalchemy/util/deprecations.py b/lib/sqlalchemy/util/deprecations.py index a4c9d9d0e..5d55a3ae6 100644 --- a/lib/sqlalchemy/util/deprecations.py +++ b/lib/sqlalchemy/util/deprecations.py @@ -1,5 +1,5 @@ # util/deprecations.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index b0963ce43..f6d44f708 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -1,5 +1,5 @@ # util/langhelpers.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/util/queue.py b/lib/sqlalchemy/util/queue.py index 3687dc8dc..99ecb4fb3 100644 --- a/lib/sqlalchemy/util/queue.py +++ b/lib/sqlalchemy/util/queue.py @@ -1,5 +1,5 @@ # util/queue.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -25,6 +25,7 @@ from . import compat from .compat import threading from .concurrency import asyncio from .concurrency import await_fallback +from .concurrency import await_only __all__ = ["Empty", "Full", "Queue"] @@ -202,7 +203,7 @@ class Queue: class AsyncAdaptedQueue: - await_ = staticmethod(await_fallback) + await_ = staticmethod(await_only) def __init__(self, maxsize=0, use_lifo=False): if use_lifo: @@ -265,3 +266,7 @@ class AsyncAdaptedQueue: Empty(), replace_context=err, ) + + +class FallbackAsyncAdaptedQueue(AsyncAdaptedQueue): + await_ = staticmethod(await_fallback) diff --git a/lib/sqlalchemy/util/topological.py b/lib/sqlalchemy/util/topological.py index b009a8ce2..8390c5554 100644 --- a/lib/sqlalchemy/util/topological.py +++ b/lib/sqlalchemy/util/topological.py @@ -1,5 +1,5 @@ # util/topological.py -# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -37,6 +37,7 @@ packages = find: python_requires = >=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.* package_dir = =lib +# TODO remove greenlet from the default requires? install_requires = importlib-metadata;python_version<"3.8" greenlet != 0.4.17;python_version>="3" @@ -64,7 +65,9 @@ postgresql_asyncpg = greenlet;python_version>="3" postgresql_psycopg2binary = psycopg2-binary postgresql_psycopg2cffi = psycopg2cffi -pymysql = pymysql +pymysql = + pymysql;python_version>="3" + pymysql<1;python_version<"3" aiomysql = aiomysql [egg_info] @@ -120,12 +123,14 @@ default = sqlite:///:memory: sqlite = sqlite:///:memory: sqlite_file = sqlite:///querytest.db postgresql = postgresql://scott:tiger@127.0.0.1:5432/test -asyncpg = postgresql+asyncpg://scott:tiger@127.0.0.1:5432/test?async_fallback=true +asyncpg = postgresql+asyncpg://scott:tiger@127.0.0.1:5432/test +asyncpg_fallback = postgresql+asyncpg://scott:tiger@127.0.0.1:5432/test?async_fallback=true pg8000 = postgresql+pg8000://scott:tiger@127.0.0.1:5432/test postgresql_psycopg2cffi = postgresql+psycopg2cffi://scott:tiger@127.0.0.1:5432/test mysql = mysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4 pymysql = mysql+pymysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4 -aiomysql = mysql+aiomysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4&async_fallback=true +aiomysql = mysql+aiomysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4 +aiomysql_fallback = mysql+aiomysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4&async_fallback=true mariadb = mariadb://scott:tiger@127.0.0.1:3306/test mssql = mssql+pyodbc://scott:tiger^5HHH@mssql2017:1433/test?driver=ODBC+Driver+13+for+SQL+Server mssql_pymssql = mssql+pymssql://scott:tiger@ms_2008 diff --git a/test/aaa_profiling/test_memusage.py b/test/aaa_profiling/test_memusage.py index 5e388c0b7..75a4f51cf 100644 --- a/test/aaa_profiling/test_memusage.py +++ b/test/aaa_profiling/test_memusage.py @@ -38,6 +38,7 @@ from sqlalchemy.sql.visitors import replacement_traverse from sqlalchemy.testing import engines from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from sqlalchemy.testing.util import gc_collect @@ -360,7 +361,7 @@ class MemUsageWBackendTest(EnsureZeroed): go() def test_session(self): - metadata = MetaData(self.engine) + metadata = MetaData() table1 = Table( "mytable", @@ -387,7 +388,7 @@ class MemUsageWBackendTest(EnsureZeroed): Column("col3", Integer, ForeignKey("mytable.col1")), ) - metadata.create_all() + metadata.create_all(self.engine) m1 = mapper( A, @@ -402,7 +403,7 @@ class MemUsageWBackendTest(EnsureZeroed): @profile_memory() def go(): - with Session() as sess: + with Session(self.engine) as sess: a1 = A(col2="a1") a2 = A(col2="a2") a3 = A(col2="a3") @@ -429,7 +430,7 @@ class MemUsageWBackendTest(EnsureZeroed): go() - metadata.drop_all() + metadata.drop_all(self.engine) del m1, m2 assert_no_mappers() @@ -535,7 +536,7 @@ class MemUsageWBackendTest(EnsureZeroed): @testing.emits_warning("Compiled statement cache for.*") def test_many_updates(self): - metadata = MetaData(self.engine) + metadata = MetaData() wide_table = Table( "t", @@ -551,8 +552,8 @@ class MemUsageWBackendTest(EnsureZeroed): mapper(Wide, wide_table, _compiled_cache_size=10) - metadata.create_all() - with Session() as session: + metadata.create_all(self.engine) + with Session(self.engine) as session: w1 = Wide() session.add(w1) session.commit() @@ -561,7 +562,7 @@ class MemUsageWBackendTest(EnsureZeroed): @profile_memory() def go(): - with Session() as session: + with Session(self.engine) as session: w1 = session.query(Wide).first() x = counter[0] dec = 10 @@ -578,7 +579,7 @@ class MemUsageWBackendTest(EnsureZeroed): try: go() finally: - metadata.drop_all() + metadata.drop_all(self.engine) @testing.requires.savepoints @testing.provide_metadata @@ -625,7 +626,7 @@ class MemUsageWBackendTest(EnsureZeroed): @testing.crashes("mysql+cymysql", "blocking") def test_unicode_warnings(self): - metadata = MetaData(self.engine) + metadata = MetaData() table1 = Table( "mytable", metadata, @@ -637,7 +638,7 @@ class MemUsageWBackendTest(EnsureZeroed): ), Column("col2", Unicode(30)), ) - metadata.create_all() + metadata.create_all(self.engine) i = [1] # the times here is cranked way up so that we can see @@ -659,7 +660,7 @@ class MemUsageWBackendTest(EnsureZeroed): try: go() finally: - metadata.drop_all() + metadata.drop_all(self.engine) def test_warnings_util(self): counter = itertools.count() @@ -677,7 +678,7 @@ class MemUsageWBackendTest(EnsureZeroed): go() def test_mapper_reset(self): - metadata = MetaData(self.engine) + metadata = MetaData() table1 = Table( "mytable", @@ -713,7 +714,7 @@ class MemUsageWBackendTest(EnsureZeroed): ) mapper(B, table2) - sess = create_session() + sess = create_session(self.engine) a1 = A(col2="a1") a2 = A(col2="a2") a3 = A(col2="a3") @@ -741,15 +742,15 @@ class MemUsageWBackendTest(EnsureZeroed): sess.close() clear_mappers() - metadata.create_all() + metadata.create_all(self.engine) try: go() finally: - metadata.drop_all() + metadata.drop_all(self.engine) assert_no_mappers() def test_alias_pathing(self): - metadata = MetaData(self.engine) + metadata = MetaData() a = Table( "a", @@ -779,8 +780,8 @@ class MemUsageWBackendTest(EnsureZeroed): mapper(ASub, asub, inherits=A, polymorphic_identity="asub") mapper(B, b, properties={"as_": relationship(A)}) - metadata.create_all() - sess = Session() + metadata.create_all(self.engine) + sess = Session(self.engine) a1 = ASub(data="a1") a2 = ASub(data="a2") a3 = ASub(data="a3") @@ -794,7 +795,7 @@ class MemUsageWBackendTest(EnsureZeroed): # "dip" again @profile_memory(maxtimes=120) def go(): - sess = Session() + sess = Session(self.engine) sess.query(B).options(subqueryload(B.as_.of_type(ASub))).all() sess.close() del sess @@ -802,7 +803,7 @@ class MemUsageWBackendTest(EnsureZeroed): try: go() finally: - metadata.drop_all() + metadata.drop_all(self.engine) clear_mappers() def test_path_registry(self): @@ -832,7 +833,7 @@ class MemUsageWBackendTest(EnsureZeroed): clear_mappers() def test_with_inheritance(self): - metadata = MetaData(self.engine) + metadata = MetaData() table1 = Table( "mytable", @@ -875,7 +876,7 @@ class MemUsageWBackendTest(EnsureZeroed): ) mapper(B, table2, inherits=A, polymorphic_identity="b") - sess = create_session() + sess = create_session(self.engine) a1 = A() a2 = A() b1 = B(col3="b1") @@ -896,15 +897,15 @@ class MemUsageWBackendTest(EnsureZeroed): del B del A - metadata.create_all() + metadata.create_all(self.engine) try: go() finally: - metadata.drop_all() + metadata.drop_all(self.engine) assert_no_mappers() def test_with_manytomany(self): - metadata = MetaData(self.engine) + metadata = MetaData() table1 = Table( "mytable", @@ -956,7 +957,7 @@ class MemUsageWBackendTest(EnsureZeroed): ) mapper(B, table2) - sess = create_session() + sess = create_session(self.engine) a1 = A(col2="a1") a2 = A(col2="a2") b1 = B(col2="b1") @@ -981,11 +982,11 @@ class MemUsageWBackendTest(EnsureZeroed): del B del A - metadata.create_all() + metadata.create_all(self.engine) try: go() finally: - metadata.drop_all() + metadata.drop_all(self.engine) assert_no_mappers() @testing.uses_deprecated() @@ -1031,7 +1032,7 @@ class MemUsageWBackendTest(EnsureZeroed): t2_mapper = mapper(T2, t2) t1_mapper.add_property("bar", relationship(t2_mapper)) - s1 = Session() + s1 = fixture_session() # this causes the path_registry to be invoked s1.query(t1_mapper)._compile_context() @@ -1043,7 +1044,7 @@ class MemUsageWBackendTest(EnsureZeroed): @testing.crashes("mysql+cymysql", "blocking") def test_join_cache_deprecated_coercion(self): - metadata = MetaData(self.engine) + metadata = MetaData() table1 = Table( "table1", metadata, @@ -1071,8 +1072,8 @@ class MemUsageWBackendTest(EnsureZeroed): mapper( Foo, table1, properties={"bars": relationship(mapper(Bar, table2))} ) - metadata.create_all() - session = sessionmaker() + metadata.create_all(self.engine) + session = sessionmaker(self.engine) @profile_memory() def go(): @@ -1087,11 +1088,11 @@ class MemUsageWBackendTest(EnsureZeroed): try: go() finally: - metadata.drop_all() + metadata.drop_all(self.engine) @testing.crashes("mysql+cymysql", "blocking") def test_join_cache(self): - metadata = MetaData(self.engine) + metadata = MetaData() table1 = Table( "table1", metadata, @@ -1119,8 +1120,8 @@ class MemUsageWBackendTest(EnsureZeroed): mapper( Foo, table1, properties={"bars": relationship(mapper(Bar, table2))} ) - metadata.create_all() - session = sessionmaker() + metadata.create_all(self.engine) + session = sessionmaker(self.engine) @profile_memory() def go(): @@ -1132,7 +1133,7 @@ class MemUsageWBackendTest(EnsureZeroed): try: go() finally: - metadata.drop_all() + metadata.drop_all(self.engine) class CycleTest(_fixtures.FixtureTest): @@ -1151,7 +1152,7 @@ class CycleTest(_fixtures.FixtureTest): User, Address = self.classes("User", "Address") configure_mappers() - s = Session() + s = fixture_session() @assert_cycles() def go(): @@ -1163,7 +1164,7 @@ class CycleTest(_fixtures.FixtureTest): User, Address = self.classes("User", "Address") configure_mappers() - s = Session() + s = fixture_session() @assert_cycles() def go(): @@ -1223,7 +1224,7 @@ class CycleTest(_fixtures.FixtureTest): User, Address = self.classes("User", "Address") configure_mappers() - s = Session() + s = fixture_session() u1 = aliased(User) @@ -1248,7 +1249,7 @@ class CycleTest(_fixtures.FixtureTest): User, Address = self.classes("User", "Address") configure_mappers() - s = Session() + s = fixture_session() def generate(): objects = s.query(User).filter(User.id == 7).all() @@ -1264,7 +1265,7 @@ class CycleTest(_fixtures.FixtureTest): def test_orm_objects_from_query_w_selectinload(self): User, Address = self.classes("User", "Address") - s = Session() + s = fixture_session() def generate(): objects = s.query(User).options(selectinload(User.addresses)).all() @@ -1328,7 +1329,7 @@ class CycleTest(_fixtures.FixtureTest): def test_orm_objects_from_query_w_joinedload(self): User, Address = self.classes("User", "Address") - s = Session() + s = fixture_session() def generate(): objects = s.query(User).options(joinedload(User.addresses)).all() @@ -1344,7 +1345,7 @@ class CycleTest(_fixtures.FixtureTest): def test_query_filtered(self): User, Address = self.classes("User", "Address") - s = Session() + s = fixture_session() @assert_cycles() def go(): @@ -1355,7 +1356,7 @@ class CycleTest(_fixtures.FixtureTest): def test_query_joins(self): User, Address = self.classes("User", "Address") - s = Session() + s = fixture_session() # cycles here are due to ClauseElement._cloned_set, others # as of cache key @@ -1368,7 +1369,7 @@ class CycleTest(_fixtures.FixtureTest): def test_query_joinedload(self): User, Address = self.classes("User", "Address") - s = Session() + s = fixture_session() def generate(): s.query(User).options(joinedload(User.addresses)).all() @@ -1388,7 +1389,7 @@ class CycleTest(_fixtures.FixtureTest): @assert_cycles() def go(): - str(users.join(addresses)) + str(users.join(addresses).compile(testing.db)) go() @@ -1400,7 +1401,7 @@ class CycleTest(_fixtures.FixtureTest): @assert_cycles(7) def go(): s = select(users).select_from(users.join(addresses)) - state = s._compile_state_factory(s, s.compile()) + state = s._compile_state_factory(s, s.compile(testing.db)) state.froms go() @@ -1410,7 +1411,7 @@ class CycleTest(_fixtures.FixtureTest): @assert_cycles() def go(): - str(orm_join(User, Address, User.addresses)) + str(orm_join(User, Address, User.addresses).compile(testing.db)) go() @@ -1418,7 +1419,7 @@ class CycleTest(_fixtures.FixtureTest): User, Address = self.classes("User", "Address") configure_mappers() - s = Session() + s = fixture_session() @assert_cycles() def go(): @@ -1430,7 +1431,7 @@ class CycleTest(_fixtures.FixtureTest): User, Address = self.classes("User", "Address") configure_mappers() - s = Session() + s = fixture_session() @assert_cycles() def go(): @@ -1442,7 +1443,7 @@ class CycleTest(_fixtures.FixtureTest): User, Address = self.classes("User", "Address") configure_mappers() - s = Session() + s = fixture_session() stmt = s.query(User).join(User.addresses).statement @@ -1460,7 +1461,7 @@ class CycleTest(_fixtures.FixtureTest): User, Address = self.classes("User", "Address") configure_mappers() - s = Session() + s = fixture_session() stmt = s.query(User).join(User.addresses).statement @@ -1475,7 +1476,7 @@ class CycleTest(_fixtures.FixtureTest): User, Address = self.classes("User", "Address") configure_mappers() - s = Session() + s = fixture_session() stmt = s.query(User).join(User.addresses).statement @@ -1491,7 +1492,7 @@ class CycleTest(_fixtures.FixtureTest): User, Address = self.classes("User", "Address") configure_mappers() - s = Session() + s = fixture_session() stmt = s.query(User).join(User.addresses).statement @@ -1507,7 +1508,7 @@ class CycleTest(_fixtures.FixtureTest): User, Address = self.classes("User", "Address") configure_mappers() - s = Session() + s = fixture_session() stmt = s.query(User).join(User.addresses).statement @@ -1568,7 +1569,7 @@ class CycleTest(_fixtures.FixtureTest): go() - @testing.fails + @testing.fails() def test_the_counter(self): @assert_cycles() def go(): diff --git a/test/aaa_profiling/test_orm.py b/test/aaa_profiling/test_orm.py index 547672961..f163078d8 100644 --- a/test/aaa_profiling/test_orm.py +++ b/test/aaa_profiling/test_orm.py @@ -20,6 +20,7 @@ from sqlalchemy.orm import sessionmaker from sqlalchemy.testing import config from sqlalchemy.testing import fixtures from sqlalchemy.testing import profiling +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -100,8 +101,8 @@ class MergeTest(NoCache, fixtures.MappedTest): def test_merge_no_load(self): Parent = self.classes.Parent - sess = sessionmaker()() - sess2 = sessionmaker()() + sess = fixture_session() + sess2 = fixture_session() p1 = sess.query(Parent).get(1) p1.children @@ -129,8 +130,8 @@ class MergeTest(NoCache, fixtures.MappedTest): def test_merge_load(self): Parent = self.classes.Parent - sess = sessionmaker()() - sess2 = sessionmaker()() + sess = fixture_session() + sess2 = fixture_session() p1 = sess.query(Parent).get(1) p1.children @@ -228,7 +229,7 @@ class LoadManyToOneFromIdentityTest(NoCache, fixtures.MappedTest): def test_many_to_one_load_no_identity(self): Parent = self.classes.Parent - sess = Session() + sess = fixture_session() parents = sess.query(Parent).all() @profiling.function_call_count(variance=0.2) @@ -241,7 +242,7 @@ class LoadManyToOneFromIdentityTest(NoCache, fixtures.MappedTest): def test_many_to_one_load_identity(self): Parent, Child = self.classes.Parent, self.classes.Child - sess = Session() + sess = fixture_session() parents = sess.query(Parent).all() children = sess.query(Child).all() children # strong reference @@ -335,7 +336,7 @@ class MergeBackrefsTest(NoCache, fixtures.MappedTest): self.classes.C, self.classes.D, ) - s = Session() + s = fixture_session() for a in [ A( id=i, @@ -398,7 +399,7 @@ class DeferOptionsTest(NoCache, fixtures.MappedTest): def test_baseline(self): # as of [ticket:2778], this is at 39025 A = self.classes.A - s = Session() + s = fixture_session() s.query(A).all() @profiling.function_call_count(variance=0.10) @@ -406,7 +407,7 @@ class DeferOptionsTest(NoCache, fixtures.MappedTest): # with [ticket:2778], this goes from 50805 to 32817, # as it should be fewer function calls than the baseline A = self.classes.A - s = Session() + s = fixture_session() s.query(A).options( *[defer(letter) for letter in ["x", "y", "z", "p", "q", "r"]] ).all() @@ -546,7 +547,7 @@ class SessionTest(NoCache, fixtures.MappedTest): Parent(children=[Child() for j in range(10)]) for i in range(10) ] - sess = Session() + sess = fixture_session() sess.add_all(obj) sess.flush() @@ -588,7 +589,7 @@ class QueryTest(NoCache, fixtures.MappedTest): def _fixture(self): Parent = self.classes.Parent - sess = Session() + sess = fixture_session() sess.add_all( [ Parent(data1="d1", data2="d2", data3="d3", data4="d4") @@ -601,7 +602,7 @@ class QueryTest(NoCache, fixtures.MappedTest): def test_query_cols(self): Parent = self.classes.Parent self._fixture() - sess = Session() + sess = fixture_session() # warm up cache for attr in [Parent.data1, Parent.data2, Parent.data3, Parent.data4]: @@ -695,7 +696,7 @@ class SelectInEagerLoadTest(NoCache, fixtures.MappedTest): def test_round_trip_results(self): A, B, C = self.classes("A", "B", "C") - sess = Session() + sess = fixture_session() q = sess.query(A).options(selectinload(A.bs).selectinload(B.cs)) @@ -835,7 +836,7 @@ class JoinedEagerLoadTest(NoCache, fixtures.MappedTest): def test_build_query(self): A, B, C, D, E, F, G = self.classes("A", "B", "C", "D", "E", "F", "G") - sess = Session() + sess = fixture_session() @profiling.function_call_count() def go(): @@ -1122,7 +1123,7 @@ class BranchedOptionTest(NoCache, fixtures.MappedTest): base.joinedload(B.fs), ] - q = Session().query(A) + q = fixture_session().query(A) context = q._compile_state() @@ -1149,7 +1150,7 @@ class BranchedOptionTest(NoCache, fixtures.MappedTest): base.joinedload(B.fs), ] - q = Session().query(A) + q = fixture_session().query(A) context = q._compile_state() @@ -1201,7 +1202,7 @@ class AnnotatedOverheadTest(NoCache, fixtures.MappedTest): def test_no_bundle(self): A = self.classes.A - s = Session() + s = fixture_session() q = s.query(A).select_from(A) @@ -1215,7 +1216,7 @@ class AnnotatedOverheadTest(NoCache, fixtures.MappedTest): def test_no_entity_wo_annotations(self): A = self.classes.A a = self.tables.a - s = Session() + s = fixture_session() q = s.query(a.c.data).select_from(A) @@ -1228,7 +1229,7 @@ class AnnotatedOverheadTest(NoCache, fixtures.MappedTest): def test_no_entity_w_annotations(self): A = self.classes.A - s = Session() + s = fixture_session() q = s.query(A.data).select_from(A) @profiling.function_call_count(warmup=1) @@ -1240,7 +1241,7 @@ class AnnotatedOverheadTest(NoCache, fixtures.MappedTest): def test_entity_w_annotations(self): A = self.classes.A - s = Session() + s = fixture_session() q = s.query(A, A.data).select_from(A) @profiling.function_call_count(warmup=1) @@ -1253,7 +1254,7 @@ class AnnotatedOverheadTest(NoCache, fixtures.MappedTest): def test_entity_wo_annotations(self): A = self.classes.A a = self.tables.a - s = Session() + s = fixture_session() q = s.query(A, a.c.data).select_from(A) @profiling.function_call_count(warmup=1) @@ -1266,7 +1267,7 @@ class AnnotatedOverheadTest(NoCache, fixtures.MappedTest): def test_no_bundle_wo_annotations(self): A = self.classes.A a = self.tables.a - s = Session() + s = fixture_session() q = s.query(a.c.data, A).select_from(A) @profiling.function_call_count(warmup=1) @@ -1278,7 +1279,7 @@ class AnnotatedOverheadTest(NoCache, fixtures.MappedTest): def test_no_bundle_w_annotations(self): A = self.classes.A - s = Session() + s = fixture_session() q = s.query(A.data, A).select_from(A) @profiling.function_call_count(warmup=1) @@ -1291,7 +1292,7 @@ class AnnotatedOverheadTest(NoCache, fixtures.MappedTest): def test_bundle_wo_annotation(self): A = self.classes.A a = self.tables.a - s = Session() + s = fixture_session() q = s.query(Bundle("ASdf", a.c.data), A).select_from(A) @profiling.function_call_count(warmup=1) @@ -1303,7 +1304,7 @@ class AnnotatedOverheadTest(NoCache, fixtures.MappedTest): def test_bundle_w_annotation(self): A = self.classes.A - s = Session() + s = fixture_session() q = s.query(Bundle("ASdf", A.data), A).select_from(A) @profiling.function_call_count(warmup=1) diff --git a/test/aaa_profiling/test_resultset.py b/test/aaa_profiling/test_resultset.py index d36a0c9e1..ae0ea4992 100644 --- a/test/aaa_profiling/test_resultset.py +++ b/test/aaa_profiling/test_resultset.py @@ -3,7 +3,6 @@ import sys from sqlalchemy import Column from sqlalchemy import create_engine from sqlalchemy import Integer -from sqlalchemy import MetaData from sqlalchemy import String from sqlalchemy import Table from sqlalchemy import testing @@ -20,17 +19,13 @@ from sqlalchemy.util import u NUM_FIELDS = 10 NUM_RECORDS = 1000 -t = t2 = metadata = None - -class ResultSetTest(fixtures.TestBase, AssertsExecutionResults): +class ResultSetTest(fixtures.TablesTest, AssertsExecutionResults): __backend__ = True @classmethod - def setup_class(cls): - global t, t2, metadata - metadata = MetaData(testing.db) - t = Table( + def define_tables(cls, metadata): + Table( "table1", metadata, *[ @@ -38,7 +33,7 @@ class ResultSetTest(fixtures.TestBase, AssertsExecutionResults): for fnum in range(NUM_FIELDS) ] ) - t2 = Table( + Table( "table2", metadata, *[ @@ -47,48 +42,46 @@ class ResultSetTest(fixtures.TestBase, AssertsExecutionResults): ] ) - def setup(self): - with testing.db.begin() as conn: - metadata.create_all(conn) - conn.execute( - t.insert(), - [ - dict( - ("field%d" % fnum, u("value%d" % fnum)) - for fnum in range(NUM_FIELDS) - ) - for r_num in range(NUM_RECORDS) - ], - ) - conn.execute( - t2.insert(), - [ - dict( - ("field%d" % fnum, u("value%d" % fnum)) - for fnum in range(NUM_FIELDS) - ) - for r_num in range(NUM_RECORDS) - ], - ) + @classmethod + def insert_data(cls, connection): + conn = connection + t, t2 = cls.tables("table1", "table2") + conn.execute( + t.insert(), + [ + dict( + ("field%d" % fnum, u("value%d" % fnum)) + for fnum in range(NUM_FIELDS) + ) + for r_num in range(NUM_RECORDS) + ], + ) + conn.execute( + t2.insert(), + [ + dict( + ("field%d" % fnum, u("value%d" % fnum)) + for fnum in range(NUM_FIELDS) + ) + for r_num in range(NUM_RECORDS) + ], + ) # warm up type caches - with testing.db.connect() as conn: - conn.execute(t.select()).fetchall() - conn.execute(t2.select()).fetchall() - conn.exec_driver_sql( - "SELECT %s FROM table1" - % (", ".join("field%d" % fnum for fnum in range(NUM_FIELDS))) - ).fetchall() - conn.exec_driver_sql( - "SELECT %s FROM table2" - % (", ".join("field%d" % fnum for fnum in range(NUM_FIELDS))) - ).fetchall() - - def teardown(self): - metadata.drop_all() + conn.execute(t.select()).fetchall() + conn.execute(t2.select()).fetchall() + conn.exec_driver_sql( + "SELECT %s FROM table1" + % (", ".join("field%d" % fnum for fnum in range(NUM_FIELDS))) + ).fetchall() + conn.exec_driver_sql( + "SELECT %s FROM table2" + % (", ".join("field%d" % fnum for fnum in range(NUM_FIELDS))) + ).fetchall() @profiling.function_call_count(variance=0.15) def test_string(self): + t, t2 = self.tables("table1", "table2") with testing.db.connect().execution_options( compiled_cache=None ) as conn: @@ -96,6 +89,8 @@ class ResultSetTest(fixtures.TestBase, AssertsExecutionResults): @profiling.function_call_count(variance=0.15) def test_unicode(self): + t, t2 = self.tables("table1", "table2") + with testing.db.connect().execution_options( compiled_cache=None ) as conn: @@ -119,6 +114,7 @@ class ResultSetTest(fixtures.TestBase, AssertsExecutionResults): @profiling.function_call_count() def test_fetch_by_key_legacy(self): + t, t2 = self.tables("table1", "table2") with testing.db.connect().execution_options( compiled_cache=None ) as conn: @@ -127,6 +123,7 @@ class ResultSetTest(fixtures.TestBase, AssertsExecutionResults): @profiling.function_call_count() def test_fetch_by_key_mappings(self): + t, t2 = self.tables("table1", "table2") with testing.db.connect().execution_options( compiled_cache=None ) as conn: @@ -142,6 +139,8 @@ class ResultSetTest(fixtures.TestBase, AssertsExecutionResults): def test_one_or_none(self, one_or_first, rows_present): # TODO: this is not testing the ORM level "scalar_mapping" # mode which has a different performance profile + t, t2 = self.tables("table1", "table2") + with testing.db.connect().execution_options( compiled_cache=None ) as conn: @@ -167,8 +166,10 @@ class ResultSetTest(fixtures.TestBase, AssertsExecutionResults): # seem to be handling this for a profile that skips result.close() - def test_contains_doesnt_compile(self): - row = t.select().execute().first() + def test_contains_doesnt_compile(self, connection): + t, t2 = self.tables("table1", "table2") + + row = connection.execute(t.select()).first() c1 = Column("some column", Integer) + Column( "some other column", Integer ) diff --git a/test/base/test_concurrency_py3k.py b/test/base/test_concurrency_py3k.py index cf1067667..e7ae8c9ad 100644 --- a/test/base/test_concurrency_py3k.py +++ b/test/base/test_concurrency_py3k.py @@ -53,7 +53,8 @@ class TestAsyncioCompat(fixtures.TestBase): to_await = run1() await_fallback(to_await) - def test_await_only_no_greenlet(self): + @async_test + async def test_await_only_no_greenlet(self): to_await = run1() with expect_raises_message( exc.InvalidRequestError, @@ -62,7 +63,7 @@ class TestAsyncioCompat(fixtures.TestBase): await_only(to_await) # ensure no warning - await_fallback(to_await) + await greenlet_spawn(await_fallback, to_await) @async_test async def test_await_fallback_error(self): diff --git a/test/dialect/mssql/test_engine.py b/test/dialect/mssql/test_engine.py index 668df6ecb..bbdbe5cca 100644 --- a/test/dialect/mssql/test_engine.py +++ b/test/dialect/mssql/test_engine.py @@ -363,15 +363,14 @@ class FastExecutemanyTest(fixtures.TestBase): __backend__ = True __requires__ = ("pyodbc_fast_executemany",) - @testing.provide_metadata - def test_flag_on(self): + def test_flag_on(self, metadata): t = Table( "t", - self.metadata, + metadata, Column("id", Integer, primary_key=True), Column("data", String(50)), ) - t.create() + t.create(testing.db) eng = engines.testing_engine(options={"fast_executemany": True}) @@ -446,10 +445,9 @@ class RealIsolationLevelTest(fixtures.TestBase): __only_on__ = "mssql" __backend__ = True - @testing.provide_metadata - def test_isolation_level(self): - Table("test", self.metadata, Column("id", Integer)).create( - checkfirst=True + def test_isolation_level(self, metadata): + Table("test", metadata, Column("id", Integer)).create( + testing.db, checkfirst=True ) with testing.db.connect() as c: diff --git a/test/dialect/mssql/test_query.py b/test/dialect/mssql/test_query.py index ea0bfa4d2..cdb37cc61 100644 --- a/test/dialect/mssql/test_query.py +++ b/test/dialect/mssql/test_query.py @@ -567,12 +567,11 @@ class MatchTest(fixtures.TablesTest, AssertsCompiledSQL): ) ).fetchall() eq_([5], [r.id for r in results1]) - results2 = ( - matchtable.select() - .where(matchtable.c.title.match("python AND nutshell")) - .execute() - .fetchall() - ) + results2 = connection.execute( + matchtable.select().where( + matchtable.c.title.match("python AND nutshell") + ) + ).fetchall() eq_([5], [r.id for r in results2]) def test_match_across_joins(self, connection): diff --git a/test/dialect/mssql/test_reflection.py b/test/dialect/mssql/test_reflection.py index 86c97316a..86eff0fe4 100644 --- a/test/dialect/mssql/test_reflection.py +++ b/test/dialect/mssql/test_reflection.py @@ -39,9 +39,8 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): __only_on__ = "mssql" __backend__ = True - @testing.provide_metadata - def test_basic_reflection(self): - meta = self.metadata + def test_basic_reflection(self, metadata, connection): + meta = metadata users = Table( "engine_users", @@ -77,59 +76,44 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): ), Column("email_address", types.String(20)), ) - meta.create_all() + meta.create_all(connection) meta2 = MetaData() reflected_users = Table( - "engine_users", meta2, autoload_with=testing.db + "engine_users", meta2, autoload_with=connection ) reflected_addresses = Table( "engine_email_addresses", meta2, - autoload_with=testing.db, + autoload_with=connection, ) self.assert_tables_equal(users, reflected_users) self.assert_tables_equal(addresses, reflected_addresses) - @testing.provide_metadata - def _test_specific_type(self, type_obj, ddl): - metadata = self.metadata + @testing.combinations( + (mssql.XML, "XML"), + (mssql.IMAGE, "IMAGE"), + (mssql.MONEY, "MONEY"), + (mssql.NUMERIC(10, 2), "NUMERIC(10, 2)"), + (mssql.FLOAT, "FLOAT(53)"), + (mssql.REAL, "REAL"), + # FLOAT(5) comes back as REAL + (mssql.FLOAT(5), "REAL"), + argnames="type_obj,ddl", + ) + def test_assorted_types(self, metadata, connection, type_obj, ddl): table = Table("type_test", metadata, Column("col1", type_obj)) - table.create() + table.create(connection) m2 = MetaData() - table2 = Table("type_test", m2, autoload_with=testing.db) + table2 = Table("type_test", m2, autoload_with=connection) self.assert_compile( schema.CreateTable(table2), "CREATE TABLE type_test (col1 %s NULL)" % ddl, ) - def test_xml_type(self): - self._test_specific_type(mssql.XML, "XML") - - def test_image_type(self): - self._test_specific_type(mssql.IMAGE, "IMAGE") - - def test_money_type(self): - self._test_specific_type(mssql.MONEY, "MONEY") - - def test_numeric_prec_scale(self): - self._test_specific_type(mssql.NUMERIC(10, 2), "NUMERIC(10, 2)") - - def test_float(self): - self._test_specific_type(mssql.FLOAT, "FLOAT(53)") - - def test_real(self): - self._test_specific_type(mssql.REAL, "REAL") - - def test_float_as_real(self): - # FLOAT(5) comes back as REAL - self._test_specific_type(mssql.FLOAT(5), "REAL") - - @testing.provide_metadata - def test_identity(self): - metadata = self.metadata + def test_identity(self, metadata, connection): table = Table( "identity_test", metadata, @@ -144,10 +128,10 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): with testing.expect_deprecated( "The dialect options 'mssql_identity_start' and" ): - table.create() + table.create(connection) meta2 = MetaData() - table2 = Table("identity_test", meta2, autoload_with=testing.db) + table2 = Table("identity_test", meta2, autoload_with=connection) eq_(table2.c["col1"].dialect_options["mssql"]["identity_start"], None) eq_( table2.c["col1"].dialect_options["mssql"]["identity_increment"], @@ -156,7 +140,6 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): eq_(table2.c["col1"].identity.start, 2) eq_(table2.c["col1"].identity.increment, 3) - @testing.provide_metadata def test_skip_types(self, connection): connection.exec_driver_sql( "create table foo (id integer primary key, data xml)" @@ -189,10 +172,8 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): ], ) - @testing.provide_metadata - def test_cross_schema_fk_pk_name_overlaps(self): + def test_cross_schema_fk_pk_name_overlaps(self, metadata, connection): # test for issue #4228 - metadata = self.metadata Table( "subject", @@ -224,9 +205,9 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): schema=testing.config.test_schema_2, ) - metadata.create_all() + metadata.create_all(connection) - insp = inspect(testing.db) + insp = inspect(connection) eq_( insp.get_foreign_keys("referrer", testing.config.test_schema), [ @@ -240,9 +221,9 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): ], ) - @testing.provide_metadata - def test_table_name_that_is_greater_than_16_chars(self): - metadata = self.metadata + def test_table_name_that_is_greater_than_16_chars( + self, metadata, connection + ): Table( "ABCDEFGHIJKLMNOPQRSTUVWXYZ", metadata, @@ -250,14 +231,13 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): Column("foo", Integer), Index("foo_idx", "foo"), ) - metadata.create_all() + metadata.create_all(connection) t = Table( - "ABCDEFGHIJKLMNOPQRSTUVWXYZ", MetaData(), autoload_with=testing.db + "ABCDEFGHIJKLMNOPQRSTUVWXYZ", MetaData(), autoload_with=connection ) eq_(t.name, "ABCDEFGHIJKLMNOPQRSTUVWXYZ") - @testing.provide_metadata @testing.combinations( ("local_temp", "#tmp", True), ("global_temp", "##tmp", True), @@ -265,12 +245,11 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): id_="iaa", argnames="table_name, exists", ) - def test_temporary_table(self, connection, table_name, exists): - metadata = self.metadata + def test_temporary_table(self, metadata, connection, table_name, exists): if exists: tt = Table( table_name, - self.metadata, + metadata, Column("id", Integer, primary_key=True), Column("txt", mssql.NVARCHAR(50)), Column("dt2", mssql.DATETIME2), @@ -309,7 +288,6 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): [(2, "bar", datetime.datetime(2020, 2, 2, 2, 2, 2))], ) - @testing.provide_metadata @testing.combinations( ("local_temp", "#tmp", True), ("global_temp", "##tmp", True), @@ -317,11 +295,13 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): id_="iaa", argnames="table_name, exists", ) - def test_has_table_temporary(self, connection, table_name, exists): + def test_has_table_temporary( + self, metadata, connection, table_name, exists + ): if exists: tt = Table( table_name, - self.metadata, + metadata, Column("id", Integer), ) tt.create(connection) @@ -329,9 +309,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): found_it = testing.db.dialect.has_table(connection, table_name) eq_(found_it, exists) - @testing.provide_metadata - def test_db_qualified_items(self): - metadata = self.metadata + def test_db_qualified_items(self, metadata, connection): Table("foo", metadata, Column("id", Integer, primary_key=True)) Table( "bar", @@ -339,17 +317,16 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): Column("id", Integer, primary_key=True), Column("foo_id", Integer, ForeignKey("foo.id", name="fkfoo")), ) - metadata.create_all() + metadata.create_all(connection) - with testing.db.connect() as c: - dbname = c.exec_driver_sql("select db_name()").scalar() - owner = c.exec_driver_sql("SELECT user_name()").scalar() + dbname = connection.exec_driver_sql("select db_name()").scalar() + owner = connection.exec_driver_sql("SELECT user_name()").scalar() referred_schema = "%(dbname)s.%(owner)s" % { "dbname": dbname, "owner": owner, } - inspector = inspect(testing.db) + inspector = inspect(connection) bar_via_db = inspector.get_foreign_keys("bar", schema=referred_schema) eq_( bar_via_db, @@ -364,33 +341,29 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): ], ) - assert inspect(testing.db).has_table("bar", schema=referred_schema) + assert inspect(connection).has_table("bar", schema=referred_schema) m2 = MetaData() Table( "bar", m2, schema=referred_schema, - autoload_with=testing.db, + autoload_with=connection, ) eq_(m2.tables["%s.foo" % referred_schema].schema, referred_schema) - @testing.provide_metadata - def test_indexes_cols(self): - metadata = self.metadata + def test_indexes_cols(self, metadata, connection): t1 = Table("t", metadata, Column("x", Integer), Column("y", Integer)) Index("foo", t1.c.x, t1.c.y) - metadata.create_all() + metadata.create_all(connection) m2 = MetaData() - t2 = Table("t", m2, autoload_with=testing.db) + t2 = Table("t", m2, autoload_with=connection) eq_(set(list(t2.indexes)[0].columns), set([t2.c["x"], t2.c.y])) - @testing.provide_metadata - def test_indexes_cols_with_commas(self): - metadata = self.metadata + def test_indexes_cols_with_commas(self, metadata, connection): t1 = Table( "t", @@ -399,16 +372,14 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): Column("y", Integer), ) Index("foo", t1.c.x, t1.c.y) - metadata.create_all() + metadata.create_all(connection) m2 = MetaData() - t2 = Table("t", m2, autoload_with=testing.db) + t2 = Table("t", m2, autoload_with=connection) eq_(set(list(t2.indexes)[0].columns), set([t2.c["x, col"], t2.c.y])) - @testing.provide_metadata - def test_indexes_cols_with_spaces(self): - metadata = self.metadata + def test_indexes_cols_with_spaces(self, metadata, connection): t1 = Table( "t", @@ -417,16 +388,14 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): Column("y", Integer), ) Index("foo", t1.c.x, t1.c.y) - metadata.create_all() + metadata.create_all(connection) m2 = MetaData() - t2 = Table("t", m2, autoload_with=testing.db) + t2 = Table("t", m2, autoload_with=connection) eq_(set(list(t2.indexes)[0].columns), set([t2.c["x col"], t2.c.y])) - @testing.provide_metadata - def test_indexes_with_filtered(self, connection): - metadata = self.metadata + def test_indexes_with_filtered(self, metadata, connection): t1 = Table( "t", @@ -454,8 +423,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): CreateIndex(idx), "CREATE INDEX idx_x ON t (x) WHERE ([x]='test')" ) - @testing.provide_metadata - def test_max_ident_in_varchar_not_present(self): + def test_max_ident_in_varchar_not_present(self, metadata, connection): """test [ticket:3504]. Here we are testing not just that the "max" token comes back @@ -464,7 +432,6 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): pattern however is likely in common use. """ - metadata = self.metadata Table( "t", @@ -475,10 +442,10 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): Column("t4", types.LargeBinary("max")), Column("t5", types.VARBINARY("max")), ) - metadata.create_all() - for col in inspect(testing.db).get_columns("t"): + metadata.create_all(connection) + for col in inspect(connection).get_columns("t"): is_(col["type"].length, None) - in_("max", str(col["type"].compile(dialect=testing.db.dialect))) + in_("max", str(col["type"].compile(dialect=connection.dialect))) class InfoCoerceUnicodeTest(fixtures.TestBase, AssertsCompiledSQL): @@ -510,42 +477,35 @@ class InfoCoerceUnicodeTest(fixtures.TestBase, AssertsCompiledSQL): ) -class ReflectHugeViewTest(fixtures.TestBase): +class ReflectHugeViewTest(fixtures.TablesTest): __only_on__ = "mssql" __backend__ = True # crashes on freetds 0.91, not worth it __skip_if__ = (lambda: testing.requires.mssql_freetds.enabled,) - def setup(self): - self.col_num = 150 + @classmethod + def define_tables(cls, metadata): + col_num = 150 - self.metadata = MetaData(testing.db) t = Table( "base_table", - self.metadata, + metadata, *[ Column("long_named_column_number_%d" % i, Integer) - for i in range(self.col_num) + for i in range(col_num) ] ) - self.view_str = ( + cls.view_str = ( view_str ) = "CREATE VIEW huge_named_view AS SELECT %s FROM base_table" % ( - ",".join( - "long_named_column_number_%d" % i for i in range(self.col_num) - ) + ",".join("long_named_column_number_%d" % i for i in range(col_num)) ) assert len(view_str) > 4000 event.listen(t, "after_create", DDL(view_str)) event.listen(t, "before_drop", DDL("DROP VIEW huge_named_view")) - self.metadata.create_all() - - def teardown(self): - self.metadata.drop_all() - def test_inspect_view_definition(self): inspector = inspect(testing.db) view_def = inspector.get_view_definition("huge_named_view") @@ -712,10 +672,10 @@ class IdentityReflectionTest(fixtures.TablesTest): ): Table("t%s" % i, metadata, col) - def test_reflect_identity(self): - insp = inspect(testing.db) + def test_reflect_identity(self, connection): + insp = inspect(connection) cols = [] - for t in self.metadata.tables.keys(): + for t in self.tables_test_metadata.tables.keys(): cols.extend(insp.get_columns(t)) for col in cols: is_true("dialect_options" not in col) diff --git a/test/dialect/mssql/test_types.py b/test/dialect/mssql/test_types.py index a4a3bedda..c2231d105 100644 --- a/test/dialect/mssql/test_types.py +++ b/test/dialect/mssql/test_types.py @@ -544,9 +544,7 @@ class TypeRoundTripTest( __backend__ = True - @testing.provide_metadata - def test_decimal_notation(self, connection): - metadata = self.metadata + def test_decimal_notation(self, metadata, connection): numeric_table = Table( "numeric_table", metadata, @@ -635,9 +633,7 @@ class TypeRoundTripTest( ) eq_(value, returned) - @testing.provide_metadata - def test_float(self, connection): - metadata = self.metadata + def test_float(self, metadata, connection): float_table = Table( "float_table", @@ -693,10 +689,8 @@ class TypeRoundTripTest( ) eq_(value, returned) - # todo this should suppress warnings, but it does not @emits_warning_on("mssql+mxodbc", r".*does not have any indexes.*") - @testing.provide_metadata - def test_dates(self): + def test_dates(self, metadata, connection): "Exercise type specification for date types." columns = [ @@ -727,8 +721,6 @@ class TypeRoundTripTest( (mssql.MSDateTime2, [1], {}, "DATETIME2(1)", [">=", (10,)]), ] - metadata = self.metadata - table_args = ["test_mssql_dates", metadata] for index, spec in enumerate(columns): type_, args, kw, res, requires = spec[0:5] @@ -738,11 +730,11 @@ class TypeRoundTripTest( or not requires ): c = Column("c%s" % index, type_(*args, **kw), nullable=None) - testing.db.dialect.type_descriptor(c.type) + connection.dialect.type_descriptor(c.type) table_args.append(c) dates_table = Table(*table_args) - gen = testing.db.dialect.ddl_compiler( - testing.db.dialect, schema.CreateTable(dates_table) + gen = connection.dialect.ddl_compiler( + connection.dialect, schema.CreateTable(dates_table) ) for col in dates_table.c: index = int(col.name[1:]) @@ -751,9 +743,9 @@ class TypeRoundTripTest( "%s %s" % (col.name, columns[index][3]), ) self.assert_(repr(col)) - dates_table.create(checkfirst=True) + dates_table.create(connection) reflected_dates = Table( - "test_mssql_dates", MetaData(), autoload_with=testing.db + "test_mssql_dates", MetaData(), autoload_with=connection ) for col in reflected_dates.c: self.assert_types_base(col, dates_table.c[col.key]) @@ -915,13 +907,13 @@ class TypeRoundTripTest( ) @emits_warning_on("mssql+mxodbc", r".*does not have any indexes.*") - @testing.provide_metadata @testing.combinations( ("legacy_large_types", False), ("sql2012_large_types", True, lambda: testing.only_on("mssql >= 11")), id_="ia", + argnames="deprecate_large_types", ) - def test_binary_reflection(self, deprecate_large_types): + def test_binary_reflection(self, metadata, deprecate_large_types): "Exercise type specification for binary types." columns = [ @@ -944,47 +936,45 @@ class TypeRoundTripTest( ), ] - metadata = self.metadata - metadata.bind = engines.testing_engine( + engine = engines.testing_engine( options={"deprecate_large_types": deprecate_large_types} ) - table_args = ["test_mssql_binary", metadata] - for index, spec in enumerate(columns): - type_, args, kw, res = spec - table_args.append( - Column("c%s" % index, type_(*args, **kw), nullable=None) + with engine.begin() as conn: + table_args = ["test_mssql_binary", metadata] + for index, spec in enumerate(columns): + type_, args, kw, res = spec + table_args.append( + Column("c%s" % index, type_(*args, **kw), nullable=None) + ) + binary_table = Table(*table_args) + metadata.create_all(conn) + reflected_binary = Table( + "test_mssql_binary", MetaData(), autoload_with=conn ) - binary_table = Table(*table_args) - metadata.create_all() - reflected_binary = Table( - "test_mssql_binary", MetaData(), autoload_with=testing.db - ) - for col, spec in zip(reflected_binary.c, columns): - eq_( - col.type.compile(dialect=mssql.dialect()), - spec[3], - "column %s %s != %s" - % ( - col.key, + for col, spec in zip(reflected_binary.c, columns): + eq_( col.type.compile(dialect=mssql.dialect()), spec[3], - ), - ) - c1 = testing.db.dialect.type_descriptor(col.type).__class__ - c2 = testing.db.dialect.type_descriptor( - binary_table.c[col.name].type - ).__class__ - assert issubclass( - c1, c2 - ), "column %s: %r is not a subclass of %r" % (col.key, c1, c2) - if binary_table.c[col.name].type.length: - testing.eq_( - col.type.length, binary_table.c[col.name].type.length + "column %s %s != %s" + % ( + col.key, + col.type.compile(dialect=conn.dialect), + spec[3], + ), ) + c1 = conn.dialect.type_descriptor(col.type).__class__ + c2 = conn.dialect.type_descriptor( + binary_table.c[col.name].type + ).__class__ + assert issubclass( + c1, c2 + ), "column %s: %r is not a subclass of %r" % (col.key, c1, c2) + if binary_table.c[col.name].type.length: + testing.eq_( + col.type.length, binary_table.c[col.name].type.length + ) - @testing.provide_metadata - def test_autoincrement(self): - metadata = self.metadata + def test_autoincrement(self, metadata, connection): Table( "ai_1", metadata, @@ -1035,7 +1025,7 @@ class TypeRoundTripTest( Column("o1", String(1), DefaultClause("x"), primary_key=True), Column("o2", String(1), DefaultClause("x"), primary_key=True), ) - metadata.create_all() + metadata.create_all(connection) table_names = [ "ai_1", @@ -1050,7 +1040,7 @@ class TypeRoundTripTest( mr = MetaData() for name in table_names: - tbl = Table(name, mr, autoload_with=testing.db) + tbl = Table(name, mr, autoload_with=connection) tbl = metadata.tables[name] # test that the flag itself reflects appropriately @@ -1081,24 +1071,23 @@ class TypeRoundTripTest( ] for counter, engine in enumerate(eng): - with engine.begin() as conn: - conn.execute(tbl.insert()) - if "int_y" in tbl.c: - eq_( - conn.execute(select(tbl.c.int_y)).scalar(), - counter + 1, - ) - assert ( - list(conn.execute(tbl.select()).first()).count( - counter + 1 - ) - == 1 - ) - else: - assert 1 not in list( - conn.execute(tbl.select()).first() + connection.execute(tbl.insert()) + if "int_y" in tbl.c: + eq_( + connection.execute(select(tbl.c.int_y)).scalar(), + counter + 1, + ) + assert ( + list(connection.execute(tbl.select()).first()).count( + counter + 1 ) - conn.execute(tbl.delete()) + == 1 + ) + else: + assert 1 not in list( + connection.execute(tbl.select()).first() + ) + connection.execute(tbl.delete()) class StringTest(fixtures.TestBase, AssertsCompiledSQL): @@ -1144,17 +1133,87 @@ class StringTest(fixtures.TestBase, AssertsCompiledSQL): ) +class MyPickleType(types.TypeDecorator): + impl = PickleType + + def process_bind_param(self, value, dialect): + if value: + value.stuff = "BIND" + value.stuff + return value + + def process_result_value(self, value, dialect): + if value: + value.stuff = value.stuff + "RESULT" + return value + + class BinaryTest(fixtures.TestBase): __only_on__ = "mssql" __requires__ = ("non_broken_binary",) __backend__ = True - def test_character_binary(self): - self._test_round_trip(mssql.MSVarBinary(800), b("some normal data")) - - @testing.provide_metadata - def _test_round_trip( - self, type_, data, deprecate_large_types=True, expected=None + @testing.combinations( + ( + mssql.MSVarBinary(800), + b("some normal data"), + None, + True, + None, + False, + ), + ( + mssql.VARBINARY("max"), + "binary_data_one.dat", + None, + False, + None, + False, + ), + ( + mssql.VARBINARY("max"), + "binary_data_one.dat", + None, + True, + None, + False, + ), + ( + sqltypes.LargeBinary, + "binary_data_one.dat", + None, + False, + None, + False, + ), + (sqltypes.LargeBinary, "binary_data_one.dat", None, True, None, False), + (mssql.MSImage, "binary_data_one.dat", None, True, None, False), + (PickleType, pickleable.Foo("im foo 1"), None, True, None, False), + ( + MyPickleType, + pickleable.Foo("im foo 1"), + pickleable.Foo("im foo 1", stuff="BINDim stuffRESULT"), + True, + None, + False, + ), + (types.BINARY(100), "binary_data_one.dat", None, True, 100, False), + (types.VARBINARY(100), "binary_data_one.dat", None, True, 100, False), + (mssql.VARBINARY(100), "binary_data_one.dat", None, True, 100, False), + (types.BINARY(100), "binary_data_two.dat", None, True, 99, True), + (types.VARBINARY(100), "binary_data_two.dat", None, True, 99, False), + (mssql.VARBINARY(100), "binary_data_two.dat", None, True, 99, False), + argnames="type_, data, expected, deprecate_large_types, " + "slice_, zeropad", + ) + def test_round_trip( + self, + metadata, + type_, + data, + expected, + deprecate_large_types, + slice_, + zeropad, ): if ( testing.db.dialect.deprecate_large_types @@ -1168,14 +1227,25 @@ class BinaryTest(fixtures.TestBase): binary_table = Table( "binary_table", - self.metadata, + metadata, Column("id", Integer, primary_key=True), Column("data", type_), ) binary_table.create(engine) + if isinstance(data, str) and ( + data == "binary_data_one.dat" or data == "binary_data_two.dat" + ): + data = self._load_stream(data) + + if slice_ is not None: + data = data[0:slice_] + if expected is None: - expected = data + if zeropad: + expected = data[0:slice_] + b"\x00" + else: + expected = data with engine.begin() as conn: conn.execute(binary_table.insert(), data=data) @@ -1205,95 +1275,6 @@ class BinaryTest(fixtures.TestBase): None, ) - def test_plain_pickle(self): - self._test_round_trip(PickleType, pickleable.Foo("im foo 1")) - - def test_custom_pickle(self): - class MyPickleType(types.TypeDecorator): - impl = PickleType - - def process_bind_param(self, value, dialect): - if value: - value.stuff = "BIND" + value.stuff - return value - - def process_result_value(self, value, dialect): - if value: - value.stuff = value.stuff + "RESULT" - return value - - data = pickleable.Foo("im foo 1") - expected = pickleable.Foo("im foo 1") - expected.stuff = "BINDim stuffRESULT" - - self._test_round_trip(MyPickleType, data, expected=expected) - - def test_image(self): - stream1 = self._load_stream("binary_data_one.dat") - self._test_round_trip(mssql.MSImage, stream1) - - def test_large_binary(self): - stream1 = self._load_stream("binary_data_one.dat") - self._test_round_trip(sqltypes.LargeBinary, stream1) - - def test_large_legacy_types(self): - stream1 = self._load_stream("binary_data_one.dat") - self._test_round_trip( - sqltypes.LargeBinary, stream1, deprecate_large_types=False - ) - - def test_mssql_varbinary_max(self): - stream1 = self._load_stream("binary_data_one.dat") - self._test_round_trip(mssql.VARBINARY("max"), stream1) - - def test_mssql_legacy_varbinary_max(self): - stream1 = self._load_stream("binary_data_one.dat") - self._test_round_trip( - mssql.VARBINARY("max"), stream1, deprecate_large_types=False - ) - - def test_binary_slice(self): - self._test_var_slice(types.BINARY) - - def test_binary_slice_zeropadding(self): - self._test_var_slice_zeropadding(types.BINARY, True) - - def test_varbinary_slice(self): - self._test_var_slice(types.VARBINARY) - - def test_varbinary_slice_zeropadding(self): - self._test_var_slice_zeropadding(types.VARBINARY, False) - - def test_mssql_varbinary_slice(self): - self._test_var_slice(mssql.VARBINARY) - - def test_mssql_varbinary_slice_zeropadding(self): - self._test_var_slice_zeropadding(mssql.VARBINARY, False) - - def _test_var_slice(self, type_): - stream1 = self._load_stream("binary_data_one.dat") - - data = stream1[0:100] - - self._test_round_trip(type_(100), data) - - def _test_var_slice_zeropadding( - self, type_, pad, deprecate_large_types=True - ): - stream2 = self._load_stream("binary_data_two.dat") - - data = stream2[0:99] - - # the type we used here is 100 bytes - # so we will get 100 bytes zero-padded - - if pad: - paddedstream = stream2[0:99] + b"\x00" - else: - paddedstream = stream2[0:99] - - self._test_round_trip(type_(100), data, expected=paddedstream) - def _load_stream(self, name, len_=3000): fp = open( os.path.join(os.path.dirname(__file__), "..", "..", name), "rb" diff --git a/test/dialect/mysql/test_compiler.py b/test/dialect/mysql/test_compiler.py index 2993f96b8..62292b9da 100644 --- a/test/dialect/mysql/test_compiler.py +++ b/test/dialect/mysql/test_compiler.py @@ -710,7 +710,9 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL): def test_unsupported_cast_literal_bind(self): expr = cast(column("foo", Integer) + 5, Float) - with expect_warnings("Datatype FLOAT does not support CAST on MySQL;"): + with expect_warnings( + "Datatype FLOAT does not support CAST on MySQL/MariaDb;" + ): self.assert_compile(expr, "(foo + 5)", literal_binds=True) m = mysql @@ -734,11 +736,35 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL): def test_unsupported_casts(self, type_, expected): t = sql.table("t", sql.column("col")) - with expect_warnings("Datatype .* does not support CAST on MySQL;"): + with expect_warnings( + "Datatype .* does not support CAST on MySQL/MariaDb;" + ): self.assert_compile(cast(t.c.col, type_), expected) + @testing.combinations( + (m.FLOAT, "CAST(t.col AS FLOAT)"), + (Float, "CAST(t.col AS FLOAT)"), + (FLOAT, "CAST(t.col AS FLOAT)"), + (m.DOUBLE, "CAST(t.col AS DOUBLE)"), + (m.FLOAT, "CAST(t.col AS FLOAT)"), + argnames="type_,expected", + ) + @testing.combinations(True, False, argnames="maria_db") + def test_float_cast(self, type_, expected, maria_db): + + dialect = mysql.dialect() + if maria_db: + dialect.is_mariadb = maria_db + dialect.server_version_info = (10, 4, 5) + else: + dialect.server_version_info = (8, 0, 17) + t = sql.table("t", sql.column("col")) + self.assert_compile(cast(t.c.col, type_), expected, dialect=dialect) + def test_cast_grouped_expression_non_castable(self): - with expect_warnings("Datatype FLOAT does not support CAST on MySQL;"): + with expect_warnings( + "Datatype FLOAT does not support CAST on MySQL/MariaDb;" + ): self.assert_compile( cast(sql.column("x") + sql.column("y"), Float), "(x + y)" ) diff --git a/test/dialect/mysql/test_reflection.py b/test/dialect/mysql/test_reflection.py index 55d88957a..40617e59c 100644 --- a/test/dialect/mysql/test_reflection.py +++ b/test/dialect/mysql/test_reflection.py @@ -44,15 +44,14 @@ class TypeReflectionTest(fixtures.TestBase): __only_on__ = "mysql", "mariadb" __backend__ = True - @testing.provide_metadata - def _run_test(self, specs, attributes): + def _run_test(self, metadata, connection, specs, attributes): columns = [Column("c%i" % (i + 1), t[0]) for i, t in enumerate(specs)] # Early 5.0 releases seem to report more "general" for columns # in a view, e.g. char -> varchar, tinyblob -> mediumblob use_views = testing.db.dialect.server_version_info > (5, 0, 10) - m = self.metadata + m = metadata Table("mysql_types", m, *columns) if use_views: @@ -67,12 +66,12 @@ class TypeReflectionTest(fixtures.TestBase): event.listen( m, "before_drop", DDL("DROP VIEW IF EXISTS mysql_types_v") ) - m.create_all() + m.create_all(connection) m2 = MetaData() - tables = [Table("mysql_types", m2, autoload_with=testing.db)] + tables = [Table("mysql_types", m2, autoload_with=connection)] if use_views: - tables.append(Table("mysql_types_v", m2, autoload_with=testing.db)) + tables.append(Table("mysql_types_v", m2, autoload_with=connection)) for table in tables: for i, (reflected_col, spec) in enumerate(zip(table.c, specs)): @@ -95,7 +94,7 @@ class TypeReflectionTest(fixtures.TestBase): ), ) - def test_time_types(self): + def test_time_types(self, metadata, connection): specs = [] if testing.requires.mysql_fsp.enabled: @@ -118,20 +117,24 @@ class TypeReflectionTest(fixtures.TestBase): ) # note 'timezone' should always be None on both - self._run_test(specs, ["fsp", "timezone"]) + self._run_test(metadata, connection, specs, ["fsp", "timezone"]) - def test_year_types(self): + def test_year_types(self, metadata, connection): specs = [ (mysql.YEAR(), mysql.YEAR(display_width=4)), (mysql.YEAR(display_width=4), mysql.YEAR(display_width=4)), ] if testing.against("mysql>=8.0.19"): - self._run_test(specs, []) + self._run_test(metadata, connection, specs, []) else: - self._run_test(specs, ["display_width"]) + self._run_test(metadata, connection, specs, ["display_width"]) - def test_string_types(self): + def test_string_types( + self, + metadata, + connection, + ): specs = [ (String(1), mysql.MSString(1)), (String(3), mysql.MSString(3)), @@ -145,9 +148,9 @@ class TypeReflectionTest(fixtures.TestBase): (mysql.MSNChar(2), mysql.MSChar(2)), (mysql.MSNVarChar(22), mysql.MSString(22)), ] - self._run_test(specs, ["length"]) + self._run_test(metadata, connection, specs, ["length"]) - def test_integer_types(self): + def test_integer_types(self, metadata, connection): specs = [] for type_ in [ mysql.TINYINT, @@ -201,11 +204,22 @@ class TypeReflectionTest(fixtures.TestBase): # on display_width. need to test this more accurately though # for the cases where it does if testing.against("mysql >= 8.0.19"): - self._run_test(specs, ["unsigned", "zerofill"]) + self._run_test( + metadata, connection, specs, ["unsigned", "zerofill"] + ) else: - self._run_test(specs, ["display_width", "unsigned", "zerofill"]) + self._run_test( + metadata, + connection, + specs, + ["display_width", "unsigned", "zerofill"], + ) - def test_binary_types(self): + def test_binary_types( + self, + metadata, + connection, + ): specs = [ (LargeBinary(3), mysql.TINYBLOB()), (LargeBinary(), mysql.BLOB()), @@ -217,13 +231,17 @@ class TypeReflectionTest(fixtures.TestBase): (mysql.MSMediumBlob(), mysql.MSMediumBlob()), (mysql.MSLongBlob(), mysql.MSLongBlob()), ] - self._run_test(specs, []) + self._run_test(metadata, connection, specs, []) - def test_legacy_enum_types(self): + def test_legacy_enum_types( + self, + metadata, + connection, + ): specs = [(mysql.ENUM("", "fleem"), mysql.ENUM("", "fleem"))] - self._run_test(specs, ["enums"]) + self._run_test(metadata, connection, specs, ["enums"]) class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): @@ -324,8 +342,7 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): str(reflected.c.c6.server_default.arg).upper(), ) - @testing.provide_metadata - def test_reflection_with_table_options(self, connection): + def test_reflection_with_table_options(self, metadata, connection): comment = r"""Comment types type speedily ' " \ '' Fun!""" if testing.against("mariadb"): kwargs = dict( @@ -348,7 +365,7 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): def_table = Table( "mysql_def", - self.metadata, + metadata, Column("c1", Integer()), comment=comment, **kwargs @@ -403,11 +420,10 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): # This is explicitly ignored when reflecting schema. # assert reflected.kwargs['mysql_auto_increment'] == '5' - @testing.provide_metadata - def test_reflection_on_include_columns(self): + def test_reflection_on_include_columns(self, metadata, connection): """Test reflection of include_columns to be sure they respect case.""" - meta = self.metadata + meta = metadata case_table = Table( "mysql_case", meta, @@ -416,11 +432,11 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): Column("C3", String(10)), ) - case_table.create(testing.db) + case_table.create(connection) reflected = Table( "mysql_case", MetaData(), - autoload_with=testing.db, + autoload_with=connection, include_columns=["c1", "C2"], ) for t in case_table, reflected: @@ -429,16 +445,15 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): reflected2 = Table( "mysql_case", MetaData(), - autoload_with=testing.db, + autoload_with=connection, include_columns=["c1", "c2"], ) assert "c1" in reflected2.c.keys() for c in ["c2", "C2", "C3"]: assert c not in reflected2.c.keys() - @testing.provide_metadata - def test_autoincrement(self): - meta = self.metadata + def test_autoincrement(self, metadata, connection): + meta = metadata Table( "ai_1", meta, @@ -520,7 +535,7 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): Column("o2", String(1), DefaultClause("x"), primary_key=True), mysql_engine="MyISAM", ) - meta.create_all(testing.db) + meta.create_all(connection) table_names = [ "ai_1", @@ -533,30 +548,27 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): "ai_8", ] mr = MetaData() - mr.reflect(testing.db, only=table_names) - - with testing.db.begin() as conn: - for tbl in [mr.tables[name] for name in table_names]: - for c in tbl.c: - if c.name.startswith("int_y"): - assert c.autoincrement - elif c.name.startswith("int_n"): - assert not c.autoincrement - conn.execute(tbl.insert()) - if "int_y" in tbl.c: - assert conn.scalar(select(tbl.c.int_y)) == 1 - assert ( - list(conn.execute(tbl.select()).first()).count(1) == 1 - ) - else: - assert 1 not in list(conn.execute(tbl.select()).first()) + mr.reflect(connection, only=table_names) + + for tbl in [mr.tables[name] for name in table_names]: + for c in tbl.c: + if c.name.startswith("int_y"): + assert c.autoincrement + elif c.name.startswith("int_n"): + assert not c.autoincrement + connection.execute(tbl.insert()) + if "int_y" in tbl.c: + assert connection.scalar(select(tbl.c.int_y)) == 1 + assert ( + list(connection.execute(tbl.select()).first()).count(1) + == 1 + ) + else: + assert 1 not in list(connection.execute(tbl.select()).first()) - @testing.provide_metadata - def test_view_reflection(self, connection): - Table( - "x", self.metadata, Column("a", Integer), Column("b", String(50)) - ) - self.metadata.create_all(connection) + def test_view_reflection(self, metadata, connection): + Table("x", metadata, Column("a", Integer), Column("b", String(50))) + metadata.create_all(connection) conn = connection conn.exec_driver_sql("CREATE VIEW v1 AS SELECT * FROM x") @@ -570,7 +582,7 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): "CREATE DEFINER=CURRENT_USER VIEW v4 AS SELECT * FROM x" ) - @event.listens_for(self.metadata, "before_drop") + @event.listens_for(metadata, "before_drop") def cleanup(*arg, **kw): with testing.db.begin() as conn: for v in ["v1", "v2", "v3", "v4"]: @@ -586,9 +598,8 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): [("a", mysql.INTEGER), ("b", mysql.VARCHAR)], ) - @testing.provide_metadata - def test_skip_not_describable(self, connection): - @event.listens_for(self.metadata, "before_drop") + def test_skip_not_describable(self, metadata, connection): + @event.listens_for(metadata, "before_drop") def cleanup(*arg, **kw): with testing.db.begin() as conn: conn.exec_driver_sql("DROP TABLE IF EXISTS test_t1") @@ -625,20 +636,18 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): view_names = dialect.get_view_names(connection, "information_schema") self.assert_("TABLES" in view_names) - @testing.provide_metadata - def test_nullable_reflection(self): + def test_nullable_reflection(self, metadata, connection): """test reflection of NULL/NOT NULL, in particular with TIMESTAMP defaults where MySQL is inconsistent in how it reports CREATE TABLE. """ - meta = self.metadata + meta = metadata # this is ideally one table, but older MySQL versions choke # on the multiple TIMESTAMP columns - with testing.db.connect() as c: - row = c.exec_driver_sql( - "show variables like '%%explicit_defaults_for_timestamp%%'" - ).first() + row = connection.exec_driver_sql( + "show variables like '%%explicit_defaults_for_timestamp%%'" + ).first() explicit_defaults_for_timestamp = row[1].lower() in ("on", "1", "true") reflected = [] @@ -659,15 +668,14 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): ): Table("nn_t%d" % idx, meta) # to allow DROP - with testing.db.begin() as c: - c.exec_driver_sql( - """ - CREATE TABLE nn_t%d ( - %s - ) - """ - % (idx, ", \n".join(cols)) - ) + connection.exec_driver_sql( + """ + CREATE TABLE nn_t%d ( + %s + ) + """ + % (idx, ", \n".join(cols)) + ) reflected.extend( { @@ -675,10 +683,10 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): "nullable": d["nullable"], "default": d["default"], } - for d in inspect(testing.db).get_columns("nn_t%d" % idx) + for d in inspect(connection).get_columns("nn_t%d" % idx) ) - if testing.db.dialect._is_mariadb_102: + if connection.dialect._is_mariadb_102: current_timestamp = "current_timestamp()" else: current_timestamp = "CURRENT_TIMESTAMP" @@ -726,11 +734,10 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): ], ) - @testing.provide_metadata - def test_reflection_with_unique_constraint(self): - insp = inspect(testing.db) + def test_reflection_with_unique_constraint(self, metadata, connection): + insp = inspect(connection) - meta = self.metadata + meta = metadata uc_table = Table( "mysql_uc", meta, @@ -738,7 +745,7 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): UniqueConstraint("a", name="uc_a"), ) - uc_table.create() + uc_table.create(connection) # MySQL converts unique constraints into unique indexes. # separately we get both @@ -762,11 +769,10 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_(indexes["uc_a"].unique) self.assert_("uc_a" not in constraints) - @testing.provide_metadata - def test_reflect_fulltext(self): + def test_reflect_fulltext(self, metadata, connection): mt = Table( "mytable", - self.metadata, + metadata, Column("id", Integer, primary_key=True), Column("textdata", String(50)), mariadb_engine="InnoDB", @@ -779,7 +785,7 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): mysql_prefix="FULLTEXT", mariadb_prefix="FULLTEXT", ) - self.metadata.create_all(testing.db) + metadata.create_all(connection) mt = Table("mytable", MetaData(), autoload_with=testing.db) idx = list(mt.indexes)[0] @@ -791,11 +797,14 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): ) @testing.requires.mysql_ngram_fulltext - @testing.provide_metadata - def test_reflect_fulltext_comment(self): + def test_reflect_fulltext_comment( + self, + metadata, + connection, + ): mt = Table( "mytable", - self.metadata, + metadata, Column("id", Integer, primary_key=True), Column("textdata", String(50)), mysql_engine="InnoDB", @@ -807,9 +816,9 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): mysql_with_parser="ngram", ) - self.metadata.create_all(testing.db) + metadata.create_all(connection) - mt = Table("mytable", MetaData(), autoload_with=testing.db) + mt = Table("mytable", MetaData(), autoload_with=connection) idx = list(mt.indexes)[0] eq_(idx.name, "textdata_ix") eq_(idx.dialect_options["mysql"]["prefix"], "FULLTEXT") @@ -820,16 +829,15 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): "(textdata) WITH PARSER ngram", ) - @testing.provide_metadata - def test_non_column_index(self): - m1 = self.metadata + def test_non_column_index(self, metadata, connection): + m1 = metadata t1 = Table( "add_ix", m1, Column("x", String(50)), mysql_engine="InnoDB" ) Index("foo_idx", t1.c.x.desc()) - m1.create_all() + m1.create_all(connection) - insp = inspect(testing.db) + insp = inspect(connection) eq_( insp.get_indexes("add_ix"), [{"name": "foo_idx", "column_names": ["x"], "unique": False}], @@ -950,12 +958,13 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): ], ) - @testing.provide_metadata - def test_case_sensitive_column_constraint_reflection(self): + def test_case_sensitive_column_constraint_reflection( + self, metadata, connection + ): # test for issue #4344 which works around # MySQL 8.0 bug https://bugs.mysql.com/bug.php?id=88718 - m1 = self.metadata + m1 = metadata Table( "Track", @@ -987,9 +996,9 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): ), mysql_engine="InnoDB", ) - m1.create_all() + m1.create_all(connection) - if testing.db.dialect._casing in (1, 2): + if connection.dialect._casing in (1, 2): # the original test for the 88718 fix here in [ticket:4344] # actually set referred_table='track', with the wrong casing! # this test was never run. with [ticket:4751], I've gone through @@ -999,7 +1008,7 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): # lower case is also an 8.0 regression. eq_( - inspect(testing.db).get_foreign_keys("PlaylistTrack"), + inspect(connection).get_foreign_keys("PlaylistTrack"), [ { "name": "FK_PlaylistTTrackId", @@ -1022,7 +1031,7 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): else: eq_( sorted( - inspect(testing.db).get_foreign_keys("PlaylistTrack"), + inspect(connection).get_foreign_keys("PlaylistTrack"), key=lambda elem: elem["name"], ), [ @@ -1046,12 +1055,13 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): ) @testing.requires.mysql_fully_case_sensitive - @testing.provide_metadata - def test_case_sensitive_reflection_dual_case_references(self): + def test_case_sensitive_reflection_dual_case_references( + self, metadata, connection + ): # this tests that within the fix we do for MySQL bug # 88718, we don't do case-insensitive logic if the backend # is case sensitive - m = self.metadata + m = metadata Table( "t1", m, @@ -1074,12 +1084,12 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): Column("cap_t1id", ForeignKey("T1.Some_Id", name="cap_t1id_fk")), mysql_engine="InnoDB", ) - m.create_all(testing.db) + m.create_all(connection) eq_( dict( (rec["name"], rec) - for rec in inspect(testing.db).get_foreign_keys("t2") + for rec in inspect(connection).get_foreign_keys("t2") ), { "cap_t1id_fk": { diff --git a/test/dialect/mysql/test_types.py b/test/dialect/mysql/test_types.py index f4621dce3..3e8aa0fb5 100644 --- a/test/dialect/mysql/test_types.py +++ b/test/dialect/mysql/test_types.py @@ -474,11 +474,10 @@ class TypeRoundTripTest(fixtures.TestBase, AssertsExecutionResults): # fixed in mysql-connector as of 2.0.1, # see http://bugs.mysql.com/bug.php?id=73266 - @testing.provide_metadata - def test_precision_float_roundtrip(self, connection): + def test_precision_float_roundtrip(self, metadata, connection): t = Table( "t", - self.metadata, + metadata, Column( "scale_value", mysql.DOUBLE(precision=15, scale=12, asdecimal=True), @@ -503,11 +502,10 @@ class TypeRoundTripTest(fixtures.TestBase, AssertsExecutionResults): eq_(result, decimal.Decimal("45.768392065789")) @testing.only_if("mysql") - @testing.provide_metadata - def test_charset_collate_table(self, connection): + def test_charset_collate_table(self, metadata, connection): t = Table( "foo", - self.metadata, + metadata, Column("id", Integer), Column("data", UnicodeText), mysql_default_charset="utf8", @@ -657,19 +655,21 @@ class TypeRoundTripTest(fixtures.TestBase, AssertsExecutionResults): impl = TIMESTAMP @testing.combinations( - (TIMESTAMP,), (MyTime(),), (String().with_variant(TIMESTAMP, "mysql"),) + (TIMESTAMP,), + (MyTime(),), + (String().with_variant(TIMESTAMP, "mysql"),), + argnames="type_", ) @testing.requires.mysql_zero_date - @testing.provide_metadata - def test_timestamp_nullable(self, type_): + def test_timestamp_nullable(self, metadata, connection, type_): ts_table = Table( "mysql_timestamp", - self.metadata, + metadata, Column("t1", type_), Column("t2", type_, nullable=False), mysql_engine="InnoDB", ) - self.metadata.create_all() + metadata.create_all(connection) # TIMESTAMP without NULL inserts current time when passed # NULL. when not passed, generates 0000-00-00 quite @@ -687,25 +687,23 @@ class TypeRoundTripTest(fixtures.TestBase, AssertsExecutionResults): else: return dt - with testing.db.begin() as conn: - now = conn.exec_driver_sql("select now()").scalar() - conn.execute(ts_table.insert(), {"t1": now, "t2": None}) - conn.execute(ts_table.insert(), {"t1": None, "t2": None}) - conn.execute(ts_table.insert(), {"t2": None}) + now = connection.exec_driver_sql("select now()").scalar() + connection.execute(ts_table.insert(), {"t1": now, "t2": None}) + connection.execute(ts_table.insert(), {"t1": None, "t2": None}) + connection.execute(ts_table.insert(), {"t2": None}) - new_now = conn.exec_driver_sql("select now()").scalar() + new_now = connection.exec_driver_sql("select now()").scalar() - eq_( - [ - tuple([normalize(dt) for dt in row]) - for row in conn.execute(ts_table.select()) - ], - [(now, now), (None, now), (None, now)], - ) + eq_( + [ + tuple([normalize(dt) for dt in row]) + for row in connection.execute(ts_table.select()) + ], + [(now, now), (None, now), (None, now)], + ) - @testing.provide_metadata - def test_time_roundtrip(self, connection): - t = Table("mysql_time", self.metadata, Column("t1", mysql.TIME())) + def test_time_roundtrip(self, metadata, connection): + t = Table("mysql_time", metadata, Column("t1", mysql.TIME())) t.create(connection) @@ -715,13 +713,12 @@ class TypeRoundTripTest(fixtures.TestBase, AssertsExecutionResults): datetime.time(8, 37, 35), ) - @testing.provide_metadata - def test_year(self, connection): + def test_year(self, metadata, connection): """Exercise YEAR.""" year_table = Table( "mysql_year", - self.metadata, + metadata, Column("y1", mysql.MSYear), Column("y2", mysql.MSYear), Column("y3", mysql.MSYear), @@ -748,26 +745,22 @@ class JSONTest(fixtures.TestBase): __only_on__ = "mysql", "mariadb" __backend__ = True - @testing.provide_metadata @testing.requires.reflects_json_type - def test_reflection(self, connection): + def test_reflection(self, metadata, connection): - Table("mysql_json", self.metadata, Column("foo", mysql.JSON)) - self.metadata.create_all(connection) + Table("mysql_json", metadata, Column("foo", mysql.JSON)) + metadata.create_all(connection) reflected = Table("mysql_json", MetaData(), autoload_with=connection) is_(reflected.c.foo.type._type_affinity, sqltypes.JSON) assert isinstance(reflected.c.foo.type, mysql.JSON) - @testing.provide_metadata - def test_rudimental_round_trip(self, connection): + def test_rudimental_round_trip(self, metadata, connection): # note that test_suite has many more JSON round trip tests # using the backend-agnostic JSON type - mysql_json = Table( - "mysql_json", self.metadata, Column("foo", mysql.JSON) - ) - self.metadata.create_all(connection) + mysql_json = Table("mysql_json", metadata, Column("foo", mysql.JSON)) + metadata.create_all(connection) value = {"json": {"foo": "bar"}, "recs": ["one", "two"]} @@ -804,8 +797,7 @@ class EnumSetTest( def get_enum_string_values(some_enum): return [str(v.value) for v in some_enum.__members__.values()] - @testing.provide_metadata - def test_enum(self, connection): + def test_enum(self, metadata, connection): """Exercise the ENUM type.""" e1 = mysql.ENUM("a", "b") @@ -815,7 +807,7 @@ class EnumSetTest( enum_table = Table( "mysql_enum", - self.metadata, + metadata, Column("e1", e1), Column("e2", e2, nullable=False), Column( @@ -857,11 +849,14 @@ class EnumSetTest( assert_raises( exc.DBAPIError, - enum_table.insert().execute, - e1=None, - e2=None, - e3=None, - e4=None, + connection.execute, + enum_table.insert(), + dict( + e1=None, + e2=None, + e3=None, + e4=None, + ), ) assert enum_table.c.e2generic.type.validate_strings @@ -948,7 +943,7 @@ class EnumSetTest( eq_(res, expected) - def _set_fixture_one(self): + def _set_fixture_one(self, metadata): e1 = mysql.SET("a", "b") e2 = mysql.SET("a", "b") e4 = mysql.SET("'a'", "b") @@ -956,7 +951,7 @@ class EnumSetTest( set_table = Table( "mysql_set", - self.metadata, + metadata, Column("e1", e1), Column("e2", e2, nullable=False), Column("e3", mysql.SET("a", "b")), @@ -965,18 +960,16 @@ class EnumSetTest( ) return set_table - def test_set_colspec(self): - self.metadata = MetaData() - set_table = self._set_fixture_one() + def test_set_colspec(self, metadata): + set_table = self._set_fixture_one(metadata) eq_(colspec(set_table.c.e1), "e1 SET('a','b')") eq_(colspec(set_table.c.e2), "e2 SET('a','b') NOT NULL") eq_(colspec(set_table.c.e3), "e3 SET('a','b')") eq_(colspec(set_table.c.e4), "e4 SET('''a''','b')") eq_(colspec(set_table.c.e5), "e5 SET('a','b')") - @testing.provide_metadata - def test_no_null(self, connection): - set_table = self._set_fixture_one() + def test_no_null(self, metadata, connection): + set_table = self._set_fixture_one(metadata) set_table.create(connection) assert_raises( exc.DBAPIError, @@ -986,11 +979,10 @@ class EnumSetTest( ) @testing.requires.mysql_non_strict - @testing.provide_metadata - def test_empty_set_no_empty_string(self, connection): + def test_empty_set_no_empty_string(self, metadata, connection): t = Table( "t", - self.metadata, + metadata, Column("id", Integer), Column("data", mysql.SET("a", "b")), ) @@ -1020,11 +1012,10 @@ class EnumSetTest( "", ) - @testing.provide_metadata - def test_empty_set_empty_string(self, connection): + def test_empty_set_empty_string(self, metadata, connection): t = Table( "t", - self.metadata, + metadata, Column("id", Integer), Column("data", mysql.SET("a", "b", "", retrieve_as_bitwise=True)), ) @@ -1048,9 +1039,8 @@ class EnumSetTest( ], ) - @testing.provide_metadata - def test_string_roundtrip(self, connection): - set_table = self._set_fixture_one() + def test_string_roundtrip(self, metadata, connection): + set_table = self._set_fixture_one(metadata) set_table.create(connection) connection.execute( set_table.insert(), @@ -1081,11 +1071,10 @@ class EnumSetTest( eq_(res, expected) - @testing.provide_metadata - def test_unicode_roundtrip(self, connection): + def test_unicode_roundtrip(self, metadata, connection): set_table = Table( "t", - self.metadata, + metadata, Column("id", Integer, primary_key=True), Column("data", mysql.SET(u("réveillé"), u("drôle"), u("S’il"))), ) @@ -1099,9 +1088,8 @@ class EnumSetTest( eq_(row, (1, set([u("réveillé"), u("drôle")]))) - @testing.provide_metadata - def test_int_roundtrip(self, connection): - set_table = self._set_fixture_one() + def test_int_roundtrip(self, metadata, connection): + set_table = self._set_fixture_one(metadata) set_table.create(connection) connection.execute( set_table.insert(), dict(e1=1, e2=2, e3=3, e4=3, e5=0) @@ -1118,11 +1106,10 @@ class EnumSetTest( ), ) - @testing.provide_metadata - def test_set_roundtrip_plus_reflection(self, connection): + def test_set_roundtrip_plus_reflection(self, metadata, connection): set_table = Table( "mysql_set", - self.metadata, + metadata, Column("s1", mysql.SET("dq", "sq")), Column("s2", mysql.SET("a")), Column("s3", mysql.SET("5", "7", "9")), @@ -1166,9 +1153,7 @@ class EnumSetTest( eq_(list(rows), [({"5"},), ({"7", "5"},)]) - @testing.provide_metadata - def test_unicode_enum(self, connection): - metadata = self.metadata + def test_unicode_enum(self, metadata, connection): t1 = Table( "table", metadata, @@ -1232,12 +1217,11 @@ class EnumSetTest( "'y', 'z')))", ) - @testing.provide_metadata - def test_enum_parse(self, connection): + def test_enum_parse(self, metadata, connection): enum_table = Table( "mysql_enum", - self.metadata, + metadata, Column("e1", mysql.ENUM("a")), Column("e2", mysql.ENUM("")), Column("e3", mysql.ENUM("a")), @@ -1261,11 +1245,10 @@ class EnumSetTest( eq_(t.c.e6.type.enums, ["", "a"]) eq_(t.c.e7.type.enums, ["", "'a'", "b'b", "'"]) - @testing.provide_metadata - def test_set_parse(self, connection): + def test_set_parse(self, metadata, connection): set_table = Table( "mysql_set", - self.metadata, + metadata, Column("e1", mysql.SET("a")), Column("e2", mysql.SET("", retrieve_as_bitwise=True)), Column("e3", mysql.SET("a")), @@ -1301,11 +1284,10 @@ class EnumSetTest( eq_(t.c.e7.type.values, ("", "'a'", "b'b", "'")) @testing.requires.mysql_non_strict - @testing.provide_metadata - def test_broken_enum_returns_blanks(self, connection): + def test_broken_enum_returns_blanks(self, metadata, connection): t = Table( "enum_missing", - self.metadata, + metadata, Column("id", Integer, primary_key=True), Column("e1", sqltypes.Enum("one", "two", "three")), Column("e2", mysql.ENUM("one", "two", "three")), diff --git a/test/dialect/oracle/test_dialect.py b/test/dialect/oracle/test_dialect.py index fa66a64d5..df87fe89f 100644 --- a/test/dialect/oracle/test_dialect.py +++ b/test/dialect/oracle/test_dialect.py @@ -481,11 +481,9 @@ class QuotedBindRoundTripTest(fixtures.TestBase): __only_on__ = "oracle" __backend__ = True - @testing.provide_metadata - def test_table_round_trip(self, connection): + def test_table_round_trip(self, metadata, connection): oracle.RESERVED_WORDS.discard("UNION") - metadata = self.metadata table = Table( "t1", metadata, @@ -496,7 +494,7 @@ class QuotedBindRoundTripTest(fixtures.TestBase): # is set Column("union", Integer, quote=True), ) - metadata.create_all() + metadata.create_all(connection) connection.execute( table.insert(), {"option": 1, "plain": 1, "union": 1} @@ -516,17 +514,15 @@ class QuotedBindRoundTripTest(fixtures.TestBase): 4, ) - @testing.provide_metadata - def test_numeric_bind_in_crud(self, connection): - t = Table("asfd", self.metadata, Column("100K", Integer)) + def test_numeric_bind_in_crud(self, metadata, connection): + t = Table("asfd", metadata, Column("100K", Integer)) t.create(connection) connection.execute(t.insert(), {"100K": 10}) eq_(connection.scalar(t.select()), 10) - @testing.provide_metadata - def test_expanding_quote_roundtrip(self, connection): - t = Table("asfd", self.metadata, Column("foo", Integer)) + def test_expanding_quote_roundtrip(self, metadata, connection): + t = Table("asfd", metadata, Column("foo", Integer)) t.create(connection) connection.execute( @@ -747,9 +743,7 @@ class ExecuteTest(fixtures.TestBase): finally: seq.drop(connection) - @testing.provide_metadata - def test_limit_offset_for_update(self, connection): - metadata = self.metadata + def test_limit_offset_for_update(self, metadata, connection): # oracle can't actually do the ROWNUM thing with FOR UPDATE # very well. @@ -794,15 +788,13 @@ class UnicodeSchemaTest(fixtures.TestBase): __only_on__ = "oracle" __backend__ = True - @testing.provide_metadata - def test_quoted_column_non_unicode(self, connection): - metadata = self.metadata + def test_quoted_column_non_unicode(self, metadata, connection): table = Table( "atable", metadata, Column("_underscorecolumn", Unicode(255), primary_key=True), ) - metadata.create_all() + metadata.create_all(connection) connection.execute(table.insert(), {"_underscorecolumn": u("’é")}) result = connection.execute( @@ -810,15 +802,13 @@ class UnicodeSchemaTest(fixtures.TestBase): ).scalar() eq_(result, u("’é")) - @testing.provide_metadata - def test_quoted_column_unicode(self, connection): - metadata = self.metadata + def test_quoted_column_unicode(self, metadata, connection): table = Table( "atable", metadata, Column(u("méil"), Unicode(255), primary_key=True), ) - metadata.create_all() + metadata.create_all(connection) connection.execute(table.insert(), {u("méil"): u("’é")}) result = connection.execute( diff --git a/test/dialect/oracle/test_reflection.py b/test/dialect/oracle/test_reflection.py index 2e515556f..81e4e4ab5 100644 --- a/test/dialect/oracle/test_reflection.py +++ b/test/dialect/oracle/test_reflection.py @@ -103,10 +103,9 @@ class MultiSchemaTest(fixtures.TestBase, AssertsCompiledSQL): if stmt.strip(): conn.exec_driver_sql(stmt) - @testing.provide_metadata - def test_create_same_names_explicit_schema(self): + def test_create_same_names_explicit_schema(self, metadata, connection): schema = testing.db.dialect.default_schema_name - meta = self.metadata + meta = metadata parent = Table( "parent", meta, @@ -120,11 +119,10 @@ class MultiSchemaTest(fixtures.TestBase, AssertsCompiledSQL): Column("pid", Integer, ForeignKey("%s.parent.pid" % schema)), schema=schema, ) - with testing.db.begin() as conn: - meta.create_all(conn) - conn.execute(parent.insert(), {"pid": 1}) - conn.execute(child.insert(), {"cid": 1, "pid": 1}) - eq_(conn.execute(child.select()).fetchall(), [(1, 1)]) + meta.create_all(connection) + connection.execute(parent.insert(), {"pid": 1}) + connection.execute(child.insert(), {"cid": 1, "pid": 1}) + eq_(connection.execute(child.select()).fetchall(), [(1, 1)]) def test_reflect_alt_table_owner_local_synonym(self): meta = MetaData() @@ -158,9 +156,8 @@ class MultiSchemaTest(fixtures.TestBase, AssertsCompiledSQL): % {"test_schema": testing.config.test_schema}, ) - @testing.provide_metadata - def test_create_same_names_implicit_schema(self, connection): - meta = self.metadata + def test_create_same_names_implicit_schema(self, metadata, connection): + meta = metadata parent = Table( "parent", meta, Column("pid", Integer, primary_key=True) ) @@ -205,18 +202,17 @@ class MultiSchemaTest(fixtures.TestBase, AssertsCompiledSQL): # check table comment (#5146) eq_(parent.comment, "my table comment") - @testing.provide_metadata - def test_reflect_table_comment(self): + def test_reflect_table_comment(self, metadata, connection): local_parent = Table( "parent", - self.metadata, + metadata, Column("q", Integer), comment="my local comment", ) - local_parent.create(testing.db) + local_parent.create(connection) - insp = inspect(testing.db) + insp = inspect(connection) eq_( insp.get_table_comment( "parent", schema=testing.config.test_schema @@ -231,7 +227,7 @@ class MultiSchemaTest(fixtures.TestBase, AssertsCompiledSQL): ) eq_( insp.get_table_comment( - "parent", schema=testing.db.dialect.default_schema_name + "parent", schema=connection.dialect.default_schema_name ), {"text": "my local comment"}, ) @@ -347,28 +343,28 @@ class ConstraintTest(fixtures.TablesTest): def define_tables(cls, metadata): Table("foo", metadata, Column("id", Integer, primary_key=True)) - def test_oracle_has_no_on_update_cascade(self): + def test_oracle_has_no_on_update_cascade(self, connection): bar = Table( "bar", - self.metadata, + self.tables_test_metadata, Column("id", Integer, primary_key=True), Column( "foo_id", Integer, ForeignKey("foo.id", onupdate="CASCADE") ), ) - assert_raises(exc.SAWarning, bar.create) + assert_raises(exc.SAWarning, bar.create, connection) bat = Table( "bat", - self.metadata, + self.tables_test_metadata, Column("id", Integer, primary_key=True), Column("foo_id", Integer), ForeignKeyConstraint(["foo_id"], ["foo.id"], onupdate="CASCADE"), ) - assert_raises(exc.SAWarning, bat.create) + assert_raises(exc.SAWarning, bat.create, connection) - def test_reflect_check_include_all(self): - insp = inspect(testing.db) + def test_reflect_check_include_all(self, connection): + insp = inspect(connection) eq_(insp.get_check_constraints("foo"), []) eq_( [ @@ -446,9 +442,9 @@ class DontReflectIOTTest(fixtures.TestBase): with testing.db.begin() as conn: conn.exec_driver_sql("drop table admin_docindex") - def test_reflect_all(self): - m = MetaData(testing.db) - m.reflect() + def test_reflect_all(self, connection): + m = MetaData() + m.reflect(connection) eq_(set(t.name for t in m.tables.values()), set(["admin_docindex"])) @@ -477,10 +473,8 @@ class TableReflectionTest(fixtures.TestBase): __only_on__ = "oracle" __backend__ = True - @testing.provide_metadata @testing.fails_if(all_tables_compression_missing) - def test_reflect_basic_compression(self): - metadata = self.metadata + def test_reflect_basic_compression(self, metadata, connection): tbl = Table( "test_compress", @@ -488,30 +482,27 @@ class TableReflectionTest(fixtures.TestBase): Column("data", Integer, primary_key=True), oracle_compress=True, ) - metadata.create_all() + metadata.create_all(connection) m2 = MetaData() - tbl = Table("test_compress", m2, autoload_with=testing.db) + tbl = Table("test_compress", m2, autoload_with=connection) # Don't hardcode the exact value, but it must be non-empty assert tbl.dialect_options["oracle"]["compress"] - @testing.provide_metadata @testing.fails_if(all_tables_compress_for_missing) - def test_reflect_oltp_compression(self): - metadata = self.metadata - + def test_reflect_oltp_compression(self, metadata, connection): tbl = Table( "test_compress", metadata, Column("data", Integer, primary_key=True), oracle_compress="OLTP", ) - metadata.create_all() + metadata.create_all(connection) m2 = MetaData() - tbl = Table("test_compress", m2, autoload_with=testing.db) + tbl = Table("test_compress", m2, autoload_with=connection) assert tbl.dialect_options["oracle"]["compress"] == "OLTP" @@ -519,10 +510,7 @@ class RoundTripIndexTest(fixtures.TestBase): __only_on__ = "oracle" __backend__ = True - @testing.provide_metadata - def test_no_pk(self): - metadata = self.metadata - + def test_no_pk(self, metadata, connection): Table( "sometable", metadata, @@ -531,9 +519,9 @@ class RoundTripIndexTest(fixtures.TestBase): Index("pk_idx_1", "id_a", "id_b", unique=True), Index("pk_idx_2", "id_b", "id_a", unique=True), ) - metadata.create_all() + metadata.create_all(connection) - insp = inspect(testing.db) + insp = inspect(connection) eq_( insp.get_indexes("sometable"), [ @@ -552,10 +540,10 @@ class RoundTripIndexTest(fixtures.TestBase): ], ) - @testing.combinations((True,), (False,)) - @testing.provide_metadata - def test_include_indexes_resembling_pk(self, explicit_pk): - metadata = self.metadata + @testing.combinations((True,), (False,), argnames="explicit_pk") + def test_include_indexes_resembling_pk( + self, metadata, connection, explicit_pk + ): t = Table( "sometable", @@ -575,9 +563,9 @@ class RoundTripIndexTest(fixtures.TestBase): "id_a", "id_b", "group", name="some_primary_key" ) ) - metadata.create_all() + metadata.create_all(connection) - insp = inspect(testing.db) + insp = inspect(connection) eq_( insp.get_indexes("sometable"), [ @@ -596,8 +584,7 @@ class RoundTripIndexTest(fixtures.TestBase): ], ) - @testing.provide_metadata - def test_reflect_fn_index(self, connection): + def test_reflect_fn_index(self, metadata, connection): """test reflection of a functional index. it appears this emitted a warning at some point but does not right now. @@ -606,7 +593,6 @@ class RoundTripIndexTest(fixtures.TestBase): """ - metadata = self.metadata s_table = Table( "sometable", metadata, @@ -630,9 +616,7 @@ class RoundTripIndexTest(fixtures.TestBase): ], ) - @testing.provide_metadata - def test_basic(self): - metadata = self.metadata + def test_basic(self, metadata, connection): s_table = Table( "sometable", @@ -657,16 +641,16 @@ class RoundTripIndexTest(fixtures.TestBase): oracle_compress=1, ) - metadata.create_all() + metadata.create_all(connection) - mirror = MetaData(testing.db) - mirror.reflect() + mirror = MetaData() + mirror.reflect(connection) - metadata.drop_all() - mirror.create_all() + metadata.drop_all(connection) + mirror.create_all(connection) - inspect = MetaData(testing.db) - inspect.reflect() + inspect = MetaData() + inspect.reflect(connection) def obj_definition(obj): return ( @@ -676,7 +660,7 @@ class RoundTripIndexTest(fixtures.TestBase): ) # find what the primary k constraint name should be - primaryconsname = testing.db.scalar( + primaryconsname = connection.scalar( text( """SELECT constraint_name FROM all_constraints @@ -773,14 +757,13 @@ class TypeReflectionTest(fixtures.TestBase): __only_on__ = "oracle" __backend__ = True - @testing.provide_metadata - def _run_test(self, specs, attributes): + def _run_test(self, metadata, connection, specs, attributes): columns = [Column("c%i" % (i + 1), t[0]) for i, t in enumerate(specs)] - m = self.metadata + m = metadata Table("oracle_types", m, *columns) - m.create_all() + m.create_all(connection) m2 = MetaData() - table = Table("oracle_types", m2, autoload_with=testing.db) + table = Table("oracle_types", m2, autoload_with=connection) for i, (reflected_col, spec) in enumerate(zip(table.c, specs)): expected_spec = spec[1] reflected_type = reflected_col.type @@ -800,15 +783,23 @@ class TypeReflectionTest(fixtures.TestBase): ), ) - def test_integer_types(self): + def test_integer_types(self, metadata, connection): specs = [(Integer, INTEGER()), (Numeric, INTEGER())] - self._run_test(specs, []) + self._run_test(metadata, connection, specs, []) - def test_number_types(self): + def test_number_types( + self, + metadata, + connection, + ): specs = [(Numeric(5, 2), NUMBER(5, 2)), (NUMBER, NUMBER())] - self._run_test(specs, ["precision", "scale"]) + self._run_test(metadata, connection, specs, ["precision", "scale"]) - def test_float_types(self): + def test_float_types( + self, + metadata, + connection, + ): specs = [ (DOUBLE_PRECISION(), FLOAT()), # when binary_precision is supported @@ -822,7 +813,7 @@ class TypeReflectionTest(fixtures.TestBase): # when binary_precision is supported # (FLOAT(5), oracle.FLOAT(binary_precision=126),), ] - self._run_test(specs, ["precision"]) + self._run_test(metadata, connection, specs, ["precision"]) class IdentityReflectionTest(fixtures.TablesTest): diff --git a/test/dialect/oracle/test_types.py b/test/dialect/oracle/test_types.py index db3825d13..f008ea019 100644 --- a/test/dialect/oracle/test_types.py +++ b/test/dialect/oracle/test_types.py @@ -188,10 +188,9 @@ class TypesTest(fixtures.TestBase): __dialect__ = oracle.OracleDialect() __backend__ = True - @testing.combinations((CHAR,), (NCHAR,)) - @testing.provide_metadata - def test_fixed_char(self, char_type): - m = self.metadata + @testing.combinations((CHAR,), (NCHAR,), argnames="char_type") + def test_fixed_char(self, metadata, connection, char_type): + m = metadata t = Table( "t1", m, @@ -204,32 +203,30 @@ class TypesTest(fixtures.TestBase): else: v1, v2, v3 = "value 1", "value 2", "value 3" - with testing.db.begin() as conn: - t.create(conn) - conn.execute( - t.insert(), - dict(id=1, data=v1), - dict(id=2, data=v2), - dict(id=3, data=v3), - ) + t.create(connection) + connection.execute( + t.insert(), + dict(id=1, data=v1), + dict(id=2, data=v2), + dict(id=3, data=v3), + ) - eq_( - conn.execute(t.select().where(t.c.data == v2)).fetchall(), - [(2, "value 2 ")], - ) + eq_( + connection.execute(t.select().where(t.c.data == v2)).fetchall(), + [(2, "value 2 ")], + ) - m2 = MetaData() - t2 = Table("t1", m2, autoload_with=conn) - is_(type(t2.c.data.type), char_type) - eq_( - conn.execute(t2.select().where(t2.c.data == v2)).fetchall(), - [(2, "value 2 ")], - ) + m2 = MetaData() + t2 = Table("t1", m2, autoload_with=connection) + is_(type(t2.c.data.type), char_type) + eq_( + connection.execute(t2.select().where(t2.c.data == v2)).fetchall(), + [(2, "value 2 ")], + ) @testing.requires.returning - @testing.provide_metadata - def test_int_not_float(self, connection): - m = self.metadata + def test_int_not_float(self, metadata, connection): + m = metadata t1 = Table("t1", m, Column("foo", Integer)) t1.create(connection) r = connection.execute(t1.insert().values(foo=5).returning(t1.c.foo)) @@ -242,14 +239,13 @@ class TypesTest(fixtures.TestBase): assert isinstance(x, int) @testing.requires.returning - @testing.provide_metadata - def test_int_not_float_no_coerce_decimal(self): + def test_int_not_float_no_coerce_decimal(self, metadata): engine = testing_engine(options=dict(coerce_to_decimal=False)) - m = self.metadata + m = metadata t1 = Table("t1", m, Column("foo", Integer)) with engine.begin() as conn: - t1.create() + t1.create(conn) r = conn.execute(t1.insert().values(foo=5).returning(t1.c.foo)) x = r.scalar() assert x == 5 @@ -259,30 +255,25 @@ class TypesTest(fixtures.TestBase): assert x == 5 assert isinstance(x, int) - @testing.provide_metadata - def test_rowid(self): - metadata = self.metadata + def test_rowid(self, metadata, connection): t = Table("t1", metadata, Column("x", Integer)) - with testing.db.begin() as conn: - t.create(conn) - conn.execute(t.insert(), {"x": 5}) - s1 = select(t).subquery() - s2 = select(column("rowid")).select_from(s1) - rowid = conn.scalar(s2) - - # the ROWID type is not really needed here, - # as cx_oracle just treats it as a string, - # but we want to make sure the ROWID works... - rowid_col = column("rowid", oracle.ROWID) - s3 = select(t.c.x, rowid_col).where( - rowid_col == cast(rowid, oracle.ROWID) - ) - eq_(conn.execute(s3).fetchall(), [(5, rowid)]) + t.create(connection) + connection.execute(t.insert(), {"x": 5}) + s1 = select(t).subquery() + s2 = select(column("rowid")).select_from(s1) + rowid = connection.scalar(s2) + + # the ROWID type is not really needed here, + # as cx_oracle just treats it as a string, + # but we want to make sure the ROWID works... + rowid_col = column("rowid", oracle.ROWID) + s3 = select(t.c.x, rowid_col).where( + rowid_col == cast(rowid, oracle.ROWID) + ) + eq_(connection.execute(s3).fetchall(), [(5, rowid)]) - @testing.provide_metadata - def test_interval(self, connection): - metadata = self.metadata + def test_interval(self, metadata, connection): interval_table = Table( "intervaltable", metadata, @@ -299,9 +290,8 @@ class TypesTest(fixtures.TestBase): row = connection.execute(interval_table.select()).first() eq_(row["day_interval"], datetime.timedelta(days=35, seconds=5743)) - @testing.provide_metadata - def test_numerics(self): - m = self.metadata + def test_numerics(self, metadata, connection): + m = metadata t1 = Table( "t1", m, @@ -314,51 +304,48 @@ class TypesTest(fixtures.TestBase): Column("numbercol2", oracle.NUMBER(9, 3)), Column("numbercol3", oracle.NUMBER), ) - with testing.db.begin() as conn: - t1.create(conn) - conn.execute( - t1.insert(), - dict( - intcol=1, - numericcol=5.2, - floatcol1=6.5, - floatcol2=8.5, - doubleprec=9.5, - numbercol1=12, - numbercol2=14.85, - numbercol3=15.76, - ), - ) + t1.create(connection) + connection.execute( + t1.insert(), + dict( + intcol=1, + numericcol=5.2, + floatcol1=6.5, + floatcol2=8.5, + doubleprec=9.5, + numbercol1=12, + numbercol2=14.85, + numbercol3=15.76, + ), + ) m2 = MetaData() - t2 = Table("t1", m2, autoload_with=testing.db) + t2 = Table("t1", m2, autoload_with=connection) - with testing.db.connect() as conn: - for row in ( - conn.execute(t1.select()).first(), - conn.execute(t2.select()).first(), + for row in ( + connection.execute(t1.select()).first(), + connection.execute(t2.select()).first(), + ): + for i, (val, type_) in enumerate( + ( + (1, int), + (decimal.Decimal("5.2"), decimal.Decimal), + (6.5, float), + (8.5, float), + (9.5, float), + (12, int), + (decimal.Decimal("14.85"), decimal.Decimal), + (15.76, float), + ) ): - for i, (val, type_) in enumerate( - ( - (1, int), - (decimal.Decimal("5.2"), decimal.Decimal), - (6.5, float), - (8.5, float), - (9.5, float), - (12, int), - (decimal.Decimal("14.85"), decimal.Decimal), - (15.76, float), - ) - ): - eq_(row[i], val) - assert isinstance(row[i], type_), "%r is not %r" % ( - row[i], - type_, - ) + eq_(row[i], val) + assert isinstance(row[i], type_), "%r is not %r" % ( + row[i], + type_, + ) - @testing.provide_metadata - def test_numeric_infinity_float(self, connection): - m = self.metadata + def test_numeric_infinity_float(self, metadata, connection): + m = metadata t1 = Table( "t1", m, @@ -388,9 +375,8 @@ class TypesTest(fixtures.TestBase): [(float("inf"),), (float("-inf"),)], ) - @testing.provide_metadata - def test_numeric_infinity_decimal(self, connection): - m = self.metadata + def test_numeric_infinity_decimal(self, metadata, connection): + m = metadata t1 = Table( "t1", m, @@ -420,9 +406,8 @@ class TypesTest(fixtures.TestBase): [(decimal.Decimal("Infinity"),), (decimal.Decimal("-Infinity"),)], ) - @testing.provide_metadata - def test_numeric_nan_float(self, connection): - m = self.metadata + def test_numeric_nan_float(self, metadata, connection): + m = metadata t1 = Table( "t1", m, @@ -460,9 +445,8 @@ class TypesTest(fixtures.TestBase): # needs https://github.com/oracle/python-cx_Oracle/ # issues/184#issuecomment-391399292 - @testing.provide_metadata - def _dont_test_numeric_nan_decimal(self, connection): - m = self.metadata + def _dont_test_numeric_nan_decimal(self, metadata, connection): + m = metadata t1 = Table( "t1", m, @@ -489,16 +473,13 @@ class TypesTest(fixtures.TestBase): [(decimal.Decimal("NaN"),), (decimal.Decimal("NaN"),)], ) - @testing.provide_metadata - def test_numerics_broken_inspection(self, connection): + def test_numerics_broken_inspection(self, metadata, connection): """Numeric scenarios where Oracle type info is 'broken', returning us precision, scale of the form (0, 0) or (0, -127). We convert to Decimal and let int()/float() processors take over. """ - metadata = self.metadata - # this test requires cx_oracle 5 foo = Table( @@ -743,9 +724,7 @@ class TypesTest(fixtures.TestBase): value = exec_sql(connection, "SELECT 'hello' FROM DUAL").scalar() assert isinstance(value, util.text_type) - @testing.provide_metadata - def test_reflect_dates(self): - metadata = self.metadata + def test_reflect_dates(self, metadata, connection): Table( "date_types", metadata, @@ -755,9 +734,9 @@ class TypesTest(fixtures.TestBase): Column("d4", TIMESTAMP(timezone=True)), Column("d5", oracle.INTERVAL(second_precision=5)), ) - metadata.create_all() + metadata.create_all(connection) m = MetaData() - t1 = Table("date_types", m, autoload_with=testing.db) + t1 = Table("date_types", m, autoload_with=connection) assert isinstance(t1.c.d1.type, oracle.DATE) assert isinstance(t1.c.d1.type, DateTime) assert isinstance(t1.c.d2.type, oracle.DATE) @@ -780,22 +759,18 @@ class TypesTest(fixtures.TestBase): for row in types_table.select().execute().fetchall(): [row[k] for k in row.keys()] - @testing.provide_metadata - def test_raw_roundtrip(self, connection): - metadata = self.metadata + def test_raw_roundtrip(self, metadata, connection): raw_table = Table( "raw", metadata, Column("id", Integer, primary_key=True), Column("data", oracle.RAW(35)), ) - metadata.create_all() + metadata.create_all(connection) connection.execute(raw_table.insert(), id=1, data=b("ABCDEF")) eq_(connection.execute(raw_table.select()).first(), (1, b("ABCDEF"))) - @testing.provide_metadata - def test_reflect_nvarchar(self, connection): - metadata = self.metadata + def test_reflect_nvarchar(self, metadata, connection): Table( "tnv", metadata, @@ -827,31 +802,26 @@ class TypesTest(fixtures.TestBase): assert isinstance(nv_data, util.text_type) assert isinstance(c_data, util.text_type) - @testing.provide_metadata - def test_reflect_unicode_no_nvarchar(self): - metadata = self.metadata + def test_reflect_unicode_no_nvarchar(self, metadata, connection): Table("tnv", metadata, Column("data", sqltypes.Unicode(255))) - metadata.create_all() + metadata.create_all(connection) m2 = MetaData() - t2 = Table("tnv", m2, autoload_with=testing.db) + t2 = Table("tnv", m2, autoload_with=connection) assert isinstance(t2.c.data.type, sqltypes.VARCHAR) if testing.against("oracle+cx_oracle"): assert isinstance( - t2.c.data.type.dialect_impl(testing.db.dialect), + t2.c.data.type.dialect_impl(connection.dialect), cx_oracle._OracleString, ) data = u("m’a réveillé.") - with testing.db.begin() as conn: - conn.execute(t2.insert(), {"data": data}) - res = conn.execute(t2.select()).first().data - eq_(res, data) - assert isinstance(res, util.text_type) + connection.execute(t2.insert(), {"data": data}) + res = connection.execute(t2.select()).first().data + eq_(res, data) + assert isinstance(res, util.text_type) - @testing.provide_metadata - def test_char_length(self): - metadata = self.metadata + def test_char_length(self, metadata, connection): t1 = Table( "t1", metadata, @@ -860,26 +830,22 @@ class TypesTest(fixtures.TestBase): Column("c3", CHAR(200)), Column("c4", NCHAR(180)), ) - t1.create() + t1.create(connection) m2 = MetaData() - t2 = Table("t1", m2, autoload_with=testing.db) + t2 = Table("t1", m2, autoload_with=connection) eq_(t2.c.c1.type.length, 50) eq_(t2.c.c2.type.length, 250) eq_(t2.c.c3.type.length, 200) eq_(t2.c.c4.type.length, 180) - @testing.provide_metadata - def test_long_type(self, connection): - metadata = self.metadata + def test_long_type(self, metadata, connection): t = Table("t", metadata, Column("data", oracle.LONG)) - metadata.create_all(testing.db) + metadata.create_all(connection) connection.execute(t.insert(), data="xyz") eq_(connection.scalar(select(t.c.data)), "xyz") - @testing.provide_metadata - def test_longstring(self, connection): - metadata = self.metadata + def test_longstring(self, metadata, connection): exec_sql( connection, """ @@ -1020,23 +986,21 @@ class LOBFetchTest(fixtures.TablesTest): self.data, ) - def test_large_stream(self): + def test_large_stream(self, connection): binary_table = self.tables.binary_table - result = ( - binary_table.select() - .order_by(binary_table.c.id) - .execute() - .fetchall() - ) + result = connection.execute( + binary_table.select().order_by(binary_table.c.id) + ).fetchall() eq_(result, [(i, self.stream) for i in range(1, 11)]) def test_large_stream_single_arraysize(self): binary_table = self.tables.binary_table eng = testing_engine(options={"arraysize": 1}) - result = eng.execute( - binary_table.select().order_by(binary_table.c.id) - ).fetchall() - eq_(result, [(i, self.stream) for i in range(1, 11)]) + with eng.connect() as conn: + result = conn.execute( + binary_table.select().order_by(binary_table.c.id) + ).fetchall() + eq_(result, [(i, self.stream) for i in range(1, 11)]) class EuroNumericTest(fixtures.TestBase): @@ -1140,10 +1104,10 @@ class SetInputSizesTest(fixtures.TestBase): (CHAR(30), "test", "FIXED_CHAR", False), (NCHAR(30), u("test"), "FIXED_NCHAR", False), (oracle.LONG(), "test", None, False), + argnames="datatype, value, sis_value_text, set_nchar_flag", ) - @testing.provide_metadata def test_setinputsizes( - self, datatype, value, sis_value_text, set_nchar_flag + self, metadata, datatype, value, sis_value_text, set_nchar_flag ): if isinstance(sis_value_text, str): sis_value = getattr(testing.db.dialect.dbapi, sis_value_text) @@ -1159,7 +1123,7 @@ class SetInputSizesTest(fixtures.TestBase): else: return self.impl - m = self.metadata + m = metadata # Oracle can have only one column of type LONG so we make three # tables rather than one table w/ three columns t1 = Table("t1", m, Column("foo", datatype)) @@ -1167,7 +1131,7 @@ class SetInputSizesTest(fixtures.TestBase): "t2", m, Column("foo", NullType().with_variant(datatype, "oracle")) ) t3 = Table("t3", m, Column("foo", TestTypeDec())) - m.create_all() + m.create_all(testing.db) class CursorWrapper(object): # cx_oracle cursor can't be modified so we have to @@ -1211,7 +1175,7 @@ class SetInputSizesTest(fixtures.TestBase): [mock.call.setinputsizes()], ) - def test_event_no_native_float(self): + def test_event_no_native_float(self, metadata): def _remove_type(inputsizes, cursor, statement, parameters, context): for param, dbapitype in list(inputsizes.items()): if dbapitype is testing.db.dialect.dbapi.NATIVE_FLOAT: @@ -1219,6 +1183,8 @@ class SetInputSizesTest(fixtures.TestBase): event.listen(testing.db, "do_setinputsizes", _remove_type) try: - self.test_setinputsizes(oracle.BINARY_FLOAT, 25.34534, None, False) + self.test_setinputsizes( + metadata, oracle.BINARY_FLOAT, 25.34534, None, False + ) finally: event.remove(testing.db, "do_setinputsizes", _remove_type) diff --git a/test/dialect/postgresql/test_async_pg_py3k.py b/test/dialect/postgresql/test_async_pg_py3k.py new file mode 100644 index 000000000..fadf939b8 --- /dev/null +++ b/test/dialect/postgresql/test_async_pg_py3k.py @@ -0,0 +1,182 @@ +import random + +from sqlalchemy import Column +from sqlalchemy import exc +from sqlalchemy import Integer +from sqlalchemy import MetaData +from sqlalchemy import String +from sqlalchemy import Table +from sqlalchemy import testing +from sqlalchemy.dialects.postgresql import ENUM +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.testing import async_test +from sqlalchemy.testing import engines +from sqlalchemy.testing import fixtures + + +class AsyncPgTest(fixtures.TestBase): + __requires__ = ("async_dialect",) + __only_on__ = "postgresql+asyncpg" + + @testing.fixture + def async_engine(self): + return create_async_engine(testing.db.url) + + @testing.fixture() + def metadata(self): + # TODO: remove when Iae6ab95938a7e92b6d42086aec534af27b5577d3 + # merges + + from sqlalchemy.testing import engines + from sqlalchemy.sql import schema + + metadata = schema.MetaData() + + try: + yield metadata + finally: + engines.drop_all_tables(metadata, testing.db) + + @async_test + async def test_detect_stale_ddl_cache_raise_recover( + self, metadata, async_engine + ): + async def async_setup(engine, strlen): + metadata.clear() + t1 = Table( + "t1", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(strlen)), + ) + + # conn is an instance of AsyncConnection + async with engine.begin() as conn: + await conn.run_sync(metadata.drop_all) + await conn.run_sync(metadata.create_all) + await conn.execute( + t1.insert(), + [{"name": "some name %d" % i} for i in range(500)], + ) + + meta = MetaData() + + t1 = Table( + "t1", + meta, + Column("id", Integer, primary_key=True), + Column("name", String), + ) + + await async_setup(async_engine, 30) + + second_engine = engines.testing_engine(asyncio=True) + + async with second_engine.connect() as conn: + result = await conn.execute( + t1.select() + .where(t1.c.name.like("some name%")) + .where(t1.c.id % 17 == 6) + ) + + rows = result.fetchall() + assert len(rows) >= 29 + + await async_setup(async_engine, 20) + + async with second_engine.connect() as conn: + with testing.expect_raises_message( + exc.NotSupportedError, + r"cached statement plan is invalid due to a database schema " + r"or configuration change \(SQLAlchemy asyncpg dialect " + r"will now invalidate all prepared caches in response " + r"to this exception\)", + ): + + result = await conn.execute( + t1.select() + .where(t1.c.name.like("some name%")) + .where(t1.c.id % 17 == 6) + ) + + # works again + async with second_engine.connect() as conn: + result = await conn.execute( + t1.select() + .where(t1.c.name.like("some name%")) + .where(t1.c.id % 17 == 6) + ) + + rows = result.fetchall() + assert len(rows) >= 29 + + @async_test + async def test_detect_stale_type_cache_raise_recover( + self, metadata, async_engine + ): + async def async_setup(engine, enums): + metadata = MetaData() + Table( + "t1", + metadata, + Column("id", Integer, primary_key=True), + Column("name", ENUM(*enums, name="my_enum")), + ) + + # conn is an instance of AsyncConnection + async with engine.begin() as conn: + await conn.run_sync(metadata.drop_all) + await conn.run_sync(metadata.create_all) + + t1 = Table( + "t1", + metadata, + Column("id", Integer, primary_key=True), + Column( + "name", + ENUM( + *("beans", "means", "keens", "faux", "beau", "flow"), + name="my_enum" + ), + ), + ) + + await async_setup(async_engine, ("beans", "means", "keens")) + + second_engine = engines.testing_engine( + asyncio=True, + options={"connect_args": {"prepared_statement_cache_size": 0}}, + ) + + async with second_engine.connect() as conn: + await conn.execute( + t1.insert(), + [ + {"name": random.choice(("beans", "means", "keens"))} + for i in range(10) + ], + ) + + await async_setup(async_engine, ("faux", "beau", "flow")) + + async with second_engine.connect() as conn: + with testing.expect_raises_message( + exc.InternalError, "cache lookup failed for type" + ): + await conn.execute( + t1.insert(), + [ + {"name": random.choice(("faux", "beau", "flow"))} + for i in range(10) + ], + ) + + # works again + async with second_engine.connect() as conn: + await conn.execute( + t1.insert(), + [ + {"name": random.choice(("faux", "beau", "flow"))} + for i in range(10) + ], + ) diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py index 3bd8e9da0..f760a309b 100644 --- a/test/dialect/postgresql/test_dialect.py +++ b/test/dialect/postgresql/test_dialect.py @@ -757,7 +757,7 @@ class MiscBackendTest( Column("date1", DateTime(timezone=True)), Column("date2", DateTime(timezone=False)), ) - metadata.create_all() + metadata.create_all(testing.db) m2 = MetaData() t2 = Table("pgdate", m2, autoload_with=testing.db) assert t2.c.date1.type.timezone is True diff --git a/test/dialect/postgresql/test_reflection.py b/test/dialect/postgresql/test_reflection.py index 824f6cd36..754eff25a 100644 --- a/test/dialect/postgresql/test_reflection.py +++ b/test/dialect/postgresql/test_reflection.py @@ -457,12 +457,8 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): __only_on__ = "postgresql" __backend__ = True - @testing.fails_if( - "postgresql < 8.4", "Better int2vector functions not available" - ) - @testing.provide_metadata - def test_reflected_primary_key_order(self): - meta1 = self.metadata + def test_reflected_primary_key_order(self, metadata, connection): + meta1 = metadata subject = Table( "subject", meta1, @@ -470,9 +466,9 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): Column("p2", Integer, primary_key=True), PrimaryKeyConstraint("p2", "p1"), ) - meta1.create_all() + meta1.create_all(connection) meta2 = MetaData() - subject = Table("subject", meta2, autoload_with=testing.db) + subject = Table("subject", meta2, autoload_with=connection) eq_(subject.primary_key.columns.keys(), ["p2", "p1"]) @testing.provide_metadata @@ -583,10 +579,9 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): user_tmp.create(testing.db) assert inspect(testing.db).has_table("some_temp_table") - @testing.provide_metadata - def test_cross_schema_reflection_one(self): + def test_cross_schema_reflection_one(self, metadata, connection): - meta1 = self.metadata + meta1 = metadata users = Table( "users", @@ -603,12 +598,12 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): Column("email_address", String(20)), schema="test_schema", ) - meta1.create_all() + meta1.create_all(connection) meta2 = MetaData() addresses = Table( "email_addresses", meta2, - autoload_with=testing.db, + autoload_with=connection, schema="test_schema", ) users = Table("users", meta2, must_exist=True, schema="test_schema") @@ -617,9 +612,8 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): (users.c.user_id == addresses.c.remote_user_id).compare(j.onclause) ) - @testing.provide_metadata - def test_cross_schema_reflection_two(self): - meta1 = self.metadata + def test_cross_schema_reflection_two(self, metadata, connection): + meta1 = metadata subject = Table( "subject", meta1, Column("id", Integer, primary_key=True) ) @@ -630,11 +624,11 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): Column("ref", Integer, ForeignKey("subject.id")), schema="test_schema", ) - meta1.create_all() + meta1.create_all(connection) meta2 = MetaData() - subject = Table("subject", meta2, autoload_with=testing.db) + subject = Table("subject", meta2, autoload_with=connection) referer = Table( - "referer", meta2, schema="test_schema", autoload_with=testing.db + "referer", meta2, schema="test_schema", autoload_with=connection ) self.assert_( (subject.c.id == referer.c.ref).compare( @@ -642,9 +636,8 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): ) ) - @testing.provide_metadata - def test_cross_schema_reflection_three(self): - meta1 = self.metadata + def test_cross_schema_reflection_three(self, metadata, connection): + meta1 = metadata subject = Table( "subject", meta1, @@ -658,13 +651,13 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): Column("ref", Integer, ForeignKey("test_schema_2.subject.id")), schema="test_schema", ) - meta1.create_all() + meta1.create_all(connection) meta2 = MetaData() subject = Table( - "subject", meta2, autoload_with=testing.db, schema="test_schema_2" + "subject", meta2, autoload_with=connection, schema="test_schema_2" ) referer = Table( - "referer", meta2, autoload_with=testing.db, schema="test_schema" + "referer", meta2, autoload_with=connection, schema="test_schema" ) self.assert_( (subject.c.id == referer.c.ref).compare( @@ -672,9 +665,8 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): ) ) - @testing.provide_metadata - def test_cross_schema_reflection_four(self): - meta1 = self.metadata + def test_cross_schema_reflection_four(self, metadata, connection): + meta1 = metadata subject = Table( "subject", meta1, @@ -688,23 +680,24 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): Column("ref", Integer, ForeignKey("test_schema_2.subject.id")), schema="test_schema", ) - meta1.create_all() + meta1.create_all(connection) - conn = testing.db.connect() - conn.detach() - conn.exec_driver_sql("SET search_path TO test_schema, test_schema_2") - meta2 = MetaData(bind=conn) + connection.detach() + connection.exec_driver_sql( + "SET search_path TO test_schema, test_schema_2" + ) + meta2 = MetaData() subject = Table( "subject", meta2, - autoload_with=testing.db, + autoload_with=connection, schema="test_schema_2", postgresql_ignore_search_path=True, ) referer = Table( "referer", meta2, - autoload_with=testing.db, + autoload_with=connection, schema="test_schema", postgresql_ignore_search_path=True, ) @@ -713,14 +706,12 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): subject.join(referer).onclause ) ) - conn.close() - @testing.provide_metadata - def test_cross_schema_reflection_five(self): - meta1 = self.metadata + def test_cross_schema_reflection_five(self, metadata, connection): + meta1 = metadata # we assume 'public' - default_schema = testing.db.dialect.default_schema_name + default_schema = connection.dialect.default_schema_name subject = Table( "subject", meta1, Column("id", Integer, primary_key=True) ) @@ -730,20 +721,20 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): Column("id", Integer, primary_key=True), Column("ref", Integer, ForeignKey("subject.id")), ) - meta1.create_all() + meta1.create_all(connection) meta2 = MetaData() subject = Table( "subject", meta2, - autoload_with=testing.db, + autoload_with=connection, schema=default_schema, postgresql_ignore_search_path=True, ) referer = Table( "referer", meta2, - autoload_with=testing.db, + autoload_with=connection, schema=default_schema, postgresql_ignore_search_path=True, ) @@ -754,11 +745,10 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): ) ) - @testing.provide_metadata - def test_cross_schema_reflection_six(self): + def test_cross_schema_reflection_six(self, metadata, connection): # test that the search path *is* taken into account # by default - meta1 = self.metadata + meta1 = metadata Table( "some_table", @@ -773,60 +763,58 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): Column("sid", Integer, ForeignKey("test_schema.some_table.id")), schema="test_schema_2", ) - meta1.create_all() - with testing.db.connect() as conn: - conn.detach() + meta1.create_all(connection) + connection.detach() - conn.exec_driver_sql( - "set search_path to test_schema_2, test_schema, public" - ) + connection.exec_driver_sql( + "set search_path to test_schema_2, test_schema, public" + ) - m1 = MetaData() + m1 = MetaData() - Table("some_table", m1, schema="test_schema", autoload_with=conn) - t2_schema = Table( - "some_other_table", - m1, - schema="test_schema_2", - autoload_with=conn, - ) + Table("some_table", m1, schema="test_schema", autoload_with=connection) + t2_schema = Table( + "some_other_table", + m1, + schema="test_schema_2", + autoload_with=connection, + ) - t2_no_schema = Table("some_other_table", m1, autoload_with=conn) + t2_no_schema = Table("some_other_table", m1, autoload_with=connection) - t1_no_schema = Table("some_table", m1, autoload_with=conn) + t1_no_schema = Table("some_table", m1, autoload_with=connection) - m2 = MetaData() - t1_schema_isp = Table( - "some_table", - m2, - schema="test_schema", - autoload_with=conn, - postgresql_ignore_search_path=True, - ) - t2_schema_isp = Table( - "some_other_table", - m2, - schema="test_schema_2", - autoload_with=conn, - postgresql_ignore_search_path=True, - ) + m2 = MetaData() + t1_schema_isp = Table( + "some_table", + m2, + schema="test_schema", + autoload_with=connection, + postgresql_ignore_search_path=True, + ) + t2_schema_isp = Table( + "some_other_table", + m2, + schema="test_schema_2", + autoload_with=connection, + postgresql_ignore_search_path=True, + ) - # t2_schema refers to t1_schema, but since "test_schema" - # is in the search path, we instead link to t2_no_schema - assert t2_schema.c.sid.references(t1_no_schema.c.id) + # t2_schema refers to t1_schema, but since "test_schema" + # is in the search path, we instead link to t2_no_schema + assert t2_schema.c.sid.references(t1_no_schema.c.id) - # the two no_schema tables refer to each other also. - assert t2_no_schema.c.sid.references(t1_no_schema.c.id) + # the two no_schema tables refer to each other also. + assert t2_no_schema.c.sid.references(t1_no_schema.c.id) - # but if we're ignoring search path, then we maintain - # those explicit schemas vs. what the "default" schema is - assert t2_schema_isp.c.sid.references(t1_schema_isp.c.id) + # but if we're ignoring search path, then we maintain + # those explicit schemas vs. what the "default" schema is + assert t2_schema_isp.c.sid.references(t1_schema_isp.c.id) - @testing.provide_metadata - def test_cross_schema_reflection_seven(self): + def test_cross_schema_reflection_seven(self, metadata, connection): # test that the search path *is* taken into account # by default - meta1 = self.metadata + meta1 = metadata Table( "some_table", @@ -841,42 +829,42 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): Column("sid", Integer, ForeignKey("test_schema.some_table.id")), schema="test_schema_2", ) - meta1.create_all() - with testing.db.connect() as conn: - conn.detach() + meta1.create_all(connection) + connection.detach() - conn.exec_driver_sql( - "set search_path to test_schema_2, test_schema, public" - ) - meta2 = MetaData(conn) - meta2.reflect(schema="test_schema_2") + connection.exec_driver_sql( + "set search_path to test_schema_2, test_schema, public" + ) + meta2 = MetaData() + meta2.reflect(connection, schema="test_schema_2") - eq_( - set(meta2.tables), - set(["test_schema_2.some_other_table", "some_table"]), - ) + eq_( + set(meta2.tables), + set(["test_schema_2.some_other_table", "some_table"]), + ) - meta3 = MetaData(conn) - meta3.reflect( - schema="test_schema_2", postgresql_ignore_search_path=True - ) + meta3 = MetaData() + meta3.reflect( + connection, + schema="test_schema_2", + postgresql_ignore_search_path=True, + ) - eq_( - set(meta3.tables), - set( - [ - "test_schema_2.some_other_table", - "test_schema.some_table", - ] - ), - ) + eq_( + set(meta3.tables), + set( + [ + "test_schema_2.some_other_table", + "test_schema.some_table", + ] + ), + ) - @testing.provide_metadata - def test_cross_schema_reflection_metadata_uses_schema(self): + def test_cross_schema_reflection_metadata_uses_schema( + self, metadata, connection + ): # test [ticket:3716] - metadata = self.metadata - Table( "some_table", metadata, @@ -890,28 +878,25 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): Column("id", Integer, primary_key=True), schema=None, ) - metadata.create_all() - with testing.db.connect() as conn: - meta2 = MetaData(conn, schema="test_schema") - meta2.reflect() + metadata.create_all(connection) + meta2 = MetaData(schema="test_schema") + meta2.reflect(connection) - eq_( - set(meta2.tables), - set(["some_other_table", "test_schema.some_table"]), - ) + eq_( + set(meta2.tables), + set(["some_other_table", "test_schema.some_table"]), + ) - @testing.provide_metadata - def test_uppercase_lowercase_table(self): - metadata = self.metadata + def test_uppercase_lowercase_table(self, metadata, connection): a_table = Table("a", metadata, Column("x", Integer)) A_table = Table("A", metadata, Column("x", Integer)) - a_table.create() - assert inspect(testing.db).has_table("a") - assert not inspect(testing.db).has_table("A") - A_table.create(checkfirst=True) - assert inspect(testing.db).has_table("A") + a_table.create(connection) + assert inspect(connection).has_table("a") + assert not inspect(connection).has_table("A") + A_table.create(connection, checkfirst=True) + assert inspect(connection).has_table("A") def test_uppercase_lowercase_sequence(self): @@ -927,12 +912,9 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): a_seq.drop(testing.db) A_seq.drop(testing.db) - @testing.provide_metadata - def test_index_reflection(self): + def test_index_reflection(self, metadata, connection): """Reflecting expression-based indexes should warn""" - metadata = self.metadata - Table( "party", metadata, @@ -940,22 +922,21 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): Column("name", String(20), index=True), Column("aname", String(20)), ) - metadata.create_all(testing.db) - with testing.db.begin() as conn: - conn.exec_driver_sql("create index idx1 on party ((id || name))") - conn.exec_driver_sql( - "create unique index idx2 on party (id) where name = 'test'" - ) - conn.exec_driver_sql( - """ - create index idx3 on party using btree - (lower(name::text), lower(aname::text)) - """ - ) + metadata.create_all(connection) + connection.exec_driver_sql("create index idx1 on party ((id || name))") + connection.exec_driver_sql( + "create unique index idx2 on party (id) where name = 'test'" + ) + connection.exec_driver_sql( + """ + create index idx3 on party using btree + (lower(name::text), lower(aname::text)) + """ + ) def go(): m2 = MetaData() - t2 = Table("party", m2, autoload_with=testing.db) + t2 = Table("party", m2, autoload_with=connection) assert len(t2.indexes) == 2 # Make sure indexes are in the order we expect them in @@ -1020,51 +1001,46 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): "WHERE ((name)::text = 'test'::text)", ) - @testing.fails_if("postgresql < 8.3", "index ordering not supported") - @testing.provide_metadata - def test_index_reflection_with_sorting(self): + def test_index_reflection_with_sorting(self, metadata, connection): """reflect indexes with sorting options set""" t1 = Table( "party", - self.metadata, + metadata, Column("id", String(10), nullable=False), Column("name", String(20)), Column("aname", String(20)), ) - with testing.db.begin() as conn: - - t1.create(conn) + t1.create(connection) - # check ASC, DESC options alone - conn.exec_driver_sql( - """ - create index idx1 on party - (id, name ASC, aname DESC) + # check ASC, DESC options alone + connection.exec_driver_sql( """ - ) + create index idx1 on party + (id, name ASC, aname DESC) + """ + ) - # check DESC w/ NULLS options - conn.exec_driver_sql( - """ - create index idx2 on party - (name DESC NULLS FIRST, aname DESC NULLS LAST) + # check DESC w/ NULLS options + connection.exec_driver_sql( """ - ) + create index idx2 on party + (name DESC NULLS FIRST, aname DESC NULLS LAST) + """ + ) - # check ASC w/ NULLS options - conn.exec_driver_sql( - """ - create index idx3 on party - (name ASC NULLS FIRST, aname ASC NULLS LAST) + # check ASC w/ NULLS options + connection.exec_driver_sql( """ - ) + create index idx3 on party + (name ASC NULLS FIRST, aname ASC NULLS LAST) + """ + ) # reflect data - with testing.db.connect() as conn: - m2 = MetaData(conn) - t2 = Table("party", m2, autoload_with=testing.db) + m2 = MetaData() + t2 = Table("party", m2, autoload_with=connection) eq_(len(t2.indexes), 3) @@ -1206,12 +1182,9 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): ) @testing.skip_if("postgresql < 11.0", "indnkeyatts not supported") - @testing.provide_metadata - def test_index_reflection_with_include(self): + def test_index_reflection_with_include(self, metadata, connection): """reflect indexes with include set""" - metadata = self.metadata - Table( "t", metadata, @@ -1219,30 +1192,27 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): Column("x", ARRAY(Integer)), Column("name", String(20)), ) - metadata.create_all() - with testing.db.begin() as conn: - conn.exec_driver_sql("CREATE INDEX idx1 ON t (x) INCLUDE (name)") + metadata.create_all(connection) + connection.exec_driver_sql("CREATE INDEX idx1 ON t (x) INCLUDE (name)") - # prior to #5205, this would return: - # [{'column_names': ['x', 'name'], - # 'name': 'idx1', 'unique': False}] + # prior to #5205, this would return: + # [{'column_names': ['x', 'name'], + # 'name': 'idx1', 'unique': False}] - ind = testing.db.dialect.get_indexes(conn, "t", None) - eq_( - ind, - [ - { - "unique": False, - "column_names": ["x"], - "include_columns": ["name"], - "name": "idx1", - } - ], - ) + ind = testing.db.dialect.get_indexes(connection, "t", None) + eq_( + ind, + [ + { + "unique": False, + "column_names": ["x"], + "include_columns": ["name"], + "name": "idx1", + } + ], + ) - @testing.provide_metadata - def test_foreign_key_option_inspection(self): - metadata = self.metadata + def test_foreign_key_option_inspection(self, metadata, connection): Table( "person", metadata, @@ -1308,8 +1278,8 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): "options": {"onupdate": "CASCADE", "ondelete": "CASCADE"}, }, } - metadata.create_all() - inspector = inspect(testing.db) + metadata.create_all(connection) + inspector = inspect(connection) fks = inspector.get_foreign_keys( "person" ) + inspector.get_foreign_keys("company") @@ -1543,12 +1513,10 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): t = Table("t", MetaData(), autoload_with=testing.db) eq_(t.c.x.type.enums, []) - @testing.provide_metadata - @testing.only_on("postgresql >= 8.5") - def test_reflection_with_unique_constraint(self): - insp = inspect(testing.db) + def test_reflection_with_unique_constraint(self, metadata, connection): + insp = inspect(connection) - meta = self.metadata + meta = metadata uc_table = Table( "pgsql_uc", meta, @@ -1556,7 +1524,7 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): UniqueConstraint("a", name="uc_a"), ) - uc_table.create() + uc_table.create(connection) # PostgreSQL will create an implicit index for a unique # constraint. Separately we get both @@ -1569,7 +1537,7 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): self.assert_("uc_a" in constraints) # reflection corrects for the dupe - reflected = Table("pgsql_uc", MetaData(), autoload_with=testing.db) + reflected = Table("pgsql_uc", MetaData(), autoload_with=connection) indexes = set(i.name for i in reflected.indexes) constraints = set(uc.name for uc in reflected.constraints) @@ -1578,9 +1546,8 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): self.assert_("uc_a" in constraints) @testing.requires.btree_gist - @testing.provide_metadata - def test_reflection_with_exclude_constraint(self): - m = self.metadata + def test_reflection_with_exclude_constraint(self, metadata, connection): + m = metadata Table( "t", m, @@ -1589,9 +1556,9 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): ExcludeConstraint(("period", "&&"), name="quarters_period_excl"), ) - m.create_all() + m.create_all(connection) - insp = inspect(testing.db) + insp = inspect(connection) # PostgreSQL will create an implicit index for an exclude constraint. # we don't reflect the EXCLUDE yet. @@ -1610,15 +1577,14 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): eq_(insp.get_indexes("t"), expected) # reflection corrects for the dupe - reflected = Table("t", MetaData(), autoload_with=testing.db) + reflected = Table("t", MetaData(), autoload_with=connection) eq_(set(reflected.indexes), set()) - @testing.provide_metadata - def test_reflect_unique_index(self): - insp = inspect(testing.db) + def test_reflect_unique_index(self, metadata, connection): + insp = inspect(connection) - meta = self.metadata + meta = metadata # a unique index OTOH we are able to detect is an index # and not a unique constraint @@ -1629,7 +1595,7 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): Index("ix_a", "a", unique=True), ) - uc_table.create() + uc_table.create(connection) indexes = dict((i["name"], i) for i in insp.get_indexes("pgsql_uc")) constraints = set( @@ -1640,7 +1606,7 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): assert indexes["ix_a"]["unique"] self.assert_("ix_a" not in constraints) - reflected = Table("pgsql_uc", MetaData(), autoload_with=testing.db) + reflected = Table("pgsql_uc", MetaData(), autoload_with=connection) indexes = dict((i.name, i) for i in reflected.indexes) constraints = set(uc.name for uc in reflected.constraints) @@ -1649,9 +1615,8 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): assert indexes["ix_a"].unique self.assert_("ix_a" not in constraints) - @testing.provide_metadata - def test_reflect_check_constraint(self): - meta = self.metadata + def test_reflect_check_constraint(self, metadata, connection): + meta = metadata udf_create = """\ CREATE OR REPLACE FUNCTION is_positive( @@ -1666,7 +1631,9 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): """ sa.event.listen(meta, "before_create", sa.DDL(udf_create)) sa.event.listen( - meta, "after_drop", sa.DDL("DROP FUNCTION is_positive(integer)") + meta, + "after_drop", + sa.DDL("DROP FUNCTION IF EXISTS is_positive(integer)"), ) Table( @@ -1680,9 +1647,9 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): CheckConstraint("b != 'hi\nim a name \nyup\n'", name="cc4"), ) - meta.create_all() + meta.create_all(connection) - reflected = Table("pgsql_cc", MetaData(), autoload_with=testing.db) + reflected = Table("pgsql_cc", MetaData(), autoload_with=connection) check_constraints = dict( (uc.name, uc.sqltext.text) diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index ae7a65a3a..e8a1876c7 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -63,9 +63,6 @@ from sqlalchemy.testing.suite import test_types as suite from sqlalchemy.testing.util import round_decimal -tztable = notztable = metadata = table = None - - class FloatCoercionTest(fixtures.TablesTest, AssertsExecutionResults): __only_on__ = "postgresql" __dialect__ = postgresql.dialect() @@ -121,9 +118,7 @@ class FloatCoercionTest(fixtures.TablesTest, AssertsExecutionResults): ).scalar() eq_(round_decimal(ret, 9), result) - @testing.provide_metadata - def test_arrays_pg(self, connection): - metadata = self.metadata + def test_arrays_pg(self, connection, metadata): t1 = Table( "t", metadata, @@ -132,16 +127,14 @@ class FloatCoercionTest(fixtures.TablesTest, AssertsExecutionResults): Column("z", postgresql.ARRAY(postgresql.DOUBLE_PRECISION)), Column("q", postgresql.ARRAY(Numeric)), ) - metadata.create_all() + metadata.create_all(connection) connection.execute( t1.insert(), x=[5], y=[5], z=[6], q=[decimal.Decimal("6.4")] ) row = connection.execute(t1.select()).first() eq_(row, ([5], [5], [6], [decimal.Decimal("6.4")])) - @testing.provide_metadata - def test_arrays_base(self, connection): - metadata = self.metadata + def test_arrays_base(self, connection, metadata): t1 = Table( "t", metadata, @@ -150,7 +143,7 @@ class FloatCoercionTest(fixtures.TablesTest, AssertsExecutionResults): Column("z", sqltypes.ARRAY(postgresql.DOUBLE_PRECISION)), Column("q", sqltypes.ARRAY(Numeric)), ) - metadata.create_all() + metadata.create_all(connection) connection.execute( t1.insert(), x=[5], y=[5], z=[6], q=[decimal.Decimal("6.4")] ) @@ -236,17 +229,14 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): ] t1.drop(conn, checkfirst=True) - def test_name_required(self): - metadata = MetaData(testing.db) + def test_name_required(self, metadata, connection): etype = Enum("four", "five", "six", metadata=metadata) - assert_raises(exc.CompileError, etype.create) + assert_raises(exc.CompileError, etype.create, connection) assert_raises( - exc.CompileError, etype.compile, dialect=postgresql.dialect() + exc.CompileError, etype.compile, dialect=connection.dialect ) - @testing.provide_metadata - def test_unicode_labels(self, connection): - metadata = self.metadata + def test_unicode_labels(self, connection, metadata): t1 = Table( "table", metadata, @@ -261,7 +251,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): ), ), ) - metadata.create_all() + metadata.create_all(connection) connection.execute(t1.insert(), value=util.u("drôle")) connection.execute(t1.insert(), value=util.u("réveillé")) connection.execute(t1.insert(), value=util.u("S’il")) @@ -274,7 +264,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): ], ) m2 = MetaData() - t2 = Table("table", m2, autoload_with=testing.db) + t2 = Table("table", m2, autoload_with=connection) eq_( t2.c.value.type.enums, [util.u("réveillé"), util.u("drôle"), util.u("S’il")], @@ -408,8 +398,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): RegexSQL("DROP TYPE myenum", dialect="postgresql"), ) - @testing.provide_metadata - def test_generate_multiple(self): + def test_generate_multiple(self, metadata, connection): """Test that the same enum twice only generates once for the create_all() call, without using checkfirst. @@ -417,21 +406,18 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): now handles this. """ - metadata = self.metadata - e1 = Enum("one", "two", "three", name="myenum") Table("e1", metadata, Column("c1", e1)) Table("e2", metadata, Column("c1", e1)) - metadata.create_all(checkfirst=False) - metadata.drop_all(checkfirst=False) + metadata.create_all(connection, checkfirst=False) + metadata.drop_all(connection, checkfirst=False) assert "myenum" not in [ - e["name"] for e in inspect(testing.db).get_enums() + e["name"] for e in inspect(connection).get_enums() ] - @testing.provide_metadata - def test_generate_alone_on_metadata(self): + def test_generate_alone_on_metadata(self, connection, metadata): """Test that the same enum twice only generates once for the create_all() call, without using checkfirst. @@ -439,20 +425,17 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): now handles this. """ - metadata = self.metadata - Enum("one", "two", "three", name="myenum", metadata=self.metadata) + Enum("one", "two", "three", name="myenum", metadata=metadata) - metadata.create_all(checkfirst=False) - assert "myenum" in [e["name"] for e in inspect(testing.db).get_enums()] - metadata.drop_all(checkfirst=False) + metadata.create_all(connection, checkfirst=False) + assert "myenum" in [e["name"] for e in inspect(connection).get_enums()] + metadata.drop_all(connection, checkfirst=False) assert "myenum" not in [ - e["name"] for e in inspect(testing.db).get_enums() + e["name"] for e in inspect(connection).get_enums() ] - @testing.provide_metadata - def test_generate_multiple_on_metadata(self): - metadata = self.metadata + def test_generate_multiple_on_metadata(self, connection, metadata): e1 = Enum("one", "two", "three", name="myenum", metadata=metadata) @@ -460,20 +443,20 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): t2 = Table("e2", metadata, Column("c1", e1)) - metadata.create_all(checkfirst=False) - assert "myenum" in [e["name"] for e in inspect(testing.db).get_enums()] - metadata.drop_all(checkfirst=False) + metadata.create_all(connection, checkfirst=False) + assert "myenum" in [e["name"] for e in inspect(connection).get_enums()] + metadata.drop_all(connection, checkfirst=False) assert "myenum" not in [ - e["name"] for e in inspect(testing.db).get_enums() + e["name"] for e in inspect(connection).get_enums() ] - e1.create() # creates ENUM - t1.create() # does not create ENUM - t2.create() # does not create ENUM + e1.create(connection) # creates ENUM + t1.create(connection) # does not create ENUM + t2.create(connection) # does not create ENUM - @testing.provide_metadata - def test_generate_multiple_schemaname_on_metadata(self): - metadata = self.metadata + def test_generate_multiple_schemaname_on_metadata( + self, metadata, connection + ): Enum("one", "two", "three", name="myenum", metadata=metadata) Enum( @@ -485,38 +468,36 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): schema="test_schema", ) - metadata.create_all(checkfirst=False) - assert "myenum" in [e["name"] for e in inspect(testing.db).get_enums()] + metadata.create_all(connection, checkfirst=False) + assert "myenum" in [e["name"] for e in inspect(connection).get_enums()] assert "myenum" in [ e["name"] - for e in inspect(testing.db).get_enums(schema="test_schema") + for e in inspect(connection).get_enums(schema="test_schema") ] - metadata.drop_all(checkfirst=False) + metadata.drop_all(connection, checkfirst=False) assert "myenum" not in [ - e["name"] for e in inspect(testing.db).get_enums() + e["name"] for e in inspect(connection).get_enums() ] assert "myenum" not in [ e["name"] - for e in inspect(testing.db).get_enums(schema="test_schema") + for e in inspect(connection).get_enums(schema="test_schema") ] - @testing.provide_metadata - def test_drops_on_table(self): - metadata = self.metadata + def test_drops_on_table(self, connection, metadata): e1 = Enum("one", "two", "three", name="myenum") table = Table("e1", metadata, Column("c1", e1)) - table.create() - table.drop() + table.create(connection) + table.drop(connection) assert "myenum" not in [ - e["name"] for e in inspect(testing.db).get_enums() + e["name"] for e in inspect(connection).get_enums() ] - table.create() - assert "myenum" in [e["name"] for e in inspect(testing.db).get_enums()] - table.drop() + table.create(connection) + assert "myenum" in [e["name"] for e in inspect(connection).get_enums()] + table.drop(connection) assert "myenum" not in [ - e["name"] for e in inspect(testing.db).get_enums() + e["name"] for e in inspect(connection).get_enums() ] def test_create_drop_schema_translate_map(self, connection): @@ -554,9 +535,8 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): assert_raises(exc.ProgrammingError, e1.drop, conn, checkfirst=False) - @testing.provide_metadata - def test_remain_on_table_metadata_wide(self): - metadata = self.metadata + def test_remain_on_table_metadata_wide(self, metadata, future_connection): + connection = future_connection e1 = Enum("one", "two", "three", name="myenum", metadata=metadata) table = Table("e1", metadata, Column("c1", e1)) @@ -566,15 +546,18 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): sa.exc.ProgrammingError, '.*type "myenum" does not exist', table.create, + connection, ) - table.create(checkfirst=True) - table.drop() - table.create(checkfirst=True) - table.drop() - assert "myenum" in [e["name"] for e in inspect(testing.db).get_enums()] - metadata.drop_all() + connection.rollback() + + table.create(connection, checkfirst=True) + table.drop(connection) + table.create(connection, checkfirst=True) + table.drop(connection) + assert "myenum" in [e["name"] for e in inspect(connection).get_enums()] + metadata.drop_all(connection) assert "myenum" not in [ - e["name"] for e in inspect(testing.db).get_enums() + e["name"] for e in inspect(connection).get_enums() ] def test_non_native_dialect(self): @@ -616,26 +599,25 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): finally: metadata.drop_all(engine) - def test_standalone_enum(self): - metadata = MetaData(testing.db) + def test_standalone_enum(self, connection, metadata): etype = Enum( "four", "five", "six", name="fourfivesixtype", metadata=metadata ) - etype.create() + etype.create(connection) try: - assert testing.db.dialect.has_type(testing.db, "fourfivesixtype") + assert testing.db.dialect.has_type(connection, "fourfivesixtype") finally: - etype.drop() + etype.drop(connection) assert not testing.db.dialect.has_type( - testing.db, "fourfivesixtype" + connection, "fourfivesixtype" ) - metadata.create_all() + metadata.create_all(connection) try: - assert testing.db.dialect.has_type(testing.db, "fourfivesixtype") + assert testing.db.dialect.has_type(connection, "fourfivesixtype") finally: - metadata.drop_all() + metadata.drop_all(connection) assert not testing.db.dialect.has_type( - testing.db, "fourfivesixtype" + connection, "fourfivesixtype" ) def test_no_support(self): @@ -655,9 +637,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): e.connect() assert not dialect.supports_native_enum - @testing.provide_metadata - def test_reflection(self): - metadata = self.metadata + def test_reflection(self, metadata, connection): etype = Enum( "four", "five", "six", name="fourfivesixtype", metadata=metadata ) @@ -670,17 +650,15 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): ), Column("value2", etype), ) - metadata.create_all() + metadata.create_all(connection) m2 = MetaData() - t2 = Table("table", m2, autoload_with=testing.db) + t2 = Table("table", m2, autoload_with=connection) eq_(t2.c.value.type.enums, ["one", "two", "three"]) eq_(t2.c.value.type.name, "onetwothreetype") eq_(t2.c.value2.type.enums, ["four", "five", "six"]) eq_(t2.c.value2.type.name, "fourfivesixtype") - @testing.provide_metadata - def test_schema_reflection(self): - metadata = self.metadata + def test_schema_reflection(self, metadata, connection): etype = Enum( "four", "five", @@ -705,9 +683,9 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): ), Column("value2", etype), ) - metadata.create_all() + metadata.create_all(connection) m2 = MetaData() - t2 = Table("table", m2, autoload_with=testing.db) + t2 = Table("table", m2, autoload_with=connection) eq_(t2.c.value.type.enums, ["one", "two", "three"]) eq_(t2.c.value.type.name, "onetwothreetype") eq_(t2.c.value2.type.enums, ["four", "five", "six"]) @@ -810,21 +788,19 @@ class OIDTest(fixtures.TestBase): __only_on__ = "postgresql" __backend__ = True - @testing.provide_metadata - def test_reflection(self): - metadata = self.metadata + def test_reflection(self, connection, metadata): Table( "table", metadata, Column("x", Integer), Column("y", postgresql.OID), ) - metadata.create_all() + metadata.create_all(connection) m2 = MetaData() t2 = Table( "table", m2, - autoload_with=testing.db, + autoload_with=connection, ) assert isinstance(t2.c.y.type, postgresql.OID) @@ -858,19 +834,18 @@ class RegClassTest(fixtures.TestBase): "pg_class", ) - def test_cast_whereclause(self): + def test_cast_whereclause(self, connection): pga = Table( "pg_attribute", - MetaData(testing.db), + MetaData(), Column("attrelid", postgresql.OID), Column("attname", String(64)), ) - with testing.db.connect() as conn: - oid = conn.scalar( - select(pga.c.attrelid).where( - pga.c.attrelid == cast("pg_class", postgresql.REGCLASS) - ) + oid = connection.scalar( + select(pga.c.attrelid).where( + pga.c.attrelid == cast("pg_class", postgresql.REGCLASS) ) + ) assert isinstance(oid, int) @@ -904,9 +879,7 @@ class NumericInterpretationTest(fixtures.TestBase): val = proc(val) assert val in (23.7, decimal.Decimal("23.7")) - @testing.provide_metadata - def test_numeric_default(self, connection): - metadata = self.metadata + def test_numeric_default(self, connection, metadata): # pg8000 appears to fail when the value is 0, # returns an int instead of decimal. t = Table( @@ -918,7 +891,7 @@ class NumericInterpretationTest(fixtures.TestBase): Column("fd", Float(asdecimal=True), default=1), Column("ff", Float(asdecimal=False), default=1), ) - metadata.create_all() + metadata.create_all(connection) connection.execute(t.insert()) row = connection.execute(t.select()).first() @@ -934,7 +907,7 @@ class PythonTypeTest(fixtures.TestBase): is_(postgresql.INTERVAL().python_type, datetime.timedelta) -class TimezoneTest(fixtures.TestBase): +class TimezoneTest(fixtures.TablesTest): __backend__ = True """Test timezone-aware datetimes. @@ -948,14 +921,11 @@ class TimezoneTest(fixtures.TestBase): __only_on__ = "postgresql" @classmethod - def setup_class(cls): - global tztable, notztable, metadata - metadata = MetaData(testing.db) - + def define_tables(cls, metadata): # current_timestamp() in postgresql is assumed to return # TIMESTAMP WITH TIMEZONE - tztable = Table( + Table( "tztable", metadata, Column("id", Integer, primary_key=True), @@ -966,7 +936,7 @@ class TimezoneTest(fixtures.TestBase): ), Column("name", String(20)), ) - notztable = Table( + Table( "notztable", metadata, Column("id", Integer, primary_key=True), @@ -979,19 +949,12 @@ class TimezoneTest(fixtures.TestBase): ), Column("name", String(20)), ) - metadata.create_all() - - @classmethod - def teardown_class(cls): - metadata.drop_all() def test_with_timezone(self, connection): - + tztable, notztable = self.tables("tztable", "notztable") # get a date with a tzinfo - somedate = testing.db.connect().scalar( - func.current_timestamp().select() - ) + somedate = connection.scalar(func.current_timestamp().select()) assert somedate.tzinfo connection.execute(tztable.insert(), id=1, name="row1", date=somedate) row = connection.execute( @@ -1012,6 +975,7 @@ class TimezoneTest(fixtures.TestBase): def test_without_timezone(self, connection): # get a date without a tzinfo + tztable, notztable = self.tables("tztable", "notztable") somedate = datetime.datetime(2005, 10, 20, 11, 52, 0) assert not somedate.tzinfo @@ -1056,14 +1020,10 @@ class TimePrecisionCompileTest(fixtures.TestBase, AssertsCompiledSQL): class TimePrecisionTest(fixtures.TestBase): - __dialect__ = postgresql.dialect() - __prefer__ = "postgresql" + __only_on__ = "postgresql" __backend__ = True - @testing.only_on("postgresql", "DB specific feature") - @testing.provide_metadata - def test_reflection(self): - metadata = self.metadata + def test_reflection(self, metadata, connection): t1 = Table( "t1", metadata, @@ -1074,9 +1034,9 @@ class TimePrecisionTest(fixtures.TestBase): Column("c5", postgresql.TIMESTAMP(precision=5)), Column("c6", postgresql.TIMESTAMP(timezone=True, precision=5)), ) - t1.create() + t1.create(connection) m2 = MetaData() - t2 = Table("t1", m2, autoload_with=testing.db) + t2 = Table("t1", m2, autoload_with=connection) eq_(t2.c.c1.type.precision, None) eq_(t2.c.c2.type.precision, 5) eq_(t2.c.c3.type.precision, 5) @@ -1391,22 +1351,18 @@ class ArrayRoundTripTest(object): assert isinstance(tbl.c.intarr.type.item_type, Integer) assert isinstance(tbl.c.strarr.type.item_type, String) - @testing.provide_metadata - def test_array_str_collation(self): - m = self.metadata - + def test_array_str_collation(self, metadata, connection): t = Table( "t", - m, + metadata, Column("data", sqltypes.ARRAY(String(50, collation="en_US"))), ) - t.create() + t.create(connection) - @testing.provide_metadata - def test_array_agg(self, connection): - values_table = Table("values", self.metadata, Column("value", Integer)) - self.metadata.create_all(testing.db) + def test_array_agg(self, metadata, connection): + values_table = Table("values", metadata, Column("value", Integer)) + metadata.create_all(connection) connection.execute( values_table.insert(), [{"value": i} for i in range(1, 10)] ) @@ -1658,9 +1614,7 @@ class ArrayRoundTripTest(object): [4, 5, 6], ) - @testing.provide_metadata - def test_tuple_flag(self, connection): - metadata = self.metadata + def test_tuple_flag(self, connection, metadata): t1 = Table( "t1", @@ -1671,7 +1625,7 @@ class ArrayRoundTripTest(object): "data2", self.ARRAY(Numeric(asdecimal=False), as_tuple=True) ), ) - metadata.create_all() + metadata.create_all(connection) connection.execute( t1.insert(), id=1, data=["1", "2", "3"], data2=[5.4, 5.6] ) @@ -2168,10 +2122,9 @@ class SpecialTypesTest(fixtures.TablesTest, ComparesTables): assert t.c.precision_interval.type.precision == 3 assert t.c.bitstring.type.length == 4 - @testing.provide_metadata - def test_tsvector_round_trip(self, connection): - t = Table("t1", self.metadata, Column("data", postgresql.TSVECTOR)) - t.create() + def test_tsvector_round_trip(self, connection, metadata): + t = Table("t1", metadata, Column("data", postgresql.TSVECTOR)) + t.create(connection) connection.execute(t.insert(), data="a fat cat sat") eq_(connection.scalar(select(t.c.data)), "'a' 'cat' 'fat' 'sat'") @@ -2182,9 +2135,7 @@ class SpecialTypesTest(fixtures.TablesTest, ComparesTables): "'a' 'cat' 'fat' 'mat' 'sat'", ) - @testing.provide_metadata - def test_bit_reflection(self): - metadata = self.metadata + def test_bit_reflection(self, metadata, connection): t1 = Table( "t1", metadata, @@ -2193,9 +2144,9 @@ class SpecialTypesTest(fixtures.TablesTest, ComparesTables): Column("bitvarying", postgresql.BIT(varying=True)), Column("bitvarying5", postgresql.BIT(5, varying=True)), ) - t1.create() + t1.create(connection) m2 = MetaData() - t2 = Table("t1", m2, autoload_with=testing.db) + t2 = Table("t1", m2, autoload_with=connection) eq_(t2.c.bit1.type.length, 1) eq_(t2.c.bit1.type.varying, False) eq_(t2.c.bit5.type.length, 5) @@ -2734,14 +2685,14 @@ class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase): def test_where_equal(self): self._test_clause( - self.col == self._data_str, + self.col == self._data_str(), "data_table.range = %(range_1)s", sqltypes.BOOLEANTYPE, ) def test_where_not_equal(self): self._test_clause( - self.col != self._data_str, + self.col != self._data_str(), "data_table.range <> %(range_1)s", sqltypes.BOOLEANTYPE, ) @@ -2760,94 +2711,94 @@ class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase): def test_where_less_than(self): self._test_clause( - self.col < self._data_str, + self.col < self._data_str(), "data_table.range < %(range_1)s", sqltypes.BOOLEANTYPE, ) def test_where_greater_than(self): self._test_clause( - self.col > self._data_str, + self.col > self._data_str(), "data_table.range > %(range_1)s", sqltypes.BOOLEANTYPE, ) def test_where_less_than_or_equal(self): self._test_clause( - self.col <= self._data_str, + self.col <= self._data_str(), "data_table.range <= %(range_1)s", sqltypes.BOOLEANTYPE, ) def test_where_greater_than_or_equal(self): self._test_clause( - self.col >= self._data_str, + self.col >= self._data_str(), "data_table.range >= %(range_1)s", sqltypes.BOOLEANTYPE, ) def test_contains(self): self._test_clause( - self.col.contains(self._data_str), + self.col.contains(self._data_str()), "data_table.range @> %(range_1)s", sqltypes.BOOLEANTYPE, ) def test_contained_by(self): self._test_clause( - self.col.contained_by(self._data_str), + self.col.contained_by(self._data_str()), "data_table.range <@ %(range_1)s", sqltypes.BOOLEANTYPE, ) def test_overlaps(self): self._test_clause( - self.col.overlaps(self._data_str), + self.col.overlaps(self._data_str()), "data_table.range && %(range_1)s", sqltypes.BOOLEANTYPE, ) def test_strictly_left_of(self): self._test_clause( - self.col << self._data_str, + self.col << self._data_str(), "data_table.range << %(range_1)s", sqltypes.BOOLEANTYPE, ) self._test_clause( - self.col.strictly_left_of(self._data_str), + self.col.strictly_left_of(self._data_str()), "data_table.range << %(range_1)s", sqltypes.BOOLEANTYPE, ) def test_strictly_right_of(self): self._test_clause( - self.col >> self._data_str, + self.col >> self._data_str(), "data_table.range >> %(range_1)s", sqltypes.BOOLEANTYPE, ) self._test_clause( - self.col.strictly_right_of(self._data_str), + self.col.strictly_right_of(self._data_str()), "data_table.range >> %(range_1)s", sqltypes.BOOLEANTYPE, ) def test_not_extend_right_of(self): self._test_clause( - self.col.not_extend_right_of(self._data_str), + self.col.not_extend_right_of(self._data_str()), "data_table.range &< %(range_1)s", sqltypes.BOOLEANTYPE, ) def test_not_extend_left_of(self): self._test_clause( - self.col.not_extend_left_of(self._data_str), + self.col.not_extend_left_of(self._data_str()), "data_table.range &> %(range_1)s", sqltypes.BOOLEANTYPE, ) def test_adjacent_to(self): self._test_clause( - self.col.adjacent_to(self._data_str), + self.col.adjacent_to(self._data_str()), "data_table.range -|- %(range_1)s", sqltypes.BOOLEANTYPE, ) @@ -2920,14 +2871,14 @@ class _RangeTypeRoundTrip(fixtures.TablesTest): def test_insert_text(self, connection): connection.execute( - self.tables.data_table.insert(), {"range": self._data_str} + self.tables.data_table.insert(), {"range": self._data_str()} ) self._assert_data(connection) def test_union_result(self, connection): # insert connection.execute( - self.tables.data_table.insert(), {"range": self._data_str} + self.tables.data_table.insert(), {"range": self._data_str()} ) # select range_ = self.tables.data_table.c.range @@ -2937,7 +2888,7 @@ class _RangeTypeRoundTrip(fixtures.TablesTest): def test_intersection_result(self, connection): # insert connection.execute( - self.tables.data_table.insert(), {"range": self._data_str} + self.tables.data_table.insert(), {"range": self._data_str()} ) # select range_ = self.tables.data_table.c.range @@ -2947,7 +2898,7 @@ class _RangeTypeRoundTrip(fixtures.TablesTest): def test_difference_result(self, connection): # insert connection.execute( - self.tables.data_table.insert(), {"range": self._data_str} + self.tables.data_table.insert(), {"range": self._data_str()} ) # select range_ = self.tables.data_table.c.range @@ -2959,7 +2910,9 @@ class _Int4RangeTests(object): _col_type = INT4RANGE _col_str = "INT4RANGE" - _data_str = "[1,2)" + + def _data_str(self): + return "[1,2)" def _data_obj(self): return self.extras().NumericRange(1, 2) @@ -2969,7 +2922,9 @@ class _Int8RangeTests(object): _col_type = INT8RANGE _col_str = "INT8RANGE" - _data_str = "[9223372036854775806,9223372036854775807)" + + def _data_str(self): + return "[9223372036854775806,9223372036854775807)" def _data_obj(self): return self.extras().NumericRange( @@ -2981,7 +2936,9 @@ class _NumRangeTests(object): _col_type = NUMRANGE _col_str = "NUMRANGE" - _data_str = "[1.0,2.0)" + + def _data_str(self): + return "[1.0,2.0)" def _data_obj(self): return self.extras().NumericRange( @@ -2993,7 +2950,9 @@ class _DateRangeTests(object): _col_type = DATERANGE _col_str = "DATERANGE" - _data_str = "[2013-03-23,2013-03-24)" + + def _data_str(self): + return "[2013-03-23,2013-03-24)" def _data_obj(self): return self.extras().DateRange( @@ -3005,7 +2964,9 @@ class _DateTimeRangeTests(object): _col_type = TSRANGE _col_str = "TSRANGE" - _data_str = "[2013-03-23 14:30,2013-03-23 23:30)" + + def _data_str(self): + return "[2013-03-23 14:30,2013-03-23 23:30)" def _data_obj(self): return self.extras().DateTimeRange( @@ -3031,7 +2992,6 @@ class _DateTimeTZRangeTests(object): self._tstzs = (lower, upper) return self._tstzs - @property def _data_str(self): return "[%s,%s)" % self.tstzs() @@ -3178,7 +3138,7 @@ class JSONRoundTripTest(fixtures.TablesTest): __only_on__ = ("postgresql >= 9.3",) __backend__ = True - test_type = JSON + data_type = JSON @classmethod def define_tables(cls, metadata): @@ -3187,8 +3147,8 @@ class JSONRoundTripTest(fixtures.TablesTest): metadata, Column("id", Integer, primary_key=True), Column("name", String(30), nullable=False), - Column("data", cls.test_type), - Column("nulldata", cls.test_type(none_as_null=True)), + Column("data", cls.data_type), + Column("nulldata", cls.data_type(none_as_null=True)), ) def _fixture_data(self, engine): @@ -3255,7 +3215,7 @@ class JSONRoundTripTest(fixtures.TablesTest): def test_reflect(self): insp = inspect(testing.db) cols = insp.get_columns("data_table") - assert isinstance(cols[2]["type"], self.test_type) + assert isinstance(cols[2]["type"], self.data_type) def test_insert(self, connection): self._test_insert(connection) @@ -3286,7 +3246,7 @@ class JSONRoundTripTest(fixtures.TablesTest): options=dict(json_serializer=dumps, json_deserializer=loads) ) - s = select(cast({"key": "value", "x": "q"}, self.test_type)) + s = select(cast({"key": "value", "x": "q"}, self.data_type)) with engine.begin() as conn: eq_(conn.scalar(s), {"key": "value", "x": "dumps_y_loads"}) @@ -3366,7 +3326,7 @@ class JSONRoundTripTest(fixtures.TablesTest): s = select( cast( {"key": "value", "key2": {"k1": "v1", "k2": "v2"}}, - self.test_type, + self.data_type, ) ) eq_( @@ -3381,7 +3341,7 @@ class JSONRoundTripTest(fixtures.TablesTest): util.u("réveillé"): util.u("réveillé"), "data": {"k1": util.u("drôle")}, }, - self.test_type, + self.data_type, ) ) eq_( @@ -3483,7 +3443,7 @@ class JSONBTest(JSONTest): class JSONBRoundTripTest(JSONRoundTripTest): __requires__ = ("postgresql_jsonb",) - test_type = JSONB + data_type = JSONB @testing.requires.postgresql_utf8_server_encoding def test_unicode_round_trip(self, connection): diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py index 8eed21281..4658b40a8 100644 --- a/test/dialect/test_sqlite.py +++ b/test/dialect/test_sqlite.py @@ -72,44 +72,30 @@ class TestTypes(fixtures.TestBase, AssertsExecutionResults): __only_on__ = "sqlite" - @testing.provide_metadata - def test_boolean(self): + def test_boolean(self, connection, metadata): """Test that the boolean only treats 1 as True""" - meta = self.metadata t = Table( "bool_table", - meta, + metadata, Column("id", Integer, primary_key=True), Column("boo", Boolean(create_constraint=False)), ) - meta.create_all(testing.db) - exec_sql( - testing.db, + metadata.create_all(connection) + for stmt in [ "INSERT INTO bool_table (id, boo) " "VALUES (1, 'false');", - ) - exec_sql( - testing.db, "INSERT INTO bool_table (id, boo) " "VALUES (2, 'true');", - ) - exec_sql( - testing.db, "INSERT INTO bool_table (id, boo) " "VALUES (3, '1');", - ) - exec_sql( - testing.db, "INSERT INTO bool_table (id, boo) " "VALUES (4, '0');", - ) - exec_sql( - testing.db, "INSERT INTO bool_table (id, boo) " "VALUES (5, 1);", - ) - exec_sql( - testing.db, "INSERT INTO bool_table (id, boo) " "VALUES (6, 0);", - ) + ]: + connection.exec_driver_sql(stmt) + eq_( - t.select(t.c.boo).order_by(t.c.id).execute().fetchall(), + connection.execute( + t.select().where(t.c.boo).order_by(t.c.id) + ).fetchall(), [(3, True), (5, True)], ) @@ -301,51 +287,41 @@ class JSONTest(fixtures.TestBase): __requires__ = ("json_type",) __only_on__ = "sqlite" - @testing.provide_metadata @testing.requires.reflects_json_type - def test_reflection(self): - Table("json_test", self.metadata, Column("foo", sqlite.JSON)) - self.metadata.create_all() + def test_reflection(self, connection, metadata): + Table("json_test", metadata, Column("foo", sqlite.JSON)) + metadata.create_all(connection) - reflected = Table("json_test", MetaData(), autoload_with=testing.db) + reflected = Table("json_test", MetaData(), autoload_with=connection) is_(reflected.c.foo.type._type_affinity, sqltypes.JSON) assert isinstance(reflected.c.foo.type, sqlite.JSON) - @testing.provide_metadata - def test_rudimentary_roundtrip(self): - sqlite_json = Table( - "json_test", self.metadata, Column("foo", sqlite.JSON) - ) + def test_rudimentary_roundtrip(self, metadata, connection): + sqlite_json = Table("json_test", metadata, Column("foo", sqlite.JSON)) - self.metadata.create_all() + metadata.create_all(connection) value = {"json": {"foo": "bar"}, "recs": ["one", "two"]} - with testing.db.begin() as conn: - conn.execute(sqlite_json.insert(), foo=value) + connection.execute(sqlite_json.insert(), foo=value) - eq_(conn.scalar(select(sqlite_json.c.foo)), value) + eq_(connection.scalar(select(sqlite_json.c.foo)), value) - @testing.provide_metadata - def test_extract_subobject(self): - sqlite_json = Table( - "json_test", self.metadata, Column("foo", sqlite.JSON) - ) + def test_extract_subobject(self, connection, metadata): + sqlite_json = Table("json_test", metadata, Column("foo", sqlite.JSON)) - self.metadata.create_all() + metadata.create_all(connection) value = {"json": {"foo": "bar"}} - with testing.db.begin() as conn: - conn.execute(sqlite_json.insert(), foo=value) - - eq_(conn.scalar(select(sqlite_json.c.foo["json"])), value["json"]) + connection.execute(sqlite_json.insert(), foo=value) - @testing.provide_metadata - def test_deprecated_serializer_args(self): - sqlite_json = Table( - "json_test", self.metadata, Column("foo", sqlite.JSON) + eq_( + connection.scalar(select(sqlite_json.c.foo["json"])), value["json"] ) + + def test_deprecated_serializer_args(self, metadata): + sqlite_json = Table("json_test", metadata, Column("foo", sqlite.JSON)) data_element = {"foo": "bar"} js = mock.Mock(side_effect=json.dumps) @@ -360,7 +336,7 @@ class JSONTest(fixtures.TestBase): engine = engines.testing_engine( options=dict(_json_serializer=js, _json_deserializer=jd) ) - self.metadata.create_all(engine) + metadata.create_all(engine) with engine.begin() as conn: conn.execute(sqlite_json.insert(), {"foo": data_element}) @@ -468,17 +444,7 @@ class DefaultsTest(fixtures.TestBase, AssertsCompiledSQL): __only_on__ = "sqlite" - @testing.exclude( - "sqlite", - "<", - (3, 3, 8), - "sqlite3 changesets 3353 and 3440 modified " - "behavior of default displayed in pragma " - "table_info()", - ) - def test_default_reflection(self): - - # (ask_for, roundtripped_as_if_different) + def test_default_reflection(self, connection, metadata): specs = [ (String(3), '"foo"'), @@ -490,18 +456,13 @@ class DefaultsTest(fixtures.TestBase, AssertsCompiledSQL): Column("c%i" % (i + 1), t[0], server_default=text(t[1])) for (i, t) in enumerate(specs) ] - db = testing.db - m = MetaData(db) - Table("t_defaults", m, *columns) - try: - m.create_all() - m2 = MetaData() - rt = Table("t_defaults", m2, autoload_with=db) - expected = [c[1] for c in specs] - for i, reflected in enumerate(rt.c): - eq_(str(reflected.server_default.arg), expected[i]) - finally: - m.drop_all() + Table("t_defaults", metadata, *columns) + metadata.create_all(connection) + m2 = MetaData() + rt = Table("t_defaults", m2, autoload_with=connection) + expected = [c[1] for c in specs] + for i, reflected in enumerate(rt.c): + eq_(str(reflected.server_default.arg), expected[i]) @testing.exclude( "sqlite", @@ -917,7 +878,7 @@ class AttachedDBTest(fixtures.TestBase): eq_(insp.get_schema_names(), ["main", "test_schema"]) def test_reflect_system_table(self): - meta = MetaData(self.conn) + meta = MetaData() alt_master = Table( "sqlite_master", meta, @@ -1758,8 +1719,8 @@ class KeywordInDatabaseNameTest(fixtures.TestBase): connection.exec_driver_sql('DETACH DATABASE "default"') def test_reflect(self, connection, db_fixture): - meta = MetaData(bind=connection, schema="default") - meta.reflect() + meta = MetaData(schema="default") + meta.reflect(connection) assert "default.a" in meta.tables diff --git a/test/engine/test_ddlevents.py b/test/engine/test_ddlevents.py index 5cbb47854..396b48aa4 100644 --- a/test/engine/test_ddlevents.py +++ b/test/engine/test_ddlevents.py @@ -376,7 +376,7 @@ class DDLEventTest(fixtures.TestBase): class DDLExecutionTest(fixtures.TestBase): def setup(self): self.engine = engines.mock_engine() - self.metadata = MetaData(self.engine) + self.metadata = MetaData() self.users = Table( "users", self.metadata, @@ -391,14 +391,14 @@ class DDLExecutionTest(fixtures.TestBase): event.listen(users, "before_drop", DDL("xyzzy")) event.listen(users, "after_drop", DDL("fnord")) - users.create() + users.create(self.engine) strings = [str(x) for x in engine.mock] assert "mxyzptlk" in strings assert "klptzyxm" in strings assert "xyzzy" not in strings assert "fnord" not in strings del engine.mock[:] - users.drop() + users.drop(self.engine) strings = [str(x) for x in engine.mock] assert "mxyzptlk" not in strings assert "klptzyxm" not in strings @@ -413,14 +413,14 @@ class DDLExecutionTest(fixtures.TestBase): event.listen(users, "before_drop", DDL("xyzzy")) event.listen(users, "after_drop", DDL("fnord")) - metadata.create_all() + metadata.create_all(self.engine) strings = [str(x) for x in engine.mock] assert "mxyzptlk" in strings assert "klptzyxm" in strings assert "xyzzy" not in strings assert "fnord" not in strings del engine.mock[:] - metadata.drop_all() + metadata.drop_all(self.engine) strings = [str(x) for x in engine.mock] assert "mxyzptlk" not in strings assert "klptzyxm" not in strings @@ -435,14 +435,14 @@ class DDLExecutionTest(fixtures.TestBase): event.listen(metadata, "before_drop", DDL("xyzzy")) event.listen(metadata, "after_drop", DDL("fnord")) - metadata.create_all() + metadata.create_all(self.engine) strings = [str(x) for x in engine.mock] assert "mxyzptlk" in strings assert "klptzyxm" in strings assert "xyzzy" not in strings assert "fnord" not in strings del engine.mock[:] - metadata.drop_all() + metadata.drop_all(self.engine) strings = [str(x) for x in engine.mock] assert "mxyzptlk" not in strings assert "klptzyxm" not in strings diff --git a/test/engine/test_deprecations.py b/test/engine/test_deprecations.py index 4ca081be2..a18cf756b 100644 --- a/test/engine/test_deprecations.py +++ b/test/engine/test_deprecations.py @@ -69,25 +69,31 @@ class ConnectionlessDeprecationTest(fixtures.TestBase): metadata = MetaData() Table("test_table", metadata, Column("foo", Integer)) for meth in [metadata.create_all, metadata.drop_all]: - assert_raises_message( - exc.UnboundExecutionError, - "MetaData object is not bound to an Engine or Connection.", - meth, - ) + with testing.expect_deprecated_20( + "The ``bind`` argument for schema methods that invoke SQL" + ): + assert_raises_message( + exc.UnboundExecutionError, + "MetaData object is not bound to an Engine or Connection.", + meth, + ) def test_bind_create_drop_err_table(self): metadata = MetaData() table = Table("test_table", metadata, Column("foo", Integer)) for meth in [table.create, table.drop]: - assert_raises_message( - exc.UnboundExecutionError, - ( - "Table object 'test_table' is not bound to an Engine or " - "Connection." - ), - meth, - ) + with testing.expect_deprecated_20( + "The ``bind`` argument for schema methods that invoke SQL" + ): + assert_raises_message( + exc.UnboundExecutionError, + ( + "Table object 'test_table' is not bound to an " + "Engine or Connection." + ), + meth, + ) def test_bind_create_drop_bound(self): @@ -106,16 +112,28 @@ class ConnectionlessDeprecationTest(fixtures.TestBase): table = Table("test_table", metadata, Column("foo", Integer)) metadata.bind = bind assert metadata.bind is table.bind is bind - metadata.create_all() + with testing.expect_deprecated_20( + "The ``bind`` argument for schema methods that invoke SQL" + ): + metadata.create_all() with testing.expect_deprecated( r"The Table.exists\(\) method is deprecated and will " "be removed in a future release." ): assert table.exists() - metadata.drop_all() - table.create() - table.drop() + with testing.expect_deprecated_20( + "The ``bind`` argument for schema methods that invoke SQL" + ): + metadata.drop_all() + with testing.expect_deprecated_20( + "The ``bind`` argument for schema methods that invoke SQL" + ): + table.create() + with testing.expect_deprecated_20( + "The ``bind`` argument for schema methods that invoke SQL" + ): + table.drop() with testing.expect_deprecated( r"The Table.exists\(\) method is deprecated and will " "be removed in a future release." @@ -135,15 +153,27 @@ class ConnectionlessDeprecationTest(fixtures.TestBase): metadata.bind = bind assert metadata.bind is table.bind is bind - metadata.create_all() + with testing.expect_deprecated_20( + "The ``bind`` argument for schema methods that invoke SQL" + ): + metadata.create_all() with testing.expect_deprecated( r"The Table.exists\(\) method is deprecated and will " "be removed in a future release." ): assert table.exists() - metadata.drop_all() - table.create() - table.drop() + with testing.expect_deprecated_20( + "The ``bind`` argument for schema methods that invoke SQL" + ): + metadata.drop_all() + with testing.expect_deprecated_20( + "The ``bind`` argument for schema methods that invoke SQL" + ): + table.create() + with testing.expect_deprecated_20( + "The ``bind`` argument for schema methods that invoke SQL" + ): + table.drop() with testing.expect_deprecated( r"The Table.exists\(\) method is deprecated and will " "be removed in a future release." @@ -158,16 +188,35 @@ class ConnectionlessDeprecationTest(fixtures.TestBase): bind.begin() try: for args in (([bind], {}), ([], {"bind": bind})): - metadata = MetaData(*args[0], **args[1]) + with testing.expect_deprecated_20( + "The MetaData.bind argument is deprecated " + ): + metadata = MetaData(*args[0], **args[1]) table = Table( "test_table", metadata, Column("foo", Integer) ) assert metadata.bind is table.bind is bind - metadata.create_all() + with testing.expect_deprecated_20( + "The ``bind`` argument for schema methods " + "that invoke SQL" + ): + metadata.create_all() is_true(inspect(bind).has_table(table.name)) - metadata.drop_all() - table.create() - table.drop() + with testing.expect_deprecated_20( + "The ``bind`` argument for schema methods " + "that invoke SQL" + ): + metadata.drop_all() + with testing.expect_deprecated_20( + "The ``bind`` argument for schema methods " + "that invoke SQL" + ): + table.create() + with testing.expect_deprecated_20( + "The ``bind`` argument for schema methods " + "that invoke SQL" + ): + table.drop() is_false(inspect(bind).has_table(table.name)) finally: if isinstance(bind, engine.Connection): @@ -315,11 +364,11 @@ class ConnectionlessDeprecationTest(fixtures.TestBase): ): eq_(testing.db.execute(stmt).fetchall(), [(1,)]) - @testing.provide_metadata - def test_implicit_execute(self): - table = Table("t", self.metadata, Column("a", Integer)) + def test_implicit_execute(self, metadata): + table = Table("t", metadata, Column("a", Integer)) table.create(testing.db) + metadata.bind = testing.db stmt = table.insert().values(a=1) with testing.expect_deprecated_20( r"The Executable.execute\(\) method is considered legacy", @@ -1225,7 +1274,7 @@ class DeprecatedReflectionTest(fixtures.TablesTest): is_true(testing.db.has_table("user")) def test_engine_table_names(self): - metadata = self.metadata + metadata = self.tables_test_metadata with testing.expect_deprecated( r"The Engine.table_names\(\) method is deprecated" @@ -1235,7 +1284,8 @@ class DeprecatedReflectionTest(fixtures.TablesTest): def test_reflecttable(self): inspector = inspect(testing.db) - metadata = self.metadata + metadata = MetaData() + table = Table("user", metadata) with testing.expect_deprecated_20( r"The Inspector.reflecttable\(\) method is considered " @@ -1632,7 +1682,7 @@ class EngineEventsTest(fixtures.TestBase): class DDLExecutionTest(fixtures.TestBase): def setup(self): self.engine = engines.mock_engine() - self.metadata = MetaData(self.engine) + self.metadata = MetaData() self.users = Table( "users", self.metadata, @@ -1742,7 +1792,7 @@ class AutocommitTextTest(AutocommitKeywordFixture, fixtures.TestBase): self._test_keyword("SELECT foo FROM table", False) -class ExplicitAutoCommitTest(fixtures.TestBase): +class ExplicitAutoCommitTest(fixtures.TablesTest): """test the 'autocommit' flag on select() and text() objects. @@ -1752,36 +1802,31 @@ class ExplicitAutoCommitTest(fixtures.TestBase): __only_on__ = "postgresql" @classmethod - def setup_class(cls): - global metadata, foo - metadata = MetaData(testing.db) - foo = Table( + def define_tables(cls, metadata): + Table( "foo", metadata, Column("id", Integer, primary_key=True), Column("data", String(100)), ) - with testing.db.begin() as conn: - metadata.create_all(conn) - conn.exec_driver_sql( + + event.listen( + metadata, + "after_create", + DDL( "create function insert_foo(varchar) " "returns integer as 'insert into foo(data) " "values ($1);select 1;' language sql" - ) - - def teardown(self): - with testing.db.begin() as conn: - conn.execute(foo.delete()) - - @classmethod - def teardown_class(cls): - with testing.db.begin() as conn: - conn.exec_driver_sql("drop function insert_foo(varchar)") - metadata.drop_all(conn) + ), + ) + event.listen( + metadata, "before_drop", DDL("drop function insert_foo(varchar)") + ) def test_control(self): # test that not using autocommit does not commit + foo = self.tables.foo conn1 = testing.db.connect() conn2 = testing.db.connect() @@ -1799,6 +1844,8 @@ class ExplicitAutoCommitTest(fixtures.TestBase): conn2.close() def test_explicit_compiled(self): + foo = self.tables.foo + conn1 = testing.db.connect() conn2 = testing.db.connect() @@ -1816,6 +1863,8 @@ class ExplicitAutoCommitTest(fixtures.TestBase): conn2.close() def test_explicit_connection(self): + foo = self.tables.foo + conn1 = testing.db.connect() conn2 = testing.db.connect() with testing.expect_deprecated_20( @@ -1853,6 +1902,8 @@ class ExplicitAutoCommitTest(fixtures.TestBase): conn2.close() def test_explicit_text(self): + foo = self.tables.foo + conn1 = testing.db.connect() conn2 = testing.db.connect() with testing.expect_deprecated_20( @@ -1869,6 +1920,8 @@ class ExplicitAutoCommitTest(fixtures.TestBase): conn2.close() def test_implicit_text(self): + foo = self.tables.foo + conn1 = testing.db.connect() conn2 = testing.db.connect() with testing.expect_deprecated_20( diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index 55a114409..21d4e06e0 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -304,6 +304,9 @@ class ExecuteTest(fixtures.TablesTest): class NonStandardException(OperationalError): pass + # TODO: this test is assuming too much of arbitrary dialects and would + # be better suited tested against a single mock dialect that does not + # have any special behaviors with patch.object( testing.db.dialect, "dbapi", Mock(Error=DBAPIError) ), patch.object( @@ -312,6 +315,10 @@ class ExecuteTest(fixtures.TablesTest): testing.db.dialect, "do_execute", Mock(side_effect=NonStandardException), + ), patch.object( + testing.db.dialect.execution_ctx_cls, + "handle_dbapi_exception", + Mock(), ): with testing.db.connect() as conn: assert_raises( @@ -823,12 +830,11 @@ class ConvenienceExecuteTest(fixtures.TablesTest): self._assert_no_data() -class CompiledCacheTest(fixtures.TablesTest): +class CompiledCacheTest(fixtures.TestBase): __backend__ = True - @classmethod - def define_tables(cls, metadata): - Table( + def test_cache(self, connection, metadata): + users = Table( "users", metadata, Column( @@ -837,9 +843,7 @@ class CompiledCacheTest(fixtures.TablesTest): Column("user_name", VARCHAR(20)), Column("extra_data", VARCHAR(20)), ) - - def test_cache(self, connection): - users = self.tables.users + users.create(connection) conn = connection cache = {} @@ -905,8 +909,17 @@ class CompiledCacheTest(fixtures.TablesTest): # the statement values (only the keys). eq_(ref_blob(), None) - def test_keys_independent_of_ordering(self, connection): - users = self.tables.users + def test_keys_independent_of_ordering(self, connection, metadata): + users = Table( + "users", + metadata, + Column( + "user_id", INT, primary_key=True, test_needs_autoincrement=True + ), + Column("user_name", VARCHAR(20)), + Column("extra_data", VARCHAR(20)), + ) + users.create(connection) connection.execute( users.insert(), @@ -954,13 +967,10 @@ class CompiledCacheTest(fixtures.TablesTest): eq_(len(cache), 1) @testing.requires.schemas - @testing.provide_metadata - def test_schema_translate_in_key(self): - Table("x", self.metadata, Column("q", Integer)) - Table( - "x", self.metadata, Column("q", Integer), schema=config.test_schema - ) - self.metadata.create_all() + def test_schema_translate_in_key(self, metadata, connection): + Table("x", metadata, Column("q", Integer)) + Table("x", metadata, Column("q", Integer), schema=config.test_schema) + metadata.create_all(connection) m = MetaData() t1 = Table("x", m, Column("q", Integer)) @@ -968,33 +978,30 @@ class CompiledCacheTest(fixtures.TablesTest): stmt = select(t1.c.q) cache = {} - with config.db.begin() as conn: - conn = conn.execution_options(compiled_cache=cache) - conn.execute(ins, {"q": 1}) - eq_(conn.scalar(stmt), 1) - with config.db.begin() as conn: - conn = conn.execution_options( - compiled_cache=cache, - schema_translate_map={None: config.test_schema}, - ) - conn.execute(ins, {"q": 2}) - eq_(conn.scalar(stmt), 2) + conn = connection.execution_options(compiled_cache=cache) + conn.execute(ins, {"q": 1}) + eq_(conn.scalar(stmt), 1) - with config.db.begin() as conn: - conn = conn.execution_options( - compiled_cache=cache, - schema_translate_map={None: None}, - ) - # should use default schema again even though statement - # was compiled with test_schema in the map - eq_(conn.scalar(stmt), 1) + conn = connection.execution_options( + compiled_cache=cache, + schema_translate_map={None: config.test_schema}, + ) + conn.execute(ins, {"q": 2}) + eq_(conn.scalar(stmt), 2) - with config.db.begin() as conn: - conn = conn.execution_options( - compiled_cache=cache, - ) - eq_(conn.scalar(stmt), 1) + conn = connection.execution_options( + compiled_cache=cache, + schema_translate_map={None: None}, + ) + # should use default schema again even though statement + # was compiled with test_schema in the map + eq_(conn.scalar(stmt), 1) + + conn = connection.execution_options( + compiled_cache=cache, + ) + eq_(conn.scalar(stmt), 1) class MockStrategyTest(fixtures.TestBase): @@ -1072,7 +1079,7 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults): Table("t1", metadata, Column("x", Integer), schema=config.test_schema) Table("t2", metadata, Column("x", Integer), schema=config.test_schema) Table("t3", metadata, Column("x", Integer), schema=None) - metadata.create_all() + metadata.create_all(testing.db) def test_ddl_hastable(self): @@ -1765,7 +1772,7 @@ class EngineEventsTest(fixtures.TestBase): ]: event.listen(engine, "before_execute", execute) event.listen(engine, "before_cursor_execute", cursor_execute) - m = MetaData(engine) + m = MetaData() t1 = Table( "t1", m, diff --git a/test/engine/test_pool.py b/test/engine/test_pool.py index 1a49cf4b9..550fedb8e 100644 --- a/test/engine/test_pool.py +++ b/test/engine/test_pool.py @@ -101,7 +101,7 @@ class PoolTest(PoolTestBase): def test_cursor_iterable(self): conn = testing.db.raw_connection() cursor = conn.cursor() - cursor.execute(str(select([1], bind=testing.db))) + cursor.execute(str(select(1).compile(testing.db))) expected = [(1,)] for row in cursor: eq_(row, expected.pop(0)) diff --git a/test/engine/test_reflection.py b/test/engine/test_reflection.py index 48b6c40d7..658cdd79f 100644 --- a/test/engine/test_reflection.py +++ b/test/engine/test_reflection.py @@ -21,7 +21,6 @@ from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import ComparesTables from sqlalchemy.testing import config -from sqlalchemy.testing import engines from sqlalchemy.testing import eq_ from sqlalchemy.testing import eq_regex from sqlalchemy.testing import expect_warnings @@ -43,13 +42,8 @@ from sqlalchemy.util import ue class ReflectionTest(fixtures.TestBase, ComparesTables): __backend__ = True - @testing.exclude( - "mssql", "<", (10, 0, 0), "Date is only supported on MSSQL 2008+" - ) - @testing.exclude("mysql", "<", (4, 1, 1), "early types are squirrely") - @testing.provide_metadata - def test_basic_reflection(self): - meta = self.metadata + def test_basic_reflection(self, connection, metadata): + meta = metadata users = Table( "engine_users", @@ -85,25 +79,22 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): Column("email_address", sa.String(20)), test_needs_fk=True, ) - meta.create_all() + meta.create_all(connection) meta2 = MetaData() reflected_users = Table( - "engine_users", meta2, autoload_with=testing.db + "engine_users", meta2, autoload_with=connection ) reflected_addresses = Table( "engine_email_addresses", meta2, - autoload_with=testing.db, + autoload_with=connection, ) self.assert_tables_equal(users, reflected_users) self.assert_tables_equal(addresses, reflected_addresses) - @testing.provide_metadata - def test_autoload_with_imply_autoload( - self, - ): - meta = self.metadata + def test_autoload_with_imply_autoload(self, metadata, connection): + meta = metadata t = Table( "t", meta, @@ -111,15 +102,14 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): Column("x", sa.String(20)), Column("y", sa.Integer), ) - meta.create_all() + meta.create_all(connection) meta2 = MetaData() - reflected_t = Table("t", meta2, autoload_with=testing.db) + reflected_t = Table("t", meta2, autoload_with=connection) self.assert_tables_equal(t, reflected_t) - @testing.provide_metadata - def test_two_foreign_keys(self): - meta = self.metadata + def test_two_foreign_keys(self, metadata, connection): + meta = metadata Table( "t1", meta, @@ -140,18 +130,17 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): Column("id", sa.Integer, primary_key=True), test_needs_fk=True, ) - meta.create_all() + meta.create_all(connection) meta2 = MetaData() t1r, t2r, t3r = [ - Table(x, meta2, autoload_with=testing.db) + Table(x, meta2, autoload_with=connection) for x in ("t1", "t2", "t3") ] assert t1r.c.t2id.references(t2r.c.id) assert t1r.c.t3id.references(t3r.c.id) - @testing.provide_metadata - def test_resolve_fks_false_table(self): - meta = self.metadata + def test_resolve_fks_false_table(self, connection, metadata): + meta = metadata Table( "t1", meta, @@ -165,9 +154,9 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): Column("id", sa.Integer, primary_key=True), test_needs_fk=True, ) - meta.create_all() + meta.create_all(connection) meta2 = MetaData() - t1 = Table("t1", meta2, resolve_fks=False, autoload_with=testing.db) + t1 = Table("t1", meta2, resolve_fks=False, autoload_with=connection) in_("t1", meta2.tables) not_in("t2", meta2.tables) @@ -176,14 +165,13 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): lambda: list(t1.c.t2id.foreign_keys)[0].column, ) - t2 = Table("t2", meta2, autoload_with=testing.db) + t2 = Table("t2", meta2, autoload_with=connection) # now it resolves is_true(t1.c.t2id.references(t2.c.id)) - @testing.provide_metadata - def test_resolve_fks_false_extend_existing(self): - meta = self.metadata + def test_resolve_fks_false_extend_existing(self, connection, metadata): + meta = metadata Table( "t1", meta, @@ -197,7 +185,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): Column("id", sa.Integer, primary_key=True), test_needs_fk=True, ) - meta.create_all() + meta.create_all(connection) meta2 = MetaData() Table("t1", meta2) in_("t1", meta2.tables) @@ -206,7 +194,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): "t1", meta2, resolve_fks=False, - autoload_with=testing.db, + autoload_with=connection, extend_existing=True, ) not_in("t2", meta2.tables) @@ -216,14 +204,13 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): lambda: list(t1.c.t2id.foreign_keys)[0].column, ) - t2 = Table("t2", meta2, autoload_with=testing.db) + t2 = Table("t2", meta2, autoload_with=connection) # now it resolves is_true(t1.c.t2id.references(t2.c.id)) - @testing.provide_metadata - def test_resolve_fks_false_metadata(self): - meta = self.metadata + def test_resolve_fks_false_metadata(self, connection, metadata): + meta = metadata Table( "t1", meta, @@ -237,9 +224,9 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): Column("id", sa.Integer, primary_key=True), test_needs_fk=True, ) - meta.create_all() + meta.create_all(connection) meta2 = MetaData() - meta2.reflect(testing.db, resolve_fks=False, only=["t1"]) + meta2.reflect(connection, resolve_fks=False, only=["t1"]) in_("t1", meta2.tables) not_in("t2", meta2.tables) @@ -250,36 +237,35 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): lambda: list(t1.c.t2id.foreign_keys)[0].column, ) - meta2.reflect(testing.db, resolve_fks=False) + meta2.reflect(connection, resolve_fks=False) t2 = meta2.tables["t2"] is_true(t1.c.t2id.references(t2.c.id)) - def test_nonexistent(self): + def test_nonexistent(self, connection): meta = MetaData() assert_raises( sa.exc.NoSuchTableError, Table, "nonexistent", meta, - autoload_with=testing.db, + autoload_with=connection, ) assert "nonexistent" not in meta.tables - @testing.provide_metadata - def test_include_columns(self): - meta = self.metadata + def test_include_columns(self, connection, metadata): + meta = metadata foo = Table( "foo", meta, *[Column(n, sa.String(30)) for n in ["a", "b", "c", "d", "e", "f"]] ) - meta.create_all() + meta.create_all(connection) meta2 = MetaData() foo = Table( "foo", meta2, - autoload_with=testing.db, + autoload_with=connection, include_columns=["b", "f", "e"], ) # test that cols come back in original order @@ -291,7 +277,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): # test against a table which is already reflected meta3 = MetaData() - foo = Table("foo", meta3, autoload_with=testing.db) + foo = Table("foo", meta3, autoload_with=connection) foo = Table( "foo", meta3, include_columns=["b", "f", "e"], extend_existing=True @@ -302,9 +288,8 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): for c in ("a", "c", "d"): assert c not in foo.c - @testing.provide_metadata - def test_extend_existing(self): - meta = self.metadata + def test_extend_existing(self, connection, metadata): + meta = metadata Table( "t", @@ -314,7 +299,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): Column("y", Integer), Column("z", Integer, server_default="5"), ) - meta.create_all() + meta.create_all(connection) m2 = MetaData() old_z = Column("z", String, primary_key=True) @@ -327,7 +312,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): m2, old_y, extend_existing=True, - autoload_with=testing.db, + autoload_with=connection, ) eq_(set(t2.columns.keys()), set(["x", "y", "z", "q", "id"])) @@ -346,7 +331,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): "t", m3, extend_existing=False, - autoload_with=testing.db, + autoload_with=connection, ) eq_(set(t3.columns.keys()), set(["z"])) @@ -362,7 +347,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): old_y, extend_existing=True, autoload_replace=False, - autoload_with=testing.db, + autoload_with=connection, ) eq_(set(t4.columns.keys()), set(["x", "y", "z", "q", "id"])) eq_(list(t4.primary_key.columns), [t4.c.z, t4.c.id]) @@ -371,9 +356,10 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): assert t4.c.z.type._type_affinity is String assert t4.c.q is old_q - @testing.provide_metadata - def test_extend_existing_reflect_all_dont_dupe_index(self): - m = self.metadata + def test_extend_existing_reflect_all_dont_dupe_index( + self, connection, metadata + ): + m = metadata d = Table( "d", m, @@ -389,10 +375,10 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): Column("id", Integer, primary_key=True), Column("aid", ForeignKey("d.id")), ) - m.create_all() + m.create_all(connection) m2 = MetaData() - m2.reflect(testing.db, extend_existing=True) + m2.reflect(connection, extend_existing=True) eq_( len( @@ -422,51 +408,51 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): ) @testing.emits_warning(r".*omitted columns") - @testing.provide_metadata - def test_include_columns_indexes(self): - m = self.metadata + def test_include_columns_indexes(self, connection, metadata): + m = metadata t1 = Table("t1", m, Column("a", sa.Integer), Column("b", sa.Integer)) sa.Index("foobar", t1.c.a, t1.c.b) sa.Index("bat", t1.c.a) - m.create_all() + m.create_all(connection) m2 = MetaData() - t2 = Table("t1", m2, autoload_with=testing.db) + t2 = Table("t1", m2, autoload_with=connection) assert len(t2.indexes) == 2 m2 = MetaData() - t2 = Table("t1", m2, autoload_with=testing.db, include_columns=["a"]) + t2 = Table("t1", m2, autoload_with=connection, include_columns=["a"]) assert len(t2.indexes) == 1 m2 = MetaData() t2 = Table( - "t1", m2, autoload_with=testing.db, include_columns=["a", "b"] + "t1", m2, autoload_with=connection, include_columns=["a", "b"] ) assert len(t2.indexes) == 2 - @testing.provide_metadata - def test_autoload_replace_foreign_key_nonpresent(self): + def test_autoload_replace_foreign_key_nonpresent( + self, connection, metadata + ): """test autoload_replace=False with col plus FK establishes the FK not present in the DB. """ - Table("a", self.metadata, Column("id", Integer, primary_key=True)) + Table("a", metadata, Column("id", Integer, primary_key=True)) Table( "b", - self.metadata, + metadata, Column("id", Integer, primary_key=True), Column("a_id", Integer), ) - self.metadata.create_all() + metadata.create_all(connection) m2 = MetaData() b2 = Table("b", m2, Column("a_id", Integer, sa.ForeignKey("a.id"))) - a2 = Table("a", m2, autoload_with=testing.db) + a2 = Table("a", m2, autoload_with=connection) b2 = Table( "b", m2, extend_existing=True, - autoload_with=testing.db, + autoload_with=connection, autoload_replace=False, ) @@ -474,30 +460,31 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): assert b2.c.a_id.references(a2.c.id) eq_(len(b2.constraints), 2) - @testing.provide_metadata - def test_autoload_replace_foreign_key_ispresent(self): + def test_autoload_replace_foreign_key_ispresent( + self, connection, metadata + ): """test autoload_replace=False with col plus FK mirroring DB-reflected FK skips the reflected FK and installs the in-python one only. """ - Table("a", self.metadata, Column("id", Integer, primary_key=True)) + Table("a", metadata, Column("id", Integer, primary_key=True)) Table( "b", - self.metadata, + metadata, Column("id", Integer, primary_key=True), Column("a_id", Integer, sa.ForeignKey("a.id")), ) - self.metadata.create_all() + metadata.create_all(connection) m2 = MetaData() b2 = Table("b", m2, Column("a_id", Integer, sa.ForeignKey("a.id"))) - a2 = Table("a", m2, autoload_with=testing.db) + a2 = Table("a", m2, autoload_with=connection) b2 = Table( "b", m2, extend_existing=True, - autoload_with=testing.db, + autoload_with=connection, autoload_replace=False, ) @@ -505,29 +492,28 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): assert b2.c.a_id.references(a2.c.id) eq_(len(b2.constraints), 2) - @testing.provide_metadata - def test_autoload_replace_foreign_key_removed(self): + def test_autoload_replace_foreign_key_removed(self, connection, metadata): """test autoload_replace=False with col minus FK that's in the DB means the FK is skipped and doesn't get installed at all. """ - Table("a", self.metadata, Column("id", Integer, primary_key=True)) + Table("a", metadata, Column("id", Integer, primary_key=True)) Table( "b", - self.metadata, + metadata, Column("id", Integer, primary_key=True), Column("a_id", Integer, sa.ForeignKey("a.id")), ) - self.metadata.create_all() + metadata.create_all(connection) m2 = MetaData() b2 = Table("b", m2, Column("a_id", Integer)) - a2 = Table("a", m2, autoload_with=testing.db) + a2 = Table("a", m2, autoload_with=connection) b2 = Table( "b", m2, extend_existing=True, - autoload_with=testing.db, + autoload_with=connection, autoload_replace=False, ) @@ -535,10 +521,9 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): assert not b2.c.a_id.references(a2.c.id) eq_(len(b2.constraints), 1) - @testing.provide_metadata - def test_autoload_replace_primary_key(self): - Table("a", self.metadata, Column("id", Integer)) - self.metadata.create_all() + def test_autoload_replace_primary_key(self, connection, metadata): + Table("a", metadata, Column("id", Integer)) + metadata.create_all(connection) m2 = MetaData() a2 = Table("a", m2, Column("id", Integer, primary_key=True)) @@ -546,7 +531,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): Table( "a", m2, - autoload_with=testing.db, + autoload_with=connection, autoload_replace=False, extend_existing=True, ) @@ -555,15 +540,14 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): def test_autoload_replace_arg(self): Table("t", MetaData(), autoload_replace=False) - @testing.provide_metadata - def test_autoincrement_col(self): + def test_autoincrement_col(self, connection, metadata): """test that 'autoincrement' is reflected according to sqla's policy. Don't mark this test as unsupported for any backend ! """ - meta = self.metadata + meta = metadata Table( "test", meta, @@ -581,41 +565,35 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): Column("data", sa.String(50)), mysql_engine="InnoDB", ) - meta.create_all() + meta.create_all(connection) m2 = MetaData() - t1a = Table("test", m2, autoload_with=testing.db) + t1a = Table("test", m2, autoload_with=connection) assert t1a._autoincrement_column is t1a.c.id - t2a = Table("test2", m2, autoload_with=testing.db) + t2a = Table("test2", m2, autoload_with=connection) assert t2a._autoincrement_column is None @skip("sqlite") - @testing.provide_metadata - def test_unknown_types(self): + def test_unknown_types(self, connection, metadata): """Test the handling of unknown types for the given dialect. sqlite is skipped because it has special rules for unknown types using 'affinity types' - this feature is tested in that dialect's test spec. """ - meta = self.metadata + meta = metadata t = Table("test", meta, Column("foo", sa.DateTime)) - ischema_names = testing.db.dialect.ischema_names - t.create() - testing.db.dialect.ischema_names = {} - try: - m2 = MetaData(testing.db) + t.create(connection) + + with mock.patch.object(connection.dialect, "ischema_names", {}): + m2 = MetaData() with testing.expect_warnings("Did not recognize type"): - t3 = Table("test", m2, autoload_with=testing.db) + t3 = Table("test", m2, autoload_with=connection) is_(t3.c.foo.type.__class__, sa.types.NullType) - finally: - testing.db.dialect.ischema_names = ischema_names - - @testing.provide_metadata - def test_basic_override(self): - meta = self.metadata + def test_basic_override(self, connection, metadata): + meta = metadata table = Table( "override_test", meta, @@ -623,7 +601,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): Column("col2", sa.String(20)), Column("col3", sa.Numeric), ) - table.create() + table.create(connection) meta2 = MetaData() table = Table( @@ -631,16 +609,15 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): meta2, Column("col2", sa.Unicode()), Column("col4", sa.String(30)), - autoload_with=testing.db, + autoload_with=connection, ) self.assert_(isinstance(table.c.col1.type, sa.Integer)) self.assert_(isinstance(table.c.col2.type, sa.Unicode)) self.assert_(isinstance(table.c.col4.type, sa.String)) - @testing.provide_metadata - def test_override_upgrade_pk_flag(self): - meta = self.metadata + def test_override_upgrade_pk_flag(self, connection, metadata): + meta = metadata table = Table( "override_test", meta, @@ -648,26 +625,25 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): Column("col2", sa.String(20)), Column("col3", sa.Numeric), ) - table.create() + table.create(connection) meta2 = MetaData() table = Table( "override_test", meta2, Column("col1", sa.Integer, primary_key=True), - autoload_with=testing.db, + autoload_with=connection, ) eq_(list(table.primary_key), [table.c.col1]) eq_(table.c.col1.primary_key, True) - @testing.provide_metadata - def test_override_pkfk(self): + def test_override_pkfk(self, connection, metadata): """test that you can override columns which contain foreign keys to other reflected tables, where the foreign key column is also a primary key column""" - meta = self.metadata + meta = metadata Table( "users", meta, @@ -681,7 +657,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): Column("street", sa.String(30)), ) - meta.create_all() + meta.create_all(connection) meta2 = MetaData() a2 = Table( "addresses", @@ -689,36 +665,35 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): Column( "id", sa.Integer, sa.ForeignKey("users.id"), primary_key=True ), - autoload_with=testing.db, + autoload_with=connection, ) - u2 = Table("users", meta2, autoload_with=testing.db) + u2 = Table("users", meta2, autoload_with=connection) assert list(a2.primary_key) == [a2.c.id] assert list(u2.primary_key) == [u2.c.id] assert u2.join(a2).onclause.compare(u2.c.id == a2.c.id) meta3 = MetaData() - u3 = Table("users", meta3, autoload_with=testing.db) + u3 = Table("users", meta3, autoload_with=connection) a3 = Table( "addresses", meta3, Column( "id", sa.Integer, sa.ForeignKey("users.id"), primary_key=True ), - autoload_with=testing.db, + autoload_with=connection, ) assert list(a3.primary_key) == [a3.c.id] assert list(u3.primary_key) == [u3.c.id] assert u3.join(a3).onclause.compare(u3.c.id == a3.c.id) - @testing.provide_metadata - def test_override_nonexistent_fk(self): + def test_override_nonexistent_fk(self, connection, metadata): """test that you can override columns and create new foreign keys to other reflected tables which have no foreign keys. this is common with MySQL MyISAM tables.""" - meta = self.metadata + meta = metadata Table( "users", meta, @@ -733,15 +708,15 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): Column("user_id", sa.Integer), ) - meta.create_all() + meta.create_all(connection) meta2 = MetaData() a2 = Table( "addresses", meta2, Column("user_id", sa.Integer, sa.ForeignKey("users.id")), - autoload_with=testing.db, + autoload_with=connection, ) - u2 = Table("users", meta2, autoload_with=testing.db) + u2 = Table("users", meta2, autoload_with=connection) assert len(a2.c.user_id.foreign_keys) == 1 assert len(a2.foreign_keys) == 1 assert [c.parent for c in a2.foreign_keys] == [a2.c.user_id] @@ -750,13 +725,13 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): assert u2.join(a2).onclause.compare(u2.c.id == a2.c.user_id) meta3 = MetaData() - u3 = Table("users", meta3, autoload_with=testing.db) + u3 = Table("users", meta3, autoload_with=connection) a3 = Table( "addresses", meta3, Column("user_id", sa.Integer, sa.ForeignKey("users.id")), - autoload_with=testing.db, + autoload_with=connection, ) assert u3.join(a3).onclause.compare(u3.c.id == a3.c.user_id) @@ -766,7 +741,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): "users", meta4, Column("id", sa.Integer, key="u_id", primary_key=True), - autoload_with=testing.db, + autoload_with=connection, ) a4 = Table( @@ -777,7 +752,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): Column( "user_id", sa.Integer, sa.ForeignKey("users.u_id"), key="id" ), - autoload_with=testing.db, + autoload_with=connection, ) # for the thing happening here with the column collection, @@ -789,12 +764,9 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): assert len(a4.columns) == 3 assert len(a4.constraints) == 2 - @testing.provide_metadata - def test_override_composite_fk(self): + def test_override_composite_fk(self, connection, metadata): """Test double-remove of composite foreign key, when replaced.""" - metadata = self.metadata - Table( "a", metadata, @@ -810,26 +782,25 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): sa.ForeignKeyConstraint(["x", "y"], ["a.x", "a.y"]), ) - metadata.create_all() + metadata.create_all(connection) meta2 = MetaData() c1 = Column("x", sa.Integer, primary_key=True) c2 = Column("y", sa.Integer, primary_key=True) f1 = sa.ForeignKeyConstraint(["x", "y"], ["a.x", "a.y"]) - b1 = Table("b", meta2, c1, c2, f1, autoload_with=testing.db) + b1 = Table("b", meta2, c1, c2, f1, autoload_with=connection) assert b1.c.x is c1 assert b1.c.y is c2 assert f1 in b1.constraints assert len(b1.constraints) == 2 - @testing.provide_metadata - def test_override_keys(self): + def test_override_keys(self, connection, metadata): """test that columns can be overridden with a 'key', and that ForeignKey targeting during reflection still works.""" - meta = self.metadata + meta = metadata Table( "a", meta, @@ -843,27 +814,26 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): Column("y", sa.Integer, sa.ForeignKey("a.x")), test_needs_fk=True, ) - meta.create_all(testing.db) + meta.create_all(connection) m2 = MetaData() a2 = Table( "a", m2, Column("x", sa.Integer, primary_key=True, key="x1"), - autoload_with=testing.db, + autoload_with=connection, ) - b2 = Table("b", m2, autoload_with=testing.db) + b2 = Table("b", m2, autoload_with=connection) assert a2.join(b2).onclause.compare(a2.c.x1 == b2.c.y) assert b2.c.y.references(a2.c.x1) - @testing.provide_metadata - def test_nonreflected_fk_raises(self): + def test_nonreflected_fk_raises(self, connection, metadata): """test that a NoReferencedColumnError is raised when reflecting a table with an FK to another table which has not included the target column in its reflection. """ - meta = self.metadata + meta = metadata Table( "a", meta, @@ -877,21 +847,19 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): Column("y", sa.Integer, sa.ForeignKey("a.x")), test_needs_fk=True, ) - meta.create_all() + meta.create_all(connection) m2 = MetaData() - a2 = Table("a", m2, include_columns=["z"], autoload_with=testing.db) - b2 = Table("b", m2, autoload_with=testing.db) + a2 = Table("a", m2, include_columns=["z"], autoload_with=connection) + b2 = Table("b", m2, autoload_with=connection) assert_raises(sa.exc.NoReferencedColumnError, a2.join, b2) - @testing.exclude("mysql", "<", (4, 1, 1), "innodb funkiness") - @testing.provide_metadata - def test_override_existing_fk(self): + def test_override_existing_fk(self, connection, metadata): """test that you can override columns and specify new foreign keys to other reflected tables, on columns which *do* already have that foreign key, and that the FK is not duped.""" - meta = self.metadata + meta = metadata Table( "users", meta, @@ -907,15 +875,15 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): test_needs_fk=True, ) - meta.create_all(testing.db) + meta.create_all(connection) meta2 = MetaData() a2 = Table( "addresses", meta2, Column("user_id", sa.Integer, sa.ForeignKey("users.id")), - autoload_with=testing.db, + autoload_with=connection, ) - u2 = Table("users", meta2, autoload_with=testing.db) + u2 = Table("users", meta2, autoload_with=connection) s = sa.select(a2).subquery() assert s.c.user_id is not None @@ -932,14 +900,14 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): "users", meta2, Column("id", sa.Integer, primary_key=True), - autoload_with=testing.db, + autoload_with=connection, ) a2 = Table( "addresses", meta2, Column("id", sa.Integer, primary_key=True), Column("user_id", sa.Integer, sa.ForeignKey("users.id")), - autoload_with=testing.db, + autoload_with=connection, ) s = sa.select(a2).subquery() @@ -953,8 +921,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): assert u2.join(a2).onclause.compare(u2.c.id == a2.c.user_id) @testing.only_on(["postgresql", "mysql"]) - @testing.provide_metadata - def test_fk_options(self): + def test_fk_options(self, connection, metadata): """test that foreign key reflection includes options (on backends with {dialect}.get_foreign_keys() support)""" @@ -989,7 +956,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): ) test_attrs = ("onupdate", "ondelete") - meta = self.metadata + meta = metadata Table( "users", meta, @@ -1004,40 +971,38 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): Column("user_id", sa.Integer, addresses_user_id_fkey), test_needs_fk=True, ) - meta.create_all() + meta.create_all(connection) meta2 = MetaData() - meta2.reflect(testing.db) + meta2.reflect(connection) for fk in meta2.tables["addresses"].foreign_keys: ref = addresses_user_id_fkey for attr in test_attrs: eq_(getattr(fk, attr), getattr(ref, attr)) - @testing.provide_metadata - def test_pks_not_uniques(self): + def test_pks_not_uniques(self, connection, metadata): """test that primary key reflection not tripped up by unique indexes""" - with testing.db.begin() as conn: - conn.exec_driver_sql( - """ - CREATE TABLE book ( - id INTEGER NOT NULL, - title VARCHAR(100) NOT NULL, - series INTEGER, - series_id INTEGER, - UNIQUE(series, series_id), - PRIMARY KEY(id) - )""" - ) + conn = connection + conn.exec_driver_sql( + """ + CREATE TABLE book ( + id INTEGER NOT NULL, + title VARCHAR(100) NOT NULL, + series INTEGER, + series_id INTEGER, + UNIQUE(series, series_id), + PRIMARY KEY(id) + )""" + ) - book = Table("book", self.metadata, autoload_with=testing.db) + book = Table("book", metadata, autoload_with=connection) assert book.primary_key.contains_column(book.c.id) assert not book.primary_key.contains_column(book.c.series) eq_(len(book.primary_key), 1) - def test_fk_error(self): - metadata = MetaData(testing.db) + def test_fk_error(self, connection, metadata): Table( "slots", metadata, @@ -1052,37 +1017,35 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): "could not find table 'pkgs' with which to generate " "a foreign key to target column 'pkg_id'", metadata.create_all, + connection, ) - @testing.provide_metadata - def test_composite_pks(self): + def test_composite_pks(self, connection, metadata): """test reflection of a composite primary key""" - with testing.db.begin() as conn: - conn.exec_driver_sql( - """ - CREATE TABLE book ( - id INTEGER NOT NULL, - isbn VARCHAR(50) NOT NULL, - title VARCHAR(100) NOT NULL, - series INTEGER NOT NULL, - series_id INTEGER NOT NULL, - UNIQUE(series, series_id), - PRIMARY KEY(id, isbn) - )""" - ) - book = Table("book", self.metadata, autoload_with=testing.db) + conn = connection + conn.exec_driver_sql( + """ + CREATE TABLE book ( + id INTEGER NOT NULL, + isbn VARCHAR(50) NOT NULL, + title VARCHAR(100) NOT NULL, + series INTEGER NOT NULL, + series_id INTEGER NOT NULL, + UNIQUE(series, series_id), + PRIMARY KEY(id, isbn) + )""" + ) + book = Table("book", metadata, autoload_with=connection) assert book.primary_key.contains_column(book.c.id) assert book.primary_key.contains_column(book.c.isbn) assert not book.primary_key.contains_column(book.c.series) eq_(len(book.primary_key), 2) - @testing.exclude("mysql", "<", (4, 1, 1), "innodb funkiness") - @testing.provide_metadata - def test_composite_fk(self): + def test_composite_fk(self, connection, metadata): """test reflection of composite foreign keys""" - meta = self.metadata + meta = metadata multi = Table( "multi", meta, @@ -1107,11 +1070,11 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): ), test_needs_fk=True, ) - meta.create_all() + meta.create_all(connection) meta2 = MetaData() - table = Table("multi", meta2, autoload_with=testing.db) - table2 = Table("multi2", meta2, autoload_with=testing.db) + table = Table("multi", meta2, autoload_with=connection) + table2 = Table("multi2", meta2, autoload_with=connection) self.assert_tables_equal(multi, table) self.assert_tables_equal(multi2, table2) j = sa.join(table, table2) @@ -1126,13 +1089,12 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): @testing.crashes("oracle", "FIXME: unknown, confirm not fails_on") @testing.requires.check_constraints - @testing.provide_metadata - def test_reserved(self): + def test_reserved(self, connection, metadata): # check a table that uses a SQL reserved name doesn't cause an # error - meta = self.metadata + meta = metadata table_a = Table( "select", meta, @@ -1142,11 +1104,11 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): ) sa.Index("where", table_a.c["from"]) - if meta.bind.dialect.requires_name_normalize: + if connection.dialect.requires_name_normalize: check_col = "TRUE" else: check_col = "true" - quoter = meta.bind.dialect.identifier_preparer.quote_identifier + quoter = connection.dialect.identifier_preparer.quote_identifier Table( "false", @@ -1164,120 +1126,81 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): sa.PrimaryKeyConstraint("or", "join", name="to"), ) index_c = sa.Index("else", table_c.c.join) - meta.create_all() - index_c.drop() + meta.create_all(connection) + index_c.drop(connection) meta2 = MetaData() - Table("select", meta2, autoload_with=testing.db) - Table("false", meta2, autoload_with=testing.db) - Table("is", meta2, autoload_with=testing.db) + Table("select", meta2, autoload_with=connection) + Table("false", meta2, autoload_with=connection) + Table("is", meta2, autoload_with=connection) - @testing.provide_metadata - def _test_reflect_uses_bind(self, fn): - from sqlalchemy.pool import AssertionPool - - e = engines.testing_engine(options={"poolclass": AssertionPool}) - fn(e) - - def test_reflect_uses_bind_constructor_conn_reflect(self): - self._test_reflect_uses_bind(lambda e: MetaData(e.connect()).reflect()) - - def test_reflect_uses_bind_constructor_engine_reflect(self): - self._test_reflect_uses_bind(lambda e: MetaData(e).reflect()) - - def test_reflect_uses_bind_conn_reflect(self): - self._test_reflect_uses_bind(lambda e: MetaData().reflect(e.connect())) - - def test_reflect_uses_bind_engine_reflect(self): - self._test_reflect_uses_bind(lambda e: MetaData().reflect(e)) - - def test_reflect_uses_bind_option_engine_reflect(self): - self._test_reflect_uses_bind( - lambda e: MetaData().reflect(e.execution_options(foo="bar")) - ) - - @testing.provide_metadata - def test_reflect_all(self): - existing = inspect(testing.db).get_table_names() + def test_reflect_all(self, connection, metadata): names = ["rt_%s" % name for name in ("a", "b", "c", "d", "e")] nameset = set(names) - for name in names: - # be sure our starting environment is sane - self.assert_(name not in existing) - self.assert_("rt_f" not in existing) - baseline = self.metadata + baseline = metadata for name in names: Table(name, baseline, Column("id", sa.Integer, primary_key=True)) - baseline.create_all() + baseline.create_all(connection) - m1 = MetaData(testing.db) - self.assert_(not m1.tables) - m1.reflect() - self.assert_(nameset.issubset(set(m1.tables.keys()))) + m1 = MetaData() + is_false(m1.tables) + m1.reflect(connection) + is_true(nameset.issubset(set(m1.tables.keys()))) m2 = MetaData() - m2.reflect(testing.db, only=["rt_a", "rt_b"]) - self.assert_(set(m2.tables.keys()) == set(["rt_a", "rt_b"])) + m2.reflect(connection, only=["rt_a", "rt_b"]) + eq_(set(m2.tables.keys()), set(["rt_a", "rt_b"])) m3 = MetaData() - c = testing.db.connect() - m3.reflect(bind=c, only=lambda name, meta: name == "rt_c") - self.assert_(set(m3.tables.keys()) == set(["rt_c"])) + m3.reflect(connection, only=lambda name, meta: name == "rt_c") + eq_(set(m3.tables.keys()), set(["rt_c"])) - m4 = MetaData(testing.db) + m4 = MetaData() assert_raises_message( sa.exc.InvalidRequestError, r"Could not reflect: requested table\(s\) not available in " r"Engine\(.*?\): \(rt_f\)", m4.reflect, + connection, only=["rt_a", "rt_f"], ) - m5 = MetaData(testing.db) - m5.reflect(only=[]) - self.assert_(not m5.tables) + m5 = MetaData() + m5.reflect(connection, only=[]) + is_false(m5.tables) - m6 = MetaData(testing.db) - m6.reflect(only=lambda n, m: False) - self.assert_(not m6.tables) + m6 = MetaData() + m6.reflect(connection, only=lambda n, m: False) + is_false(m6.tables) - m7 = MetaData(testing.db) - m7.reflect() - self.assert_(nameset.issubset(set(m7.tables.keys()))) + m7 = MetaData() + m7.reflect(connection) + is_true(nameset.issubset(set(m7.tables.keys()))) - m8 = MetaData() - assert_raises(sa.exc.UnboundExecutionError, m8.reflect) - - m8_e1 = MetaData(testing.db) + m8_e1 = MetaData() rt_c = Table("rt_c", m8_e1) - m8_e1.reflect(extend_existing=True) + m8_e1.reflect(connection, extend_existing=True) eq_(set(m8_e1.tables.keys()), set(names)) eq_(rt_c.c.keys(), ["id"]) - m8_e2 = MetaData(testing.db) + m8_e2 = MetaData() rt_c = Table("rt_c", m8_e2) - m8_e2.reflect(extend_existing=True, only=["rt_a", "rt_c"]) + m8_e2.reflect(connection, extend_existing=True, only=["rt_a", "rt_c"]) eq_(set(m8_e2.tables.keys()), set(["rt_a", "rt_c"])) eq_(rt_c.c.keys(), ["id"]) - if existing: - print("Other tables present in database, skipping some checks.") - else: - baseline.drop_all() - m9 = MetaData(testing.db) - m9.reflect() - self.assert_(not m9.tables) + baseline.drop_all(connection) + m9 = MetaData() + m9.reflect(connection) + is_false(m9.tables) - @testing.provide_metadata - def test_reflect_all_unreflectable_table(self): + def test_reflect_all_unreflectable_table(self, connection, metadata): names = ["rt_%s" % name for name in ("a", "b", "c", "d", "e")] for name in names: - Table( - name, self.metadata, Column("id", sa.Integer, primary_key=True) - ) - self.metadata.create_all() + Table(name, metadata, Column("id", sa.Integer, primary_key=True)) + metadata.create_all(connection) m = MetaData() @@ -1292,7 +1215,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): with mock.patch.object(inspector, "reflect_table", patched): with expect_warnings("Skipping table rt_c: Can't reflect rt_c"): - m.reflect(bind=testing.db) + m.reflect(connection) assert_raises_message( sa.exc.UnreflectableTableError, @@ -1300,23 +1223,11 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): Table, "rt_c", m, - autoload_with=testing.db, + autoload_with=connection, ) - def test_reflect_all_conn_closing(self): - m1 = MetaData() - c = testing.db.connect() - m1.reflect(bind=c) - assert not c.closed - - def test_inspector_conn_closing(self): - c = testing.db.connect() - inspect(c) - assert not c.closed - - @testing.provide_metadata - def test_index_reflection(self): - m1 = self.metadata + def test_index_reflection(self, connection, metadata): + m1 = metadata t1 = Table( "party", m1, @@ -1325,9 +1236,9 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): ) sa.Index("idx1", t1.c.id, unique=True) sa.Index("idx2", t1.c.name, t1.c.id, unique=False) - m1.create_all() + m1.create_all(connection) m2 = MetaData() - t2 = Table("party", m2, autoload_with=testing.db) + t2 = Table("party", m2, autoload_with=connection) assert len(t2.indexes) == 3 # Make sure indexes are in the order we expect them in @@ -1345,18 +1256,17 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): assert set([t2.c.name]) == set(r3.columns) @testing.requires.comment_reflection - @testing.provide_metadata - def test_comment_reflection(self): - m1 = self.metadata + def test_comment_reflection(self, connection, metadata): + m1 = metadata Table( "sometable", m1, Column("id", sa.Integer, comment="c1 comment"), comment="t1 comment", ) - m1.create_all() + m1.create_all(connection) m2 = MetaData() - t2 = Table("sometable", m2, autoload_with=testing.db) + t2 = Table("sometable", m2, autoload_with=connection) eq_(t2.comment, "t1 comment") eq_(t2.c.id.comment, "c1 comment") @@ -1366,18 +1276,17 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): eq_(t3.c.id.comment, "c1 comment") @testing.requires.check_constraint_reflection - @testing.provide_metadata - def test_check_constraint_reflection(self): - m1 = self.metadata + def test_check_constraint_reflection(self, connection, metadata): + m1 = metadata Table( "x", m1, Column("q", Integer), sa.CheckConstraint("q > 10", name="ck1"), ) - m1.create_all() + m1.create_all(connection) m2 = MetaData() - t2 = Table("x", m2, autoload_with=testing.db) + t2 = Table("x", m2, autoload_with=connection) ck = [ const @@ -1388,40 +1297,35 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): eq_regex(ck.sqltext.text, r"[\(`]*q[\)`]* > 10") eq_(ck.name, "ck1") - @testing.provide_metadata - def test_index_reflection_cols_busted(self): - t = Table( - "x", self.metadata, Column("a", Integer), Column("b", Integer) - ) + def test_index_reflection_cols_busted(self, connection, metadata): + t = Table("x", metadata, Column("a", Integer), Column("b", Integer)) sa.Index("x_ix", t.c.a, t.c.b) - self.metadata.create_all() + metadata.create_all(connection) def mock_get_columns(self, connection, table_name, **kw): return [{"name": "b", "type": Integer, "primary_key": False}] with testing.mock.patch.object( - testing.db.dialect, "get_columns", mock_get_columns + connection.dialect, "get_columns", mock_get_columns ): m = MetaData() with testing.expect_warnings( "index key 'a' was not located in columns" ): - t = Table("x", m, autoload_with=testing.db) + t = Table("x", m, autoload_with=connection) eq_(list(t.indexes)[0].columns, [t.c.b]) @testing.requires.views - @testing.provide_metadata - def test_views(self): - metadata = self.metadata + def test_views(self, connection, metadata): users, addresses, dingalings = createTables(metadata) try: - metadata.create_all() - _create_views(metadata.bind, None) + metadata.create_all(connection) + _create_views(connection, None) m2 = MetaData() - users_v = Table("users_v", m2, autoload_with=testing.db) + users_v = Table("users_v", m2, autoload_with=connection) addresses_v = Table( - "email_addresses_v", m2, autoload_with=testing.db + "email_addresses_v", m2, autoload_with=connection ) for c1, c2 in zip(users_v.c, users.c): @@ -1432,25 +1336,23 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): eq_(c1.name, c2.name) self.assert_types_base(c1, c2) finally: - _drop_views(metadata.bind) + _drop_views(connection) @testing.requires.views - @testing.provide_metadata - def test_reflect_all_with_views(self): - metadata = self.metadata + def test_reflect_all_with_views(self, connection, metadata): users, addresses, dingalings = createTables(metadata, None) try: - metadata.create_all() - _create_views(metadata.bind, None) - m2 = MetaData(testing.db) + metadata.create_all(connection) + _create_views(connection, None) + m2 = MetaData() - m2.reflect(views=False) + m2.reflect(connection, views=False) eq_( set(m2.tables), set(["users", "email_addresses", "dingalings"]) ) - m2 = MetaData(testing.db) - m2.reflect(views=True) + m2 = MetaData() + m2.reflect(connection, views=True) eq_( set(m2.tables), set( @@ -1464,7 +1366,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): ), ) finally: - _drop_views(metadata.bind) + _drop_views(connection) class CreateDropTest(fixtures.TablesTest): @@ -1473,6 +1375,13 @@ class CreateDropTest(fixtures.TablesTest): run_create_tables = None @classmethod + def teardown_class(cls): + # TablesTest is used here without + # run_create_tables, so add an explicit drop of whatever is in + # metadata + cls._tables_metadata.drop_all(testing.db) + + @classmethod def define_tables(cls, metadata): Table( "users", @@ -1525,11 +1434,8 @@ class CreateDropTest(fixtures.TablesTest): Column("item_name", sa.VARCHAR(50)), ) - def teardown(self): - self.metadata.drop_all(testing.db) - def test_sorter(self): - tables = self.metadata.sorted_tables + tables = self.tables_test_metadata.sorted_tables table_names = [t.name for t in tables] ua = [n for n in table_names if n in ("users", "email_addresses")] oi = [n for n in table_names if n in ("orders", "items")] @@ -1537,39 +1443,41 @@ class CreateDropTest(fixtures.TablesTest): eq_(ua, ["users", "email_addresses"]) eq_(oi, ["orders", "items"]) - def test_checkfirst(self): - insp = inspect(testing.db) + def test_checkfirst(self, connection): + insp = inspect(connection) users = self.tables.users is_false(insp.has_table("users")) - users.create(bind=testing.db) + users.create(connection) is_true(insp.has_table("users")) - users.create(bind=testing.db, checkfirst=True) - users.drop(bind=testing.db) - users.drop(bind=testing.db, checkfirst=True) + users.create(connection, checkfirst=True) + users.drop(connection) + users.drop(connection, checkfirst=True) is_false(insp.has_table("users")) - users.create(bind=testing.db, checkfirst=True) - users.drop(bind=testing.db) + users.create(connection, checkfirst=True) + users.drop(connection) - def test_createdrop(self): - insp = inspect(testing.db) - metadata = self.metadata - metadata.create_all(bind=testing.db) + def test_createdrop(self, connection): + insp = inspect(connection) + + metadata = self.tables_test_metadata + + metadata.create_all(connection) is_true(insp.has_table("items")) is_true(insp.has_table("email_addresses")) - metadata.create_all(bind=testing.db) + metadata.create_all(connection) is_true(insp.has_table("items")) - metadata.drop_all(bind=testing.db) + metadata.drop_all(connection) is_false(insp.has_table("items")) is_false(insp.has_table("email_addresses")) - metadata.drop_all(bind=testing.db) + metadata.drop_all(connection) is_false(insp.has_table("items")) - def test_tablenames(self): - metadata = self.metadata - metadata.create_all(bind=testing.db) - insp = inspect(testing.db) + def test_tablenames(self, connection): + metadata = self.tables_test_metadata + metadata.create_all(bind=connection) + insp = inspect(connection) # ensure all tables we created are in the list. is_true(set(insp.get_table_names()).issuperset(metadata.tables)) @@ -1597,12 +1505,11 @@ class SchemaManipulationTest(fixtures.TestBase): assert addresses.constraints == set([addresses.primary_key, fk]) -class UnicodeReflectionTest(fixtures.TestBase): +class UnicodeReflectionTest(fixtures.TablesTest): __backend__ = True @classmethod - def setup_class(cls): - cls.metadata = metadata = MetaData() + def define_tables(cls, metadata): no_multibyte_period = set([("plain", "col_plain", "ix_plain")]) no_has_table = [ @@ -1671,32 +1578,24 @@ class UnicodeReflectionTest(fixtures.TestBase): ) schema.Index(ixname, t.c[cname]) - metadata.create_all(testing.db) cls.names = names - @classmethod - def teardown_class(cls): - cls.metadata.drop_all(testing.db, checkfirst=False) - @testing.requires.unicode_connections - def test_has_table(self): - insp = inspect(testing.db) + def test_has_table(self, connection): + insp = inspect(connection) for tname, cname, ixname in self.names: assert insp.has_table(tname), "Can't detect name %s" % tname @testing.requires.unicode_connections - def test_basic(self): + def test_basic(self, connection): # the 'convert_unicode' should not get in the way of the # reflection process. reflect_table for oracle, postgresql # (others?) expect non-unicode strings in result sets/bind # params - bind = testing.db names = set([rec[0] for rec in self.names]) - reflected = set(inspect(bind).get_table_names()) - - # Jython 2.5 on Java 5 lacks unicodedata.normalize + reflected = set(inspect(connection).get_table_names()) if not names.issubset(reflected) and hasattr(unicodedata, "normalize"): @@ -1711,14 +1610,14 @@ class UnicodeReflectionTest(fixtures.TestBase): # Yep. But still ensure that bulk reflection and # create/drop work with either normalization. - r = MetaData(bind) - r.reflect() - r.drop_all(checkfirst=False) - r.create_all(checkfirst=False) + r = MetaData() + r.reflect(connection) + r.drop_all(connection, checkfirst=False) + r.create_all(connection, checkfirst=False) @testing.requires.unicode_connections - def test_get_names(self): - inspector = inspect(testing.db) + def test_get_names(self, connection): + inspector = inspect(connection) names = dict( (tname, (cname, ixname)) for tname, cname, ixname in self.names ) @@ -1760,8 +1659,7 @@ class SchemaTest(fixtures.TestBase): @testing.requires.cross_schema_fk_reflection @testing.requires.implicit_default_schema @testing.provide_metadata - def test_blank_schema_arg(self): - metadata = self.metadata + def test_blank_schema_arg(self, connection, metadata): Table( "some_table", @@ -1778,37 +1676,27 @@ class SchemaTest(fixtures.TestBase): schema=None, test_needs_fk=True, ) - metadata.create_all() - with testing.db.connect() as conn: - meta2 = MetaData(conn, schema=testing.config.test_schema) - meta2.reflect() + metadata.create_all(connection) + meta2 = MetaData(schema=testing.config.test_schema) + meta2.reflect(connection) - eq_( - set(meta2.tables), - set( - [ - "some_other_table", - "%s.some_table" % testing.config.test_schema, - ] - ), - ) + eq_( + set(meta2.tables), + set( + [ + "some_other_table", + "%s.some_table" % testing.config.test_schema, + ] + ), + ) @testing.requires.schemas - def test_explicit_default_schema(self): - engine = testing.db - engine.connect().close() - - if testing.against("sqlite"): - # Works for CREATE TABLE main.foo, SELECT FROM main.foo, etc., - # but fails on: - # FOREIGN KEY(col2) REFERENCES main.table1 (col1) - schema = "main" - else: - schema = engine.dialect.default_schema_name + def test_explicit_default_schema(self, connection, metadata): + + schema = connection.dialect.default_schema_name assert bool(schema) - metadata = MetaData() Table( "table1", metadata, @@ -1826,54 +1714,41 @@ class SchemaTest(fixtures.TestBase): test_needs_fk=True, schema=schema, ) - try: - metadata.create_all(engine) - metadata.create_all(engine, checkfirst=True) - assert len(metadata.tables) == 2 - metadata.clear() - - Table("table1", metadata, autoload_with=engine, schema=schema) - Table("table2", metadata, autoload_with=engine, schema=schema) - assert len(metadata.tables) == 2 - finally: - metadata.drop_all(engine) + metadata.create_all(connection) + metadata.create_all(connection, checkfirst=True) + eq_(len(metadata.tables), 2) + + m1 = MetaData() + Table("table1", m1, autoload_with=connection, schema=schema) + Table("table2", m1, autoload_with=connection, schema=schema) + eq_(len(m1.tables), 2) @testing.requires.schemas - @testing.provide_metadata - def test_schema_translation(self): + def test_schema_translation(self, connection, metadata): Table( "foob", - self.metadata, + metadata, Column("q", Integer), schema=config.test_schema, ) - self.metadata.create_all() + metadata.create_all(connection) m = MetaData() map_ = {"foob": config.test_schema} - with config.db.connect().execution_options( - schema_translate_map=map_ - ) as conn: - t = Table("foob", m, schema="foob", autoload_with=conn) - eq_(t.schema, "foob") - eq_(t.c.keys(), ["q"]) + + c2 = connection.execution_options(schema_translate_map=map_) + t = Table("foob", m, schema="foob", autoload_with=c2) + eq_(t.schema, "foob") + eq_(t.c.keys(), ["q"]) @testing.requires.schemas @testing.fails_on("sybase", "FIXME: unknown") - def test_explicit_default_schema_metadata(self): - engine = testing.db - - if testing.against("sqlite"): - # Works for CREATE TABLE main.foo, SELECT FROM main.foo, etc., - # but fails on: - # FOREIGN KEY(col2) REFERENCES main.table1 (col1) - schema = "main" - else: - schema = engine.dialect.default_schema_name + def test_explicit_default_schema_metadata(self, connection, metadata): + schema = connection.dialect.default_schema_name - assert bool(schema) + is_true(schema) - metadata = MetaData(schema=schema) + metadata.schema = schema Table( "table1", metadata, @@ -1887,26 +1762,21 @@ class SchemaTest(fixtures.TestBase): Column("col2", sa.Integer, sa.ForeignKey("table1.col1")), test_needs_fk=True, ) - try: - metadata.create_all(engine) - metadata.create_all(engine, checkfirst=True) - assert len(metadata.tables) == 2 - metadata.clear() - - Table("table1", metadata, autoload_with=engine) - Table("table2", metadata, autoload_with=engine) - assert len(metadata.tables) == 2 - finally: - metadata.drop_all(engine) + metadata.create_all(connection) + metadata.create_all(connection, checkfirst=True) + + m1 = MetaData(schema=schema) + + Table("table1", m1, autoload_with=connection) + Table("table2", m1, autoload_with=connection) + eq_(len(m1.tables), 2) @testing.requires.schemas - @testing.provide_metadata - def test_metadata_reflect_schema(self): - metadata = self.metadata + def test_metadata_reflect_schema(self, connection, metadata): createTables(metadata, testing.config.test_schema) - metadata.create_all() - m2 = MetaData(schema=testing.config.test_schema, bind=testing.db) - m2.reflect() + metadata.create_all(connection) + m2 = MetaData(schema=testing.config.test_schema) + m2.reflect(connection) eq_( set(m2.tables), set( @@ -1921,24 +1791,23 @@ class SchemaTest(fixtures.TestBase): @testing.requires.schemas @testing.requires.cross_schema_fk_reflection @testing.requires.implicit_default_schema - @testing.provide_metadata - def test_reflect_all_schemas_default_overlap(self): - Table("t", self.metadata, Column("id", Integer, primary_key=True)) + def test_reflect_all_schemas_default_overlap(self, connection, metadata): + Table("t", metadata, Column("id", Integer, primary_key=True)) Table( "t", - self.metadata, + metadata, Column("id1", sa.ForeignKey("t.id")), schema=testing.config.test_schema, ) - self.metadata.create_all() + metadata.create_all(connection) m2 = MetaData() - m2.reflect(testing.db, schema=testing.config.test_schema) + m2.reflect(connection, schema=testing.config.test_schema) m3 = MetaData() - m3.reflect(testing.db) - m3.reflect(testing.db, schema=testing.config.test_schema) + m3.reflect(connection) + m3.reflect(connection, schema=testing.config.test_schema) eq_( set((t.name, t.schema) for t in m2.tables.values()), @@ -2015,30 +1884,28 @@ def createIndexes(con, schema=None): @testing.requires.views -def _create_views(con, schema=None): - with testing.db.begin() as conn: - for table_name in ("users", "email_addresses"): - fullname = table_name - if schema: - fullname = "%s.%s" % (schema, table_name) - view_name = fullname + "_v" - query = "CREATE VIEW %s AS SELECT * FROM %s" % ( - view_name, - fullname, - ) - conn.execute(sa.sql.text(query)) +def _create_views(conn, schema=None): + for table_name in ("users", "email_addresses"): + fullname = table_name + if schema: + fullname = "%s.%s" % (schema, table_name) + view_name = fullname + "_v" + query = "CREATE VIEW %s AS SELECT * FROM %s" % ( + view_name, + fullname, + ) + conn.execute(sa.sql.text(query)) @testing.requires.views -def _drop_views(con, schema=None): - with testing.db.begin() as conn: - for table_name in ("email_addresses", "users"): - fullname = table_name - if schema: - fullname = "%s.%s" % (schema, table_name) - view_name = fullname + "_v" - query = "DROP VIEW %s" % view_name - conn.execute(sa.sql.text(query)) +def _drop_views(conn, schema=None): + for table_name in ("email_addresses", "users"): + fullname = table_name + if schema: + fullname = "%s.%s" % (schema, table_name) + view_name = fullname + "_v" + query = "DROP VIEW %s" % view_name + conn.execute(sa.sql.text(query)) class ReverseCasingReflectTest(fixtures.TestBase, AssertsCompiledSQL): @@ -2064,9 +1931,9 @@ class ReverseCasingReflectTest(fixtures.TestBase, AssertsCompiledSQL): conn.exec_driver_sql("drop table weird_casing") @testing.requires.denormalized_names - def test_direct_quoting(self): + def test_direct_quoting(self, connection): m = MetaData() - t = Table("weird_casing", m, autoload_with=testing.db) + t = Table("weird_casing", m, autoload_with=connection) self.assert_compile( t.select(), "SELECT weird_casing.col1, " @@ -2097,13 +1964,13 @@ class CaseSensitiveTest(fixtures.TablesTest): ) @testing.fails_if(testing.requires._has_mysql_on_windows) - def test_table_names(self): - x = inspect(testing.db).get_table_names() + def test_table_names(self, connection): + x = inspect(connection).get_table_names() assert set(["SomeTable", "SomeOtherTable"]).issubset(x) - def test_reflect_exact_name(self): + def test_reflect_exact_name(self, connection): m = MetaData() - t1 = Table("SomeTable", m, autoload_with=testing.db) + t1 = Table("SomeTable", m, autoload_with=connection) eq_(t1.name, "SomeTable") assert t1.c.x is not None @@ -2111,47 +1978,43 @@ class CaseSensitiveTest(fixtures.TablesTest): lambda: testing.against(("mysql", "<", (5, 5))) and not testing.requires._has_mysql_fully_case_sensitive() ) - def test_reflect_via_fk(self): + def test_reflect_via_fk(self, connection): m = MetaData() - t2 = Table("SomeOtherTable", m, autoload_with=testing.db) + t2 = Table("SomeOtherTable", m, autoload_with=connection) eq_(t2.name, "SomeOtherTable") assert "SomeTable" in m.tables @testing.fails_if(testing.requires._has_mysql_fully_case_sensitive) @testing.fails_on_everything_except("sqlite", "mysql", "mssql") - def test_reflect_case_insensitive(self): + def test_reflect_case_insensitive(self, connection): m = MetaData() - t2 = Table("sOmEtAbLe", m, autoload_with=testing.db) + t2 = Table("sOmEtAbLe", m, autoload_with=connection) eq_(t2.name, "sOmEtAbLe") -class ColumnEventsTest(fixtures.RemovesEvents, fixtures.TestBase): +class ColumnEventsTest(fixtures.RemovesEvents, fixtures.TablesTest): __backend__ = True @classmethod - def setup_class(cls): - cls.metadata = MetaData() - cls.to_reflect = Table( + def define_tables(cls, metadata): + to_reflect = Table( "to_reflect", - cls.metadata, + metadata, Column("x", sa.Integer, primary_key=True, autoincrement=False), Column("y", sa.Integer), test_needs_fk=True, ) - cls.related = Table( + Table( "related", - cls.metadata, + metadata, Column("q", sa.Integer, sa.ForeignKey("to_reflect.x")), test_needs_fk=True, ) - sa.Index("some_index", cls.to_reflect.c.y) - cls.metadata.create_all(testing.db) + sa.Index("some_index", to_reflect.c.y) - @classmethod - def teardown_class(cls): - cls.metadata.drop_all(testing.db) - - def _do_test(self, col, update, assert_, tablename="to_reflect"): + def _do_test( + self, connection, col, update, assert_, tablename="to_reflect" + ): # load the actual Table class, not the test # wrapper from sqlalchemy.schema import Table @@ -2165,31 +2028,31 @@ class ColumnEventsTest(fixtures.RemovesEvents, fixtures.TestBase): t = Table( tablename, m, - autoload_with=testing.db, + autoload_with=connection, listeners=[("column_reflect", column_reflect)], ) assert_(t) m = MetaData() self.event_listen(Table, "column_reflect", column_reflect) - t2 = Table(tablename, m, autoload_with=testing.db) + t2 = Table(tablename, m, autoload_with=connection) assert_(t2) - def test_override_key(self): + def test_override_key(self, connection): def assertions(table): eq_(table.c.YXZ.name, "x") eq_(set(table.primary_key), set([table.c.YXZ])) - self._do_test("x", {"key": "YXZ"}, assertions) + self._do_test(connection, "x", {"key": "YXZ"}, assertions) - def test_override_index(self): + def test_override_index(self, connection): def assertions(table): idx = list(table.indexes)[0] eq_(idx.columns, [table.c.YXZ]) - self._do_test("y", {"key": "YXZ"}, assertions) + self._do_test(connection, "y", {"key": "YXZ"}, assertions) - def test_override_key_fk(self): + def test_override_key_fk(self, connection): m = MetaData() def column_reflect(insp, table, column_info): @@ -2202,48 +2065,51 @@ class ColumnEventsTest(fixtures.RemovesEvents, fixtures.TestBase): to_reflect = Table( "to_reflect", m, - autoload_with=testing.db, + autoload_with=connection, listeners=[("column_reflect", column_reflect)], ) related = Table( "related", m, - autoload_with=testing.db, + autoload_with=connection, listeners=[("column_reflect", column_reflect)], ) assert related.c.qyz.references(to_reflect.c.xyz) - def test_override_type(self): + def test_override_type(self, connection): def assert_(table): assert isinstance(table.c.x.type, sa.String) - self._do_test("x", {"type": sa.String}, assert_) + self._do_test(connection, "x", {"type": sa.String}, assert_) - def test_override_info(self): + def test_override_info(self, connection): self._do_test( + connection, "x", {"info": {"a": "b"}}, lambda table: eq_(table.c.x.info, {"a": "b"}), ) - def test_override_server_default_fetchedvalue(self): + def test_override_server_default_fetchedvalue(self, connection): my_default = FetchedValue() self._do_test( + connection, "x", {"default": my_default}, lambda table: eq_(table.c.x.server_default, my_default), ) - def test_override_server_default_default_clause(self): + def test_override_server_default_default_clause(self, connection): my_default = DefaultClause("1") self._do_test( + connection, "x", {"default": my_default}, lambda table: eq_(table.c.x.server_default, my_default), ) - def test_override_server_default_plain_text(self): + def test_override_server_default_plain_text(self, connection): my_default = "1" def assert_text_of_one(table): @@ -2254,9 +2120,11 @@ class ColumnEventsTest(fixtures.RemovesEvents, fixtures.TestBase): ) eq_(str(table.c.x.server_default.arg), "1") - self._do_test("x", {"default": my_default}, assert_text_of_one) + self._do_test( + connection, "x", {"default": my_default}, assert_text_of_one + ) - def test_override_server_default_textclause(self): + def test_override_server_default_textclause(self, connection): my_default = sa.text("1") def assert_text_of_one(table): @@ -2267,9 +2135,11 @@ class ColumnEventsTest(fixtures.RemovesEvents, fixtures.TestBase): ) eq_(str(table.c.x.server_default.arg), "1") - self._do_test("x", {"default": my_default}, assert_text_of_one) + self._do_test( + connection, "x", {"default": my_default}, assert_text_of_one + ) - def test_listen_metadata_obj(self): + def test_listen_metadata_obj(self, connection): m1 = MetaData() m2 = MetaData() @@ -2280,13 +2150,13 @@ class ColumnEventsTest(fixtures.RemovesEvents, fixtures.TestBase): def go(insp, table, info): canary.append(info["name"]) - Table("related", m1, autoload_with=testing.db) + Table("related", m1, autoload_with=connection) - Table("related", m2, autoload_with=testing.db) + Table("related", m2, autoload_with=connection) eq_(canary, ["q", "x", "y"]) - def test_listen_metadata_cls(self): + def test_listen_metadata_cls(self, connection): m1 = MetaData() m2 = MetaData() @@ -2298,9 +2168,9 @@ class ColumnEventsTest(fixtures.RemovesEvents, fixtures.TestBase): self.event_listen(MetaData, "column_reflect", go) - Table("related", m1, autoload_with=testing.db) + Table("related", m1, autoload_with=connection) - Table("related", m2, autoload_with=testing.db) + Table("related", m2, autoload_with=connection) eq_(canary, ["q", "x", "y", "q", "x", "y"]) diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py index cd1e16ed9..49bf20baf 100644 --- a/test/ext/asyncio/test_engine_py3k.py +++ b/test/ext/asyncio/test_engine_py3k.py @@ -17,13 +17,18 @@ from sqlalchemy.ext.asyncio import engine as _async_engine from sqlalchemy.ext.asyncio import exc as asyncio_exc from sqlalchemy.testing import async_test from sqlalchemy.testing import combinations +from sqlalchemy.testing import engines from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_raises from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ +from sqlalchemy.testing import is_false +from sqlalchemy.testing import is_none from sqlalchemy.testing import is_not +from sqlalchemy.testing import is_true from sqlalchemy.testing import mock +from sqlalchemy.testing import ne_ from sqlalchemy.util.concurrency import greenlet_spawn @@ -32,7 +37,7 @@ class EngineFixture(fixtures.TablesTest): @testing.fixture def async_engine(self): - return create_async_engine(testing.db.url) + return engines.testing_engine(asyncio=True) @classmethod def define_tables(cls, metadata): @@ -55,6 +60,12 @@ class EngineFixture(fixtures.TablesTest): class AsyncEngineTest(EngineFixture): __backend__ = True + @testing.fails("the failure is the test") + @async_test + async def test_we_are_definitely_running_async_tests(self, async_engine): + async with async_engine.connect() as conn: + eq_(await conn.scalar(text("select 1")), 2) + def test_proxied_attrs_engine(self, async_engine): sync_engine = async_engine.sync_engine @@ -65,6 +76,53 @@ class AsyncEngineTest(EngineFixture): eq_(async_engine.driver, sync_engine.driver) eq_(async_engine.echo, sync_engine.echo) + @async_test + async def test_engine_eq_ne(self, async_engine): + e2 = _async_engine.AsyncEngine(async_engine.sync_engine) + e3 = testing.engines.testing_engine(asyncio=True) + + eq_(async_engine, e2) + ne_(async_engine, e3) + + is_false(async_engine == None) + + @async_test + async def test_connection_info(self, async_engine): + + async with async_engine.connect() as conn: + conn.info["foo"] = "bar" + + eq_(conn.sync_connection.info, {"foo": "bar"}) + + @async_test + async def test_connection_eq_ne(self, async_engine): + + async with async_engine.connect() as conn: + c2 = _async_engine.AsyncConnection( + async_engine, conn.sync_connection + ) + + eq_(conn, c2) + + async with async_engine.connect() as c3: + ne_(conn, c3) + + is_false(conn == None) + + @async_test + async def test_transaction_eq_ne(self, async_engine): + + async with async_engine.connect() as conn: + t1 = await conn.begin() + + t2 = _async_engine.AsyncTransaction._from_existing_transaction( + conn, t1._proxied + ) + + eq_(t1, t2) + + is_false(t1 == None) + def test_clear_compiled_cache(self, async_engine): async_engine.sync_engine._compiled_cache["foo"] = "bar" eq_(async_engine.sync_engine._compiled_cache["foo"], "bar") @@ -97,6 +155,48 @@ class AsyncEngineTest(EngineFixture): eq_(conn.default_isolation_level, sync_conn.default_isolation_level) @async_test + async def test_transaction_accessor(self, async_engine): + async with async_engine.connect() as conn: + is_none(conn.get_transaction()) + is_false(conn.in_transaction()) + is_false(conn.in_nested_transaction()) + + trans = await conn.begin() + + is_true(conn.in_transaction()) + is_false(conn.in_nested_transaction()) + + is_( + trans.sync_transaction, conn.get_transaction().sync_transaction + ) + + nested = await conn.begin_nested() + + is_true(conn.in_transaction()) + is_true(conn.in_nested_transaction()) + + is_( + conn.get_nested_transaction().sync_transaction, + nested.sync_transaction, + ) + eq_(conn.get_nested_transaction(), nested) + + is_( + trans.sync_transaction, conn.get_transaction().sync_transaction + ) + + await nested.commit() + + is_true(conn.in_transaction()) + is_false(conn.in_nested_transaction()) + + await trans.rollback() + + is_none(conn.get_transaction()) + is_false(conn.in_transaction()) + is_false(conn.in_nested_transaction()) + + @async_test async def test_invalidate(self, async_engine): conn = await async_engine.connect() diff --git a/test/ext/asyncio/test_session_py3k.py b/test/ext/asyncio/test_session_py3k.py index 37e1b807b..e56adec4d 100644 --- a/test/ext/asyncio/test_session_py3k.py +++ b/test/ext/asyncio/test_session_py3k.py @@ -5,10 +5,10 @@ from sqlalchemy import select from sqlalchemy import testing from sqlalchemy import update from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.orm import selectinload from sqlalchemy.orm import sessionmaker from sqlalchemy.testing import async_test +from sqlalchemy.testing import engines from sqlalchemy.testing import eq_ from sqlalchemy.testing import is_ from sqlalchemy.testing import mock @@ -24,7 +24,7 @@ class AsyncFixture(_fixtures.FixtureTest): @testing.fixture def async_engine(self): - return create_async_engine(testing.db.url) + return engines.testing_engine(asyncio=True) @testing.fixture def async_session(self, async_engine): @@ -40,6 +40,11 @@ class AsyncSessionTest(AsyncFixture): bind=async_engine.sync_engine, ) + def test_info(self, async_session): + async_session.info["foo"] = "bar" + + eq_(async_session.sync_session.info, {"foo": "bar"}) + class AsyncSessionQueryTest(AsyncFixture): @async_test @@ -297,6 +302,107 @@ class AsyncSessionTransactionTest(AsyncFixture): is_(new_u_merged, u1) eq_(u1.name, "new u1") + @async_test + async def test_join_to_external_transaction(self, async_engine): + User = self.classes.User + + async with async_engine.connect() as conn: + t1 = await conn.begin() + + async_session = AsyncSession(conn) + + aconn = await async_session.connection() + + eq_(aconn.get_transaction(), t1) + + eq_(aconn, conn) + is_(aconn.sync_connection, conn.sync_connection) + + u1 = User(id=1, name="u1") + + async_session.add(u1) + + await async_session.commit() + + assert conn.in_transaction() + await conn.rollback() + + async with AsyncSession(async_engine) as async_session: + result = await async_session.execute(select(User)) + eq_(result.all(), []) + + @testing.requires.savepoints + @async_test + async def test_join_to_external_transaction_with_savepoints( + self, async_engine + ): + """This is the full 'join to an external transaction' recipe + implemented for async using savepoints. + + It's not particularly simple to understand as we have to switch between + async / sync APIs but it works and it's a start. + + """ + + User = self.classes.User + + async with async_engine.connect() as conn: + + await conn.begin() + + await conn.begin_nested() + + async_session = AsyncSession(conn) + + @event.listens_for( + async_session.sync_session, "after_transaction_end" + ) + def end_savepoint(session, transaction): + """here's an event. inside the event we write blocking + style code. wow will this be fun to try to explain :) + + """ + + if conn.closed: + return + + if not conn.in_nested_transaction(): + conn.sync_connection.begin_nested() + + aconn = await async_session.connection() + is_(aconn.sync_connection, conn.sync_connection) + + u1 = User(id=1, name="u1") + + async_session.add(u1) + + await async_session.commit() + + result = (await async_session.execute(select(User))).all() + eq_(len(result), 1) + + u2 = User(id=2, name="u2") + async_session.add(u2) + + await async_session.flush() + + result = (await async_session.execute(select(User))).all() + eq_(len(result), 2) + + # a rollback inside the session ultimately ends the savepoint + await async_session.rollback() + + # but the previous thing we "committed" is still in the DB + result = (await async_session.execute(select(User))).all() + eq_(len(result), 1) + + assert conn.in_transaction() + await conn.rollback() + + async with AsyncSession(async_engine) as async_session: + result = await async_session.execute(select(User)) + eq_(result.all(), []) + class AsyncEventTest(AsyncFixture): """The engine events all run in their normal synchronous context. diff --git a/test/ext/declarative/test_inheritance.py b/test/ext/declarative/test_inheritance.py index 77d4a80fe..2b80b753e 100644 --- a/test/ext/declarative/test_inheritance.py +++ b/test/ext/declarative/test_inheritance.py @@ -10,7 +10,6 @@ from sqlalchemy.ext.declarative import has_inherited_table from sqlalchemy.orm import clear_mappers from sqlalchemy.orm import close_all_sessions from sqlalchemy.orm import configure_mappers -from sqlalchemy.orm import create_session from sqlalchemy.orm import exc as orm_exc from sqlalchemy.orm import polymorphic_union from sqlalchemy.orm import relationship @@ -19,6 +18,7 @@ from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import mock +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from test.orm.test_events import _RemoveListeners @@ -34,7 +34,7 @@ class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults): def teardown(self): close_all_sessions() clear_mappers() - Base.metadata.drop_all() + Base.metadata.drop_all(testing.db) class ConcreteInhTest( @@ -49,8 +49,8 @@ class ConcreteInhTest( polymorphic=True, explicit_type=False, ): - Base.metadata.create_all() - sess = create_session() + Base.metadata.create_all(testing.db) + sess = fixture_session() e1 = Engineer(name="dilbert", primary_language="java") e2 = Engineer(name="wally", primary_language="c++") m1 = Manager(name="dogbert", golf_swing="fore!") @@ -342,7 +342,7 @@ class ConcreteInhTest( "concrete": True, } - Base.metadata.create_all() + Base.metadata.create_all(testing.db) sess = Session() sess.add(Engineer(name="d")) sess.commit() @@ -552,7 +552,7 @@ class ConcreteExtensionConfigTest( c_data = Column(String(50)) __mapper_args__ = {"polymorphic_identity": "c", "concrete": True} - Base.metadata.create_all() + Base.metadata.create_all(testing.db) sess = Session() sess.add_all( [ diff --git a/test/ext/declarative/test_reflection.py b/test/ext/declarative/test_reflection.py index 504025d6f..d7fcbf9e8 100644 --- a/test/ext/declarative/test_reflection.py +++ b/test/ext/declarative/test_reflection.py @@ -1,6 +1,5 @@ from sqlalchemy import ForeignKey from sqlalchemy import Integer -from sqlalchemy import MetaData from sqlalchemy import String from sqlalchemy import testing from sqlalchemy.ext.declarative import DeferredReflection @@ -26,7 +25,7 @@ class DeclarativeReflectionBase(fixtures.TablesTest): def setup(self): global Base, registry - registry = decl.registry(metadata=MetaData(bind=testing.db)) + registry = decl.registry() Base = registry.generate_base() def teardown(self): @@ -102,7 +101,7 @@ class DeferredReflectionTest(DeferredReflectBase): u1 = User( name="u1", addresses=[Address(email="one"), Address(email="two")] ) - sess = create_session() + sess = create_session(testing.db) sess.add(u1) sess.flush() sess.expunge_all() @@ -192,7 +191,7 @@ class DeferredReflectionTest(DeferredReflectBase): return {"primary_key": cls.__table__.c.id} DeferredReflection.prepare(testing.db) - sess = Session() + sess = Session(testing.db) sess.add_all( [User(name="G"), User(name="Q"), User(name="A"), User(name="C")] ) @@ -256,7 +255,7 @@ class DeferredSecondaryReflectionTest(DeferredReflectBase): u1 = User(name="u1", items=[Item(name="i1"), Item(name="i2")]) - sess = Session() + sess = Session(testing.db) sess.add(u1) sess.commit() diff --git a/test/ext/test_associationproxy.py b/test/ext/test_associationproxy.py index df27c8d27..b1f5cc956 100644 --- a/test/ext/test_associationproxy.py +++ b/test/ext/test_associationproxy.py @@ -21,7 +21,6 @@ from sqlalchemy.orm import clear_mappers from sqlalchemy.orm import collections from sqlalchemy.orm import composite from sqlalchemy.orm import configure_mappers -from sqlalchemy.orm import create_session from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship from sqlalchemy.orm import Session @@ -34,6 +33,7 @@ from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing.assertions import expect_warnings +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.mock import call from sqlalchemy.testing.mock import Mock from sqlalchemy.testing.schema import Column @@ -201,7 +201,7 @@ class _CollectionOperations(fixtures.TestBase): def setup(self): collection_class = self.collection_class - metadata = MetaData(testing.db) + metadata = MetaData() parents_table = Table( "Parent", @@ -254,14 +254,14 @@ class _CollectionOperations(fixtures.TestBase): ) mapper(Child, children_table) - metadata.create_all() + metadata.create_all(testing.db) self.metadata = metadata - self.session = create_session() + self.session = fixture_session() self.Parent, self.Child = Parent, Child def teardown(self): - self.metadata.drop_all() + self.metadata.drop_all(testing.db) def roundtrip(self, obj): if obj not in self.session: @@ -886,7 +886,7 @@ class CustomObjectTest(_CollectionOperations): class ProxyFactoryTest(ListTest): def setup(self): - metadata = MetaData(testing.db) + metadata = MetaData() parents_table = Table( "Parent", @@ -940,10 +940,10 @@ class ProxyFactoryTest(ListTest): ) mapper(Child, children_table) - metadata.create_all() + metadata.create_all(testing.db) self.metadata = metadata - self.session = create_session() + self.session = fixture_session() self.Parent, self.Child = Parent, Child def test_sequence_ops(self): @@ -1003,8 +1003,8 @@ class ScalarTest(fixtures.TestBase): ) mapper(Child, children_table) - metadata.create_all() - session = create_session() + metadata.create_all(testing.db) + session = fixture_session() def roundtrip(obj): if obj not in session: @@ -1158,7 +1158,7 @@ class ScalarTest(fixtures.TestBase): class LazyLoadTest(fixtures.TestBase): def setup(self): - metadata = MetaData(testing.db) + metadata = MetaData() parents_table = Table( "Parent", @@ -1190,15 +1190,15 @@ class LazyLoadTest(fixtures.TestBase): self.name = name mapper(Child, children_table) - metadata.create_all() + metadata.create_all(testing.db) self.metadata = metadata - self.session = create_session() + self.session = fixture_session() self.Parent, self.Child = Parent, Child self.table = parents_table def teardown(self): - self.metadata.drop_all() + self.metadata.drop_all(testing.db) def roundtrip(self, obj): self.session.add(obj) @@ -1369,7 +1369,7 @@ class ReconstitutionTest(fixtures.MappedTest): properties=dict(children=relationship(Child)), ) mapper(Child, self.tables.children) - session = create_session() + session = fixture_session() def add_child(parent_name, child_name): parent = session.query(Parent).filter_by(name=parent_name).one() @@ -3367,7 +3367,7 @@ class ProxyHybridTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): def test_comparator_ambiguous(self): A, B = self.classes("A", "B") - s = Session() + s = fixture_session() self.assert_compile( s.query(A).filter(A.b_data.any()), "SELECT a.id AS a_id FROM a WHERE EXISTS " @@ -3377,7 +3377,7 @@ class ProxyHybridTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): def test_explicit_expr(self): (C,) = self.classes("C") - s = Session() + s = fixture_session() self.assert_compile( s.query(C).filter_by(attr=5), "SELECT c.id AS c_id, c.b_id AS c_b_id FROM c WHERE EXISTS " diff --git a/test/ext/test_automap.py b/test/ext/test_automap.py index da0e7c133..bddb42b03 100644 --- a/test/ext/test_automap.py +++ b/test/ext/test_automap.py @@ -30,7 +30,7 @@ class AutomapTest(fixtures.MappedTest): FixtureTest.define_tables(metadata) def test_relationship_o2m_default(self): - Base = automap_base(metadata=self.metadata) + Base = automap_base(metadata=self.tables_test_metadata) Base.prepare() User = Base.classes.users @@ -41,7 +41,7 @@ class AutomapTest(fixtures.MappedTest): assert a1.users is u1 def test_relationship_explicit_override_o2m(self): - Base = automap_base(metadata=self.metadata) + Base = automap_base(metadata=self.tables_test_metadata) prop = relationship("addresses", collection_class=set) class User(Base): @@ -58,7 +58,7 @@ class AutomapTest(fixtures.MappedTest): assert a1.user is u1 def test_exception_prepare_not_called(self): - Base = automap_base(metadata=self.metadata) + Base = automap_base(metadata=self.tables_test_metadata) class User(Base): __tablename__ = "users" @@ -75,7 +75,7 @@ class AutomapTest(fixtures.MappedTest): ) def test_relationship_explicit_override_m2o(self): - Base = automap_base(metadata=self.metadata) + Base = automap_base(metadata=self.tables_test_metadata) prop = relationship("users") @@ -93,7 +93,7 @@ class AutomapTest(fixtures.MappedTest): assert a1.users is u1 def test_relationship_self_referential(self): - Base = automap_base(metadata=self.metadata) + Base = automap_base(metadata=self.tables_test_metadata) Base.prepare() Node = Base.classes.nodes @@ -110,7 +110,7 @@ class AutomapTest(fixtures.MappedTest): This test verifies that prepare can accept an optional schema argument and pass it to reflect. """ - Base = automap_base(metadata=self.metadata) + Base = automap_base(metadata=self.tables_test_metadata) engine_mock = Mock() with patch.object(Base.metadata, "reflect") as reflect_mock: Base.prepare(autoload_with=engine_mock, schema="some_schema") @@ -128,7 +128,7 @@ class AutomapTest(fixtures.MappedTest): This test verifies that prepare passes a default None if no schema is provided. """ - Base = automap_base(metadata=self.metadata) + Base = automap_base(metadata=self.tables_test_metadata) engine_mock = Mock() with patch.object(Base.metadata, "reflect") as reflect_mock: Base.prepare(autoload_with=engine_mock) @@ -140,7 +140,7 @@ class AutomapTest(fixtures.MappedTest): ) def test_naming_schemes(self): - Base = automap_base(metadata=self.metadata) + Base = automap_base(metadata=self.tables_test_metadata) def classname_for_table(base, tablename, table): return str("cls_" + tablename) @@ -170,7 +170,7 @@ class AutomapTest(fixtures.MappedTest): assert a1.scalar_cls_users is u1 def test_relationship_m2m(self): - Base = automap_base(metadata=self.metadata) + Base = automap_base(metadata=self.tables_test_metadata) Base.prepare() @@ -182,7 +182,7 @@ class AutomapTest(fixtures.MappedTest): assert o1 in i1.orders_collection def test_relationship_explicit_override_forwards_m2m(self): - Base = automap_base(metadata=self.metadata) + Base = automap_base(metadata=self.tables_test_metadata) class Order(Base): __tablename__ = "orders" @@ -205,7 +205,7 @@ class AutomapTest(fixtures.MappedTest): assert o1 in i1.order_collection def test_relationship_pass_params(self): - Base = automap_base(metadata=self.metadata) + Base = automap_base(metadata=self.tables_test_metadata) mock = Mock() @@ -269,7 +269,7 @@ class CascadeTest(fixtures.MappedTest): ) def test_o2m_relationship_cascade(self): - Base = automap_base(metadata=self.metadata) + Base = automap_base(metadata=self.tables_test_metadata) Base.prepare() configure_mappers() diff --git a/test/ext/test_baked.py b/test/ext/test_baked.py index eff3ccdae..71fabc629 100644 --- a/test/ext/test_baked.py +++ b/test/ext/test_baked.py @@ -18,6 +18,7 @@ from sqlalchemy.testing import eq_ from sqlalchemy.testing import is_ from sqlalchemy.testing import is_not from sqlalchemy.testing import mock +from sqlalchemy.testing.fixtures import fixture_session from test.orm import _fixtures @@ -42,7 +43,7 @@ class StateChangeTest(BakedTest): def test_initial_key(self): User = self.classes.User - session = Session() + session = fixture_session() def l1(): return session.query(User) @@ -53,7 +54,7 @@ class StateChangeTest(BakedTest): def test_inplace_add(self): User = self.classes.User - session = Session() + session = fixture_session() def l1(): return session.query(User) @@ -73,7 +74,7 @@ class StateChangeTest(BakedTest): def test_inplace_add_operator(self): User = self.classes.User - session = Session() + session = fixture_session() def l1(): return session.query(User) @@ -90,7 +91,7 @@ class StateChangeTest(BakedTest): def test_chained_add(self): User = self.classes.User - session = Session() + session = fixture_session() def l1(): return session.query(User) @@ -108,7 +109,7 @@ class StateChangeTest(BakedTest): def test_chained_add_operator(self): User = self.classes.User - session = Session() + session = fixture_session() def l1(): return session.query(User) @@ -138,7 +139,7 @@ class LikeQueryTest(BakedTest): bq = self.bakery(lambda s: s.query(User)) bq += lambda q: q.filter(User.name == "asdf") - eq_(bq(Session()).first(), None) + eq_(bq(fixture_session()).first(), None) def test_first_multiple_result(self): User = self.classes.User @@ -146,7 +147,7 @@ class LikeQueryTest(BakedTest): bq = self.bakery(lambda s: s.query(User.id)) bq += lambda q: q.filter(User.name.like("%ed%")).order_by(User.id) - eq_(bq(Session()).first(), (8,)) + eq_(bq(fixture_session()).first(), (8,)) def test_one_or_none_no_result(self): User = self.classes.User @@ -154,7 +155,7 @@ class LikeQueryTest(BakedTest): bq = self.bakery(lambda s: s.query(User)) bq += lambda q: q.filter(User.name == "asdf") - eq_(bq(Session()).one_or_none(), None) + eq_(bq(fixture_session()).one_or_none(), None) def test_one_or_none_result(self): User = self.classes.User @@ -162,7 +163,7 @@ class LikeQueryTest(BakedTest): bq = self.bakery(lambda s: s.query(User)) bq += lambda q: q.filter(User.name == "ed") - u1 = bq(Session()).one_or_none() + u1 = bq(fixture_session()).one_or_none() eq_(u1.name, "ed") def test_one_or_none_multiple_result(self): @@ -174,7 +175,7 @@ class LikeQueryTest(BakedTest): assert_raises_message( orm_exc.MultipleResultsFound, "Multiple rows were found when one or none was required", - bq(Session()).one_or_none, + bq(fixture_session()).one_or_none, ) def test_one_no_result(self): @@ -186,7 +187,7 @@ class LikeQueryTest(BakedTest): assert_raises_message( orm_exc.NoResultFound, "No row was found when one was required", - bq(Session()).one, + bq(fixture_session()).one, ) def test_one_result(self): @@ -195,7 +196,7 @@ class LikeQueryTest(BakedTest): bq = self.bakery(lambda s: s.query(User)) bq += lambda q: q.filter(User.name == "ed") - u1 = bq(Session()).one() + u1 = bq(fixture_session()).one() eq_(u1.name, "ed") def test_one_multiple_result(self): @@ -207,7 +208,7 @@ class LikeQueryTest(BakedTest): assert_raises_message( orm_exc.MultipleResultsFound, "Multiple rows were found when exactly one was required", - bq(Session()).one, + bq(fixture_session()).one, ) def test_get(self): @@ -215,7 +216,7 @@ class LikeQueryTest(BakedTest): bq = self.bakery(lambda s: s.query(User)) - sess = Session() + sess = fixture_session() def go(): u1 = bq(sess).get(7) @@ -242,7 +243,7 @@ class LikeQueryTest(BakedTest): bq = self.bakery(lambda s: s.query(User.id)) - sess = Session() + sess = fixture_session() bq += lambda q: q.filter(User.id == 7) @@ -253,7 +254,7 @@ class LikeQueryTest(BakedTest): bq = self.bakery(lambda s: s.query(User)) - sess = Session() + sess = fixture_session() eq_(bq(sess).count(), 4) @@ -272,7 +273,7 @@ class LikeQueryTest(BakedTest): bq = self.bakery(lambda s: s.query(User)) - sess = Session() + sess = fixture_session() eq_(bq(sess).count(), 4) @@ -306,7 +307,7 @@ class LikeQueryTest(BakedTest): bq = self.bakery(lambda s: s.query(AddressUser)) - sess = Session() + sess = fixture_session() def go(): u1 = bq(sess).get((10, None)) @@ -329,7 +330,7 @@ class LikeQueryTest(BakedTest): bq = self.bakery(lambda s: s.query(User)) for i in range(5): - sess = Session() + sess = fixture_session() u1 = bq(sess).get(7) eq_(u1.name, "jack") sess.close() @@ -343,7 +344,7 @@ class LikeQueryTest(BakedTest): del inspect(User).__dict__["_get_clause"] for i in range(5): - sess = Session() + sess = fixture_session() u1 = bq(sess).get(7) eq_(u1.name, "jack") sess.close() @@ -463,7 +464,7 @@ class ResultTest(BakedTest): bq2 = self.bakery(fn, 8) for i in range(3): - session = Session(autocommit=True) + session = fixture_session() eq_(bq1(session).all(), [(7,)]) eq_(bq2(session).all(), [(8,)]) @@ -476,7 +477,7 @@ class ResultTest(BakedTest): ) for i in range(3): - session = Session(autocommit=True) + session = fixture_session() eq_( bq(session).all(), [(7, "jack"), (8, "ed"), (9, "fred"), (10, "chuck")], @@ -490,7 +491,7 @@ class ResultTest(BakedTest): ) bq += lambda q: q.limit(bindparam("limit")).offset(bindparam("offset")) - session = Session(autocommit=True) + session = fixture_session() for i in range(4): for limit, offset, exp in [ @@ -522,7 +523,7 @@ class ResultTest(BakedTest): bq += fn2 - sess = Session(autocommit=True, enable_baked_queries=False) + sess = fixture_session(autocommit=True, enable_baked_queries=False) eq_(bq.add_criteria(fn3)(sess).params(id=7).all(), [(7, "jack")]) eq_( @@ -562,7 +563,7 @@ class ResultTest(BakedTest): bq += fn2 - sess = Session(autocommit=True) + sess = fixture_session() eq_( bq.spoil(full=True).add_criteria(fn3)(sess).params(id=7).all(), [(7, "jack")], @@ -609,7 +610,7 @@ class ResultTest(BakedTest): bq += fn2 - sess = Session(autocommit=True) + sess = fixture_session() eq_( bq.spoil().add_criteria(fn3)(sess).params(id=7).all(), [(7, "jack")], @@ -639,7 +640,7 @@ class ResultTest(BakedTest): bq += lambda q: q._from_self().with_entities(func.count(User.id)) for i in range(3): - session = Session(autocommit=True) + session = fixture_session() eq_(bq(session).all(), [(4,)]) def test_conditional_step(self): @@ -674,7 +675,7 @@ class ResultTest(BakedTest): bq += lambda q: q._from_self().with_entities( func.count(User.id) ) - sess = Session(autocommit=True) + sess = fixture_session() result = bq(sess).all() if cond4: if cond1: @@ -729,7 +730,7 @@ class ResultTest(BakedTest): if cond1 else (lambda q: q.filter(User.name == "jack")) ) # noqa - sess = Session(autocommit=True) + sess = fixture_session() result = bq(sess).all() if cond1: @@ -754,7 +755,7 @@ class ResultTest(BakedTest): main_bq += lambda q: q.filter(sub_bq.to_query(q).exists()) main_bq += lambda q: q.order_by(Address.id) - sess = Session() + sess = fixture_session() result = main_bq(sess).all() eq_(result, [(2,), (3,), (4,)]) @@ -775,7 +776,7 @@ class ResultTest(BakedTest): ) main_bq += lambda q: q.order_by(Address.id) - sess = Session() + sess = fixture_session() result = main_bq(sess).all() eq_(result, [(2, "ed"), (3, "ed"), (4, "ed")]) @@ -840,7 +841,7 @@ class ResultTest(BakedTest): print("HI----") bq = base_bq._clone() - sess = Session() + sess = fixture_session() if cond1: bq += lambda q: q.filter(User.name == "jack") @@ -908,7 +909,7 @@ class ResultTest(BakedTest): bq += lambda q: q.options(subqueryload(User.addresses)) bq += lambda q: q.order_by(User.id) bq += lambda q: q.filter(User.name == bindparam("name")) - sess = Session() + sess = fixture_session() def set_params(q): return q.params(name="jack") @@ -950,7 +951,7 @@ class ResultTest(BakedTest): bq += lambda q: q.options(subqueryload(User.addresses)) bq += lambda q: q.order_by(User.id) bq += lambda q: q.filter(User.name == bindparam("name")) - sess = Session() + sess = fixture_session() def set_params(q): return q.params(name="jack") @@ -1007,7 +1008,7 @@ class CustomIntegrationTest(testing.AssertsCompiledSQL, BakedTest): cache[cache_key] = retval = createfunc().freeze() return retval() - s1 = Session(query_cls=CachingQuery) + s1 = fixture_session(query_cls=CachingQuery) @event.listens_for(s1, "do_orm_execute", retval=True) def do_orm_execute(orm_context): diff --git a/test/ext/test_deprecations.py b/test/ext/test_deprecations.py index b209de36d..b6976299b 100644 --- a/test/ext/test_deprecations.py +++ b/test/ext/test_deprecations.py @@ -16,7 +16,7 @@ class AutomapTest(fixtures.MappedTest): FixtureTest.define_tables(metadata) def test_reflect_true(self): - Base = automap_base(metadata=self.metadata) + Base = automap_base(metadata=self.tables_test_metadata) engine_mock = mock.Mock() with mock.patch.object(Base.metadata, "reflect") as reflect_mock: with testing.expect_deprecated( diff --git a/test/ext/test_horizontal_shard.py b/test/ext/test_horizontal_shard.py index e46c65ff0..038bdd83e 100644 --- a/test/ext/test_horizontal_shard.py +++ b/test/ext/test_horizontal_shard.py @@ -47,7 +47,7 @@ class ShardTest(object): db1, db2, db3, db4 = self._dbs = self._init_dbs() - meta = self.metadata = MetaData() + meta = self.tables_test_metadata = MetaData() ids = Table("ids", meta, Column("nextid", Integer, nullable=False)) def id_generator(ctx): @@ -786,7 +786,7 @@ class MultipleDialectShardTest(ShardTest, fixtures.TestBase): os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT)) with self.postgresql_engine.begin() as conn: - self.metadata.drop_all(conn) + self.tables_test_metadata.drop_all(conn) for i in [2, 4]: conn.exec_driver_sql("DROP SCHEMA shard%s CASCADE" % (i,)) @@ -898,7 +898,7 @@ class LazyLoadIdentityKeyTest(fixtures.DeclarativeMappedTest): ) for db in (db1, db2): - self.metadata.create_all(db) + self.tables_test_metadata.create_all(db) self.dbs = [db1, db2] diff --git a/test/ext/test_hybrid.py b/test/ext/test_hybrid.py index fbac35f7e..048a8b52d 100644 --- a/test/ext/test_hybrid.py +++ b/test/ext/test_hybrid.py @@ -17,6 +17,7 @@ from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column @@ -70,14 +71,14 @@ class PropertyComparatorTest(fixtures.TestBase, AssertsCompiledSQL): def test_query(self): A = self._fixture() - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(A.value), "SELECT a.value AS a_value FROM a" ) def test_aliased_query(self): A = self._fixture() - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(aliased(A).value), "SELECT a_1.value AS a_1_value FROM a AS a_1", @@ -85,7 +86,7 @@ class PropertyComparatorTest(fixtures.TestBase, AssertsCompiledSQL): def test_aliased_filter(self): A = self._fixture() - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(aliased(A)).filter_by(value="foo"), "SELECT a_1.value AS a_1_value, a_1.id AS a_1_id " @@ -183,7 +184,7 @@ class PropertyExpressionTest(fixtures.TestBase, AssertsCompiledSQL): def test_any(self): A, B = self._relationship_fixture() - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(B).filter(B.as_.any(value=5)), "SELECT b.id AS b_id FROM b WHERE EXISTS " @@ -200,7 +201,7 @@ class PropertyExpressionTest(fixtures.TestBase, AssertsCompiledSQL): def test_query(self): A = self._fixture() - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(A).filter_by(value="foo"), "SELECT a.value AS a_value, a.id AS a_id " @@ -209,7 +210,7 @@ class PropertyExpressionTest(fixtures.TestBase, AssertsCompiledSQL): def test_aliased_query(self): A = self._fixture() - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(aliased(A)).filter_by(value="foo"), "SELECT a_1.value AS a_1_value, a_1.id AS a_1_id " @@ -489,7 +490,7 @@ class MethodExpressionTest(fixtures.TestBase, AssertsCompiledSQL): def test_query(self): A = self._fixture() - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(A).filter(A.value(5) == "foo"), "SELECT a.value AS a_value, a.id AS a_id " @@ -498,7 +499,7 @@ class MethodExpressionTest(fixtures.TestBase, AssertsCompiledSQL): def test_aliased_query(self): A = self._fixture() - sess = Session() + sess = fixture_session() a1 = aliased(A) self.assert_compile( sess.query(a1).filter(a1.value(5) == "foo"), @@ -508,7 +509,7 @@ class MethodExpressionTest(fixtures.TestBase, AssertsCompiledSQL): def test_query_col(self): A = self._fixture() - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(A.value(5)), "SELECT foo(a.value, :foo_1) + :foo_2 AS anon_1 FROM a", @@ -516,7 +517,7 @@ class MethodExpressionTest(fixtures.TestBase, AssertsCompiledSQL): def test_aliased_query_col(self): A = self._fixture() - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(aliased(A).value(5)), "SELECT foo(a_1.value, :foo_1) + :foo_2 AS anon_1 FROM a AS a_1", @@ -610,7 +611,7 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): def test_evaluate_hybrid_attr_indirect(self): Person = self.classes.Person - s = Session() + s = fixture_session() jill = s.query(Person).get(3) s.query(Person).update( @@ -621,7 +622,7 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): def test_evaluate_hybrid_attr_plain(self): Person = self.classes.Person - s = Session() + s = fixture_session() jill = s.query(Person).get(3) s.query(Person).update( @@ -632,7 +633,7 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): def test_fetch_hybrid_attr_indirect(self): Person = self.classes.Person - s = Session() + s = fixture_session() jill = s.query(Person).get(3) s.query(Person).update( @@ -643,7 +644,7 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): def test_fetch_hybrid_attr_plain(self): Person = self.classes.Person - s = Session() + s = fixture_session() jill = s.query(Person).get(3) s.query(Person).update( @@ -654,7 +655,7 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): def test_evaluate_hybrid_attr_w_update_expr(self): Person = self.classes.Person - s = Session() + s = fixture_session() jill = s.query(Person).get(3) s.query(Person).update( @@ -665,7 +666,7 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): def test_fetch_hybrid_attr_w_update_expr(self): Person = self.classes.Person - s = Session() + s = fixture_session() jill = s.query(Person).get(3) s.query(Person).update( @@ -676,7 +677,7 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): def test_evaluate_hybrid_attr_indirect_w_update_expr(self): Person = self.classes.Person - s = Session() + s = fixture_session() jill = s.query(Person).get(3) s.query(Person).update( @@ -813,7 +814,7 @@ class SpecialObjectTest(fixtures.TestBase, AssertsCompiledSQL): def test_query_one(self): BankAccount, Amount = self.BankAccount, self.Amount - session = Session() + session = fixture_session() query = session.query(BankAccount).filter( BankAccount.balance == Amount(10000, "cad") @@ -829,7 +830,7 @@ class SpecialObjectTest(fixtures.TestBase, AssertsCompiledSQL): def test_query_two(self): BankAccount, Amount = self.BankAccount, self.Amount - session = Session() + session = fixture_session() # alternatively we can do the calc on the DB side. query = ( @@ -858,7 +859,7 @@ class SpecialObjectTest(fixtures.TestBase, AssertsCompiledSQL): def test_query_three(self): BankAccount = self.BankAccount - session = Session() + session = fixture_session() query = session.query(BankAccount).filter( BankAccount.balance.as_currency("cad") @@ -879,7 +880,7 @@ class SpecialObjectTest(fixtures.TestBase, AssertsCompiledSQL): def test_query_four(self): BankAccount = self.BankAccount - session = Session() + session = fixture_session() # 4c. query all amounts, converting to "CAD" on the DB side query = session.query(BankAccount.balance.as_currency("cad").amount) @@ -892,7 +893,7 @@ class SpecialObjectTest(fixtures.TestBase, AssertsCompiledSQL): def test_query_five(self): BankAccount = self.BankAccount - session = Session() + session = fixture_session() # 4d. average balance in EUR query = session.query(func.avg(BankAccount.balance.as_currency("eur"))) diff --git a/test/ext/test_mutable.py b/test/ext/test_mutable.py index acb0ad490..eba2ac0cb 100644 --- a/test/ext/test_mutable.py +++ b/test/ext/test_mutable.py @@ -15,7 +15,6 @@ from sqlalchemy.orm import attributes from sqlalchemy.orm import column_property from sqlalchemy.orm import composite from sqlalchemy.orm import mapper -from sqlalchemy.orm import Session from sqlalchemy.orm.instrumentation import ClassManager from sqlalchemy.orm.mapper import Mapper from sqlalchemy.testing import assert_raises @@ -23,6 +22,7 @@ from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import mock +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from sqlalchemy.testing.util import picklers @@ -106,7 +106,7 @@ class _MutableDictTestBase(_MutableDictTestFixture): mapper(Foo, foo) def test_coerce_none(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=None) sess.add(f1) sess.commit() @@ -121,7 +121,7 @@ class _MutableDictTestBase(_MutableDictTestFixture): ) def test_in_place_mutation(self): - sess = Session() + sess = fixture_session() f1 = Foo(data={"a": "b"}) sess.add(f1) @@ -149,7 +149,7 @@ class _MutableDictTestBase(_MutableDictTestFixture): ) def test_clear(self): - sess = Session() + sess = fixture_session() f1 = Foo(data={"a": "b"}) sess.add(f1) @@ -161,7 +161,7 @@ class _MutableDictTestBase(_MutableDictTestFixture): eq_(f1.data, {}) def test_update(self): - sess = Session() + sess = fixture_session() f1 = Foo(data={"a": "b"}) sess.add(f1) @@ -173,7 +173,7 @@ class _MutableDictTestBase(_MutableDictTestFixture): eq_(f1.data, {"a": "z"}) def test_pop(self): - sess = Session() + sess = fixture_session() f1 = Foo(data={"a": "b", "c": "d"}) sess.add(f1) @@ -187,7 +187,7 @@ class _MutableDictTestBase(_MutableDictTestFixture): eq_(f1.data, {"c": "d"}) def test_pop_default(self): - sess = Session() + sess = fixture_session() f1 = Foo(data={"a": "b", "c": "d"}) sess.add(f1) @@ -200,7 +200,7 @@ class _MutableDictTestBase(_MutableDictTestFixture): eq_(f1.data, {"c": "d"}) def test_popitem(self): - sess = Session() + sess = fixture_session() orig = {"a": "b", "c": "d"} @@ -220,7 +220,7 @@ class _MutableDictTestBase(_MutableDictTestFixture): eq_(f1.data, orig) def test_setdefault(self): - sess = Session() + sess = fixture_session() f1 = Foo(data={"a": "b"}) sess.add(f1) @@ -237,7 +237,7 @@ class _MutableDictTestBase(_MutableDictTestFixture): eq_(f1.data, {"a": "b", "c": "d"}) def test_replace(self): - sess = Session() + sess = fixture_session() f1 = Foo(data={"a": "b"}) sess.add(f1) sess.flush() @@ -247,7 +247,7 @@ class _MutableDictTestBase(_MutableDictTestFixture): eq_(f1.data, {"b": "c"}) def test_replace_itself_still_ok(self): - sess = Session() + sess = fixture_session() f1 = Foo(data={"a": "b"}) sess.add(f1) sess.flush() @@ -258,7 +258,7 @@ class _MutableDictTestBase(_MutableDictTestFixture): eq_(f1.data, {"a": "b", "b": "c"}) def test_pickle_parent(self): - sess = Session() + sess = fixture_session() f1 = Foo(data={"a": "b"}) sess.add(f1) @@ -267,14 +267,14 @@ class _MutableDictTestBase(_MutableDictTestFixture): sess.close() for loads, dumps in picklers(): - sess = Session() + sess = fixture_session() f2 = loads(dumps(f1)) sess.add(f2) f2.data["a"] = "c" assert f2 in sess.dirty def test_unrelated_flush(self): - sess = Session() + sess = fixture_session() f1 = Foo(data={"a": "b"}, unrelated_data="unrelated") sess.add(f1) sess.flush() @@ -285,7 +285,7 @@ class _MutableDictTestBase(_MutableDictTestFixture): eq_(f1.data["a"], "c") def _test_non_mutable(self): - sess = Session() + sess = fixture_session() f1 = Foo(non_mutable_data={"a": "b"}) sess.add(f1) @@ -328,7 +328,7 @@ class _MutableListTestBase(_MutableListTestFixture): mapper(Foo, foo) def test_coerce_none(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=None) sess.add(f1) sess.commit() @@ -343,7 +343,7 @@ class _MutableListTestBase(_MutableListTestFixture): ) def test_in_place_mutation(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=[1, 2]) sess.add(f1) @@ -355,7 +355,7 @@ class _MutableListTestBase(_MutableListTestFixture): eq_(f1.data, [3, 2]) def test_in_place_slice_mutation(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=[1, 2, 3, 4]) sess.add(f1) @@ -367,7 +367,7 @@ class _MutableListTestBase(_MutableListTestFixture): eq_(f1.data, [1, 5, 6, 4]) def test_del_slice(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=[1, 2, 3, 4]) sess.add(f1) @@ -382,7 +382,7 @@ class _MutableListTestBase(_MutableListTestFixture): if not hasattr(list, "clear"): # py2 list doesn't have 'clear' return - sess = Session() + sess = fixture_session() f1 = Foo(data=[1, 2]) sess.add(f1) @@ -394,7 +394,7 @@ class _MutableListTestBase(_MutableListTestFixture): eq_(f1.data, []) def test_pop(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=[1, 2, 3]) sess.add(f1) @@ -409,7 +409,7 @@ class _MutableListTestBase(_MutableListTestFixture): eq_(f1.data, [2]) def test_append(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=[1, 2]) sess.add(f1) @@ -421,7 +421,7 @@ class _MutableListTestBase(_MutableListTestFixture): eq_(f1.data, [1, 2, 5]) def test_extend(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=[1, 2]) sess.add(f1) @@ -433,7 +433,7 @@ class _MutableListTestBase(_MutableListTestFixture): eq_(f1.data, [1, 2, 5]) def test_operator_extend(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=[1, 2]) sess.add(f1) @@ -445,7 +445,7 @@ class _MutableListTestBase(_MutableListTestFixture): eq_(f1.data, [1, 2, 5]) def test_insert(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=[1, 2]) sess.add(f1) @@ -457,7 +457,7 @@ class _MutableListTestBase(_MutableListTestFixture): eq_(f1.data, [1, 5, 2]) def test_remove(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=[1, 2, 3]) sess.add(f1) @@ -469,7 +469,7 @@ class _MutableListTestBase(_MutableListTestFixture): eq_(f1.data, [1, 3]) def test_sort(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=[1, 3, 2]) sess.add(f1) @@ -481,7 +481,7 @@ class _MutableListTestBase(_MutableListTestFixture): eq_(f1.data, [1, 2, 3]) def test_sort_w_key(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=[1, 3, 2]) sess.add(f1) @@ -493,7 +493,7 @@ class _MutableListTestBase(_MutableListTestFixture): eq_(f1.data, [3, 2, 1]) def test_sort_w_reverse_kwarg(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=[1, 3, 2]) sess.add(f1) @@ -505,7 +505,7 @@ class _MutableListTestBase(_MutableListTestFixture): eq_(f1.data, [3, 2, 1]) def test_reverse(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=[1, 3, 2]) sess.add(f1) @@ -517,7 +517,7 @@ class _MutableListTestBase(_MutableListTestFixture): eq_(f1.data, [2, 3, 1]) def test_pickle_parent(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=[1, 2]) sess.add(f1) @@ -526,14 +526,14 @@ class _MutableListTestBase(_MutableListTestFixture): sess.close() for loads, dumps in picklers(): - sess = Session() + sess = fixture_session() f2 = loads(dumps(f1)) sess.add(f2) f2.data[0] = 3 assert f2 in sess.dirty def test_unrelated_flush(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=[1, 2], unrelated_data="unrelated") sess.add(f1) sess.flush() @@ -635,7 +635,7 @@ class _MutableSetTestBase(_MutableSetTestFixture): mapper(Foo, foo) def test_coerce_none(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=None) sess.add(f1) sess.commit() @@ -650,7 +650,7 @@ class _MutableSetTestBase(_MutableSetTestFixture): ) def test_clear(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=set([1, 2])) sess.add(f1) @@ -662,7 +662,7 @@ class _MutableSetTestBase(_MutableSetTestFixture): eq_(f1.data, set()) def test_pop(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=set([1])) sess.add(f1) @@ -676,7 +676,7 @@ class _MutableSetTestBase(_MutableSetTestFixture): eq_(f1.data, set()) def test_add(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=set([1, 2])) sess.add(f1) @@ -688,7 +688,7 @@ class _MutableSetTestBase(_MutableSetTestFixture): eq_(f1.data, set([1, 2, 5])) def test_update(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=set([1, 2])) sess.add(f1) @@ -700,7 +700,7 @@ class _MutableSetTestBase(_MutableSetTestFixture): eq_(f1.data, set([1, 2, 5])) def test_binary_update(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=set([1, 2])) sess.add(f1) @@ -712,7 +712,7 @@ class _MutableSetTestBase(_MutableSetTestFixture): eq_(f1.data, set([1, 2, 5])) def test_intersection_update(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=set([1, 2])) sess.add(f1) @@ -724,7 +724,7 @@ class _MutableSetTestBase(_MutableSetTestFixture): eq_(f1.data, set([2])) def test_binary_intersection_update(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=set([1, 2])) sess.add(f1) @@ -736,7 +736,7 @@ class _MutableSetTestBase(_MutableSetTestFixture): eq_(f1.data, set([2])) def test_difference_update(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=set([1, 2])) sess.add(f1) @@ -748,7 +748,7 @@ class _MutableSetTestBase(_MutableSetTestFixture): eq_(f1.data, set([1])) def test_operator_difference_update(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=set([1, 2])) sess.add(f1) @@ -760,7 +760,7 @@ class _MutableSetTestBase(_MutableSetTestFixture): eq_(f1.data, set([1])) def test_symmetric_difference_update(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=set([1, 2])) sess.add(f1) @@ -772,7 +772,7 @@ class _MutableSetTestBase(_MutableSetTestFixture): eq_(f1.data, set([1, 5])) def test_binary_symmetric_difference_update(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=set([1, 2])) sess.add(f1) @@ -784,7 +784,7 @@ class _MutableSetTestBase(_MutableSetTestFixture): eq_(f1.data, set([1, 5])) def test_remove(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=set([1, 2, 3])) sess.add(f1) @@ -796,7 +796,7 @@ class _MutableSetTestBase(_MutableSetTestFixture): eq_(f1.data, set([1, 3])) def test_discard(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=set([1, 2, 3])) sess.add(f1) @@ -813,7 +813,7 @@ class _MutableSetTestBase(_MutableSetTestFixture): eq_(f1.data, set([1, 3])) def test_pickle_parent(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=set([1, 2])) sess.add(f1) @@ -822,14 +822,14 @@ class _MutableSetTestBase(_MutableSetTestFixture): sess.close() for loads, dumps in picklers(): - sess = Session() + sess = fixture_session() f2 = loads(dumps(f1)) sess.add(f2) f2.data.add(3) assert f2 in sess.dirty def test_unrelated_flush(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=set([1, 2]), unrelated_data="unrelated") sess.add(f1) sess.flush() @@ -873,7 +873,7 @@ class MutableColumnDefaultTest(_MutableDictTestFixture, fixtures.MappedTest): def test_evt_on_flush_refresh(self): # test for #3427 - sess = Session() + sess = fixture_session() f1 = Foo() sess.add(f1) @@ -1080,7 +1080,7 @@ class MutableAssocWithAttrInheritTest( MutableDict.associate_with_attribute(Foo.data) def test_in_place_mutation(self): - sess = Session() + sess = fixture_session() f1 = SubFoo(data={"a": "b"}) sess.add(f1) @@ -1092,7 +1092,7 @@ class MutableAssocWithAttrInheritTest( eq_(f1.data, {"a": "c"}) def test_replace(self): - sess = Session() + sess = fixture_session() f1 = SubFoo(data={"a": "b"}) sess.add(f1) sess.flush() @@ -1213,7 +1213,7 @@ class CustomMutableAssociationScalarJSONTest( pass def test_coerce(self): - sess = Session() + sess = fixture_session() f1 = Foo(data={"a": "b"}) sess.add(f1) sess.flush() @@ -1283,7 +1283,7 @@ class MutableCompositeColumnDefaultTest( def test_evt_on_flush_refresh(self): # this still worked prior to #3427 being fixed in any case - sess = Session() + sess = fixture_session() f1 = Foo(data=self.Point(None, None)) sess.add(f1) @@ -1325,7 +1325,7 @@ class MutableCompositesTest(_CompositeTestBase, fixtures.MappedTest): ) def test_in_place_mutation(self): - sess = Session() + sess = fixture_session() d = Point(3, 4) f1 = Foo(data=d) sess.add(f1) @@ -1337,7 +1337,7 @@ class MutableCompositesTest(_CompositeTestBase, fixtures.MappedTest): eq_(f1.data, Point(3, 5)) def test_pickle_of_parent(self): - sess = Session() + sess = fixture_session() d = Point(3, 4) f1 = Foo(data=d) sess.add(f1) @@ -1348,14 +1348,14 @@ class MutableCompositesTest(_CompositeTestBase, fixtures.MappedTest): sess.close() for loads, dumps in picklers(): - sess = Session() + sess = fixture_session() f2 = loads(dumps(f1)) sess.add(f2) f2.data.y = 12 assert f2 in sess.dirty def test_set_none(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=None) sess.add(f1) sess.commit() @@ -1377,7 +1377,7 @@ class MutableCompositesTest(_CompositeTestBase, fixtures.MappedTest): ) def test_unrelated_flush(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=Point(3, 4), unrelated_data="unrelated") sess.add(f1) sess.flush() @@ -1407,7 +1407,7 @@ class MutableCompositeCallableTest(_CompositeTestBase, fixtures.MappedTest): ) def test_basic(self): - sess = Session() + sess = fixture_session() f1 = Foo(data=Point(3, 4)) sess.add(f1) sess.flush() @@ -1442,7 +1442,7 @@ class MutableCompositeCustomCoerceTest( eq_(f.data, Point(3, 4)) def test_round_trip_ok(self): - sess = Session() + sess = fixture_session() f = Foo() f.data = (3, 4) @@ -1483,7 +1483,7 @@ class MutableInheritedCompositesTest(_CompositeTestBase, fixtures.MappedTest): mapper(SubFoo, subfoo, inherits=Foo) def test_in_place_mutation_subclass(self): - sess = Session() + sess = fixture_session() d = Point(3, 4) f1 = SubFoo(data=d) sess.add(f1) @@ -1495,7 +1495,7 @@ class MutableInheritedCompositesTest(_CompositeTestBase, fixtures.MappedTest): eq_(f1.data, Point(3, 5)) def test_pickle_of_parent_subclass(self): - sess = Session() + sess = fixture_session() d = Point(3, 4) f1 = SubFoo(data=d) sess.add(f1) @@ -1506,7 +1506,7 @@ class MutableInheritedCompositesTest(_CompositeTestBase, fixtures.MappedTest): sess.close() for loads, dumps in picklers(): - sess = Session() + sess = fixture_session() f2 = loads(dumps(f1)) sess.add(f2) f2.data.y = 12 diff --git a/test/ext/test_orderinglist.py b/test/ext/test_orderinglist.py index a1a6c6918..f23d6cb57 100644 --- a/test/ext/test_orderinglist.py +++ b/test/ext/test_orderinglist.py @@ -4,11 +4,11 @@ from sqlalchemy import MetaData from sqlalchemy import String from sqlalchemy import testing from sqlalchemy.ext.orderinglist import ordering_list -from sqlalchemy.orm import create_session from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing.fixtures import create_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from sqlalchemy.testing.util import picklers @@ -64,7 +64,7 @@ class OrderingListTest(fixtures.TestBase): global metadata, slides_table, bullets_table, Slide, Bullet slides_table, bullets_table = None, None Slide, Bullet = None, None - metadata = MetaData(testing.db) + metadata = MetaData() def _setup(self, test_collection_class): """Build a relationship situation using the given @@ -120,10 +120,10 @@ class OrderingListTest(fixtures.TestBase): ) mapper(Bullet, bullets_table) - metadata.create_all() + metadata.create_all(testing.db) def teardown(self): - metadata.drop_all() + metadata.drop_all(testing.db) def test_append_no_reorder(self): self._setup( diff --git a/test/ext/test_serializer.py b/test/ext/test_serializer.py index 6d7b8da33..12e4255fa 100644 --- a/test/ext/test_serializer.py +++ b/test/ext/test_serializer.py @@ -67,7 +67,7 @@ class SerializeTest(AssertsCompiledSQL, fixtures.MappedTest): @classmethod def setup_mappers(cls): global Session - Session = scoped_session(sessionmaker()) + Session = scoped_session(sessionmaker(testing.db)) mapper( User, users, @@ -141,9 +141,8 @@ class SerializeTest(AssertsCompiledSQL, fixtures.MappedTest): serializer.dumps(expr, -1), users.metadata, None ) eq_(str(expr), str(re_expr)) - assert re_expr.bind is testing.db eq_( - re_expr.execute().fetchall(), + Session.connection().execute(re_expr).fetchall(), [(7, "jack"), (8, "ed"), (8, "ed"), (8, "ed"), (9, "fred")], ) diff --git a/test/orm/declarative/test_basic.py b/test/orm/declarative/test_basic.py index fd00717f4..4c005d336 100644 --- a/test/orm/declarative/test_basic.py +++ b/test/orm/declarative/test_basic.py @@ -19,7 +19,6 @@ from sqlalchemy.orm import close_all_sessions from sqlalchemy.orm import column_property from sqlalchemy.orm import composite from sqlalchemy.orm import configure_mappers -from sqlalchemy.orm import create_session from sqlalchemy.orm import decl_base from sqlalchemy.orm import declarative_base from sqlalchemy.orm import declared_attr @@ -44,6 +43,7 @@ from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import mock +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from sqlalchemy.util import with_metaclass @@ -68,7 +68,7 @@ class DeclarativeTestBase( def teardown(self): close_all_sessions() clear_mappers() - Base.metadata.drop_all() + Base.metadata.drop_all(testing.db) class DeclarativeTest(DeclarativeTestBase): @@ -93,7 +93,7 @@ class DeclarativeTest(DeclarativeTestBase): "user_id", Integer, ForeignKey("users.id"), key="_user_id" ) - Base.metadata.create_all() + Base.metadata.create_all(testing.db) eq_(Address.__table__.c["id"].name, "id") eq_(Address.__table__.c["_email"].name, "email") @@ -102,7 +102,7 @@ class DeclarativeTest(DeclarativeTestBase): u1 = User( name="u1", addresses=[Address(email="one"), Address(email="two")] ) - sess = create_session() + sess = fixture_session() sess.add(u1) sess.flush() sess.expunge_all() @@ -193,14 +193,15 @@ class DeclarativeTest(DeclarativeTestBase): orm_exc.UnmappedClassError, "Class .*User has a deferred " "mapping on it. It is not yet usable as a mapped class.", - Session().query, + fixture_session().query, User, ) User.prepare() self.assert_compile( - Session().query(User), 'SELECT "user".id AS user_id FROM "user"' + fixture_session().query(User), + 'SELECT "user".id AS user_id FROM "user"', ) def test_unicode_string_resolve(self): @@ -598,8 +599,8 @@ class DeclarativeTest(DeclarativeTestBase): email = Column(String(50)) user_id = Column(Integer) # note no foreign key - Base.metadata.create_all() - sess = create_session() + Base.metadata.create_all(testing.db) + sess = fixture_session() u1 = User( name="ed", addresses=[ @@ -644,8 +645,8 @@ class DeclarativeTest(DeclarativeTestBase): ) name = Column(String(50)) - Base.metadata.create_all() - sess = create_session() + Base.metadata.create_all(testing.db) + sess = fixture_session() u1 = User(name="ed") sess.add(u1) sess.flush() @@ -693,7 +694,7 @@ class DeclarativeTest(DeclarativeTestBase): name = Column(String(50)) users = relationship("User", order_by="User.fullname") - s = Session() + s = fixture_session() self.assert_compile( s.query(Game).options(joinedload(Game.users)), "SELECT game.id AS game_id, game.name AS game_name, " @@ -738,7 +739,7 @@ class DeclarativeTest(DeclarativeTestBase): id = Column(Integer, primary_key=True) - s = Session() + s = fixture_session() self.assert_compile( s.query(A).join(A.d), "SELECT a.id AS a_id, a.b_id AS a_b_id FROM a JOIN " @@ -1061,8 +1062,8 @@ class DeclarativeTest(DeclarativeTestBase): # generally go downhill from there. class_mapper(User) - Base.metadata.create_all() - sess = create_session() + Base.metadata.create_all(testing.db) + sess = fixture_session() u1 = User( name="ed", addresses=[ @@ -1158,12 +1159,12 @@ class DeclarativeTest(DeclarativeTestBase): master_id = Column(None, ForeignKey(Master.id)) master = relationship(Master) - Base.metadata.create_all() + Base.metadata.create_all(testing.db) configure_mappers() assert class_mapper(Detail).get_property("master").strategy.use_get m1 = Master() d1 = Detail(master=m1) - sess = create_session() + sess = fixture_session() sess.add(d1) sess.flush() sess.expunge_all() @@ -1193,7 +1194,7 @@ class DeclarativeTest(DeclarativeTestBase): assert User.__table__.c.name in set(i.columns) # tables create fine - Base.metadata.create_all() + Base.metadata.create_all(testing.db) def test_add_prop(self): class User(Base, fixtures.ComparableEntity): @@ -1217,14 +1218,14 @@ class DeclarativeTest(DeclarativeTestBase): Address.user_id = Column( "user_id", Integer, ForeignKey("users.id"), key="_user_id" ) - Base.metadata.create_all() + Base.metadata.create_all(testing.db) eq_(Address.__table__.c["id"].name, "id") eq_(Address.__table__.c["_email"].name, "email") eq_(Address.__table__.c["_user_id"].name, "user_id") u1 = User( name="u1", addresses=[Address(email="one"), Address(email="two")] ) - sess = create_session() + sess = fixture_session() sess.add(u1) sess.flush() sess.expunge_all() @@ -1331,11 +1332,11 @@ class DeclarativeTest(DeclarativeTestBase): name = Column("name", String(50)) addresses = relationship("Address", order_by=Address.email) - Base.metadata.create_all() + Base.metadata.create_all(testing.db) u1 = User( name="u1", addresses=[Address(email="two"), Address(email="one")] ) - sess = create_session() + sess = fixture_session() sess.add(u1) sess.flush() sess.expunge_all() @@ -1370,11 +1371,11 @@ class DeclarativeTest(DeclarativeTestBase): "Address", order_by=(Address.email, Address.id) ) - Base.metadata.create_all() + Base.metadata.create_all(testing.db) u1 = User( name="u1", addresses=[Address(email="two"), Address(email="one")] ) - sess = create_session() + sess = fixture_session() sess.add(u1) sess.flush() sess.expunge_all() @@ -1403,7 +1404,7 @@ class DeclarativeTest(DeclarativeTestBase): reg = registry(metadata=Base.metadata) reg.mapped(User) reg.mapped(Address) - reg.metadata.create_all() + reg.metadata.create_all(testing.db) u1 = User( name="u1", addresses=[Address(email="one"), Address(email="two")] ) @@ -1593,11 +1594,11 @@ class DeclarativeTest(DeclarativeTestBase): .where(Address.user_id == User.id) .scalar_subquery() ) - Base.metadata.create_all() + Base.metadata.create_all(testing.db) u1 = User( name="u1", addresses=[Address(email="one"), Address(email="two")] ) - sess = create_session() + sess = fixture_session() sess.add(u1) sess.flush() sess.expunge_all() @@ -1641,11 +1642,11 @@ class DeclarativeTest(DeclarativeTestBase): .scalar_subquery() ) - Base.metadata.create_all() + Base.metadata.create_all(testing.db) u1 = User( name="u1", addresses=[Address(email="one"), Address(email="two")] ) - sess = create_session() + sess = fixture_session() sess.add(u1) sess.flush() sess.expunge_all() @@ -1707,11 +1708,11 @@ class DeclarativeTest(DeclarativeTestBase): User.a = Column("a", String(10)) User.b = Column(String(10)) - Base.metadata.create_all() + Base.metadata.create_all(testing.db) u1 = User(name="u1", a="a", b="b") eq_(u1.a, "a") eq_(User.a.get_history(u1), (["a"], (), ())) - sess = create_session() + sess = fixture_session() sess.add(u1) sess.flush() sess.expunge_all() @@ -1742,11 +1743,11 @@ class DeclarativeTest(DeclarativeTestBase): ) addresses = relationship(Address) - Base.metadata.create_all() + Base.metadata.create_all(testing.db) u1 = User( name="u1", addresses=[Address(email="one"), Address(email="two")] ) - sess = create_session() + sess = fixture_session() sess.add(u1) sess.flush() sess.expunge_all() @@ -1792,8 +1793,8 @@ class DeclarativeTest(DeclarativeTestBase): ) name = sa.orm.deferred(Column(String(50))) - Base.metadata.create_all() - sess = create_session() + Base.metadata.create_all(testing.db) + sess = fixture_session() sess.add(User(name="u1")) sess.flush() sess.expunge_all() @@ -1825,8 +1826,8 @@ class DeclarativeTest(DeclarativeTestBase): Column("state", String(2)), ) - Base.metadata.create_all() - sess = Session() + Base.metadata.create_all(testing.db) + sess = fixture_session() sess.add(User(address=AddressComposite("123 anywhere street", "MD"))) sess.commit() eq_( @@ -1852,8 +1853,8 @@ class DeclarativeTest(DeclarativeTestBase): state = Column(String(2)) address = composite(AddressComposite, street, state) - Base.metadata.create_all() - sess = Session() + Base.metadata.create_all(testing.db) + sess = fixture_session() sess.add(User(address=AddressComposite("123 anywhere street", "MD"))) sess.commit() eq_( @@ -1908,8 +1909,8 @@ class DeclarativeTest(DeclarativeTestBase): "_name", descriptor=property(_get_name, _set_name) ) - Base.metadata.create_all() - sess = create_session() + Base.metadata.create_all(testing.db) + sess = fixture_session() u1 = User(name="someuser") eq_(u1.name, "SOMENAME someuser") sess.add(u1) @@ -1937,8 +1938,8 @@ class DeclarativeTest(DeclarativeTestBase): _name = Column("name", String(50)) name = sa.orm.synonym("_name", comparator_factory=CustomCompare) - Base.metadata.create_all() - sess = create_session() + Base.metadata.create_all(testing.db) + sess = fixture_session() u1 = User(name="someuser FOO") sess.add(u1) sess.flush() @@ -1962,8 +1963,8 @@ class DeclarativeTest(DeclarativeTestBase): name = property(_get_name, _set_name) User.name = sa.orm.synonym("_name", descriptor=User.name) - Base.metadata.create_all() - sess = create_session() + Base.metadata.create_all(testing.db) + sess = fixture_session() u1 = User(name="someuser") eq_(u1.name, "SOMENAME someuser") sess.add(u1) @@ -2000,11 +2001,11 @@ class DeclarativeTest(DeclarativeTestBase): list(Address.user_id.property.columns[0].foreign_keys)[0].column, User.__table__.c.id, ) - Base.metadata.create_all() + Base.metadata.create_all(testing.db) u1 = User( name="u1", addresses=[Address(email="one"), Address(email="two")] ) - sess = create_session() + sess = fixture_session() sess.add(u1) sess.flush() sess.expunge_all() @@ -2044,11 +2045,11 @@ class DeclarativeTest(DeclarativeTestBase): .where(Address.user_id == User.id) .scalar_subquery() ) - Base.metadata.create_all() + Base.metadata.create_all(testing.db) u1 = User( name="u1", addresses=[Address(email="one"), Address(email="two")] ) - sess = create_session() + sess = fixture_session() sess.add(u1) sess.flush() sess.expunge_all() @@ -2096,11 +2097,11 @@ class DeclarativeTest(DeclarativeTestBase): __table__ = Table("t1", Base.metadata, autoload_with=testing.db) - sess = create_session() + sess = fixture_session() m = MyObj(id="someid", data="somedata") sess.add(m) sess.flush() - eq_(t1.select().execute().fetchall(), [("someid", "somedata")]) + eq_(sess.execute(t1.select()).fetchall(), [("someid", "somedata")]) def test_synonym_for(self): class User(Base, fixtures.ComparableEntity): @@ -2116,8 +2117,8 @@ class DeclarativeTest(DeclarativeTestBase): def namesyn(self): return self.name - Base.metadata.create_all() - sess = create_session() + Base.metadata.create_all(testing.db) + sess = fixture_session() u1 = User(name="someuser") eq_(u1.name, "someuser") eq_(u1.namesyn, "someuser") @@ -2356,7 +2357,7 @@ def _produce_test(inline, stringbased): # PropertyLoader.Comparator will annotate the left side with # _orm_adapt, though. - sess = create_session() + sess = fixture_session() eq_( sess.query(User) .join(User.addresses, aliased=True) diff --git a/test/orm/declarative/test_concurrency.py b/test/orm/declarative/test_concurrency.py index d731c6afa..5f12d8272 100644 --- a/test/orm/declarative/test_concurrency.py +++ b/test/orm/declarative/test_concurrency.py @@ -12,8 +12,8 @@ from sqlalchemy.orm import declarative_base from sqlalchemy.orm import declared_attr from sqlalchemy.orm import exc as orm_exc from sqlalchemy.orm import relationship -from sqlalchemy.orm import Session from sqlalchemy.testing import fixtures +from sqlalchemy.testing.fixtures import fixture_session class ConcurrentUseDeclMappingTest(fixtures.TestBase): @@ -34,7 +34,7 @@ class ConcurrentUseDeclMappingTest(fixtures.TestBase): @classmethod def query_a(cls, Base, result): - s = Session() + s = fixture_session() time.sleep(random.random() / 100) A = cls.A try: diff --git a/test/orm/declarative/test_inheritance.py b/test/orm/declarative/test_inheritance.py index d8847ed40..cc29cab7d 100644 --- a/test/orm/declarative/test_inheritance.py +++ b/test/orm/declarative/test_inheritance.py @@ -9,11 +9,9 @@ from sqlalchemy.orm import class_mapper from sqlalchemy.orm import clear_mappers from sqlalchemy.orm import close_all_sessions from sqlalchemy.orm import configure_mappers -from sqlalchemy.orm import create_session from sqlalchemy.orm import deferred from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship -from sqlalchemy.orm import Session from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ @@ -21,6 +19,7 @@ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import is_false from sqlalchemy.testing import is_true +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -35,7 +34,7 @@ class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults): def teardown(self): close_all_sessions() clear_mappers() - Base.metadata.drop_all() + Base.metadata.drop_all(testing.db) class DeclarativeInheritanceTest(DeclarativeTestBase): @@ -137,8 +136,8 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): ) golf_swing = Column("golf_swing", String(50)) - Base.metadata.create_all() - sess = create_session() + Base.metadata.create_all(testing.db) + sess = fixture_session() c1 = Company( name="MegaCorp, Inc.", employees=[ @@ -218,8 +217,8 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): ) Engineer.primary_language = Column("primary_language", String(50)) - Base.metadata.create_all() - sess = create_session() + Base.metadata.create_all(testing.db) + sess = fixture_session() e1 = Engineer(primary_language="java", name="dilbert") sess.add(e1) sess.flush() @@ -249,8 +248,8 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): ) Person.name = Column("name", String(50)) - Base.metadata.create_all() - sess = create_session() + Base.metadata.create_all(testing.db) + sess = fixture_session() e1 = Engineer(primary_language="java", name="dilbert") sess.add(e1) sess.flush() @@ -289,8 +288,8 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): ) Person.name = Column("name", String(50)) - Base.metadata.create_all() - sess = create_session() + Base.metadata.create_all(testing.db) + sess = fixture_session() e1 = Admin(primary_language="java", name="dilbert", workstation="foo") sess.add(e1) sess.flush() @@ -531,8 +530,8 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): __mapper_args__ = {"polymorphic_identity": "manager"} - Base.metadata.create_all() - sess = create_session() + Base.metadata.create_all(testing.db) + sess = fixture_session() c1 = Company( name="MegaCorp, Inc.", employees=[ @@ -621,8 +620,8 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): assert not hasattr(Person, "golf_swing") assert not hasattr(Engineer, "golf_swing") assert not hasattr(Manager, "primary_language") - Base.metadata.create_all() - sess = create_session() + Base.metadata.create_all(testing.db) + sess = fixture_session() e1 = Engineer(name="dilbert", primary_language="java") e2 = Engineer(name="wally", primary_language="c++") m1 = Manager(name="dogbert", golf_swing="fore!") @@ -837,8 +836,8 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): Manager.target_id.property.columns[0], Person.__table__.c.target_id ) # do a brief round trip on this - Base.metadata.create_all() - session = Session() + Base.metadata.create_all(testing.db) + session = fixture_session() o1, o2 = Other(), Other() session.add_all( [Engineer(target=o1), Manager(target=o2), Manager(target=o1)] @@ -957,8 +956,8 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): assert not hasattr(Person, "golf_swing") assert not hasattr(Engineer, "golf_swing") assert not hasattr(Manager, "primary_language") - Base.metadata.create_all() - sess = create_session() + Base.metadata.create_all(testing.db) + sess = fixture_session() e1 = Engineer(name="dilbert", primary_language="java") e2 = Engineer(name="wally", primary_language="c++") m1 = Manager(name="dogbert", golf_swing="fore!") @@ -1043,8 +1042,8 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): ) Person.name = deferred(Column(String(10))) - Base.metadata.create_all() - sess = create_session() + Base.metadata.create_all(testing.db) + sess = fixture_session() p = Person(name="ratbert") sess.add(p) sess.flush() @@ -1085,8 +1084,8 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): name = Column(String(50)) assert not hasattr(Person, "primary_language_id") - Base.metadata.create_all() - sess = create_session() + Base.metadata.create_all(testing.db) + sess = fixture_session() java, cpp, cobol = ( Language(name="java"), Language(name="cpp"), diff --git a/test/orm/declarative/test_mixin.py b/test/orm/declarative/test_mixin.py index bc36ee962..631527daf 100644 --- a/test/orm/declarative/test_mixin.py +++ b/test/orm/declarative/test_mixin.py @@ -12,7 +12,6 @@ from sqlalchemy.orm import clear_mappers from sqlalchemy.orm import close_all_sessions from sqlalchemy.orm import column_property from sqlalchemy.orm import configure_mappers -from sqlalchemy.orm import create_session from sqlalchemy.orm import declarative_base from sqlalchemy.orm import declared_attr from sqlalchemy.orm import deferred @@ -20,7 +19,6 @@ from sqlalchemy.orm import events as orm_events from sqlalchemy.orm import has_inherited_table from sqlalchemy.orm import registry from sqlalchemy.orm import relationship -from sqlalchemy.orm import Session from sqlalchemy.orm import synonym from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message @@ -29,6 +27,7 @@ from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import mock +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from sqlalchemy.testing.util import gc_collect @@ -42,13 +41,14 @@ class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults): def setup(self): global Base, mapper_registry - mapper_registry = registry(metadata=MetaData(bind=testing.db)) + mapper_registry = registry(metadata=MetaData()) Base = mapper_registry.generate_base() def teardown(self): close_all_sessions() clear_mappers() - Base.metadata.drop_all() + with testing.db.begin() as conn: + Base.metadata.drop_all(conn) class DeclarativeMixinTest(DeclarativeTestBase): @@ -68,7 +68,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): name = Column(String(100), nullable=False, index=True) Base.metadata.create_all(testing.db) - session = create_session() + session = fixture_session() session.add(MyModel(name="testing")) session.flush() session.expunge_all() @@ -94,7 +94,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): name = Column(String(100), nullable=False, index=True) Base.metadata.create_all(testing.db) - session = create_session() + session = fixture_session() session.add(MyModel(name="testing")) session.flush() session.expunge_all() @@ -135,7 +135,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): name = Column(String(100), nullable=False, index=True) Base.metadata.create_all(testing.db) - session = create_session() + session = fixture_session() session.add(MyModel(name="testing", baz="fu")) session.flush() session.expunge_all() @@ -166,7 +166,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): name = Column(String(100), nullable=False, index=True) Base.metadata.create_all(testing.db) - session = create_session() + session = fixture_session() session.add(MyModel(name="testing", baz="fu")) session.flush() session.expunge_all() @@ -459,7 +459,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): ) # do a brief round trip on this Base.metadata.create_all(testing.db) - session = Session() + session = fixture_session() o1, o2 = Other(), Other() session.add_all( [Engineer(target=o1), Manager(target=o2), Manager(target=o1)] @@ -1388,7 +1388,7 @@ class DeclarativeMixinPropertyTest( MyModel.prop_hoho.property is not MyOtherModel.prop_hoho.property ) Base.metadata.create_all(testing.db) - sess = create_session() + sess = fixture_session() m1, m2 = MyModel(prop_hoho="foo"), MyOtherModel(prop_hoho="bar") sess.add_all([m1, m2]) sess.flush() @@ -1497,7 +1497,7 @@ class DeclarativeMixinPropertyTest( d1 = inspect(Derived) is_(b1.attrs["data_syn"], d1.attrs["data_syn"]) - s = Session() + s = fixture_session() self.assert_compile( s.query(Base.data_syn).filter(Base.data_syn == "foo"), "SELECT test.data AS test_data FROM test " @@ -1567,7 +1567,7 @@ class DeclarativeMixinPropertyTest( ) Base.metadata.create_all(testing.db) - sess = create_session() + sess = fixture_session() sess.add_all([MyModel(data="d1"), MyModel(data="d2")]) sess.flush() sess.expunge_all() @@ -1619,7 +1619,7 @@ class DeclarativeMixinPropertyTest( ) Base.metadata.create_all(testing.db) - sess = create_session() + sess = fixture_session() t1, t2 = Target(), Target() f1, f2, b1 = Foo(target=t1), Foo(target=t2), Bar(target=t1) sess.add_all([f1, f2, b1]) @@ -1678,7 +1678,7 @@ class DeclaredAttrTest(DeclarativeTestBase, testing.AssertsCompiledSQL): is_(a_col, A.__table__.c.x) is_(b_col, B.__table__.c.x) - s = Session() + s = fixture_session() self.assert_compile( s.query(A), "SELECT a.x AS a_x, a.x + :x_1 AS anon_1, a.id AS a_id FROM a", @@ -1982,7 +1982,7 @@ class DeclaredAttrTest(DeclarativeTestBase, testing.AssertsCompiledSQL): eq_(counter.mock_calls, [mock.call(User.id)]) - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(User).having(User.address_count > 5), "SELECT (SELECT count(address.id) AS " diff --git a/test/orm/declarative/test_reflection.py b/test/orm/declarative/test_reflection.py index 32514a473..241528c44 100644 --- a/test/orm/declarative/test_reflection.py +++ b/test/orm/declarative/test_reflection.py @@ -9,7 +9,7 @@ from sqlalchemy.orm import relationship from sqlalchemy.testing import assert_raises from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures -from sqlalchemy.testing.fixtures import create_session +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -77,7 +77,7 @@ class DeclarativeReflectionTest(DeclarativeReflectionBase): u1 = User( name="u1", addresses=[Address(email="one"), Address(email="two")] ) - sess = create_session() + sess = fixture_session() sess.add(u1) sess.flush() sess.expunge_all() @@ -110,7 +110,7 @@ class DeclarativeReflectionTest(DeclarativeReflectionBase): u1 = User( nom="u1", addresses=[Address(email="one"), Address(email="two")] ) - sess = create_session() + sess = fixture_session() sess.add(u1) sess.flush() sess.expunge_all() @@ -146,7 +146,7 @@ class DeclarativeReflectionTest(DeclarativeReflectionBase): u1 = User( nom="u1", addresses=[Address(email="one"), Address(email="two")] ) - sess = create_session() + sess = fixture_session() sess.add(u1) sess.flush() sess.expunge_all() @@ -184,7 +184,7 @@ class DeclarativeReflectionTest(DeclarativeReflectionBase): IMHandle(network="lol", handle="zomg"), ], ) - sess = create_session() + sess = fixture_session() sess.add(u1) sess.flush() sess.expunge_all() diff --git a/test/orm/inheritance/_poly_fixtures.py b/test/orm/inheritance/_poly_fixtures.py index 4253b4bee..7a70810a1 100644 --- a/test/orm/inheritance/_poly_fixtures.py +++ b/test/orm/inheritance/_poly_fixtures.py @@ -575,6 +575,10 @@ class GeometryFixtureBase(fixtures.DeclarativeMappedTest): if "subclasses" in value: self._fixture_from_geometry(value["subclasses"], klass) - if is_base and self.metadata.tables and self.run_create_tables: - self.tables.update(self.metadata.tables) - self.metadata.create_all(config.db) + if ( + is_base + and self.tables_test_metadata.tables + and self.run_create_tables + ): + self.tables.update(self.tables_test_metadata.tables) + self.tables_test_metadata.create_all(config.db) diff --git a/test/orm/inheritance/test_abc_inheritance.py b/test/orm/inheritance/test_abc_inheritance.py index bce554f30..a368e7b2f 100644 --- a/test/orm/inheritance/test_abc_inheritance.py +++ b/test/orm/inheritance/test_abc_inheritance.py @@ -9,7 +9,7 @@ from sqlalchemy.orm import relationship from sqlalchemy.orm.interfaces import MANYTOONE from sqlalchemy.orm.interfaces import ONETOMANY from sqlalchemy.testing import fixtures -from sqlalchemy.testing.fixtures import create_session +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -221,7 +221,7 @@ class ABCTest(fixtures.MappedTest): parent_class = {"a": A, "b": B, "c": C}[parent] child_class = {"a": A, "b": B, "c": C}[child] - sess = create_session() + sess = fixture_session(autoflush=False, expire_on_commit=False) parent_obj = parent_class("parent1") child_obj = child_class("child1") diff --git a/test/orm/inheritance/test_abc_polymorphic.py b/test/orm/inheritance/test_abc_polymorphic.py index 0d28ef342..fd3d50ddd 100644 --- a/test/orm/inheritance/test_abc_polymorphic.py +++ b/test/orm/inheritance/test_abc_polymorphic.py @@ -5,7 +5,7 @@ from sqlalchemy import testing from sqlalchemy.orm import mapper from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures -from sqlalchemy.testing.fixtures import create_session +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -77,7 +77,7 @@ class ABCTest(fixtures.MappedTest): c2 = C(cdata="c2", bdata="c2", adata="c2") c3 = C(cdata="c2", bdata="c2", adata="c2") - sess = create_session() + sess = fixture_session() for x in (a1, b1, b2, b3, c1, c2, c3): sess.add(x) sess.flush() diff --git a/test/orm/inheritance/test_assorted_poly.py b/test/orm/inheritance/test_assorted_poly.py index ce8d76a53..3cf9c9837 100644 --- a/test/orm/inheritance/test_assorted_poly.py +++ b/test/orm/inheritance/test_assorted_poly.py @@ -21,14 +21,13 @@ from sqlalchemy.orm import joinedload from sqlalchemy.orm import mapper from sqlalchemy.orm import polymorphic_union from sqlalchemy.orm import relationship -from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import with_polymorphic from sqlalchemy.orm.interfaces import MANYTOONE from sqlalchemy.testing import AssertsExecutionResults from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures -from sqlalchemy.testing.fixtures import create_session +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -108,7 +107,7 @@ class RelationshipTest1(fixtures.MappedTest): [(managers.c.person_id, people.c.manager_id)], ) - session = create_session() + session = fixture_session() p = Person(name="some person") m = Manager(name="some manager") p.manager = m @@ -140,7 +139,7 @@ class RelationshipTest1(fixtures.MappedTest): }, ) - session = create_session() + session = fixture_session() p = Person(name="some person") m = Manager(name="some manager") m.employee = p @@ -297,7 +296,7 @@ class RelationshipTest2(fixtures.MappedTest): }, ) - sess = create_session() + sess = fixture_session() p = Person(name="person1") m = Manager(name="manager1") m.colleague = p @@ -462,7 +461,7 @@ class RelationshipTest3(fixtures.MappedTest): self._setup_mappings(jointype, usedata) Person, Manager, Data = self.classes("Person", "Manager", "Data") - sess = create_session() + sess = fixture_session() p = Person(name="person1") p2 = Person(name="person2") p3 = Person(name="person3") @@ -611,7 +610,7 @@ class RelationshipTest4(fixtures.MappedTest): ) mapper(Car, cars, properties={"employee": relationship(person_mapper)}) - session = create_session() + session = fixture_session() # creating 5 managers named from M1 to E5 for i in range(1, 5): @@ -780,7 +779,7 @@ class RelationshipTest5(fixtures.MappedTest): }, ) - sess = create_session() + sess = fixture_session() car1 = Car() car2 = Car() car2.manager = Manager() @@ -855,7 +854,7 @@ class RelationshipTest6(fixtures.MappedTest): }, ) - sess = create_session() + sess = fixture_session() m = Manager(name="manager1") m2 = Manager(name="manager2") m.colleague = m2 @@ -1027,7 +1026,7 @@ class RelationshipTest7(fixtures.MappedTest): polymorphic_identity="manager", ) - session = create_session() + session = fixture_session() for i in range(1, 4): if i % 2: @@ -1095,7 +1094,7 @@ class RelationshipTest8(fixtures.MappedTest): u1 = User(data="u1") t1 = Taggable(owner=u1) - sess = create_session() + sess = fixture_session() sess.add(t1) sess.flush() @@ -1303,7 +1302,7 @@ class GenerativeTest(fixtures.MappedTest, AssertsExecutionResults): Status, Person, Engineer, Manager, Car = self.classes( "Status", "Person", "Engineer", "Manager", "Car" ) - session = create_session() + session = fixture_session() r = ( session.query(Person) @@ -1328,7 +1327,7 @@ class GenerativeTest(fixtures.MappedTest, AssertsExecutionResults): Status, Person, Engineer, Manager, Car = self.classes( "Status", "Person", "Engineer", "Manager", "Car" ) - session = create_session() + session = fixture_session() r = ( session.query(Engineer) .join("status") @@ -1351,7 +1350,7 @@ class GenerativeTest(fixtures.MappedTest, AssertsExecutionResults): Status, Person, Engineer, Manager, Car = self.classes( "Status", "Person", "Engineer", "Manager", "Car" ) - session = create_session() + session = fixture_session() r = session.query(Person).filter( exists().where(Car.owner == Person.person_id) ) @@ -1471,7 +1470,7 @@ class MultiLevelTest(fixtures.MappedTest): b = Engineer().set(egn="two", machine="any") c = Manager().set(name="head", machine="fast", duties="many") - session = create_session() + session = fixture_session() session.add(a) session.add(b) session.add(c) @@ -1621,7 +1620,7 @@ class CustomPKTest(fixtures.MappedTest): mapper(T2, t2, inherits=T1, polymorphic_identity="t2") ot1 = T1() ot2 = T2() - sess = create_session() + sess = fixture_session() sess.add(ot1) sess.add(ot2) sess.flush() @@ -1668,7 +1667,7 @@ class CustomPKTest(fixtures.MappedTest): ot1 = T1() ot2 = T2() - sess = create_session() + sess = fixture_session() sess.add(ot1) sess.add(ot2) sess.flush() @@ -1754,7 +1753,7 @@ class InheritingEagerTest(fixtures.MappedTest): ) mapper(Tag, tags) - session = create_session() + session = fixture_session() bob = Employee() session.add(bob) @@ -1852,7 +1851,7 @@ class MissingPolymorphicOnTest(fixtures.MappedTest): c = C(cdata="c1", adata="a1", b=B(data="c")) d = D(cdata="c2", adata="a2", ddata="d2", b=B(data="d")) - sess = create_session() + sess = fixture_session() sess.add(c) sess.add(d) sess.flush() @@ -1899,7 +1898,7 @@ class JoinedInhAdjacencyTest(fixtures.MappedTest): def _roundtrip(self): User = self.classes.User - sess = Session() + sess = fixture_session() u1 = User() u2 = User() u2.supervisor = u1 @@ -1910,7 +1909,7 @@ class JoinedInhAdjacencyTest(fixtures.MappedTest): def _dude_roundtrip(self): Dude, User = self.classes.Dude, self.classes.User - sess = Session() + sess = fixture_session() u1 = User() d1 = Dude() d1.supervisor = u1 @@ -2070,7 +2069,7 @@ class Ticket2419Test(fixtures.DeclarativeMappedTest): ) def test_join_w_eager_w_any(self): B, C, D = (self.classes.B, self.classes.C, self.classes.D) - s = Session(testing.db) + s = fixture_session() b = B(ds=[D()]) s.add_all([C(b=b)]) @@ -2114,7 +2113,7 @@ class ColSubclassTest( def test_polymorphic_adaptation(self): A, B = self.classes.A, self.classes.B - s = Session() + s = fixture_session() self.assert_compile( s.query(A).join(B).filter(B.x == "test"), "SELECT a.id AS a_id FROM a JOIN " @@ -2181,7 +2180,7 @@ class CorrelateExceptWPolyAdaptTest( poly = with_polymorphic(Superclass, "*") - s = Session() + s = fixture_session() q = ( s.query(poly) .options(contains_eager(poly.common_relationship)) @@ -2210,7 +2209,7 @@ class CorrelateExceptWPolyAdaptTest( poly = with_polymorphic(Superclass, "*") - s = Session() + s = fixture_session() q = ( s.query(poly) .options(contains_eager(poly.common_relationship)) diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py index af960625e..bdcdedc44 100644 --- a/test/orm/inheritance/test_basic.py +++ b/test/orm/inheritance/test_basic.py @@ -27,7 +27,6 @@ from sqlalchemy.orm import object_mapper from sqlalchemy.orm import polymorphic_union from sqlalchemy.orm import relationship from sqlalchemy.orm import Session -from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import synonym from sqlalchemy.orm.util import instance_str from sqlalchemy.testing import assert_raises @@ -42,7 +41,7 @@ from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.testing.assertsql import Conditional from sqlalchemy.testing.assertsql import Or from sqlalchemy.testing.assertsql import RegexSQL -from sqlalchemy.testing.fixtures import create_session +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -104,7 +103,7 @@ class O2MTest(fixtures.MappedTest): properties={"parent_foo": relationship(Foo)}, ) - sess = create_session() + sess = fixture_session() b1 = Blub("blub #1") b2 = Blub("blub #2") f = Foo("foo #1") @@ -166,7 +165,7 @@ class ColExpressionsTest(fixtures.DeclarativeMappedTest): def test_group_by(self): B = self.classes.B - s = Session() + s = fixture_session() rows = ( s.query(B.id.expressions[0], B.id.expressions[1], func.sum(B.data)) @@ -222,7 +221,7 @@ class PolyExpressionEagerLoad(fixtures.DeclarativeMappedTest): A = self.classes.A B = self.classes.B - session = Session(testing.db) + session = fixture_session() result = ( session.query(A) .filter_by(child_id=None) @@ -590,7 +589,7 @@ class PolymorphicOnNotLocalTest(fixtures.MappedTest): else: assert False, "Got unexpected identity %r" % ident - s = Session(testing.db) + s = fixture_session() s.add_all([Parent(q="p1"), Child(q="c1", y="c1"), Parent(q="p2")]) s.commit() s.close() @@ -647,7 +646,7 @@ class SortOnlyOnImportantFKsTest(fixtures.MappedTest): cls.classes.B = B def test_flush(self): - s = Session(testing.db) + s = fixture_session() s.add(self.classes.B()) s.flush() @@ -674,7 +673,7 @@ class FalseDiscriminatorTest(fixtures.MappedTest): mapper(Foo, t1, polymorphic_on=t1.c.type, polymorphic_identity=True) mapper(Bar, inherits=Foo, polymorphic_identity=False) - sess = create_session() + sess = fixture_session() b1 = Bar() sess.add(b1) sess.flush() @@ -691,7 +690,7 @@ class FalseDiscriminatorTest(fixtures.MappedTest): mapper(Ding, t1, polymorphic_on=t1.c.type, polymorphic_identity=False) mapper(Bat, inherits=Ding, polymorphic_identity=True) - sess = create_session() + sess = fixture_session() d1 = Ding() sess.add(d1) sess.flush() @@ -741,7 +740,7 @@ class PolymorphicSynonymTest(fixtures.MappedTest): properties={"info": synonym("_info", map_column=True)}, ) mapper(T2, t2, inherits=T1, polymorphic_identity="t2") - sess = create_session() + sess = fixture_session() at1 = T1(info="at1") at2 = T2(info="at2", data="t2 data") sess.add(at1) @@ -832,7 +831,7 @@ class PolymorphicAttributeManagementTest(fixtures.MappedTest): def test_base_class(self): A, C, B = (self.classes.A, self.classes.C, self.classes.B) - sess = Session() + sess = fixture_session() c1 = C() sess.add(c1) sess.commit() @@ -849,7 +848,7 @@ class PolymorphicAttributeManagementTest(fixtures.MappedTest): """ D, B = self.classes.D, self.classes.B - sess = Session() + sess = fixture_session() b1 = B() b1.class_name = "d" sess.add(b1) @@ -863,7 +862,7 @@ class PolymorphicAttributeManagementTest(fixtures.MappedTest): """ C = self.classes.C - sess = Session() + sess = fixture_session() c1 = C() c1.class_name = "b" sess.add(c1) @@ -882,7 +881,7 @@ class PolymorphicAttributeManagementTest(fixtures.MappedTest): """ B = self.classes.B - sess = Session() + sess = fixture_session() b1 = B() b1.class_name = "c" sess.add(b1) @@ -898,7 +897,7 @@ class PolymorphicAttributeManagementTest(fixtures.MappedTest): """test warn on an unknown polymorphic identity.""" B = self.classes.B - sess = Session() + sess = fixture_session() b1 = B() b1.class_name = "xyz" sess.add(b1) @@ -913,7 +912,7 @@ class PolymorphicAttributeManagementTest(fixtures.MappedTest): def test_not_set_on_upate(self): C = self.classes.C - sess = Session() + sess = fixture_session() c1 = C() sess.add(c1) sess.commit() @@ -925,7 +924,7 @@ class PolymorphicAttributeManagementTest(fixtures.MappedTest): def test_validate_on_upate(self): C = self.classes.C - sess = Session() + sess = fixture_session() c1 = C() sess.add(c1) sess.commit() @@ -1009,7 +1008,7 @@ class CascadeTest(fixtures.MappedTest): ) mapper(T4, t4) - sess = create_session() + sess = fixture_session() t1_1 = T1(data="t1") t3_1 = T3(data="t3", moredata="t3") @@ -1095,7 +1094,7 @@ class M2OUseGetTest(fixtures.MappedTest): assert class_mapper(Related).get_property("sub").strategy.use_get - sess = create_session() + sess = fixture_session() s1 = Sub() r1 = Related(sub=s1) sess.add(r1) @@ -1181,7 +1180,7 @@ class GetTest(fixtures.MappedTest): mapper(Bar, bar, inherits=Foo) mapper(Blub, blub, inherits=Bar) - sess = create_session() + sess = fixture_session() f = Foo() b = Bar() bl = Blub() @@ -1295,7 +1294,7 @@ class EagerLazyTest(fixtures.MappedTest): def test_basic(self): Bar = self.classes.Bar - sess = create_session() + sess = fixture_session() q = sess.query(Bar) self.assert_(len(q.first().lazy) == 1) self.assert_(len(q.first().eager) == 1) @@ -1350,7 +1349,7 @@ class EagerTargetingTest(fixtures.MappedTest): }, ) - sess = create_session() + sess = fixture_session() b1 = B(id=1, name="b1", b_data="i") sess.add(b1) @@ -1454,7 +1453,7 @@ class FlushTest(fixtures.MappedTest): }, ) mapper(Admin, admins, inherits=user_mapper) - sess = create_session() + sess = fixture_session() adminrole = Role() sess.add(adminrole) sess.flush() @@ -1507,7 +1506,7 @@ class FlushTest(fixtures.MappedTest): # create roles adminrole = Role("admin") - sess = create_session() + sess = fixture_session() sess.add(adminrole) sess.flush() @@ -1581,7 +1580,7 @@ class PassiveDeletesTest(fixtures.MappedTest): A, B, C = self.classes("A", "B", "C") self._fixture() - s = Session() + s = fixture_session() a1, b1, c1 = A(id=1), B(id=2), C(cid=1, id=3) s.add_all([a1, b1, c1]) s.commit() @@ -1605,7 +1604,7 @@ class PassiveDeletesTest(fixtures.MappedTest): A, B, C = self.classes("A", "B", "C") self._fixture(c_p=True) - s = Session() + s = fixture_session() a1, b1, c1 = A(id=1), B(id=2), C(cid=1, id=3) s.add_all([a1, b1, c1]) s.commit() @@ -1647,7 +1646,7 @@ class PassiveDeletesTest(fixtures.MappedTest): A, B, C = self.classes("A", "B", "C") self._fixture(b_p=True) - s = Session() + s = fixture_session() a1, b1, c1 = A(id=1), B(id=2), C(cid=1, id=3) s.add_all([a1, b1, c1]) s.commit() @@ -1685,7 +1684,7 @@ class PassiveDeletesTest(fixtures.MappedTest): A, B, C = self.classes("A", "B", "C") self._fixture(a_p=True) - s = Session() + s = fixture_session() a1, b1, c1 = A(id=1), B(id=2), C(cid=1, id=3) s.add_all([a1, b1, c1]) s.commit() @@ -1767,7 +1766,7 @@ class OptimizedGetOnDeferredTest(fixtures.MappedTest): def test_column_property(self): A, B = self.classes("A", "B") - sess = Session() + sess = fixture_session() b1 = B(data="x") sess.add(b1) sess.flush() @@ -1776,7 +1775,7 @@ class OptimizedGetOnDeferredTest(fixtures.MappedTest): def test_expired_column(self): A, B = self.classes("A", "B") - sess = Session() + sess = fixture_session() b1 = B(data="x") sess.add(b1) sess.flush() @@ -1830,7 +1829,7 @@ class JoinedNoFKSortingTest(fixtures.MappedTest): def test_ordering(self): B, C = self.classes.B, self.classes.C - sess = Session() + sess = fixture_session() sess.add_all([B(), C(), B(), C()]) self.assert_sql_execution( testing.db, @@ -1918,7 +1917,7 @@ class VersioningTest(fixtures.MappedTest): ) mapper(Sub, subtable, inherits=Base, polymorphic_identity=2) - sess = Session(autoflush=False) + sess = fixture_session(autoflush=False) b1 = Base(value="b1") s1 = Sub(value="sub1", subdata="some subdata") @@ -1927,7 +1926,7 @@ class VersioningTest(fixtures.MappedTest): sess.commit() - sess2 = Session(autoflush=False) + sess2 = fixture_session(autoflush=False) s2 = sess2.get(Base, s1.id) s2.subdata = "sess2 subdata" @@ -1976,7 +1975,7 @@ class VersioningTest(fixtures.MappedTest): ) mapper(Sub, subtable, inherits=Base, polymorphic_identity=2) - sess = Session(autoflush=False, expire_on_commit=False) + sess = fixture_session(autoflush=False, expire_on_commit=False) b1 = Base(value="b1") s1 = Sub(value="sub1", subdata="some subdata") @@ -1987,7 +1986,7 @@ class VersioningTest(fixtures.MappedTest): sess.commit() - sess2 = Session(autoflush=False, expire_on_commit=False) + sess2 = fixture_session(autoflush=False, expire_on_commit=False) s3 = sess2.get(Base, s1.id) sess2.delete(s3) sess2.commit() @@ -2099,7 +2098,7 @@ class DistinctPKTest(fixtures.MappedTest): self._do_test(False) def _do_test(self, composite): - session = create_session() + session = fixture_session() if composite: alice1 = session.get(Employee, [1, 2]) @@ -2178,7 +2177,7 @@ class SyncCompileTest(fixtures.MappedTest): mapper(B, _b_table, inherits=A, inherit_condition=j1) mapper(C, _c_table, inherits=B, inherit_condition=j2) - session = create_session() + session = fixture_session() a = A(data1="a1") session.add(a) @@ -2284,7 +2283,7 @@ class OverrideColKeyTest(fixtures.MappedTest): s1 = Sub() s1.id = 10 - sess = create_session() + sess = fixture_session() sess.add(s1) sess.flush() assert sess.get(Sub, 10) is s1 @@ -2312,7 +2311,7 @@ class OverrideColKeyTest(fixtures.MappedTest): s2 = Sub() s2.base_id = 15 - sess = create_session() + sess = fixture_session() sess.add_all([s1, s2]) sess.flush() @@ -2384,7 +2383,7 @@ class OverrideColKeyTest(fixtures.MappedTest): mapper(Sub, subtable, inherits=Base) s1 = Sub() - sess = create_session() + sess = fixture_session() sess.add(s1) sess.flush() assert sess.query(Sub).one().data == "im the data" @@ -2409,7 +2408,7 @@ class OverrideColKeyTest(fixtures.MappedTest): mapper(Sub, subtable, inherits=Base) s1 = Sub() - sess = create_session() + sess = fixture_session() sess.add(s1) sess.flush() assert sess.query(Sub).one().data == "im the data" @@ -2426,7 +2425,7 @@ class OverrideColKeyTest(fixtures.MappedTest): mapper(Base, base) mapper(Sub, subtable, inherits=Base) - sess = create_session() + sess = fixture_session() b1 = Base() assert b1.subdata == "this is base" s1 = Sub() @@ -2452,7 +2451,7 @@ class OverrideColKeyTest(fixtures.MappedTest): mapper(Base, base) mapper(Sub, subtable, inherits=Base) - sess = create_session() + sess = fixture_session() b1 = Base() assert b1.data == "this is base" s1 = Sub() @@ -2528,7 +2527,7 @@ class OptimizedLoadTest(fixtures.MappedTest): ) mapper(SubJoinBase, inherits=JoinBase) - sess = Session() + sess = fixture_session() sess.add(Base(data="data")) sess.commit() @@ -2595,7 +2594,7 @@ class OptimizedLoadTest(fixtures.MappedTest): base.outerjoin(sub).select().apply_labels().alias("foo"), ), ) - sess = Session() + sess = fixture_session() s1 = Sub( data="s1data", sub="s1sub", subcounter=1, counter=1, subcounter2=1 ) @@ -2638,7 +2637,7 @@ class OptimizedLoadTest(fixtures.MappedTest): polymorphic_identity="sub", properties={"id": [sub.c.id, base.c.id]}, ) - sess = sessionmaker()() + sess = fixture_session() s1 = Sub(data="s1data", sub="s1sub") sess.add(s1) sess.commit() @@ -2673,7 +2672,7 @@ class OptimizedLoadTest(fixtures.MappedTest): "concat": column_property(sub.c.sub + "|" + sub.c.sub) }, ) - sess = sessionmaker()() + sess = fixture_session() s1 = Sub(data="s1data", sub="s1sub") sess.add(s1) sess.commit() @@ -2702,7 +2701,7 @@ class OptimizedLoadTest(fixtures.MappedTest): "concat": column_property(base.c.data + "|" + sub.c.sub) }, ) - sess = sessionmaker()() + sess = fixture_session() s1 = Sub(data="s1data", sub="s1sub") s2 = Sub(data="s2data", sub="s2sub") s3 = Sub(data="s3data", sub="s3sub") @@ -2752,7 +2751,7 @@ class OptimizedLoadTest(fixtures.MappedTest): polymorphic_identity="wc", properties={"comp": composite(Comp, with_comp.c.a, with_comp.c.b)}, ) - sess = sessionmaker()() + sess = fixture_session() s1 = WithComp(data="s1data", comp=Comp("ham", "cheese")) s2 = WithComp(data="s2data", comp=Comp("bacon", "eggs")) sess.add_all([s1, s2]) @@ -2777,7 +2776,7 @@ class OptimizedLoadTest(fixtures.MappedTest): Base, base, polymorphic_on=base.c.type, polymorphic_identity="base" ) mapper(Sub, sub, inherits=Base, polymorphic_identity="sub") - sess = Session() + sess = fixture_session() s1 = Sub(data="s1") sess.add(s1) self.assert_sql_execution( @@ -2871,7 +2870,7 @@ class OptimizedLoadTest(fixtures.MappedTest): ) mapper(Sub, sub, inherits=Base, polymorphic_identity="sub") mapper(SubSub, subsub, inherits=Sub, polymorphic_identity="subsub") - sess = Session() + sess = fixture_session() s1 = SubSub(data="s1", counter=1, subcounter=2) sess.add(s1) self.assert_sql_execution( @@ -3152,7 +3151,7 @@ class PKDiscriminatorTest(fixtures.MappedTest): mapper(A, inherits=Child, polymorphic_identity=2) - s = create_session() + s = fixture_session() p = Parent("p1") a = A("a1") p.children.append(a) @@ -3220,7 +3219,7 @@ class NoPolyIdentInMiddleTest(fixtures.MappedTest): def test_load_from_middle(self): C, B = self.classes.C, self.classes.B - s = Session() + s = fixture_session() s.add(C()) o = s.query(B).first() eq_(o.type, "c") @@ -3229,7 +3228,7 @@ class NoPolyIdentInMiddleTest(fixtures.MappedTest): def test_load_from_base(self): A, C = self.classes.A, self.classes.C - s = Session() + s = fixture_session() s.add(C()) o = s.query(A).first() eq_(o.type, "c") @@ -3250,7 +3249,7 @@ class NoPolyIdentInMiddleTest(fixtures.MappedTest): self.tables.base, ) - s = Session() + s = fixture_session() s.add_all([C(), D(), E()]) eq_(s.query(B).order_by(base.c.type).all(), [C(), D()]) @@ -3314,7 +3313,7 @@ class DeleteOrphanTest(fixtures.MappedTest): }, ) - sess = create_session() + sess = fixture_session() s1 = SubClass(data="s1") sess.add(s1) assert_raises(sa_exc.DBAPIError, sess.flush) @@ -3418,7 +3417,7 @@ class DiscriminatorOrPkNoneTest(fixtures.DeclarativeMappedTest): @classmethod def insert_data(cls, connection): Parent, A, B = cls.classes("Parent", "A", "B") - s = Session() + s = fixture_session() p1 = Parent(id=1) p2 = Parent(id=2) @@ -3443,7 +3442,7 @@ class DiscriminatorOrPkNoneTest(fixtures.DeclarativeMappedTest): def test_pk_is_null(self): Parent, A = self.classes("Parent", "A") - sess = Session() + sess = fixture_session() q = ( sess.query(Parent, A) .select_from(Parent) @@ -3457,7 +3456,7 @@ class DiscriminatorOrPkNoneTest(fixtures.DeclarativeMappedTest): def test_pk_not_null_discriminator_null_from_base(self): (A,) = self.classes("A") - sess = Session() + sess = fixture_session() q = sess.query(A).filter(A.id == 3) assert_raises_message( sa_exc.InvalidRequestError, @@ -3470,7 +3469,7 @@ class DiscriminatorOrPkNoneTest(fixtures.DeclarativeMappedTest): def test_pk_not_null_discriminator_null_from_sub(self): (B,) = self.classes("B") - sess = Session() + sess = fixture_session() q = sess.query(B).filter(B.id == 4) assert_raises_message( sa_exc.InvalidRequestError, @@ -3528,7 +3527,7 @@ class UnexpectedPolymorphicIdentityTest(fixtures.DeclarativeMappedTest): ASingleSubA, ASingleSubB, AJoinedSubA, AJoinedSubB = cls.classes( "ASingleSubA", "ASingleSubB", "AJoinedSubA", "AJoinedSubB" ) - s = Session() + s = fixture_session() s.add_all([ASingleSubA(), ASingleSubB(), AJoinedSubA(), AJoinedSubB()]) s.commit() @@ -3536,7 +3535,7 @@ class UnexpectedPolymorphicIdentityTest(fixtures.DeclarativeMappedTest): def test_single_invalid_ident(self): ASingle, ASingleSubA = self.classes("ASingle", "ASingleSubA") - s = Session() + s = fixture_session() q = s.query(ASingleSubA).select_entity_from(select(ASingle).subquery()) @@ -3552,7 +3551,7 @@ class UnexpectedPolymorphicIdentityTest(fixtures.DeclarativeMappedTest): def test_joined_invalid_ident(self): AJoined, AJoinedSubA = self.classes("AJoined", "AJoinedSubA") - s = Session() + s = fixture_session() q = s.query(AJoinedSubA).select_entity_from(select(AJoined).subquery()) @@ -3600,7 +3599,7 @@ class NameConflictTest(fixtures.MappedTest): mapper( Foo, self.tables.foo, inherits=Content, polymorphic_identity="foo" ) - sess = create_session() + sess = fixture_session() f = Foo() f.content_type = "bar" sess.add(f) diff --git a/test/orm/inheritance/test_concrete.py b/test/orm/inheritance/test_concrete.py index e2777e9e9..f2f8d629b 100644 --- a/test/orm/inheritance/test_concrete.py +++ b/test/orm/inheritance/test_concrete.py @@ -11,13 +11,12 @@ from sqlalchemy.orm import joinedload from sqlalchemy.orm import mapper from sqlalchemy.orm import polymorphic_union from sqlalchemy.orm import relationship -from sqlalchemy.orm import sessionmaker from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import mock -from sqlalchemy.testing.fixtures import create_session +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -165,7 +164,7 @@ class ConcreteTest(fixtures.MappedTest): concrete=True, polymorphic_identity="engineer", ) - session = create_session() + session = fixture_session() session.add(Manager("Tom", "knows how to manage things")) session.add(Engineer("Kurt", "knows how to hack")) session.flush() @@ -225,7 +224,7 @@ class ConcreteTest(fixtures.MappedTest): concrete=True, polymorphic_identity="hacker", ) - session = create_session() + session = fixture_session() tom = Manager("Tom", "knows how to manage things") assert_raises_message( @@ -348,7 +347,7 @@ class ConcreteTest(fixtures.MappedTest): polymorphic_identity="engineer", ) - session = create_session() + session = fixture_session() tom = ManagerWHybrid("Tom", "mgrdata") # mapping did not impact the engineer_info @@ -422,7 +421,7 @@ class ConcreteTest(fixtures.MappedTest): concrete=True, polymorphic_identity="hacker", ) - session = create_session() + session = fixture_session() tom = Manager("Tom", "knows how to manage things") jerry = Engineer("Jerry", "knows how to program") hacker = Hacker("Kurt", "Badass", "knows how to hack") @@ -509,7 +508,7 @@ class ConcreteTest(fixtures.MappedTest): concrete=True, polymorphic_identity="hacker", ) - session = create_session() + session = fixture_session() jdoe = Employee("Jdoe") tom = Manager("Tom", "knows how to manage things") jerry = Engineer("Jerry", "knows how to program") @@ -635,7 +634,7 @@ class ConcreteTest(fixtures.MappedTest): concrete=True, polymorphic_identity="engineer", ) - session = create_session() + session = fixture_session() c = Company() c.employees.append(Manager("Tom", "knows how to manage things")) c.employees.append(Engineer("Kurt", "knows how to hack")) @@ -788,7 +787,7 @@ class PropertyInheritanceTest(fixtures.MappedTest): "many_b": relationship(B, back_populates="some_dest"), }, ) - sess = sessionmaker()() + sess = fixture_session() dest1 = Dest(name="c1") dest2 = Dest(name="c2") a1 = A(some_dest=dest1, aname="a1") @@ -916,7 +915,7 @@ class PropertyInheritanceTest(fixtures.MappedTest): }, ) - sess = sessionmaker()() + sess = fixture_session() dest1 = Dest(name="c1") dest2 = Dest(name="c2") a1 = A(some_dest=dest1, aname="a1", id=1) @@ -1021,7 +1020,7 @@ class PropertyInheritanceTest(fixtures.MappedTest): assert B.some_dest.property.parent is class_mapper(B) assert A.some_dest.property.parent is class_mapper(A) - sess = sessionmaker()() + sess = fixture_session() dest1 = Dest(name="d1") dest2 = Dest(name="d2") a1 = A(some_dest=dest2, aname="a1") @@ -1030,7 +1029,7 @@ class PropertyInheritanceTest(fixtures.MappedTest): sess.add_all([dest1, dest2, c1, a1, b1]) sess.commit() - sess2 = sessionmaker()() + sess2 = fixture_session() merged_c1 = sess2.merge(c1) eq_(merged_c1.some_dest.name, "d2") eq_(merged_c1.some_dest_id, c1.some_dest_id) @@ -1135,7 +1134,7 @@ class ManyToManyTest(fixtures.MappedTest): }, ) mapper(Related, related) - sess = sessionmaker()() + sess = fixture_session() b1, s1, r1, r2, r3 = Base(), Sub(), Related(), Related(), Related() b1.related.append(r1) b1.related.append(r2) @@ -1227,7 +1226,7 @@ class ColKeysTest(fixtures.MappedTest): concrete=True, polymorphic_identity="refugee", ) - sess = create_session() + sess = fixture_session() eq_(sess.get(Refugee, 1).name, "refugee1") eq_(sess.get(Refugee, 2).name, "refugee2") eq_(sess.get(Office, 1).name, "office1") diff --git a/test/orm/inheritance/test_magazine.py b/test/orm/inheritance/test_magazine.py index 334ed22f3..7f347e40f 100644 --- a/test/orm/inheritance/test_magazine.py +++ b/test/orm/inheritance/test_magazine.py @@ -10,9 +10,9 @@ from sqlalchemy.orm import backref from sqlalchemy.orm import mapper from sqlalchemy.orm import polymorphic_union from sqlalchemy.orm import relationship -from sqlalchemy.orm import Session from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -360,7 +360,7 @@ class MagazineTest(fixtures.MappedTest): Publication = self.classes.Publication - session = Session() + session = fixture_session() pub = self._generate_data() session.add(pub) diff --git a/test/orm/inheritance/test_manytomany.py b/test/orm/inheritance/test_manytomany.py index 207ac09c7..f790b11ac 100644 --- a/test/orm/inheritance/test_manytomany.py +++ b/test/orm/inheritance/test_manytomany.py @@ -9,7 +9,7 @@ from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures -from sqlalchemy.testing.fixtures import create_session +from sqlalchemy.testing.fixtures import fixture_session class InheritTest(fixtures.MappedTest): @@ -114,7 +114,7 @@ class InheritTest(fixtures.MappedTest): login_id="lg1", ) ) - sess = create_session() + sess = fixture_session() sess.add(g) sess.flush() # TODO: put an assertion @@ -164,7 +164,7 @@ class InheritTest2(fixtures.MappedTest): print(foo.join(bar).primary_key) print(class_mapper(Bar).primary_key) b = Bar("somedata") - sess = create_session() + sess = fixture_session() sess.add(b) sess.flush() sess.expunge_all() @@ -192,7 +192,7 @@ class InheritTest2(fixtures.MappedTest): }, ) - sess = create_session() + sess = fixture_session() b = Bar("barfoo") sess.add(b) sess.flush() @@ -304,7 +304,7 @@ class InheritTest3(fixtures.MappedTest): }, ) - sess = create_session() + sess = fixture_session() b = Bar("bar #1") sess.add(b) b.foos.append(Foo("foo #1")) @@ -352,7 +352,7 @@ class InheritTest3(fixtures.MappedTest): }, ) - sess = create_session() + sess = fixture_session() f1 = Foo("foo #1") b1 = Bar("bar #1") b2 = Bar("bar #2") diff --git a/test/orm/inheritance/test_poly_linked_list.py b/test/orm/inheritance/test_poly_linked_list.py index 83c3e75a0..d5305c473 100644 --- a/test/orm/inheritance/test_poly_linked_list.py +++ b/test/orm/inheritance/test_poly_linked_list.py @@ -8,7 +8,7 @@ from sqlalchemy.orm import configure_mappers from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship from sqlalchemy.testing import fixtures -from sqlalchemy.testing.fixtures import create_session +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -211,7 +211,7 @@ class PolymorphicCircularTest(fixtures.MappedTest): ) def _testlist(self, classes): - sess = create_session() + sess = fixture_session() # create objects in a linked list count = 1 diff --git a/test/orm/inheritance/test_poly_loading.py b/test/orm/inheritance/test_poly_loading.py index d7040e822..2f31ab0c4 100644 --- a/test/orm/inheritance/test_poly_loading.py +++ b/test/orm/inheritance/test_poly_loading.py @@ -16,6 +16,7 @@ from sqlalchemy.testing.assertsql import AllOf from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.testing.assertsql import EachOf from sqlalchemy.testing.assertsql import Or +from sqlalchemy.testing.fixtures import fixture_session from ._poly_fixtures import _Polymorphic from ._poly_fixtures import Company from ._poly_fixtures import Engineer @@ -145,7 +146,7 @@ class LoadBaseAndSubWEagerRelOpt( def test_load(self): A, B, ASub, C = self.classes("A", "B", "ASub", "C") - s = Session() + s = fixture_session() q = ( s.query(A) @@ -169,7 +170,7 @@ class LoadBaseAndSubWEagerRelMapped( def test_load(self): A, B, ASub, C = self.classes("A", "B", "ASub", "C") - s = Session() + s = fixture_session() q = ( s.query(A) @@ -182,7 +183,7 @@ class LoadBaseAndSubWEagerRelMapped( class FixtureLoadTest(_Polymorphic, testing.AssertsExecutionResults): def test_person_selectin_subclasses(self): - s = Session() + s = fixture_session() q = s.query(Person).options( selectin_polymorphic(Person, [Engineer, Manager]) ) @@ -228,7 +229,7 @@ class FixtureLoadTest(_Polymorphic, testing.AssertsExecutionResults): eq_(result, self.all_employees) def test_load_company_plus_employees(self): - s = Session() + s = fixture_session() q = ( s.query(Company) .options( @@ -316,7 +317,7 @@ class TestGeometries(GeometryFixtureBase): ) a, b, c, d, e = self.classes("a", "b", "c", "d", "e") - sess = Session() + sess = fixture_session() sess.add_all([d(d_data="d1"), e(e_data="e1")]) sess.commit() @@ -370,7 +371,7 @@ class TestGeometries(GeometryFixtureBase): ) a, b, c, d, e = self.classes("a", "b", "c", "d", "e") - sess = Session() + sess = fixture_session() sess.add_all([d(d_data="d1"), e(e_data="e1")]) sess.commit() @@ -420,7 +421,7 @@ class TestGeometries(GeometryFixtureBase): ) a, b, c, d, e = self.classes("a", "b", "c", "d", "e") - sess = Session() + sess = fixture_session() sess.add_all([d(d_data="d1"), e(e_data="e1")]) sess.commit() @@ -507,7 +508,7 @@ class TestGeometries(GeometryFixtureBase): ) a, a1, a2 = self.classes("a", "a1", "a2") - sess = Session() + sess = fixture_session() a1_obj = a1() a2_obj = a2() @@ -586,7 +587,7 @@ class LoaderOptionsTest( Parent, ChildSubclass1, Other = self.classes( "Parent", "ChildSubclass1", "Other" ) - session = Session(enable_baked_queries=enable_baked) + session = fixture_session(enable_baked_queries=enable_baked) def no_opt(): q = session.query(Parent).options( diff --git a/test/orm/inheritance/test_poly_persistence.py b/test/orm/inheritance/test_poly_persistence.py index 99cab870b..c33f3e0de 100644 --- a/test/orm/inheritance/test_poly_persistence.py +++ b/test/orm/inheritance/test_poly_persistence.py @@ -14,7 +14,7 @@ from sqlalchemy.testing import assert_raises from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ -from sqlalchemy.testing.fixtures import create_session +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column @@ -161,7 +161,7 @@ class InsertOrderTest(PolymorphTest): }, ) - session = create_session() + session = fixture_session() c = Company(name="company1") c.employees.append( Manager( @@ -391,7 +391,7 @@ class RoundTripTest(PolymorphTest): else: person_attribute_name = "name" - session = create_session() + session = fixture_session() dilbert = ( session.query(Engineer) @@ -429,7 +429,7 @@ class RoundTripTest(PolymorphTest): self.assert_sql_count(testing.db, go, 3) def test_baseclass_lookup(self, get_dilbert): - session = Session() + session = fixture_session() dilbert = get_dilbert(session) if self.redefine_colprop: @@ -449,7 +449,7 @@ class RoundTripTest(PolymorphTest): ) def test_subclass_lookup(self, get_dilbert): - session = Session() + session = fixture_session() dilbert = get_dilbert(session) if self.redefine_colprop: @@ -465,7 +465,7 @@ class RoundTripTest(PolymorphTest): ) def test_baseclass_base_alias_filter(self, get_dilbert): - session = Session() + session = fixture_session() dilbert = get_dilbert(session) # test selecting from the query, joining against @@ -485,7 +485,7 @@ class RoundTripTest(PolymorphTest): ) def test_subclass_base_alias_filter(self, get_dilbert): - session = Session() + session = fixture_session() dilbert = get_dilbert(session) palias = people.alias("palias") @@ -501,7 +501,7 @@ class RoundTripTest(PolymorphTest): ) def test_baseclass_sub_table_filter(self, get_dilbert): - session = Session() + session = fixture_session() dilbert = get_dilbert(session) # this unusual test is selecting from the plain people/engineers @@ -518,7 +518,7 @@ class RoundTripTest(PolymorphTest): ) def test_subclass_getitem(self, get_dilbert): - session = Session() + session = fixture_session() dilbert = get_dilbert(session) is_( @@ -530,7 +530,7 @@ class RoundTripTest(PolymorphTest): def test_primary_table_only_for_requery(self): - session = Session() + session = fixture_session() if self.redefine_colprop: person_attribute_name = "person_name" @@ -560,7 +560,7 @@ class RoundTripTest(PolymorphTest): else: person_attribute_name = "name" - session = Session() + session = fixture_session() daboss = Boss( status="BBB", diff --git a/test/orm/inheritance/test_polymorphic_rel.py b/test/orm/inheritance/test_polymorphic_rel.py index 84ef22d52..581fa45fd 100644 --- a/test/orm/inheritance/test_polymorphic_rel.py +++ b/test/orm/inheritance/test_polymorphic_rel.py @@ -14,7 +14,7 @@ from sqlalchemy.orm import with_polymorphic from sqlalchemy.testing import assert_raises from sqlalchemy.testing import eq_ from sqlalchemy.testing.assertsql import CompiledSQL -from sqlalchemy.testing.fixtures import create_session +from sqlalchemy.testing.fixtures import fixture_session from ._poly_fixtures import _Polymorphic from ._poly_fixtures import _PolymorphicAliasedJoins from ._poly_fixtures import _PolymorphicJoins @@ -68,7 +68,7 @@ class _PolymorphicTestBase(object): with_polymorphic is used. """ - sess = create_session() + sess = fixture_session() def go(): eq_( @@ -83,7 +83,7 @@ class _PolymorphicTestBase(object): # For both joinedload() and subqueryload(), if the original q is # not loading the subclass table, the joinedload doesn't happen. - sess = create_session() + sess = fixture_session() def go(): eq_( @@ -99,7 +99,7 @@ class _PolymorphicTestBase(object): def test_primary_eager_aliasing_subqueryload(self): # test that subqueryload does not occur because the parent # row cannot support it - sess = create_session() + sess = fixture_session() def go(): eq_( @@ -116,7 +116,7 @@ class _PolymorphicTestBase(object): def test_primary_eager_aliasing_selectinload(self): # test that selectinload does not occur because the parent # row cannot support it - sess = create_session() + sess = fixture_session() def go(): eq_( @@ -134,7 +134,7 @@ class _PolymorphicTestBase(object): # assert the JOINs don't over JOIN - sess = create_session() + sess = fixture_session() def go(): eq_( @@ -168,28 +168,28 @@ class _PolymorphicTestBase(object): For all mappers, ensure the primary key has been calculated as just the "person_id" column. """ - sess = create_session() + sess = fixture_session() eq_( sess.get(Person, e1.person_id), Engineer(name="dilbert", primary_language="java"), ) def test_get_two(self): - sess = create_session() + sess = fixture_session() eq_( sess.get(Engineer, e1.person_id), Engineer(name="dilbert", primary_language="java"), ) def test_get_three(self): - sess = create_session() + sess = fixture_session() eq_( sess.get(Manager, b1.person_id), Boss(name="pointy haired boss", golf_swing="fore"), ) def test_multi_join(self): - sess = create_session() + sess = fixture_session() e = aliased(Person) c = aliased(Company) q = ( @@ -230,7 +230,7 @@ class _PolymorphicTestBase(object): ) def test_multi_join_future(self): - sess = create_session(future=True) + sess = fixture_session(future=True) e = aliased(Person) c = aliased(Company) @@ -279,22 +279,22 @@ class _PolymorphicTestBase(object): ) def test_filter_on_subclass_one(self): - sess = create_session() + sess = fixture_session() eq_(sess.query(Engineer).all()[0], Engineer(name="dilbert")) def test_filter_on_subclass_one_future(self): - sess = create_session(future=True) + sess = fixture_session(future=True) eq_( sess.execute(select(Engineer)).scalar(), Engineer(name="dilbert"), ) def test_filter_on_subclass_two(self): - sess = create_session() + sess = fixture_session() eq_(sess.query(Engineer).first(), Engineer(name="dilbert")) def test_filter_on_subclass_three(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Engineer) .filter(Engineer.person_id == e1.person_id) @@ -303,7 +303,7 @@ class _PolymorphicTestBase(object): ) def test_filter_on_subclass_four(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Manager) .filter(Manager.person_id == m1.person_id) @@ -312,7 +312,7 @@ class _PolymorphicTestBase(object): ) def test_filter_on_subclass_five(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Manager) .filter(Manager.person_id == b1.person_id) @@ -321,14 +321,14 @@ class _PolymorphicTestBase(object): ) def test_filter_on_subclass_six(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Boss).filter(Boss.person_id == b1.person_id).one(), Boss(name="pointy haired boss"), ) def test_join_from_polymorphic_nonaliased_one(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Person) .join("paperwork") @@ -338,7 +338,7 @@ class _PolymorphicTestBase(object): ) def test_join_from_polymorphic_nonaliased_one_future(self): - sess = create_session(future=True) + sess = fixture_session(future=True) eq_( sess.execute( select(Person) @@ -352,7 +352,7 @@ class _PolymorphicTestBase(object): ) def test_join_from_polymorphic_nonaliased_two(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Person) .order_by(Person.person_id) @@ -363,7 +363,7 @@ class _PolymorphicTestBase(object): ) def test_join_from_polymorphic_nonaliased_three(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Engineer) .order_by(Person.person_id) @@ -374,7 +374,7 @@ class _PolymorphicTestBase(object): ) def test_join_from_polymorphic_nonaliased_four(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Person) .order_by(Person.person_id) @@ -386,7 +386,7 @@ class _PolymorphicTestBase(object): ) def test_join_from_polymorphic_flag_aliased_one(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Person) .order_by(Person.person_id) @@ -397,7 +397,7 @@ class _PolymorphicTestBase(object): ) def test_join_from_polymorphic_flag_aliased_one_future(self): - sess = create_session(future=True) + sess = fixture_session(future=True) pa = aliased(Paperwork) eq_( @@ -414,7 +414,7 @@ class _PolymorphicTestBase(object): ) def test_join_from_polymorphic_explicit_aliased_one(self): - sess = create_session() + sess = fixture_session() pa = aliased(Paperwork) eq_( sess.query(Person) @@ -426,7 +426,7 @@ class _PolymorphicTestBase(object): ) def test_join_from_polymorphic_flag_aliased_two(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Person) .order_by(Person.person_id) @@ -437,7 +437,7 @@ class _PolymorphicTestBase(object): ) def test_join_from_polymorphic_explicit_aliased_two(self): - sess = create_session() + sess = fixture_session() pa = aliased(Paperwork) eq_( sess.query(Person) @@ -449,7 +449,7 @@ class _PolymorphicTestBase(object): ) def test_join_from_polymorphic_flag_aliased_three(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Engineer) .order_by(Person.person_id) @@ -460,7 +460,7 @@ class _PolymorphicTestBase(object): ) def test_join_from_polymorphic_explicit_aliased_three(self): - sess = create_session() + sess = fixture_session() pa = aliased(Paperwork) eq_( sess.query(Engineer) @@ -472,7 +472,7 @@ class _PolymorphicTestBase(object): ) def test_join_from_polymorphic_aliased_four(self): - sess = create_session() + sess = fixture_session() pa = aliased(Paperwork) eq_( sess.query(Person) @@ -485,7 +485,7 @@ class _PolymorphicTestBase(object): ) def test_join_from_with_polymorphic_nonaliased_one(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Person) .with_polymorphic(Manager) @@ -497,7 +497,7 @@ class _PolymorphicTestBase(object): ) def test_join_from_with_polymorphic_nonaliased_one_future(self): - sess = create_session(future=True) + sess = fixture_session(future=True) pm = with_polymorphic(Person, [Manager]) eq_( @@ -514,7 +514,7 @@ class _PolymorphicTestBase(object): ) def test_join_from_with_polymorphic_nonaliased_two(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Person) .with_polymorphic([Manager, Engineer]) @@ -526,7 +526,7 @@ class _PolymorphicTestBase(object): ) def test_join_from_with_polymorphic_nonaliased_three(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Person) .with_polymorphic([Manager, Engineer]) @@ -539,7 +539,7 @@ class _PolymorphicTestBase(object): ) def test_join_from_with_polymorphic_flag_aliased_one(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Person) .with_polymorphic(Manager) @@ -550,7 +550,7 @@ class _PolymorphicTestBase(object): ) def test_join_from_with_polymorphic_explicit_aliased_one(self): - sess = create_session() + sess = fixture_session() pa = aliased(Paperwork) eq_( sess.query(Person) @@ -562,7 +562,7 @@ class _PolymorphicTestBase(object): ) def test_join_from_with_polymorphic_flag_aliased_two(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Person) .with_polymorphic([Manager, Engineer]) @@ -574,7 +574,7 @@ class _PolymorphicTestBase(object): ) def test_join_from_with_polymorphic_explicit_aliased_two(self): - sess = create_session() + sess = fixture_session() pa = aliased(Paperwork) eq_( sess.query(Person) @@ -587,7 +587,7 @@ class _PolymorphicTestBase(object): ) def test_join_from_with_polymorphic_aliased_three(self): - sess = create_session() + sess = fixture_session() pa = aliased(Paperwork) eq_( @@ -602,7 +602,7 @@ class _PolymorphicTestBase(object): ) def test_join_to_polymorphic_nonaliased(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Company) .join("employees") @@ -612,7 +612,7 @@ class _PolymorphicTestBase(object): ) def test_join_to_polymorphic_flag_aliased(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Company) .join("employees", aliased=True) @@ -622,7 +622,7 @@ class _PolymorphicTestBase(object): ) def test_join_to_polymorphic_explicit_aliased(self): - sess = create_session() + sess = fixture_session() ea = aliased(Person) eq_( sess.query(Company) @@ -633,13 +633,13 @@ class _PolymorphicTestBase(object): ) def test_polymorphic_any_one(self): - sess = create_session() + sess = fixture_session() any_ = Company.employees.any(Person.name == "vlad") eq_(sess.query(Company).filter(any_).all(), [c2]) def test_polymorphic_any_flag_alias_two(self): - sess = create_session() + sess = fixture_session() # test that the aliasing on "Person" does not bleed into the # EXISTS clause generated by any() any_ = Company.employees.any(Person.name == "wally") @@ -653,7 +653,7 @@ class _PolymorphicTestBase(object): ) def test_polymorphic_any_explicit_alias_two(self): - sess = create_session() + sess = fixture_session() # test that the aliasing on "Person" does not bleed into the # EXISTS clause generated by any() any_ = Company.employees.any(Person.name == "wally") @@ -668,7 +668,7 @@ class _PolymorphicTestBase(object): ) def test_polymorphic_any_three(self): - sess = create_session() + sess = fixture_session() any_ = Company.employees.any(Person.name == "vlad") ea = aliased(Person) eq_( @@ -681,7 +681,7 @@ class _PolymorphicTestBase(object): ) def test_polymorphic_any_eight(self): - sess = create_session() + sess = fixture_session() any_ = Engineer.machines.any(Machine.name == "Commodore 64") eq_( sess.query(Person).order_by(Person.person_id).filter(any_).all(), @@ -689,7 +689,7 @@ class _PolymorphicTestBase(object): ) def test_polymorphic_any_nine(self): - sess = create_session() + sess = fixture_session() any_ = Person.paperwork.any(Paperwork.description == "review #2") eq_( sess.query(Person).order_by(Person.person_id).filter(any_).all(), @@ -697,13 +697,13 @@ class _PolymorphicTestBase(object): ) def test_join_from_columns_or_subclass_one(self): - sess = create_session() + sess = fixture_session() expected = [("dogbert",), ("pointy haired boss",)] eq_(sess.query(Manager.name).order_by(Manager.name).all(), expected) def test_join_from_columns_or_subclass_two(self): - sess = create_session() + sess = fixture_session() expected = [("dogbert",), ("dogbert",), ("pointy haired boss",)] eq_( sess.query(Manager.name) @@ -714,7 +714,7 @@ class _PolymorphicTestBase(object): ) def test_join_from_columns_or_subclass_three(self): - sess = create_session() + sess = fixture_session() expected = [ ("dilbert",), ("dilbert",), @@ -734,7 +734,7 @@ class _PolymorphicTestBase(object): ) def test_join_from_columns_or_subclass_four(self): - sess = create_session() + sess = fixture_session() # Load Person.name, joining from Person -> paperwork, get all # the people. expected = [ @@ -756,7 +756,7 @@ class _PolymorphicTestBase(object): ) def test_join_from_columns_or_subclass_five(self): - sess = create_session() + sess = fixture_session() # same, on manager. get only managers. expected = [("dogbert",), ("dogbert",), ("pointy haired boss",)] eq_( @@ -768,7 +768,7 @@ class _PolymorphicTestBase(object): ) def test_join_from_columns_or_subclass_six(self): - sess = create_session() + sess = fixture_session() if self.select_type == "": # this now raises, due to [ticket:1892]. Manager.person_id # is now the "person_id" column on Manager. SQL is incorrect. @@ -813,7 +813,7 @@ class _PolymorphicTestBase(object): ) def test_join_from_columns_or_subclass_seven(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Manager) .join(Paperwork, Manager.paperwork) @@ -823,7 +823,7 @@ class _PolymorphicTestBase(object): ) def test_join_from_columns_or_subclass_eight(self): - sess = create_session() + sess = fixture_session() expected = [("dogbert",), ("dogbert",), ("pointy haired boss",)] eq_( sess.query(Manager.name) @@ -834,7 +834,7 @@ class _PolymorphicTestBase(object): ) def test_join_from_columns_or_subclass_nine(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Manager.person_id) .join(paperwork, Manager.person_id == paperwork.c.person_id) @@ -844,7 +844,7 @@ class _PolymorphicTestBase(object): ) def test_join_from_columns_or_subclass_ten(self): - sess = create_session() + sess = fixture_session() expected = [ ("pointy haired boss", "review #1"), ("dogbert", "review #2"), @@ -859,7 +859,7 @@ class _PolymorphicTestBase(object): ) def test_join_from_columns_or_subclass_eleven(self): - sess = create_session() + sess = fixture_session() expected = [("pointy haired boss",), ("dogbert",), ("dogbert",)] malias = aliased(Manager) eq_( @@ -870,7 +870,7 @@ class _PolymorphicTestBase(object): ) def test_subclass_option_pathing(self): - sess = create_session() + sess = fixture_session() dilbert = ( sess.query(Person) .options(defaultload(Engineer.machines).defer(Machine.name)) @@ -887,7 +887,7 @@ class _PolymorphicTestBase(object): the select_table mapper. """ - sess = create_session() + sess = fixture_session() name = "dogbert" m1 = sess.query(Manager).filter(Manager.name == name).one() @@ -900,7 +900,7 @@ class _PolymorphicTestBase(object): assert m2.golf_swing == "fore" def test_with_polymorphic_one(self): - sess = create_session() + sess = fixture_session() def go(): eq_( @@ -914,7 +914,7 @@ class _PolymorphicTestBase(object): self.assert_sql_count(testing.db, go, 1) def test_with_polymorphic_two(self): - sess = create_session() + sess = fixture_session() def go(): eq_( @@ -928,7 +928,7 @@ class _PolymorphicTestBase(object): self.assert_sql_count(testing.db, go, 1) def test_with_polymorphic_three(self): - sess = create_session() + sess = fixture_session() def go(): eq_( @@ -942,7 +942,7 @@ class _PolymorphicTestBase(object): self.assert_sql_count(testing.db, go, 3) def test_with_polymorphic_four(self): - sess = create_session() + sess = fixture_session() def go(): eq_( @@ -956,7 +956,7 @@ class _PolymorphicTestBase(object): self.assert_sql_count(testing.db, go, 3) def test_with_polymorphic_five(self): - sess = create_session() + sess = fixture_session() def go(): # limit the polymorphic join down to just "Person", @@ -969,7 +969,7 @@ class _PolymorphicTestBase(object): self.assert_sql_count(testing.db, go, 6) def test_with_polymorphic_six(self): - sess = create_session() + sess = fixture_session() assert_raises( sa_exc.InvalidRequestError, @@ -988,7 +988,7 @@ class _PolymorphicTestBase(object): ) def test_with_polymorphic_seven(self): - sess = create_session() + sess = fixture_session() # compare to entities without related collections to prevent # additional lazy SQL from firing on loaded entities eq_( @@ -1001,7 +1001,7 @@ class _PolymorphicTestBase(object): def test_relationship_to_polymorphic_one(self): expected = self._company_with_emps_machines_fixture() - sess = create_session() + sess = fixture_session() def go(): # test load Companies with lazy load to 'employees' @@ -1012,7 +1012,7 @@ class _PolymorphicTestBase(object): def test_relationship_to_polymorphic_two(self): expected = self._company_with_emps_machines_fixture() - sess = create_session() + sess = fixture_session() def go(): # with #2438, of_type() is recognized. This @@ -1040,9 +1040,9 @@ class _PolymorphicTestBase(object): def test_relationship_to_polymorphic_three(self): expected = self._company_with_emps_machines_fixture() - sess = create_session() + sess = fixture_session() - sess = create_session() + sess = fixture_session() def go(): eq_( @@ -1072,7 +1072,7 @@ class _PolymorphicTestBase(object): self.assert_sql_count(testing.db, go, count) def test_joinedload_on_subclass(self): - sess = create_session() + sess = fixture_session() expected = [ Engineer( name="dilbert", @@ -1100,7 +1100,7 @@ class _PolymorphicTestBase(object): self.assert_sql_count(testing.db, go, 1) def test_subqueryload_on_subclass(self): - sess = create_session() + sess = fixture_session() expected = [ Engineer( name="dilbert", @@ -1140,12 +1140,12 @@ class _PolymorphicTestBase(object): self.assert_sql_count(testing.db, go, 2) def test_query_subclass_join_to_base_relationship(self): - sess = create_session() + sess = fixture_session() # non-polymorphic eq_(sess.query(Engineer).join(Person.paperwork).all(), [e1, e2, e3]) def test_join_to_subclass(self): - sess = create_session() + sess = fixture_session() # TODO: these should all be deprecated (?) - these joins are on the # core tables and should not be getting adapted, not sure why @@ -1161,7 +1161,7 @@ class _PolymorphicTestBase(object): ) def test_join_to_subclass_one(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Company) .select_from(companies.join(people).join(engineers)) @@ -1171,7 +1171,7 @@ class _PolymorphicTestBase(object): ) def test_join_to_subclass_two(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Company) .join(people.join(engineers), "employees") @@ -1181,7 +1181,7 @@ class _PolymorphicTestBase(object): ) def test_join_to_subclass_three(self): - sess = create_session() + sess = fixture_session() ealias = aliased(Engineer) eq_( sess.query(Company) @@ -1192,7 +1192,7 @@ class _PolymorphicTestBase(object): ) def test_join_to_subclass_six(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Company) .join(people.join(engineers), "employees") @@ -1202,7 +1202,7 @@ class _PolymorphicTestBase(object): ) def test_join_to_subclass_six_point_five(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Company) .join(people.join(engineers), "employees") @@ -1213,7 +1213,7 @@ class _PolymorphicTestBase(object): ) def test_join_to_subclass_seven(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Company) .join(people.join(engineers), "employees") @@ -1224,11 +1224,11 @@ class _PolymorphicTestBase(object): ) def test_join_to_subclass_eight(self): - sess = create_session() + sess = fixture_session() eq_(sess.query(Person).join(Engineer.machines).all(), [e1, e2, e3]) def test_join_to_subclass_nine(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Company) .select_from(companies.join(people).join(engineers)) @@ -1238,7 +1238,7 @@ class _PolymorphicTestBase(object): ) def test_join_to_subclass_ten(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Company) .join("employees") @@ -1248,7 +1248,7 @@ class _PolymorphicTestBase(object): ) def test_join_to_subclass_eleven(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Company) .select_from(companies.join(people).join(engineers)) @@ -1258,11 +1258,11 @@ class _PolymorphicTestBase(object): ) def test_join_to_subclass_twelve(self): - sess = create_session() + sess = fixture_session() eq_(sess.query(Person).join(Engineer.machines).all(), [e1, e2, e3]) def test_join_to_subclass_thirteen(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Person) .join(Engineer.machines) @@ -1272,14 +1272,14 @@ class _PolymorphicTestBase(object): ) def test_join_to_subclass_fourteen(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Company).join("employees", Engineer.machines).all(), [c1, c2], ) def test_join_to_subclass_fifteen(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Company) .join("employees", Engineer.machines) @@ -1289,12 +1289,12 @@ class _PolymorphicTestBase(object): ) def test_join_to_subclass_sixteen(self): - sess = create_session() + sess = fixture_session() # non-polymorphic eq_(sess.query(Engineer).join(Engineer.machines).all(), [e1, e2, e3]) def test_join_to_subclass_seventeen(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Engineer) .join(Engineer.machines) @@ -1304,7 +1304,7 @@ class _PolymorphicTestBase(object): ) def test_join_and_thru_polymorphic_nonaliased_one(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Company) .join(Company.employees) @@ -1314,7 +1314,7 @@ class _PolymorphicTestBase(object): ) def test_join_and_thru_polymorphic_aliased_one(self): - sess = create_session() + sess = fixture_session() ea = aliased(Person) pa = aliased(Paperwork) eq_( @@ -1326,7 +1326,7 @@ class _PolymorphicTestBase(object): ) def test_join_through_polymorphic_nonaliased_one(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Company) .join(Company.employees) @@ -1337,7 +1337,7 @@ class _PolymorphicTestBase(object): ) def test_join_through_polymorphic_nonaliased_two(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Company) .join(Company.employees) @@ -1348,7 +1348,7 @@ class _PolymorphicTestBase(object): ) def test_join_through_polymorphic_nonaliased_three(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Company) .join(Company.employees) @@ -1360,7 +1360,7 @@ class _PolymorphicTestBase(object): ) def test_join_through_polymorphic_nonaliased_four(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Company) .join(Company.employees) @@ -1372,7 +1372,7 @@ class _PolymorphicTestBase(object): ) def test_join_through_polymorphic_nonaliased_five(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Company) .join("employees") @@ -1384,7 +1384,7 @@ class _PolymorphicTestBase(object): ) def test_join_through_polymorphic_nonaliased_six(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Company) .join("employees") @@ -1396,7 +1396,7 @@ class _PolymorphicTestBase(object): ) def test_join_through_polymorphic_aliased_one(self): - sess = create_session() + sess = fixture_session() ea = aliased(Person) pa = aliased(Paperwork) eq_( @@ -1409,7 +1409,7 @@ class _PolymorphicTestBase(object): ) def test_join_through_polymorphic_aliased_two(self): - sess = create_session() + sess = fixture_session() ea = aliased(Person) pa = aliased(Paperwork) eq_( @@ -1422,7 +1422,7 @@ class _PolymorphicTestBase(object): ) def test_join_through_polymorphic_aliased_three(self): - sess = create_session() + sess = fixture_session() ea = aliased(Person) pa = aliased(Paperwork) eq_( @@ -1436,7 +1436,7 @@ class _PolymorphicTestBase(object): ) def test_join_through_polymorphic_aliased_four(self): - sess = create_session() + sess = fixture_session() ea = aliased(Person) pa = aliased(Paperwork) eq_( @@ -1450,7 +1450,7 @@ class _PolymorphicTestBase(object): ) def test_join_through_polymorphic_aliased_five(self): - sess = create_session() + sess = fixture_session() ea = aliased(Person) pa = aliased(Paperwork) eq_( @@ -1464,7 +1464,7 @@ class _PolymorphicTestBase(object): ) def test_join_through_polymorphic_aliased_six(self): - sess = create_session() + sess = fixture_session() pa = aliased(Paperwork) ea = aliased(Person) eq_( @@ -1478,7 +1478,7 @@ class _PolymorphicTestBase(object): ) def test_explicit_polymorphic_join_one(self): - sess = create_session() + sess = fixture_session() # join from Company to Engineer; join condition formulated by # ORMJoin using regular table foreign key connections. Engineer @@ -1493,7 +1493,7 @@ class _PolymorphicTestBase(object): ) def test_explicit_polymorphic_join_two(self): - sess = create_session() + sess = fixture_session() # same, using explicit join condition. Query.join() must # adapt the on clause here to match the subquery wrapped around @@ -1507,7 +1507,7 @@ class _PolymorphicTestBase(object): ) def test_filter_on_baseclass(self): - sess = create_session() + sess = fixture_session() eq_(sess.query(Person).order_by(Person.person_id).all(), all_employees) eq_( sess.query(Person).order_by(Person.person_id).first(), @@ -1522,7 +1522,7 @@ class _PolymorphicTestBase(object): ) def test_from_alias(self): - sess = create_session() + sess = fixture_session() palias = aliased(Person) eq_( sess.query(palias) @@ -1533,7 +1533,7 @@ class _PolymorphicTestBase(object): ) def test_self_referential_one(self): - sess = create_session() + sess = fixture_session() palias = aliased(Person) expected = [(m1, e1), (m1, e2), (m1, b1)] @@ -1549,7 +1549,7 @@ class _PolymorphicTestBase(object): def test_self_referential_two(self): - sess = create_session() + sess = fixture_session() palias = aliased(Person) expected = [(m1, e1), (m1, e2), (m1, b1)] @@ -1567,7 +1567,7 @@ class _PolymorphicTestBase(object): def test_self_referential_two_point_five(self): """Using two aliases, the above case works.""" - sess = create_session() + sess = fixture_session() palias = aliased(Person) palias2 = aliased(Person) @@ -1588,7 +1588,7 @@ class _PolymorphicTestBase(object): def test_self_referential_two_future(self): # TODO: this is the SECOND test *EVER* of an aliased class of # an aliased class. - sess = create_session(future=True) + sess = fixture_session(future=True) expected = [(m1, e1), (m1, e2), (m1, b1)] # not aliasing the first class @@ -1618,7 +1618,7 @@ class _PolymorphicTestBase(object): # TODO: this is the first test *EVER* of an aliased class of # an aliased class. we should add many more tests for this. # new case added in Id810f485c5f7ed971529489b84694e02a3356d6d - sess = create_session(future=True) + sess = fixture_session(future=True) expected = [(m1, e1), (m1, e2), (m1, b1)] # aliasing the first class @@ -1648,7 +1648,7 @@ class _PolymorphicTestBase(object): # second "filter" from hitting it, which would pollute the # subquery and usually results in recursion overflow errors # within the adaption. - sess = create_session() + sess = fixture_session() subq = ( sess.query(engineers.c.person_id) .filter(Engineer.primary_language == "java") @@ -1658,7 +1658,7 @@ class _PolymorphicTestBase(object): eq_(sess.query(Person).filter(Person.person_id.in_(subq)).one(), e1) def test_mixed_entities_one(self): - sess = create_session() + sess = fixture_session() expected = [ ( @@ -1744,7 +1744,7 @@ class _PolymorphicTestBase(object): _join_to_poly_wp_three, ) def test_mixed_entities_join_to_poly(self, q): - sess = create_session() + sess = fixture_session() expected = [ ("dilbert", "MegaCorp, Inc."), ("wally", "MegaCorp, Inc."), @@ -1758,7 +1758,7 @@ class _PolymorphicTestBase(object): ) def test_mixed_entities_two(self): - sess = create_session() + sess = fixture_session() expected = [ ("java", "MegaCorp, Inc."), ("cobol", "Elbonia, Inc."), @@ -1774,7 +1774,7 @@ class _PolymorphicTestBase(object): ) def test_mixed_entities_three(self): - sess = create_session() + sess = fixture_session() palias = aliased(Person) expected = [ ( @@ -1807,7 +1807,7 @@ class _PolymorphicTestBase(object): ) def test_mixed_entities_four(self): - sess = create_session() + sess = fixture_session() palias = aliased(Person) expected = [ ( @@ -1841,7 +1841,7 @@ class _PolymorphicTestBase(object): ) def test_mixed_entities_five(self): - sess = create_session() + sess = fixture_session() palias = aliased(Person) expected = [("vlad", "Elbonia, Inc.", "dilbert")] eq_( @@ -1855,7 +1855,7 @@ class _PolymorphicTestBase(object): ) def test_mixed_entities_six(self): - sess = create_session() + sess = fixture_session() palias = aliased(Person) expected = [ ("manager", "dogbert", "engineer", "dilbert"), @@ -1873,7 +1873,7 @@ class _PolymorphicTestBase(object): ) def test_mixed_entities_seven(self): - sess = create_session() + sess = fixture_session() expected = [ ("dilbert", "tps report #1"), ("dilbert", "tps report #2"), @@ -1893,7 +1893,7 @@ class _PolymorphicTestBase(object): ) def test_mixed_entities_eight(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(func.count(Person.person_id)) .filter(Engineer.primary_language == "java") @@ -1902,7 +1902,7 @@ class _PolymorphicTestBase(object): ) def test_mixed_entities_nine(self): - sess = create_session() + sess = fixture_session() expected = [("Elbonia, Inc.", 1), ("MegaCorp, Inc.", 4)] eq_( sess.query(Company.name, func.count(Person.person_id)) @@ -1914,7 +1914,7 @@ class _PolymorphicTestBase(object): ) def test_mixed_entities_ten(self): - sess = create_session() + sess = fixture_session() expected = [("Elbonia, Inc.", 1), ("MegaCorp, Inc.", 4)] eq_( sess.query(Company.name, func.count(Person.person_id)) @@ -1926,7 +1926,7 @@ class _PolymorphicTestBase(object): ) # def test_mixed_entities(self): - # sess = create_session() + # sess = fixture_session() # TODO: I think raise error on these for now. different # inheritance/loading schemes have different results here, # all incorrect @@ -1936,7 +1936,7 @@ class _PolymorphicTestBase(object): # []) # def test_mixed_entities(self): - # sess = create_session() + # sess = fixture_session() # eq_(sess.query( # Person.name, # Engineer.primary_language, @@ -1945,7 +1945,7 @@ class _PolymorphicTestBase(object): # []) def test_mixed_entities_eleven(self): - sess = create_session() + sess = fixture_session() expected = [("java",), ("c++",), ("cobol",)] eq_( sess.query(Engineer.primary_language) @@ -1955,7 +1955,7 @@ class _PolymorphicTestBase(object): ) def test_mixed_entities_twelve(self): - sess = create_session() + sess = fixture_session() expected = [("vlad", "Elbonia, Inc.")] eq_( sess.query(Person.name, Company.name) @@ -1966,12 +1966,12 @@ class _PolymorphicTestBase(object): ) def test_mixed_entities_thirteen(self): - sess = create_session() + sess = fixture_session() expected = [("pointy haired boss", "fore")] eq_(sess.query(Boss.name, Boss.golf_swing).all(), expected) def test_mixed_entities_fourteen(self): - sess = create_session() + sess = fixture_session() expected = [("dilbert", "java"), ("wally", "c++"), ("vlad", "cobol")] eq_( sess.query(Engineer.name, Engineer.primary_language).all(), @@ -1979,7 +1979,7 @@ class _PolymorphicTestBase(object): ) def test_mixed_entities_fifteen(self): - sess = create_session() + sess = fixture_session() expected = [ ( @@ -2001,7 +2001,7 @@ class _PolymorphicTestBase(object): ) def test_mixed_entities_sixteen(self): - sess = create_session() + sess = fixture_session() expected = [ ( Engineer( @@ -2022,17 +2022,17 @@ class _PolymorphicTestBase(object): ) def test_mixed_entities_seventeen(self): - sess = create_session() + sess = fixture_session() expected = [("pointy haired boss",), ("dogbert",)] eq_(sess.query(Manager.name).all(), expected) def test_mixed_entities_eighteen(self): - sess = create_session() + sess = fixture_session() expected = [("pointy haired boss foo",), ("dogbert foo",)] eq_(sess.query(Manager.name + " foo").all(), expected) def test_mixed_entities_nineteen(self): - sess = create_session() + sess = fixture_session() row = ( sess.query(Engineer.name, Engineer.primary_language) .filter(Engineer.name == "dilbert") @@ -2042,7 +2042,7 @@ class _PolymorphicTestBase(object): assert row.primary_language == "java" def test_correlation_one(self): - sess = create_session() + sess = fixture_session() # this for a long time did not work with PolymorphicAliased and # PolymorphicUnions, which was due to the no_replacement_traverse @@ -2063,7 +2063,7 @@ class _PolymorphicTestBase(object): ) def test_correlation_two(self): - sess = create_session() + sess = fixture_session() paliased = aliased(Person) @@ -2081,7 +2081,7 @@ class _PolymorphicTestBase(object): ) def test_correlation_three(self): - sess = create_session() + sess = fixture_session() paliased = aliased(Person, flat=True) @@ -2101,7 +2101,7 @@ class _PolymorphicTestBase(object): class PolymorphicTest(_PolymorphicTestBase, _Polymorphic): def test_join_to_subclass_four(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Person) .select_from(people.join(engineers)) @@ -2111,7 +2111,7 @@ class PolymorphicTest(_PolymorphicTestBase, _Polymorphic): ) def test_join_to_subclass_five(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Person) .select_from(people.join(engineers)) @@ -2123,7 +2123,7 @@ class PolymorphicTest(_PolymorphicTestBase, _Polymorphic): def test_correlation_w_polymorphic(self): - sess = create_session() + sess = fixture_session() p_poly = with_polymorphic(Person, "*") @@ -2142,7 +2142,7 @@ class PolymorphicTest(_PolymorphicTestBase, _Polymorphic): def test_correlation_w_polymorphic_flat(self): - sess = create_session() + sess = fixture_session() p_poly = with_polymorphic(Person, "*", flat=True) @@ -2184,7 +2184,7 @@ class PolymorphicPolymorphicTest( # aliased(polymorphic) will normally do the old-school # "(SELECT * FROM a JOIN b ...) AS anon_1" thing. # this is the safest - sess = create_session() + sess = fixture_session() palias = aliased(Person) self.assert_compile( sess.query(palias, Company.name) @@ -2235,7 +2235,7 @@ class PolymorphicPolymorphicTest( ) def test_flat_aliased_w_select_from(self): - sess = create_session() + sess = fixture_session() palias = aliased(Person, flat=True) self.assert_compile( sess.query(palias, Company.name) @@ -2275,7 +2275,7 @@ class PolymorphicPolymorphicTest( class PolymorphicUnionsTest(_PolymorphicTestBase, _PolymorphicUnions): def test_subqueryload_on_subclass_uses_path_correctly(self): - sess = create_session() + sess = fixture_session() expected = [ Engineer( name="dilbert", @@ -2361,7 +2361,7 @@ class PolymorphicAliasedJoinsTest( class PolymorphicJoinsTest(_PolymorphicTestBase, _PolymorphicJoins): def test_having_group_by(self): - sess = create_session() + sess = fixture_session() eq_( sess.query(Person.name) .group_by(Person.name) diff --git a/test/orm/inheritance/test_productspec.py b/test/orm/inheritance/test_productspec.py index 35c7565fb..e940cb0f4 100644 --- a/test/orm/inheritance/test_productspec.py +++ b/test/orm/inheritance/test_productspec.py @@ -11,7 +11,7 @@ from sqlalchemy.orm import deferred from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship from sqlalchemy.testing import fixtures -from sqlalchemy.testing.fixtures import create_session +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -174,7 +174,7 @@ class InheritTest(fixtures.MappedTest): ), ) - session = create_session() + session = fixture_session() a1 = Assembly(name="a1") @@ -223,7 +223,7 @@ class InheritTest(fixtures.MappedTest): ), ) - session = create_session() + session = fixture_session() s = SpecLine(follower=Product(name="p1")) s2 = SpecLine(follower=Detail(name="d1")) @@ -300,7 +300,7 @@ class InheritTest(fixtures.MappedTest): polymorphic_identity="raster_document", ) - session = create_session() + session = fixture_session() a1 = Assembly(name="a1") a1.specification.append(SpecLine(follower=Detail(name="d1"))) @@ -359,7 +359,7 @@ class InheritTest(fixtures.MappedTest): polymorphic_identity="raster_document", ) - session = create_session() + session = fixture_session() a1 = Assembly(name="a1") a1.documents.append(RasterDocument("doc2")) @@ -448,7 +448,7 @@ class InheritTest(fixtures.MappedTest): mapper(Assembly, inherits=Product, polymorphic_identity="assembly") - session = create_session() + session = fixture_session() a1 = Assembly(name="a1") a1.specification.append(SpecLine(follower=Detail(name="d1"))) diff --git a/test/orm/inheritance/test_relationship.py b/test/orm/inheritance/test_relationship.py index 6879e1465..214be5e9a 100644 --- a/test/orm/inheritance/test_relationship.py +++ b/test/orm/inheritance/test_relationship.py @@ -21,7 +21,7 @@ from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing.entities import ComparableEntity -from sqlalchemy.testing.fixtures import create_session +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -116,7 +116,7 @@ class SelfReferentialTestJoinedToBase(fixtures.MappedTest): def test_has(self): p1 = Person(name="dogbert") e1 = Engineer(name="dilbert", primary_language="java", reports_to=p1) - sess = create_session() + sess = fixture_session() sess.add(p1) sess.add(e1) sess.flush() @@ -131,7 +131,7 @@ class SelfReferentialTestJoinedToBase(fixtures.MappedTest): def test_oftype_aliases_in_exists(self): e1 = Engineer(name="dilbert", primary_language="java") e2 = Engineer(name="wally", primary_language="c++", reports_to=e1) - sess = create_session() + sess = fixture_session() sess.add_all([e1, e2]) sess.flush() eq_( @@ -148,7 +148,7 @@ class SelfReferentialTestJoinedToBase(fixtures.MappedTest): def test_join(self): p1 = Person(name="dogbert") e1 = Engineer(name="dilbert", primary_language="java", reports_to=p1) - sess = create_session() + sess = fixture_session() sess.add(p1) sess.add(e1) sess.flush() @@ -242,7 +242,7 @@ class SelfReferentialJ2JTest(fixtures.MappedTest): def test_has(self): m1 = Manager(name="dogbert") e1 = Engineer(name="dilbert", primary_language="java", reports_to=m1) - sess = create_session() + sess = fixture_session() sess.add(m1) sess.add(e1) sess.flush() @@ -258,7 +258,7 @@ class SelfReferentialJ2JTest(fixtures.MappedTest): def test_join(self): m1 = Manager(name="dogbert") e1 = Engineer(name="dilbert", primary_language="java", reports_to=m1) - sess = create_session() + sess = fixture_session() sess.add(m1) sess.add(e1) sess.flush() @@ -281,7 +281,7 @@ class SelfReferentialJ2JTest(fixtures.MappedTest): e2 = Engineer(name="dilbert", primary_language="c++", reports_to=m2) e3 = Engineer(name="etc", primary_language="c++") - sess = create_session() + sess = fixture_session() sess.add_all([m1, m2, e1, e2, e3]) sess.flush() sess.expunge_all() @@ -318,7 +318,7 @@ class SelfReferentialJ2JTest(fixtures.MappedTest): e2 = Engineer(name="wally", primary_language="c++", reports_to=m2) e3 = Engineer(name="etc", primary_language="c++") - sess = create_session() + sess = fixture_session() sess.add(m1) sess.add(m2) sess.add(e1) @@ -409,13 +409,13 @@ class SelfReferentialJ2JSelfTest(fixtures.MappedTest): def _two_obj_fixture(self): e1 = Engineer(name="wally") e2 = Engineer(name="dilbert", reports_to=e1) - sess = Session() + sess = fixture_session() sess.add_all([e1, e2]) sess.commit() return sess def _five_obj_fixture(self): - sess = Session() + sess = fixture_session() e1, e2, e3, e4, e5 = [Engineer(name="e%d" % (i + 1)) for i in range(5)] e3.reports_to = e1 e4.reports_to = e2 @@ -596,7 +596,7 @@ class M2MFilterTest(fixtures.MappedTest): def test_not_contains(self): Organization = self.classes.Organization - sess = create_session() + sess = fixture_session() e1 = sess.query(Person).filter(Engineer.name == "e1").one() eq_( @@ -615,7 +615,7 @@ class M2MFilterTest(fixtures.MappedTest): ) def test_any(self): - sess = create_session() + sess = fixture_session() Organization = self.classes.Organization eq_( @@ -718,7 +718,7 @@ class SelfReferentialM2MTest(fixtures.MappedTest, AssertsCompiledSQL): def test_query_crit(self): Child1, Child2 = self.classes.Child1, self.classes.Child2 - sess = create_session() + sess = fixture_session() c11, c12, c13 = Child1(), Child1(), Child1() c21, c22, c23 = Child2(), Child2(), Child2() c11.left_child2 = c22 @@ -812,7 +812,7 @@ class SelfReferentialM2MTest(fixtures.MappedTest, AssertsCompiledSQL): def test_eager_join(self): Child1, Child2 = self.classes.Child1, self.classes.Child2 - sess = create_session() + sess = fixture_session() c1 = Child1() c1.left_child2 = Child2() sess.add(c1) @@ -849,7 +849,7 @@ class SelfReferentialM2MTest(fixtures.MappedTest, AssertsCompiledSQL): def test_subquery_load(self): Child1, Child2 = self.classes.Child1, self.classes.Child2 - sess = create_session() + sess = fixture_session() c1 = Child1() c1.left_child2 = Child2() sess.add(c1) @@ -974,7 +974,7 @@ class EagerToSubclassTest(fixtures.MappedTest): def test_joinedload(self): Parent = self.classes.Parent - sess = Session() + sess = fixture_session() def go(): eq_( @@ -987,7 +987,7 @@ class EagerToSubclassTest(fixtures.MappedTest): def test_contains_eager(self): Parent = self.classes.Parent Sub = self.classes.Sub - sess = Session() + sess = fixture_session() def go(): eq_( @@ -1004,7 +1004,7 @@ class EagerToSubclassTest(fixtures.MappedTest): def test_subq_through_related(self): Parent = self.classes.Parent Base = self.classes.Base - sess = Session() + sess = fixture_session() def go(): eq_( @@ -1023,7 +1023,7 @@ class EagerToSubclassTest(fixtures.MappedTest): Parent = self.classes.Parent Base = self.classes.Base pa = aliased(Parent) - sess = Session() + sess = fixture_session() def go(): eq_( @@ -1150,7 +1150,7 @@ class SubClassEagerToSubClassTest(fixtures.MappedTest): def test_joinedload(self): Subparent = self.classes.Subparent - sess = create_session() + sess = fixture_session() def go(): eq_( @@ -1175,7 +1175,7 @@ class SubClassEagerToSubClassTest(fixtures.MappedTest): def test_contains_eager(self): Subparent = self.classes.Subparent - sess = create_session() + sess = fixture_session() def go(): eq_( @@ -1204,7 +1204,7 @@ class SubClassEagerToSubClassTest(fixtures.MappedTest): def test_subqueryload(self): Subparent = self.classes.Subparent - sess = create_session() + sess = fixture_session() def go(): eq_( @@ -1336,7 +1336,7 @@ class SameNamedPropTwoPolymorphicSubClassesTest(fixtures.MappedTest): C = self.classes.C D = self.classes.D - session = Session() + session = fixture_session() d = session.query(D).one() a_poly = with_polymorphic(A, [B, C]) @@ -1354,7 +1354,7 @@ class SameNamedPropTwoPolymorphicSubClassesTest(fixtures.MappedTest): C = self.classes.C D = self.classes.D - session = Session() + session = fixture_session() d = session.query(D).one() def go(): @@ -1375,7 +1375,7 @@ class SameNamedPropTwoPolymorphicSubClassesTest(fixtures.MappedTest): C = self.classes.C D = self.classes.D - session = Session() + session = fixture_session() d = session.query(D).one() a_poly = with_polymorphic(A, [B, C]) @@ -1393,7 +1393,7 @@ class SameNamedPropTwoPolymorphicSubClassesTest(fixtures.MappedTest): C = self.classes.C D = self.classes.D - session = Session() + session = fixture_session() d = session.query(D).one() def go(): @@ -1499,7 +1499,7 @@ class SubClassToSubClassFromParentTest(fixtures.MappedTest): def test_2617(self): A = self.classes.A - session = Session() + session = fixture_session() def go(): a1 = session.query(A).first() @@ -1642,7 +1642,7 @@ class SubClassToSubClassMultiTest(AssertsCompiledSQL, fixtures.MappedTest): def test_one(self): Parent, Base1, Base2, Sub1, Sub2, EP1, EP2 = self._classes() - s = Session() + s = fixture_session() self.assert_compile( s.query(Parent) .join(Parent.sub1, Sub1.sub2) @@ -1663,7 +1663,7 @@ class SubClassToSubClassMultiTest(AssertsCompiledSQL, fixtures.MappedTest): s2a = aliased(Sub2, flat=True) - s = Session() + s = fixture_session() self.assert_compile( s.query(Parent).join(Parent.sub1).join(s2a, Sub1.sub2), "SELECT parent.id AS parent_id, parent.data AS parent_data " @@ -1677,7 +1677,7 @@ class SubClassToSubClassMultiTest(AssertsCompiledSQL, fixtures.MappedTest): def test_three(self): Parent, Base1, Base2, Sub1, Sub2, EP1, EP2 = self._classes() - s = Session() + s = fixture_session() self.assert_compile( s.query(Base1).join(Base1.sub2).join(Sub2.ep1).join(Sub2.ep2), "SELECT base1.id AS base1_id, base1.data AS base1_data " @@ -1691,7 +1691,7 @@ class SubClassToSubClassMultiTest(AssertsCompiledSQL, fixtures.MappedTest): def test_four(self): Parent, Base1, Base2, Sub1, Sub2, EP1, EP2 = self._classes() - s = Session() + s = fixture_session() self.assert_compile( s.query(Sub2) .join(Base1, Base1.id == Sub2.base1_id) @@ -1709,7 +1709,7 @@ class SubClassToSubClassMultiTest(AssertsCompiledSQL, fixtures.MappedTest): def test_five(self): Parent, Base1, Base2, Sub1, Sub2, EP1, EP2 = self._classes() - s = Session() + s = fixture_session() self.assert_compile( s.query(Sub2) .join(Sub1, Sub1.id == Sub2.base1_id) @@ -1729,7 +1729,7 @@ class SubClassToSubClassMultiTest(AssertsCompiledSQL, fixtures.MappedTest): def test_six_legacy(self): Parent, Base1, Base2, Sub1, Sub2, EP1, EP2 = self._classes() - s = Session() + s = fixture_session() # as of from_self() changing in # I3abfb45dd6e50f84f29d39434caa0b550ce27864, @@ -1781,7 +1781,7 @@ class SubClassToSubClassMultiTest(AssertsCompiledSQL, fixtures.MappedTest): def test_seven_legacy(self): Parent, Base1, Base2, Sub1, Sub2, EP1, EP2 = self._classes() - s = Session() + s = fixture_session() # as of from_self() changing in # I3abfb45dd6e50f84f29d39434caa0b550ce27864, @@ -1946,7 +1946,7 @@ class JoinedloadWPolyOfTypeContinued( def test_joined_load_lastlink_subclass(self): Foo, User, SubBar = self.classes("Foo", "User", "SubBar") - s = Session() + s = fixture_session() foo_polymorphic = with_polymorphic(Foo, "*", aliased=True) @@ -1992,7 +1992,7 @@ class JoinedloadWPolyOfTypeContinued( def test_joined_load_lastlink_baseclass(self): Foo, User, Bar = self.classes("Foo", "User", "Bar") - s = Session() + s = fixture_session() foo_polymorphic = with_polymorphic(Foo, "*", aliased=True) @@ -2072,7 +2072,7 @@ class ContainsEagerMultipleOfType( def test_contains_eager_multi_alias(self): X, B, A = self.classes("X", "B", "A") - s = Session() + s = fixture_session() a_b_alias = aliased(B, name="a_b") b_x_alias = aliased(X, name="b_x") @@ -2141,7 +2141,7 @@ class JoinedloadSinglePolysubSingle( def test_query(self): Thing = self.classes.Thing - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(Thing), "SELECT things.id AS things_id, " @@ -2281,7 +2281,7 @@ class JoinedloadOverWPolyAliased( cls = fn() Link = self.classes.Link - session = Session() + session = fixture_session() q = session.query(cls).options( joinedload(cls.links).joinedload(Link.child).joinedload(cls.links) ) @@ -2310,7 +2310,7 @@ class JoinedloadOverWPolyAliased( parent_cls = fn() Link = self.classes.Link - session = Session() + session = fixture_session() q = session.query(Link).options( joinedload(Link.child).joinedload(parent_cls.owner) ) @@ -2341,7 +2341,7 @@ class JoinedloadOverWPolyAliased( poly = with_polymorphic(Parent, [Sub1]) - session = Session() + session = fixture_session() q = session.query(poly).options( joinedload(poly.Sub1.links) .joinedload(Link.child.of_type(Sub1)) @@ -2369,7 +2369,7 @@ class JoinedloadOverWPolyAliased( poly = with_polymorphic(Parent, [Sub1]) - session = Session() + session = fixture_session() q = session.query(poly).options( joinedload(poly.Sub1.links, innerjoin=True) .joinedload(Link.child.of_type(Sub1), innerjoin=True) @@ -2395,7 +2395,7 @@ class JoinedloadOverWPolyAliased( Parent = self.classes.Parent Link = self.classes.Link - session = Session() + session = fixture_session() session.add_all([Parent(), Parent()]) # represents "Parent" and "Sub1" rows @@ -2472,7 +2472,7 @@ class JoinAcrossJoinedInhMultiPath( t1_alias = aliased(Target) t2_alias = aliased(Target) - sess = Session() + sess = fixture_session() q = ( sess.query(Root) .join(s1_alias, Root.sub1) @@ -2508,7 +2508,7 @@ class JoinAcrossJoinedInhMultiPath( t1_alias = aliased(Target) t2_alias = aliased(Target) - sess = Session() + sess = fixture_session() q = ( sess.query(Root) .join(s1_alias, Root.sub1) @@ -2539,7 +2539,7 @@ class JoinAcrossJoinedInhMultiPath( self.classes.Sub1, ) - sess = Session() + sess = fixture_session() q = sess.query(Root).options( joinedload(Root.sub1).joinedload(Sub1.target), joinedload(Root.intermediate) @@ -2631,7 +2631,7 @@ class MultipleAdaptUsesEntityOverTableTest( def _two_join_fixture(self): B, C, D = (self.classes.B, self.classes.C, self.classes.D) - s = Session() + s = fixture_session() return ( s.query(B.name, C.name, D.name) .select_from(B) @@ -2805,7 +2805,7 @@ class BetweenSubclassJoinWExtraJoinedLoad( def test_query(self): Engineer, Manager = self.classes("Engineer", "Manager") - sess = Session() + sess = fixture_session() # eager join is both from Enginer->LastSeen as well as # Manager->LastSeen. In the case of Manager->LastSeen, @@ -2885,7 +2885,7 @@ class M2ODontLoadSiblingTest(fixtures.DeclarativeMappedTest): def test_load_m2o_emit_query(self): Other, Child1 = self.classes("Other", "Child1") - s = Session() + s = fixture_session() obj = s.query(Other).first() @@ -2894,7 +2894,7 @@ class M2ODontLoadSiblingTest(fixtures.DeclarativeMappedTest): def test_load_m2o_use_get(self): Other, Child1 = self.classes("Other", "Child1") - s = Session() + s = fixture_session() obj = s.query(Other).first() c1 = s.query(Child1).first() diff --git a/test/orm/inheritance/test_selects.py b/test/orm/inheritance/test_selects.py index dab184194..24297dd0e 100644 --- a/test/orm/inheritance/test_selects.py +++ b/test/orm/inheritance/test_selects.py @@ -6,6 +6,7 @@ from sqlalchemy.orm import mapper from sqlalchemy.orm import Session from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -117,7 +118,7 @@ class JoinFromSelectPersistenceTest(fixtures.MappedTest): ) mapper(Child, child, inherits=Base, polymorphic_identity="child") - sess = Session() + sess = fixture_session() # 2. use an id other than "1" here so can't rely on # the two inserts having the same id diff --git a/test/orm/inheritance/test_single.py b/test/orm/inheritance/test_single.py index cbe6bd238..11c6bb212 100644 --- a/test/orm/inheritance/test_single.py +++ b/test/orm/inheritance/test_single.py @@ -16,14 +16,13 @@ from sqlalchemy.orm import joinedload from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship from sqlalchemy.orm import Session -from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import subqueryload from sqlalchemy.orm import with_polymorphic from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing.assertsql import CompiledSQL -from sqlalchemy.testing.fixtures import create_session +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -126,7 +125,7 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest): self.classes.Engineer, ) - session = create_session() + session = fixture_session() m1 = Manager(name="Tom", manager_data="knows how to manage things") e1 = Engineer(name="Kurt", engineer_info="knows how to hack") @@ -283,7 +282,7 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest): def test_from_self_legacy(self): Engineer = self.classes.Engineer - sess = create_session() + sess = fixture_session() with testing.expect_deprecated(r"The Query.from_self\(\) method"): self.assert_compile( sess.query(Engineer).from_self(), @@ -350,7 +349,7 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest): def test_select_from_aliased_w_subclass(self): Engineer = self.classes.Engineer - sess = create_session() + sess = fixture_session() a1 = aliased(Engineer) self.assert_compile( @@ -369,7 +368,7 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest): def test_union_modifiers(self): Engineer, Manager = self.classes("Engineer", "Manager") - sess = create_session() + sess = fixture_session() q1 = sess.query(Engineer).filter(Engineer.engineer_info == "foo") q2 = sess.query(Manager).filter(Manager.manager_data == "bar") @@ -420,7 +419,7 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest): Engineer, Manager = self.classes("Engineer", "Manager") - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(Engineer) @@ -437,7 +436,7 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest): def test_from_self_count(self): Engineer = self.classes.Engineer - sess = create_session() + sess = fixture_session() col = func.count(literal_column("*")) with testing.expect_deprecated(r"The Query.from_self\(\) method"): self.assert_compile( @@ -452,7 +451,7 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest): def test_select_from_count(self): Manager, Engineer = (self.classes.Manager, self.classes.Engineer) - sess = create_session() + sess = fixture_session() m1 = Manager(name="Tom", manager_data="data1") e1 = Engineer(name="Kurt", engineer_info="knows how to hack") sess.add_all([m1, e1]) @@ -468,7 +467,7 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest): self.classes.Engineer, ) - sess = create_session() + sess = fixture_session() m1 = Manager(name="Tom", manager_data="data1") m2 = Manager(name="Tom2", manager_data="data2") e1 = Engineer(name="Kurt", engineer_info="knows how to hack") @@ -500,7 +499,7 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest): self.classes.Engineer, ) - sess = create_session() + sess = fixture_session() r1, r2, r3, r4 = ( Report(name="r1"), Report(name="r2"), @@ -544,7 +543,7 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest): Manager = self.classes.Manager Engineer = self.classes.Engineer - sess = create_session() + sess = fixture_session() m1 = Manager(name="Tom", manager_data="data1") m2 = Manager(name="Tom2", manager_data="data2") e1 = Engineer(name="Kurt", engineer_info="data3") @@ -562,7 +561,7 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest): def test_exists_standalone(self): Engineer = self.classes.Engineer - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query( @@ -580,7 +579,7 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest): self.classes.Engineer, ) - sess = create_session() + sess = fixture_session() m1 = Manager(name="Tom", manager_data="data1") r1 = Report(employee=m1) @@ -602,7 +601,7 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest): self.classes.Engineer, ) - sess = create_session() + sess = fixture_session() m1 = Manager(name="Tom", manager_data="data1") r1 = Report(employee=m1) @@ -695,7 +694,7 @@ class RelationshipFromSingleTest( def test_subquery_load(self): Employee, Stuff, Manager = self.classes("Employee", "Stuff", "Manager") - sess = create_session() + sess = fixture_session() with self.sql_execution_asserter(testing.db) as asserter: sess.query(Manager).options(subqueryload("stuff")).all() @@ -809,7 +808,7 @@ class RelationshipToSingleTest( inherits=Engineer, polymorphic_identity="juniorengineer", ) - sess = sessionmaker()() + sess = fixture_session() c1 = Company(name="c1") c2 = Company(name="c2") @@ -851,7 +850,7 @@ class RelationshipToSingleTest( mapper(Employee, employees, polymorphic_on=employees.c.type) mapper(Engineer, inherits=Employee, polymorphic_identity="engineer") - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(Company).outerjoin( Company.employee.of_type(Engineer), @@ -882,7 +881,7 @@ class RelationshipToSingleTest( mapper(Employee, employees) mapper(Engineer, inherits=Employee) - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(Company, Engineer.name).join( Engineer, Company.company_id == Engineer.company_id @@ -910,7 +909,7 @@ class RelationshipToSingleTest( mapper(Employee, employees, polymorphic_on=employees.c.type) mapper(Engineer, inherits=Employee, polymorphic_identity="engineer") - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(Company, Engineer.name).outerjoin("engineers"), "SELECT companies.company_id AS companies_company_id, " @@ -938,7 +937,7 @@ class RelationshipToSingleTest( mapper(Engineer, inherits=Employee, polymorphic_identity="engineer") eng_alias = aliased(Engineer) - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(Company, eng_alias.name).outerjoin( eng_alias, Company.engineers @@ -967,7 +966,7 @@ class RelationshipToSingleTest( mapper(Employee, employees, polymorphic_on=employees.c.type) mapper(Engineer, inherits=Employee, polymorphic_identity="engineer") - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(Company, Engineer).outerjoin( Engineer, Company.company_id == Engineer.company_id @@ -1002,7 +1001,7 @@ class RelationshipToSingleTest( mapper(Engineer, inherits=Employee, polymorphic_identity="engineer") eng_alias = aliased(Engineer) - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(Company, eng_alias).outerjoin( eng_alias, Company.company_id == eng_alias.company_id @@ -1036,7 +1035,7 @@ class RelationshipToSingleTest( mapper(Employee, employees, polymorphic_on=employees.c.type) mapper(Engineer, inherits=Employee, polymorphic_identity="engineer") - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(Company, Engineer).outerjoin(Engineer), "SELECT companies.company_id AS companies_company_id, " @@ -1069,7 +1068,7 @@ class RelationshipToSingleTest( mapper(Engineer, inherits=Employee, polymorphic_identity="engineer") eng_alias = aliased(Engineer) - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(Company, eng_alias).outerjoin(eng_alias), "SELECT companies.company_id AS companies_company_id, " @@ -1102,7 +1101,7 @@ class RelationshipToSingleTest( ) mapper(Engineer, inherits=Employee, polymorphic_identity="engineer") - sess = create_session() + sess = fixture_session() engineer_count = ( sess.query(func.count(Engineer.employee_id)) .select_from(Engineer) @@ -1144,7 +1143,7 @@ class RelationshipToSingleTest( mapper(Engineer, inherits=Employee, polymorphic_identity="engineer") mapper(Manager, inherits=Employee, polymorphic_identity="manager") - s = create_session() + s = fixture_session() q1 = ( s.query(Engineer) @@ -1236,7 +1235,7 @@ class RelationshipToSingleTest( inherits=Engineer, polymorphic_identity="juniorengineer", ) - sess = sessionmaker()() + sess = fixture_session() c1 = Company(name="c1") c2 = Company(name="c2") @@ -1433,7 +1432,7 @@ class ManyToManyToSingleTest(fixtures.MappedTest, AssertsCompiledSQL): Parent = self.classes.Parent SubChild1 = self.classes.SubChild1 - s = Session() + s = fixture_session() p1 = s.query(Parent).options(joinedload(Parent.s1)).all()[0] eq_(p1.__dict__["s1"], SubChild1(name="sc1_1")) @@ -1443,7 +1442,7 @@ class ManyToManyToSingleTest(fixtures.MappedTest, AssertsCompiledSQL): Child = self.classes.Child SubChild1 = self.classes.SubChild1 - s = Session() + s = fixture_session() p1, c1 = s.query(Parent, Child).outerjoin(Parent.s1).all()[0] eq_(c1, SubChild1(name="sc1_1")) @@ -1452,7 +1451,7 @@ class ManyToManyToSingleTest(fixtures.MappedTest, AssertsCompiledSQL): Parent = self.classes.Parent Child = self.classes.Child - s = Session() + s = fixture_session() self.assert_compile( s.query(Parent, Child).outerjoin(Parent.s1), @@ -1468,7 +1467,7 @@ class ManyToManyToSingleTest(fixtures.MappedTest, AssertsCompiledSQL): def test_assert_joinedload_sql(self): Parent = self.classes.Parent - s = Session() + s = fixture_session() self.assert_compile( s.query(Parent).options(joinedload(Parent.s1)), @@ -1538,7 +1537,7 @@ class SingleOnJoinedTest(fixtures.MappedTest): ) mapper(Manager, inherits=Employee, polymorphic_identity="manager") - sess = create_session() + sess = fixture_session() sess.add(Person(name="p1")) sess.add(Employee(name="e1", employee_data="ed1")) sess.add(Manager(name="m1", employee_data="ed2", manager_data="md1")) @@ -1663,7 +1662,7 @@ class SingleFromPolySelectableTest( [self.classes.Boss, self.classes.Manager, self.classes.Engineer], self._with_poly_fixture(), ) - s = Session() + s = fixture_session() q = s.query(poly.Boss) self.assert_compile( q, @@ -1695,7 +1694,7 @@ class SingleFromPolySelectableTest( poly = self._with_poly_fixture() - s = Session() + s = fixture_session() q = s.query(Boss).with_polymorphic(Boss, poly) self.assert_compile( q, @@ -1719,7 +1718,7 @@ class SingleFromPolySelectableTest( def test_single_inh_subclass_join_joined_inh_subclass(self): Boss, Engineer = self.classes("Boss", "Engineer") - s = Session() + s = fixture_session() q = s.query(Boss).join(Engineer, Engineer.manager_id == Boss.id) @@ -1744,7 +1743,7 @@ class SingleFromPolySelectableTest( self._with_poly_fixture(), ) - s = Session() + s = fixture_session() q = s.query(Boss).join( poly.Engineer, poly.Engineer.manager_id == Boss.id @@ -1776,7 +1775,7 @@ class SingleFromPolySelectableTest( def test_joined_inh_subclass_join_single_inh_subclass(self): Engineer = self.classes.Engineer Boss = self.classes.Boss - s = Session() + s = fixture_session() q = s.query(Engineer).join(Boss, Engineer.manager_id == Boss.id) @@ -1826,7 +1825,7 @@ class EagerDefaultEvalTest(fixtures.DeclarativeMappedTest): foo = Foo() - session = Session() + session = fixture_session() session.add(foo) session.flush() @@ -1839,7 +1838,7 @@ class EagerDefaultEvalTest(fixtures.DeclarativeMappedTest): def test_persist_bar(self): Bar = self.classes.Bar bar = Bar() - session = Session() + session = fixture_session() session.add(bar) session.flush() diff --git a/test/orm/inheritance/test_with_poly.py b/test/orm/inheritance/test_with_poly.py index dee76fc7b..2492e593c 100644 --- a/test/orm/inheritance/test_with_poly.py +++ b/test/orm/inheritance/test_with_poly.py @@ -2,9 +2,9 @@ from sqlalchemy import and_ from sqlalchemy import exc from sqlalchemy import or_ from sqlalchemy import testing -from sqlalchemy.orm import create_session from sqlalchemy.orm import with_polymorphic from sqlalchemy.testing import eq_ +from sqlalchemy.testing.fixtures import fixture_session from ._poly_fixtures import _Polymorphic from ._poly_fixtures import _PolymorphicAliasedJoins from ._poly_fixtures import _PolymorphicFixtureBase @@ -19,7 +19,7 @@ from ._poly_fixtures import Person class WithPolymorphicAPITest(_Polymorphic, _PolymorphicFixtureBase): def test_no_use_flat_and_aliased(self): - sess = create_session() + sess = fixture_session() subq = sess.query(Person).subquery() @@ -37,7 +37,7 @@ class WithPolymorphicAPITest(_Polymorphic, _PolymorphicFixtureBase): class _WithPolymorphicBase(_PolymorphicFixtureBase): def test_join_base_to_sub(self): - sess = create_session() + sess = fixture_session() pa = with_polymorphic(Person, [Engineer]) def go(): @@ -51,7 +51,7 @@ class _WithPolymorphicBase(_PolymorphicFixtureBase): self.assert_sql_count(testing.db, go, 1) def test_col_expression_base_plus_two_subs(self): - sess = create_session() + sess = fixture_session() pa = with_polymorphic(Person, [Engineer, Manager]) eq_( @@ -70,7 +70,7 @@ class _WithPolymorphicBase(_PolymorphicFixtureBase): ) def test_join_to_join_entities(self): - sess = create_session() + sess = fixture_session() pa = with_polymorphic(Person, [Engineer]) pa_alias = with_polymorphic(Person, [Engineer], aliased=True) @@ -101,7 +101,7 @@ class _WithPolymorphicBase(_PolymorphicFixtureBase): ) def test_join_to_join_columns(self): - sess = create_session() + sess = fixture_session() pa = with_polymorphic(Person, [Engineer]) pa_alias = with_polymorphic(Person, [Engineer], aliased=True) diff --git a/test/orm/test_ac_relationships.py b/test/orm/test_ac_relationships.py index fbbf192a0..40f099fc8 100644 --- a/test/orm/test_ac_relationships.py +++ b/test/orm/test_ac_relationships.py @@ -16,6 +16,7 @@ from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.testing.fixtures import ComparableEntity +from sqlalchemy.testing.fixtures import fixture_session class PartitionByFixture(fixtures.DeclarativeMappedTest): @@ -213,7 +214,7 @@ class AltSelectableTest( def test_lazyload(self): A, B = self.classes("A", "B") - sess = Session() + sess = fixture_session() a1 = sess.query(A).first() with self.sql_execution_asserter() as asserter: @@ -232,7 +233,7 @@ class AltSelectableTest( def test_joinedload(self): A, B = self.classes("A", "B") - sess = Session() + sess = fixture_session() with self.sql_execution_asserter() as asserter: # note this is many-to-one. use_get is unconditionally turned @@ -254,7 +255,7 @@ class AltSelectableTest( def test_selectinload(self): A, B = self.classes("A", "B") - sess = Session() + sess = fixture_session() with self.sql_execution_asserter() as asserter: # note this is many-to-one. use_get is unconditionally turned @@ -280,7 +281,7 @@ class AltSelectableTest( def test_join(self): A, B = self.classes("A", "B") - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(A).join(A.b), diff --git a/test/orm/test_association.py b/test/orm/test_association.py index b4c689c01..30e6f3541 100644 --- a/test/orm/test_association.py +++ b/test/orm/test_association.py @@ -3,11 +3,11 @@ from sqlalchemy import func from sqlalchemy import Integer from sqlalchemy import select from sqlalchemy import String -from sqlalchemy.orm import create_session from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -119,7 +119,7 @@ class AssociationTest(fixtures.MappedTest): self.classes.Keyword, ) - sess = create_session() + sess = fixture_session() item1 = Item("item1") item2 = Item("item2") item1.keywords.append( @@ -144,7 +144,7 @@ class AssociationTest(fixtures.MappedTest): self.classes.Keyword, ) - sess = create_session() + sess = fixture_session() item1 = Item("item1") item1.keywords.append( KeywordAssociation(Keyword("blue"), "blue_assoc") @@ -170,7 +170,7 @@ class AssociationTest(fixtures.MappedTest): self.classes.Keyword, ) - sess = create_session() + sess = fixture_session() item1 = Item("item1") item2 = Item("item2") item1.keywords.append( @@ -211,7 +211,7 @@ class AssociationTest(fixtures.MappedTest): item_keywords = self.tables.item_keywords Keyword = self.classes.Keyword - sess = create_session() + sess = fixture_session() item1 = Item("item1") item2 = Item("item2") item1.keywords.append( @@ -223,9 +223,19 @@ class AssociationTest(fixtures.MappedTest): ) sess.add_all((item1, item2)) sess.flush() - eq_(select(func.count("*")).select_from(item_keywords).scalar(), 3) + eq_( + sess.connection().scalar( + select(func.count("*")).select_from(item_keywords) + ), + 3, + ) sess.delete(item1) sess.delete(item2) sess.flush() - eq_(select(func.count("*")).select_from(item_keywords).scalar(), 0) + eq_( + sess.connection().scalar( + select(func.count("*")).select_from(item_keywords) + ), + 0, + ) diff --git a/test/orm/test_assorted_eager.py b/test/orm/test_assorted_eager.py index 310e50eb2..8ca6a8d86 100644 --- a/test/orm/test_assorted_eager.py +++ b/test/orm/test_assorted_eager.py @@ -15,11 +15,12 @@ from sqlalchemy import String from sqlalchemy import testing from sqlalchemy import text from sqlalchemy.orm import backref -from sqlalchemy.orm import create_session from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship +from sqlalchemy.orm import Session from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -151,7 +152,7 @@ class EagerTest(fixtures.MappedTest): cls.classes.Thing, ) - session = create_session(connection) + session = Session(connection) o = Owner() c = Category(name="Some Category") @@ -167,7 +168,7 @@ class EagerTest(fixtures.MappedTest): session.flush() - def test_noorm(self): + def test_noorm(self, connection): """test the control case""" tests, options, categories = ( @@ -187,7 +188,7 @@ class EagerTest(fixtures.MappedTest): # not orm style correct query print("Obtaining correct results without orm") - result = ( + result = connection.execute( sa.select(tests.c.id, categories.c.name) .where( sa.and_( @@ -208,9 +209,7 @@ class EagerTest(fixtures.MappedTest): ), ) ) - .execute() - .fetchall() - ) + ).fetchall() eq_(result, [(1, "Some Category"), (3, "Some Category")]) def test_withoutjoinedload(self): @@ -220,7 +219,7 @@ class EagerTest(fixtures.MappedTest): self.tables.options, ) - s = create_session() + s = fixture_session() result = ( s.query(Thing) .select_from( @@ -260,7 +259,7 @@ class EagerTest(fixtures.MappedTest): self.tables.options, ) - s = create_session() + s = fixture_session() q = s.query(Thing).options(sa.orm.joinedload("category")) result = q.select_from( @@ -293,7 +292,7 @@ class EagerTest(fixtures.MappedTest): self.tables.options, ) - s = create_session() + s = fixture_session() q = s.query(Thing).options(sa.orm.joinedload("category")) result = q.filter( sa.and_( @@ -312,7 +311,7 @@ class EagerTest(fixtures.MappedTest): def test_without_outerjoin_literal(self): Thing, tests = (self.classes.Thing, self.tables.tests) - s = create_session() + s = fixture_session() q = s.query(Thing).options(sa.orm.joinedload("category")) result = q.filter( (tests.c.owner_id == 1) @@ -331,7 +330,7 @@ class EagerTest(fixtures.MappedTest): self.tables.options, ) - s = create_session() + s = fixture_session() q = s.query(Thing).options(sa.orm.joinedload("category")) result = q.filter( (tests.c.owner_id == 1) @@ -434,7 +433,7 @@ class EagerTest2(fixtures.MappedTest): p.left.append(Left("l1")) p.right.append(Right("r1")) - session = create_session() + session = fixture_session() session.add(p) session.flush() session.expunge_all() @@ -509,7 +508,7 @@ class EagerTest3(fixtures.MappedTest): mapper(Stat, stats, properties={"data": relationship(Data)}) - session = create_session() + session = fixture_session() data = [Data(a=x) for x in range(5)] session.add_all(data) @@ -538,7 +537,7 @@ class EagerTest3(fixtures.MappedTest): .group_by(stats.c.data_id) ) - arb_result = arb_data.execute().fetchall() + arb_result = session.connection().execute(arb_data).fetchall() # order the result list descending based on 'max' arb_result.sort(key=lambda a: a._mapping["max"], reverse=True) @@ -633,7 +632,7 @@ class EagerTest4(fixtures.MappedTest): for e in "Joe", "Bob", "Mary", "Wally": d2.employees.append(Employee(name=e)) - sess = create_session() + sess = fixture_session() sess.add_all((d1, d2)) sess.flush() @@ -751,7 +750,7 @@ class EagerTest5(fixtures.MappedTest): mapper(DerivedII, derivedII, inherits=baseMapper) - sess = create_session() + sess = fixture_session() d = Derived("uid1", "x", "y") d.comments = [Comment("uid1", "comment")] d2 = DerivedII("uid2", "xx", "z") @@ -910,7 +909,7 @@ class EagerTest6(fixtures.MappedTest): ) d = Design() - sess = create_session() + sess = fixture_session() sess.add(d) sess.flush() sess.expunge_all() @@ -1024,7 +1023,7 @@ class EagerTest7(fixtures.MappedTest): c1 = Company(company_name="company 1", addresses=[a1, a2]) i1 = Invoice(date=datetime.datetime.now(), company=c1) - session = create_session() + session = fixture_session() session.add(i1) session.flush() @@ -1169,7 +1168,7 @@ class EagerTest8(fixtures.MappedTest): properties=dict(type=relationship(Task_Type, lazy="joined")), ) - session = create_session() + session = fixture_session() eq_( session.query(Joined).limit(10).offset(0).one(), @@ -1284,7 +1283,7 @@ class EagerTest9(fixtures.MappedTest): self.classes.Transaction, ) - session = create_session() + session = fixture_session() tx1 = Transaction(name="tx1") tx2 = Transaction(name="tx2") diff --git a/test/orm/test_backref_mutations.py b/test/orm/test_backref_mutations.py index c873f46c7..a6d651d22 100644 --- a/test/orm/test_backref_mutations.py +++ b/test/orm/test_backref_mutations.py @@ -16,9 +16,9 @@ from sqlalchemy.orm import backref from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship from sqlalchemy.orm import Session -from sqlalchemy.orm import sessionmaker from sqlalchemy.testing import eq_ from sqlalchemy.testing import is_ +from sqlalchemy.testing.fixtures import fixture_session from test.orm import _fixtures @@ -44,7 +44,7 @@ class O2MCollectionTest(_fixtures.FixtureTest): def test_collection_move_hitslazy(self): User, Address = self.classes.User, self.classes.Address - sess = sessionmaker()() + sess = fixture_session() a1 = Address(email_address="address1") a2 = Address(email_address="address2") a3 = Address(email_address="address3") @@ -65,7 +65,7 @@ class O2MCollectionTest(_fixtures.FixtureTest): def test_collection_move_preloaded(self): User, Address = self.classes.User, self.classes.Address - sess = sessionmaker()() + sess = fixture_session() a1 = Address(email_address="address1") u1 = User(name="jack", addresses=[a1]) @@ -88,7 +88,7 @@ class O2MCollectionTest(_fixtures.FixtureTest): def test_collection_move_notloaded(self): User, Address = self.classes.User, self.classes.Address - sess = sessionmaker()() + sess = fixture_session() a1 = Address(email_address="address1") u1 = User(name="jack", addresses=[a1]) @@ -109,7 +109,7 @@ class O2MCollectionTest(_fixtures.FixtureTest): def test_collection_move_commitfirst(self): User, Address = self.classes.User, self.classes.Address - sess = sessionmaker()() + sess = fixture_session() a1 = Address(email_address="address1") u1 = User(name="jack", addresses=[a1]) @@ -134,7 +134,7 @@ class O2MCollectionTest(_fixtures.FixtureTest): def test_scalar_move_preloaded(self): User, Address = self.classes.User, self.classes.Address - sess = sessionmaker()() + sess = fixture_session() u1 = User(name="jack") u2 = User(name="ed") @@ -161,7 +161,7 @@ class O2MCollectionTest(_fixtures.FixtureTest): User, Address = self.classes.User, self.classes.Address - sess = sessionmaker()() + sess = fixture_session() u1 = User(name="jack") u2 = User(name="ed") a1 = Address(email_address="a1") @@ -183,7 +183,7 @@ class O2MCollectionTest(_fixtures.FixtureTest): def test_set_none(self): User, Address = self.classes.User, self.classes.Address - sess = sessionmaker()() + sess = fixture_session() u1 = User(name="jack") a1 = Address(email_address="a1") a1.user = u1 @@ -201,7 +201,7 @@ class O2MCollectionTest(_fixtures.FixtureTest): def test_scalar_move_notloaded(self): User, Address = self.classes.User, self.classes.Address - sess = sessionmaker()() + sess = fixture_session() u1 = User(name="jack") u2 = User(name="ed") @@ -221,7 +221,7 @@ class O2MCollectionTest(_fixtures.FixtureTest): def test_scalar_move_commitfirst(self): User, Address = self.classes.User, self.classes.Address - sess = sessionmaker()() + sess = fixture_session() u1 = User(name="jack") u2 = User(name="ed") @@ -399,7 +399,7 @@ class O2OScalarBackrefMoveTest(_fixtures.FixtureTest): def test_collection_move_preloaded(self): User, Address = self.classes.User, self.classes.Address - sess = sessionmaker()() + sess = fixture_session() a1 = Address(email_address="address1") u1 = User(name="jack", address=a1) @@ -425,7 +425,7 @@ class O2OScalarBackrefMoveTest(_fixtures.FixtureTest): def test_scalar_move_preloaded(self): User, Address = self.classes.User, self.classes.Address - sess = sessionmaker()() + sess = fixture_session() a1 = Address(email_address="address1") a2 = Address(email_address="address1") u1 = User(name="jack", address=a1) @@ -449,7 +449,7 @@ class O2OScalarBackrefMoveTest(_fixtures.FixtureTest): def test_collection_move_notloaded(self): User, Address = self.classes.User, self.classes.Address - sess = sessionmaker()() + sess = fixture_session() a1 = Address(email_address="address1") u1 = User(name="jack", address=a1) @@ -471,7 +471,7 @@ class O2OScalarBackrefMoveTest(_fixtures.FixtureTest): def test_scalar_move_notloaded(self): User, Address = self.classes.User, self.classes.Address - sess = sessionmaker()() + sess = fixture_session() a1 = Address(email_address="address1") a2 = Address(email_address="address1") u1 = User(name="jack", address=a1) @@ -492,7 +492,7 @@ class O2OScalarBackrefMoveTest(_fixtures.FixtureTest): def test_collection_move_commitfirst(self): User, Address = self.classes.User, self.classes.Address - sess = sessionmaker()() + sess = fixture_session() a1 = Address(email_address="address1") u1 = User(name="jack", address=a1) @@ -519,7 +519,7 @@ class O2OScalarBackrefMoveTest(_fixtures.FixtureTest): def test_scalar_move_commitfirst(self): User, Address = self.classes.User, self.classes.Address - sess = sessionmaker()() + sess = fixture_session() a1 = Address(email_address="address1") a2 = Address(email_address="address2") u1 = User(name="jack", address=a1) @@ -568,7 +568,7 @@ class O2OScalarMoveTest(_fixtures.FixtureTest): def test_collection_move_commitfirst(self): User, Address = self.classes.User, self.classes.Address - sess = sessionmaker()() + sess = fixture_session() a1 = Address(email_address="address1") u1 = User(name="jack", address=a1) @@ -622,7 +622,7 @@ class O2OScalarOrphanTest(_fixtures.FixtureTest): def test_m2o_event(self): User, Address = self.classes.User, self.classes.Address - sess = sessionmaker()() + sess = fixture_session() a1 = Address(email_address="address1") u1 = User(name="jack", address=a1) @@ -667,7 +667,7 @@ class M2MCollectionMoveTest(_fixtures.FixtureTest): Item, Keyword = (self.classes.Item, self.classes.Keyword) - session = Session(autoflush=False) + session = fixture_session(autoflush=False) i1 = Item(description="i1") session.add(i1) @@ -685,7 +685,7 @@ class M2MCollectionMoveTest(_fixtures.FixtureTest): Item, Keyword = (self.classes.Item, self.classes.Keyword) - session = Session(autoflush=False) + session = fixture_session(autoflush=False) k1 = Keyword(name="k1") i1 = Item(description="i1", keywords=[k1]) @@ -805,7 +805,7 @@ class M2MScalarMoveTest(_fixtures.FixtureTest): def test_collection_move_preloaded(self): Item, Keyword = self.classes.Item, self.classes.Keyword - sess = sessionmaker()() + sess = fixture_session() k1 = Keyword(name="k1") i1 = Item(description="i1", keyword=k1) @@ -828,7 +828,7 @@ class M2MScalarMoveTest(_fixtures.FixtureTest): def test_collection_move_notloaded(self): Item, Keyword = self.classes.Item, self.classes.Keyword - sess = sessionmaker()() + sess = fixture_session() k1 = Keyword(name="k1") i1 = Item(description="i1", keyword=k1) @@ -847,7 +847,7 @@ class M2MScalarMoveTest(_fixtures.FixtureTest): def test_collection_move_commit(self): Item, Keyword = self.classes.Item, self.classes.Keyword - sess = sessionmaker()() + sess = fixture_session() k1 = Keyword(name="k1") i1 = Item(description="i1", keyword=k1) diff --git a/test/orm/test_bind.py b/test/orm/test_bind.py index 64f85b335..2f54f7fff 100644 --- a/test/orm/test_bind.py +++ b/test/orm/test_bind.py @@ -8,7 +8,6 @@ from sqlalchemy import table from sqlalchemy import testing from sqlalchemy import true from sqlalchemy.orm import backref -from sqlalchemy.orm import create_session from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship from sqlalchemy.orm import Session @@ -20,6 +19,7 @@ from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import mock +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.mock import Mock from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -55,9 +55,7 @@ class BindIntegrationTest(_fixtures.FixtureTest): }, ) - sess = Session( - binds={User: self.metadata.bind, Address: self.metadata.bind} - ) + sess = Session(binds={User: testing.db, Address: testing.db}) u1 = User(id=1, name="ed") sess.add(u1) @@ -112,13 +110,13 @@ class BindIntegrationTest(_fixtures.FixtureTest): }, ) - Session = sessionmaker( + maker = sessionmaker( binds={ - users_unbound: self.metadata.bind, - addresses_unbound: self.metadata.bind, + users_unbound: testing.db, + addresses_unbound: testing.db, } ) - sess = Session() + sess = maker() u1 = User(id=1, name="ed") sess.add(u1) @@ -151,7 +149,7 @@ class BindIntegrationTest(_fixtures.FixtureTest): mapper(User, users) - session = Session() + session = fixture_session() session.execute(users.insert(), dict(name="Johnny")) @@ -376,7 +374,7 @@ class BindIntegrationTest(_fixtures.FixtureTest): sess.close() def test_bind_arg(self): - sess = Session() + sess = fixture_session() assert_raises_message( sa.exc.ArgumentError, @@ -403,7 +401,7 @@ class BindIntegrationTest(_fixtures.FixtureTest): mapper(User, users) c = testing.db.connect() - sess = create_session(bind=c) + sess = Session(bind=c) sess.begin() transaction = sess._legacy_transaction() u = User(name="u1") @@ -432,7 +430,7 @@ class BindIntegrationTest(_fixtures.FixtureTest): mapper(User, users) c = testing.db.connect() - sess = create_session(bind=c, autocommit=False) + sess = Session(bind=c, autocommit=False) u = User(name="u1") sess.add(u) sess.flush() @@ -440,7 +438,7 @@ class BindIntegrationTest(_fixtures.FixtureTest): assert not c.in_transaction() assert c.exec_driver_sql("select count(1) from users").scalar() == 0 - sess = create_session(bind=c, autocommit=False) + sess = Session(bind=c, autocommit=False) u = User(name="u2") sess.add(u) sess.flush() @@ -455,7 +453,7 @@ class BindIntegrationTest(_fixtures.FixtureTest): c = testing.db.connect() trans = c.begin() - sess = create_session(bind=c, autocommit=True) + sess = Session(bind=c, autocommit=True) u = User(name="u3") sess.add(u) sess.flush() @@ -495,11 +493,11 @@ class SessionBindTest(fixtures.MappedTest): def test_session_bind(self): Foo = self.classes.Foo - engine = self.metadata.bind + engine = testing.db for bind in (engine, engine.connect()): try: - sess = create_session(bind=bind) + sess = Session(bind=bind) assert sess.bind is bind f = Foo() sess.add(f) @@ -512,7 +510,7 @@ class SessionBindTest(fixtures.MappedTest): def test_session_unbound(self): Foo = self.classes.Foo - sess = create_session() + sess = Session() sess.add(Foo()) assert_raises_message( sa.exc.UnboundExecutionError, @@ -578,10 +576,6 @@ class GetBindTest(fixtures.MappedTest): def _fixture(self, binds): return Session(binds=binds) - def test_fallback_table_metadata(self): - session = self._fixture({}) - is_(session.get_bind(self.classes.BaseClass), testing.db) - def test_bind_base_table_base_class(self): base_class_bind = Mock() session = self._fixture({self.tables.base_table: base_class_bind}) @@ -610,11 +604,25 @@ class GetBindTest(fixtures.MappedTest): # table, so this is what we expect is_(session.get_bind(self.classes.JoinedSubClass), base_class_bind) + def test_fallback_table_metadata(self): + session = self._fixture({}) + assert_raises_message( + sa.exc.UnboundExecutionError, + "Could not locate a bind configured on mapper mapped class", + session.get_bind, + self.classes.BaseClass, + ) + def test_bind_base_table_concrete_sub_class(self): base_class_bind = Mock() session = self._fixture({self.tables.base_table: base_class_bind}) - is_(session.get_bind(self.classes.ConcreteSubClass), testing.db) + assert_raises_message( + sa.exc.UnboundExecutionError, + "Could not locate a bind configured on mapper mapped class", + session.get_bind, + self.classes.ConcreteSubClass, + ) def test_bind_sub_table_concrete_sub_class(self): base_class_bind = Mock(name="base") diff --git a/test/orm/test_bulk.py b/test/orm/test_bulk.py index 27b187342..83f74f055 100644 --- a/test/orm/test_bulk.py +++ b/test/orm/test_bulk.py @@ -4,12 +4,12 @@ from sqlalchemy import Integer from sqlalchemy import String from sqlalchemy import testing from sqlalchemy.orm import mapper -from sqlalchemy.orm import Session from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import mock from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.testing.assertsql import Conditional +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from test.orm import _fixtures @@ -48,7 +48,7 @@ class BulkInsertUpdateVersionId(BulkTest, fixtures.MappedTest): def test_bulk_insert_via_save(self): Foo = self.classes.Foo - s = Session() + s = fixture_session() s.bulk_save_objects([Foo(value="value")]) @@ -58,7 +58,7 @@ class BulkInsertUpdateVersionId(BulkTest, fixtures.MappedTest): def test_bulk_update_via_save(self): Foo = self.classes.Foo - s = Session() + s = fixture_session() s.add(Foo(value="value")) s.commit() @@ -84,7 +84,7 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): def test_bulk_save_return_defaults(self): (User,) = self.classes("User") - s = Session() + s = fixture_session() objects = [User(name="u1"), User(name="u2"), User(name="u3")] assert "id" not in objects[0].__dict__ @@ -121,7 +121,7 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): def test_bulk_save_mappings_preserve_order(self): (User,) = self.classes("User") - s = Session() + s = fixture_session() # commit some object into db user1 = User(name="i1") @@ -177,7 +177,7 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): def test_bulk_save_no_defaults(self): (User,) = self.classes("User") - s = Session() + s = fixture_session() objects = [User(name="u1"), User(name="u2"), User(name="u3")] assert "id" not in objects[0].__dict__ @@ -195,7 +195,7 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): def test_bulk_save_updated_include_unchanged(self): (User,) = self.classes("User") - s = Session(expire_on_commit=False) + s = fixture_session(expire_on_commit=False) objects = [User(name="u1"), User(name="u2"), User(name="u3")] s.add_all(objects) s.commit() @@ -203,7 +203,7 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): objects[0].name = "u1new" objects[2].name = "u3new" - s = Session() + s = fixture_session() with self.sql_execution_asserter() as asserter: s.bulk_save_objects(objects, update_changed_only=False) @@ -221,12 +221,12 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): def test_bulk_update(self): (User,) = self.classes("User") - s = Session(expire_on_commit=False) + s = fixture_session(expire_on_commit=False) objects = [User(name="u1"), User(name="u2"), User(name="u3")] s.add_all(objects) s.commit() - s = Session() + s = fixture_session() with self.sql_execution_asserter() as asserter: s.bulk_update_mappings( User, @@ -251,7 +251,7 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): def test_bulk_insert(self): (User,) = self.classes("User") - s = Session() + s = fixture_session() with self.sql_execution_asserter() as asserter: s.bulk_insert_mappings( User, @@ -276,7 +276,7 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): def test_bulk_insert_render_nulls(self): (Order,) = self.classes("Order") - s = Session() + s = fixture_session() with self.sql_execution_asserter() as asserter: s.bulk_insert_mappings( Order, @@ -334,7 +334,7 @@ class BulkUDPostfetchTest(BulkTest, fixtures.MappedTest): def test_insert_w_fetch(self): A = self.classes.A - s = Session() + s = fixture_session() a1 = A(x=1) s.bulk_save_objects([a1]) s.commit() @@ -342,7 +342,7 @@ class BulkUDPostfetchTest(BulkTest, fixtures.MappedTest): def test_update_w_fetch(self): A = self.classes.A - s = Session() + s = fixture_session() a1 = A(x=1, y=2) s.add(a1) s.commit() @@ -488,7 +488,7 @@ class BulkUDTestAltColKeys(BulkTest, fixtures.MappedTest): def _test_insert(self, person_cls): Person = person_cls - s = Session() + s = fixture_session() with self.sql_execution_asserter(testing.db) as asserter: s.bulk_insert_mappings( Person, [{"id": 5, "personname": "thename"}] @@ -501,7 +501,7 @@ class BulkUDTestAltColKeys(BulkTest, fixtures.MappedTest): def _test_update(self, person_cls): Person = person_cls - s = Session() + s = fixture_session() s.add(Person(id=5, personname="thename")) s.commit() @@ -605,7 +605,7 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest): "Person", "Engineer", "Manager", "Boss" ) - s = Session() + s = fixture_session() objects = [ Manager(name="m1", status="s1", manager_name="mn1"), Engineer(name="e1", status="s2", primary_language="l1"), @@ -684,7 +684,7 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest): "Person", "Engineer", "Manager", "Boss" ) - s = Session() + s = fixture_session() with self.sql_execution_asserter() as asserter: s.bulk_save_objects( [ @@ -766,7 +766,7 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest): "Person", "Engineer", "Manager", "Boss" ) - s = Session() + s = fixture_session() with self.sql_execution_asserter() as asserter: s.bulk_insert_mappings( Boss, diff --git a/test/orm/test_bundle.py b/test/orm/test_bundle.py index 956645506..b0113f1fc 100644 --- a/test/orm/test_bundle.py +++ b/test/orm/test_bundle.py @@ -16,6 +16,7 @@ from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -124,7 +125,7 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL): def test_same_named_col_in_orderby(self): Data, Other = self.classes("Data", "Other") bundle = Bundle("pk", Data.id, Other.id) - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(Data, Other).order_by(bundle), @@ -138,7 +139,7 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL): def test_same_named_col_in_fetch(self): Data, Other = self.classes("Data", "Other") bundle = Bundle("pk", Data.id, Other.id) - sess = Session() + sess = fixture_session() eq_( sess.query(bundle) @@ -159,7 +160,7 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL): def test_result(self): Data = self.classes.Data - sess = Session() + sess = fixture_session() b1 = Bundle("b1", Data.d1, Data.d2) @@ -170,7 +171,7 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL): def test_subclass(self): Data = self.classes.Data - sess = Session() + sess = fixture_session() class MyBundle(Bundle): def create_row_processor(self, query, procs, labels): @@ -199,7 +200,7 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL): b1 = Bundle("b1", d1.d1, d1.d2) b2 = Bundle("b2", Data.d1, Other.o1) - sess = Session() + sess = fixture_session() q = ( sess.query(b1, b2) @@ -249,7 +250,7 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL): def test_single_entity_legacy_query(self): Data = self.classes.Data - sess = Session() + sess = fixture_session() b1 = Bundle("b1", Data.d1, Data.d2, single_entity=True) @@ -260,7 +261,7 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL): def test_labeled_cols_non_single_entity_legacy_query(self): Data = self.classes.Data - sess = Session() + sess = fixture_session() b1 = Bundle("b1", Data.d1.label("x"), Data.d2.label("y")) @@ -271,7 +272,7 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL): def test_labeled_cols_single_entity_legacy_query(self): Data = self.classes.Data - sess = Session() + sess = fixture_session() b1 = Bundle( "b1", Data.d1.label("x"), Data.d2.label("y"), single_entity=True @@ -284,7 +285,7 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL): def test_labeled_cols_as_rows_future(self): Data = self.classes.Data - sess = Session() + sess = fixture_session() b1 = Bundle("b1", Data.d1.label("x"), Data.d2.label("y")) @@ -297,7 +298,7 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL): def test_labeled_cols_as_scalars_future(self): Data = self.classes.Data - sess = Session() + sess = fixture_session() b1 = Bundle("b1", Data.d1.label("x"), Data.d2.label("y")) @@ -340,7 +341,7 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL): def test_single_entity_flag_but_multi_entities(self): Data = self.classes.Data - sess = Session() + sess = fixture_session() b1 = Bundle("b1", Data.d1, Data.d2, single_entity=True) b2 = Bundle("b1", Data.d3, single_entity=True) @@ -356,7 +357,7 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL): def test_bundle_nesting(self): Data = self.classes.Data - sess = Session() + sess = fixture_session() b1 = Bundle("b1", Data.d1, Bundle("b2", Data.d2, Data.d3)) @@ -374,7 +375,7 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL): def test_bundle_nesting_unions(self): Data = self.classes.Data - sess = Session() + sess = fixture_session() b1 = Bundle("b1", Data.d1, Bundle("b2", Data.d2, Data.d3)) @@ -407,12 +408,12 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL): def test_query_count(self): Data = self.classes.Data b1 = Bundle("b1", Data.d1, Data.d2) - eq_(Session().query(b1).count(), 10) + eq_(fixture_session().query(b1).count(), 10) def test_join_relationship(self): Data = self.classes.Data - sess = Session() + sess = fixture_session() b1 = Bundle("b1", Data.d1, Data.d2) q = sess.query(b1).join(Data.others) self.assert_compile( @@ -426,7 +427,7 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL): Data = self.classes.Data Other = self.classes.Other - sess = Session() + sess = fixture_session() b1 = Bundle("b1", Data.d1, Data.d2) q = sess.query(b1).join(Other) self.assert_compile( @@ -444,7 +445,7 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL): b1 = Bundle("b1", Data.id, Data.d1, Data.d2) - session = Session() + session = fixture_session() first = session.query(b1) second = session.query(b1) unioned = first.union(second) @@ -488,7 +489,7 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL): b1 = Bundle("b1", Data.id, Data.d1, Data.d2) - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(b1).filter_by(d1="d1"), @@ -501,7 +502,7 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL): b1 = Bundle("b1", Data.id, Data.d1, Data.d2) - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(Data).order_by(b1), "SELECT data.id AS data_id, data.d1 AS data_d1, " @@ -520,7 +521,7 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL): b1 = Bundle("b1", data_table.c.d1, data_table.c.d2) - sess = Session() + sess = fixture_session() eq_( sess.query(b1).filter(b1.c.d1.between("d3d1", "d5d1")).all(), [(("d3d1", "d3d2"),), (("d4d1", "d4d2"),), (("d5d1", "d5d2"),)], @@ -531,7 +532,7 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL): b1 = Bundle("b1", data_table.c.d1, data_table.c.d2, single_entity=True) - sess = Session() + sess = fixture_session() eq_( sess.query(b1).filter(b1.c.d1.between("d3d1", "d5d1")).all(), [("d3d1", "d3d2"), ("d4d1", "d4d2"), ("d5d1", "d5d2")], diff --git a/test/orm/test_cache_key.py b/test/orm/test_cache_key.py index 59d747012..7ef9d1b60 100644 --- a/test/orm/test_cache_key.py +++ b/test/orm/test_cache_key.py @@ -21,6 +21,7 @@ from sqlalchemy.sql.base import CacheableOptions from sqlalchemy.sql.visitors import InternalTraversal from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ +from sqlalchemy.testing.fixtures import fixture_session from test.orm import _fixtures from .inheritance import _poly_fixtures from .test_query import QueryTest @@ -260,20 +261,25 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest): self._run_cache_key_fixture( lambda: stmt_20( - Session().query(User).join(User.addresses), - Session().query(User).join(User.orders), - Session().query(User).join(User.addresses).join(User.orders), - Session() + fixture_session().query(User).join(User.addresses), + fixture_session().query(User).join(User.orders), + fixture_session() + .query(User) + .join(User.addresses) + .join(User.orders), + fixture_session() .query(User) .join("addresses") .join("dingalings", from_joinpoint=True), - Session().query(User).join("addresses"), - Session().query(User).join("orders"), - Session().query(User).join("addresses").join("orders"), - Session().query(User).join(Address, User.addresses), - Session().query(User).join(a1, "addresses"), - Session().query(User).join(a1, "addresses", aliased=True), - Session().query(User).join(User.addresses.of_type(a1)), + fixture_session().query(User).join("addresses"), + fixture_session().query(User).join("orders"), + fixture_session().query(User).join("addresses").join("orders"), + fixture_session().query(User).join(Address, User.addresses), + fixture_session().query(User).join(a1, "addresses"), + fixture_session() + .query(User) + .join(a1, "addresses", aliased=True), + fixture_session().query(User).join(User.addresses.of_type(a1)), ), compare_values=True, ) @@ -285,21 +291,21 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest): self._run_cache_key_fixture( lambda: stmt_20( - Session() + fixture_session() .query(User) .from_statement(text("select * from user")), - Session() + fixture_session() .query(User) .options(selectinload(User.addresses)) .from_statement(text("select * from user")), - Session() + fixture_session() .query(User) .options(subqueryload(User.addresses)) .from_statement(text("select * from user")), - Session() + fixture_session() .query(User) .from_statement(text("select * from user order by id")), - Session() + fixture_session() .query(User.id) .from_statement(text("select * from user")), ), @@ -316,28 +322,40 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest): self._run_cache_key_fixture( lambda: stmt_20( - Session().query(User), - Session().query(User).prefix_with("foo"), - Session().query(User).filter_by(name="ed"), - Session().query(User).filter_by(name="ed").order_by(User.id), - Session().query(User).filter_by(name="ed").order_by(User.name), - Session().query(User).filter_by(name="ed").group_by(User.id), - Session() + fixture_session().query(User), + fixture_session().query(User).prefix_with("foo"), + fixture_session().query(User).filter_by(name="ed"), + fixture_session() + .query(User) + .filter_by(name="ed") + .order_by(User.id), + fixture_session() + .query(User) + .filter_by(name="ed") + .order_by(User.name), + fixture_session() + .query(User) + .filter_by(name="ed") + .group_by(User.id), + fixture_session() .query(User) .join(User.addresses) .filter(User.name == "ed"), - Session().query(User).join(User.orders), - Session() + fixture_session().query(User).join(User.orders), + fixture_session() .query(User) .join(User.orders) .filter(Order.description == "adsf"), - Session().query(User).join(User.addresses).join(User.orders), - Session().query(User).join(Address, User.addresses), - Session().query(User).join(a1, User.addresses), - Session().query(User).join(User.addresses.of_type(a1)), - Session().query(Address).join(Address.user), - Session().query(User, Address).filter_by(name="ed"), - Session().query(User, a1).filter_by(name="ed"), + fixture_session() + .query(User) + .join(User.addresses) + .join(User.orders), + fixture_session().query(User).join(Address, User.addresses), + fixture_session().query(User).join(a1, User.addresses), + fixture_session().query(User).join(User.addresses.of_type(a1)), + fixture_session().query(Address).join(Address.user), + fixture_session().query(User, Address).filter_by(name="ed"), + fixture_session().query(User, a1).filter_by(name="ed"), ), compare_values=True, ) @@ -401,27 +419,29 @@ class PolyCacheKeyTest(CacheKeyFixture, _poly_fixtures._Polymorphic): def one(): return ( - Session().query(Person).with_polymorphic([Manager, Engineer]) + fixture_session() + .query(Person) + .with_polymorphic([Manager, Engineer]) ) def two(): wp = with_polymorphic(Person, [Manager, Engineer]) - return Session().query(wp) + return fixture_session().query(wp) def three(): wp = with_polymorphic(Person, [Manager, Engineer]) - return Session().query(wp).filter(wp.name == "asdfo") + return fixture_session().query(wp).filter(wp.name == "asdfo") def three_a(): wp = with_polymorphic(Person, [Manager, Engineer], flat=True) - return Session().query(wp).filter(wp.name == "asdfo") + return fixture_session().query(wp).filter(wp.name == "asdfo") def four(): return ( - Session() + fixture_session() .query(Person) .with_polymorphic([Manager, Engineer]) .filter(Person.name == "asdf") @@ -436,7 +456,7 @@ class PolyCacheKeyTest(CacheKeyFixture, _poly_fixtures._Polymorphic): ) wp = with_polymorphic(Person, [Manager, Engineer], subq) - return Session().query(wp).filter(wp.name == "asdfo") + return fixture_session().query(wp).filter(wp.name == "asdfo") def six(): subq = ( @@ -447,7 +467,7 @@ class PolyCacheKeyTest(CacheKeyFixture, _poly_fixtures._Polymorphic): ) return ( - Session() + fixture_session() .query(Person) .with_polymorphic([Manager, Engineer], subq) .filter(Person.name == "asdfo") @@ -467,7 +487,7 @@ class PolyCacheKeyTest(CacheKeyFixture, _poly_fixtures._Polymorphic): def one(): return ( - Session() + fixture_session() .query(Company) .join(Company.employees) .filter(Person.name == "asdf") @@ -476,7 +496,7 @@ class PolyCacheKeyTest(CacheKeyFixture, _poly_fixtures._Polymorphic): def two(): wp = with_polymorphic(Person, [Manager, Engineer]) return ( - Session() + fixture_session() .query(Company) .join(Company.employees.of_type(wp)) .filter(wp.name == "asdf") @@ -485,7 +505,7 @@ class PolyCacheKeyTest(CacheKeyFixture, _poly_fixtures._Polymorphic): def three(): wp = with_polymorphic(Person, [Manager, Engineer]) return ( - Session() + fixture_session() .query(Company) .join(Company.employees.of_type(wp)) .filter(wp.Engineer.name == "asdf") diff --git a/test/orm/test_cascade.py b/test/orm/test_cascade.py index 6a916e28a..180b479ba 100644 --- a/test/orm/test_cascade.py +++ b/test/orm/test_cascade.py @@ -18,7 +18,6 @@ from sqlalchemy.orm import mapper from sqlalchemy.orm import object_mapper from sqlalchemy.orm import relationship from sqlalchemy.orm import Session -from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import util as orm_util from sqlalchemy.orm.attributes import instance_state from sqlalchemy.testing import assert_raises @@ -27,6 +26,7 @@ from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import in_ from sqlalchemy.testing import not_in +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from test.orm import _fixtures @@ -270,7 +270,7 @@ class O2MCascadeDeleteOrphanTest(fixtures.MappedTest): def test_list_assignment_new(self): User, Order = self.classes.User, self.classes.Order - with Session() as sess: + with fixture_session() as sess: u = User( name="jack", orders=[ @@ -295,7 +295,7 @@ class O2MCascadeDeleteOrphanTest(fixtures.MappedTest): def test_list_assignment_replace(self): User, Order = self.classes.User, self.classes.Order - with Session() as sess: + with fixture_session() as sess: u = User( name="jack", orders=[ @@ -331,7 +331,7 @@ class O2MCascadeDeleteOrphanTest(fixtures.MappedTest): def test_standalone_orphan(self): Order = self.classes.Order - with Session() as sess: + with fixture_session() as sess: o5 = Order(description="order 5") sess.add(o5) assert_raises(sa_exc.DBAPIError, sess.flush) @@ -342,7 +342,7 @@ class O2MCascadeDeleteOrphanTest(fixtures.MappedTest): Order, User = self.classes.Order, self.classes.User - sess = sessionmaker(expire_on_commit=False)() + sess = fixture_session(expire_on_commit=False) o1, o2, o3 = ( Order(description="o1"), Order(description="o2"), @@ -363,7 +363,7 @@ class O2MCascadeDeleteOrphanTest(fixtures.MappedTest): def test_remove_pending_from_collection(self): User, Order = self.classes.User, self.classes.Order - with Session() as sess: + with fixture_session() as sess: u = User(name="jack") sess.add(u) @@ -380,7 +380,7 @@ class O2MCascadeDeleteOrphanTest(fixtures.MappedTest): User, Order = self.classes.User, self.classes.Order - with Session() as sess: + with fixture_session() as sess: u = User(name="jack") @@ -407,7 +407,7 @@ class O2MCascadeDeleteOrphanTest(fixtures.MappedTest): self.classes.Order, ) - with Session() as sess: + with fixture_session() as sess: u = User( name="jack", orders=[ @@ -444,7 +444,7 @@ class O2MCascadeDeleteOrphanTest(fixtures.MappedTest): self.classes.Address, ) - with Session() as sess: + with fixture_session() as sess: u = User( name="jack", addresses=[ @@ -497,7 +497,7 @@ class O2MCascadeDeleteOrphanTest(fixtures.MappedTest): self.tables.orders, ) - with Session(autoflush=False) as sess: + with fixture_session(autoflush=False) as sess: u = User( name="jack", orders=[ @@ -550,7 +550,7 @@ class O2MCascadeDeleteOrphanTest(fixtures.MappedTest): self.classes.Address, ) - sess = Session() + sess = fixture_session() u = User(name="jack") sess.add(u) assert "orders" not in u.__dict__ @@ -580,7 +580,7 @@ class O2MCascadeDeleteOrphanTest(fixtures.MappedTest): self.classes.Order, ) - sess = Session() + sess = fixture_session() u = User( name="jack", orders=[ @@ -619,7 +619,7 @@ class O2MCascadeDeleteOrphanTest(fixtures.MappedTest): self.classes.Order, ) - with Session() as sess: + with fixture_session() as sess: u = User( name="jack", orders=[ @@ -710,7 +710,7 @@ class O2MCascadeTest(fixtures.MappedTest): def test_none_o2m_collection_assignment(self): User = self.classes.User - s = Session() + s = fixture_session() u1 = User(name="u", addresses=[None]) s.add(u1) eq_(u1.addresses, [None]) @@ -723,7 +723,7 @@ class O2MCascadeTest(fixtures.MappedTest): def test_none_o2m_collection_append(self): User = self.classes.User - s = Session() + s = fixture_session() u1 = User(name="u") s.add(u1) @@ -793,7 +793,7 @@ class O2MCascadeDeleteNoOrphanTest(fixtures.MappedTest): self.tables.users, ) - with Session() as sess: + with fixture_session() as sess: u = User( name="jack", orders=[ @@ -936,7 +936,7 @@ class O2OSingleParentNoFlushTest(fixtures.MappedTest): User, Address = self.classes.User, self.classes.Address a1 = Address(email_address="some address") u1 = User(name="u1", address=a1) - sess = Session() + sess = fixture_session() sess.add(u1) sess.commit() @@ -1083,7 +1083,7 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): User, Address = self.classes.User, self.classes.Address self._one_to_many_fixture(o2m=True, m2o=False) - sess = Session() + sess = fixture_session() u1 = User(name="u1") a1 = Address(email_address="a1") u1.addresses.append(a1) @@ -1096,7 +1096,7 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): User, Address = self.classes.User, self.classes.Address self._one_to_many_fixture(o2m=True, m2o=False, o2m_cascade=False) - sess = Session() + sess = fixture_session() u1 = User(name="u1") a1 = Address(email_address="a1") u1.addresses.append(a1) @@ -1109,7 +1109,7 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): User, Address = self.classes.User, self.classes.Address self._one_to_many_fixture(o2m=True, m2o=False, o2m_cascade=False) - sess = Session() + sess = fixture_session() u1 = User(name="u1") a1 = Address(email_address="a1") sess.add(a1) @@ -1127,7 +1127,7 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): User, Address = self.classes.User, self.classes.Address self._one_to_many_fixture(o2m=True, m2o=True) - sess = Session() + sess = fixture_session() u1 = User(name="u1") a1 = Address(email_address="a1") u1.addresses.append(a1) @@ -1140,7 +1140,7 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): User, Address = self.classes.User, self.classes.Address self._one_to_many_fixture(o2m=True, m2o=True, o2m_cascade=False) - sess = Session() + sess = fixture_session() u1 = User(name="u1") a1 = Address(email_address="a1") u1.addresses.append(a1) @@ -1153,7 +1153,7 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): User, Address = self.classes.User, self.classes.Address self._one_to_many_fixture(o2m=True, m2o=True, o2m_cascade=False) - sess = Session() + sess = fixture_session() u1 = User(name="u1") a1 = Address(email_address="a1") u1.addresses.append(a1) @@ -1172,7 +1172,7 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): User, Address = self.classes.User, self.classes.Address self._one_to_many_fixture(o2m=True, m2o=True, o2m_cascade=False) - sess = Session() + sess = fixture_session() u1 = User(name="u1") a1 = Address(email_address="a1") sess.add(a1) @@ -1189,7 +1189,7 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): User, Address = self.classes.User, self.classes.Address self._one_to_many_fixture(o2m=True, m2o=True, o2m_cascade=False) - sess = Session() + sess = fixture_session() u1 = User(name="u1") a1 = Address(email_address="a1") sess.add(a1) @@ -1212,7 +1212,7 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): User, Address = self.classes.User, self.classes.Address self._one_to_many_fixture(o2m=False, m2o=True) - sess = Session() + sess = fixture_session() u1 = User(name="u1") a1 = Address(email_address="a1") a1.user = u1 @@ -1225,7 +1225,7 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): User, Address = self.classes.User, self.classes.Address self._one_to_many_fixture(o2m=False, m2o=True, m2o_cascade=False) - sess = Session() + sess = fixture_session() u1 = User(name="u1") a1 = Address(email_address="a1") a1.user = u1 @@ -1238,7 +1238,7 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): User, Address = self.classes.User, self.classes.Address self._one_to_many_fixture(o2m=False, m2o=True, m2o_cascade=False) - sess = Session() + sess = fixture_session() u1 = User(name="u1") sess.add(u1) sess.flush() @@ -1255,7 +1255,7 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): User, Address = self.classes.User, self.classes.Address self._one_to_many_fixture(o2m=True, m2o=True) - sess = Session() + sess = fixture_session() u1 = User(name="u1") a1 = Address(email_address="a1") a1.user = u1 @@ -1268,7 +1268,7 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): User, Address = self.classes.User, self.classes.Address self._one_to_many_fixture(o2m=True, m2o=True, m2o_cascade=False) - sess = Session() + sess = fixture_session() u1 = User(name="u1") a1 = Address(email_address="a1") a1.user = u1 @@ -1281,7 +1281,7 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): User, Address = self.classes.User, self.classes.Address self._one_to_many_fixture(o2m=True, m2o=True, m2o_cascade=False) - with Session() as sess: + with fixture_session() as sess: u1 = User(name="u1") sess.add(u1) sess.flush() @@ -1324,7 +1324,7 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): User, Address = self.classes.User, self.classes.Address self._one_to_many_fixture(o2m=True, m2o=True, m2o_cascade=False) - sess = Session() + sess = fixture_session() u1 = User(name="u1") a1 = Address(email_address="a1") @@ -1346,7 +1346,7 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): self._one_to_many_fixture(o2m=True, m2o=True, m2o_cascade=False) - with Session() as sess: + with fixture_session() as sess: u1 = User(name="u1") sess.add(u1) sess.flush() @@ -1400,7 +1400,7 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): Item, Keyword = self.classes.Item, self.classes.Keyword self._many_to_many_fixture(fwd=True, bkd=False) - sess = Session() + sess = fixture_session() i1 = Item(description="i1") k1 = Keyword(name="k1") i1.keywords.append(k1) @@ -1413,7 +1413,7 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): Item, Keyword = self.classes.Item, self.classes.Keyword self._many_to_many_fixture(fwd=True, bkd=False, fwd_cascade=False) - sess = Session() + sess = fixture_session() i1 = Item(description="i1") k1 = Keyword(name="k1") i1.keywords.append(k1) @@ -1426,7 +1426,7 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): Item, Keyword = self.classes.Item, self.classes.Keyword self._many_to_many_fixture(fwd=True, bkd=False, fwd_cascade=False) - sess = Session() + sess = fixture_session() i1 = Item(description="i1") k1 = Keyword(name="k1") sess.add(k1) @@ -1444,7 +1444,7 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): Item, Keyword = self.classes.Item, self.classes.Keyword self._many_to_many_fixture(fwd=True, bkd=True) - sess = Session() + sess = fixture_session() i1 = Item(description="i1") k1 = Keyword(name="k1") i1.keywords.append(k1) @@ -1457,7 +1457,7 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): Item, Keyword = self.classes.Item, self.classes.Keyword self._many_to_many_fixture(fwd=True, bkd=True, fwd_cascade=False) - sess = Session() + sess = fixture_session() i1 = Item(description="i1") k1 = Keyword(name="k1") i1.keywords.append(k1) @@ -1470,7 +1470,7 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): Item, Keyword = self.classes.Item, self.classes.Keyword self._many_to_many_fixture(fwd=True, bkd=True, fwd_cascade=False) - sess = Session() + sess = fixture_session() i1 = Item(description="i1") k1 = Keyword(name="k1") i1.keywords.append(k1) @@ -1489,7 +1489,7 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): Item, Keyword = self.classes.Item, self.classes.Keyword self._many_to_many_fixture(fwd=True, bkd=True, fwd_cascade=False) - sess = Session() + sess = fixture_session() i1 = Item(description="i1") k1 = Keyword(name="k1") sess.add(k1) @@ -1506,7 +1506,7 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): Item, Keyword = self.classes.Item, self.classes.Keyword self._many_to_many_fixture(fwd=True, bkd=True, fwd_cascade=False) - sess = Session() + sess = fixture_session() i1 = Item(description="i1") k1 = Keyword(name="k1") sess.add(k1) @@ -1549,7 +1549,7 @@ class NoSaveCascadeBackrefTest(_fixtures.FixtureTest): ), ) - sess = Session() + sess = fixture_session() o1 = Order() sess.add(o1) @@ -1584,7 +1584,7 @@ class NoSaveCascadeBackrefTest(_fixtures.FixtureTest): ) mapper(User, users) - sess = Session() + sess = fixture_session() u1 = User() sess.add(u1) @@ -1625,7 +1625,7 @@ class NoSaveCascadeBackrefTest(_fixtures.FixtureTest): ) mapper(Keyword, keywords) - sess = Session() + sess = fixture_session() i1 = Item() k1 = Keyword() @@ -1753,7 +1753,7 @@ class M2OCascadeDeleteOrphanTestOne(fixtures.MappedTest): self.tables.extra, ) - sess = Session() + sess = fixture_session() eq_( sess.execute(select(func.count("*")).select_from(prefs)).scalar(), 3, @@ -1779,7 +1779,7 @@ class M2OCascadeDeleteOrphanTestOne(fixtures.MappedTest): Foo, User = self.classes.Foo, self.classes.User - sess = sessionmaker(expire_on_commit=True)() + sess = fixture_session(expire_on_commit=True) u1 = User(name="jack", foo=Foo(data="f1")) sess.add(u1) @@ -1802,7 +1802,7 @@ class M2OCascadeDeleteOrphanTestOne(fixtures.MappedTest): Pref, User = self.classes.Pref, self.classes.User - sess = sessionmaker(expire_on_commit=False)() + sess = fixture_session(expire_on_commit=False) p1, p2 = Pref(data="p1"), Pref(data="p2") u = User(name="jack", pref=p1) @@ -1824,7 +1824,7 @@ class M2OCascadeDeleteOrphanTestOne(fixtures.MappedTest): self.tables.extra, ) - sess = Session() + sess = fixture_session() jack = sess.query(User).filter_by(name="jack").one() p = jack.pref e = jack.pref.extra[0] @@ -1849,7 +1849,7 @@ class M2OCascadeDeleteOrphanTestOne(fixtures.MappedTest): def test_pending_expunge(self): Pref, User = self.classes.Pref, self.classes.User - sess = Session() + sess = fixture_session() someuser = User(name="someuser") sess.add(someuser) sess.flush() @@ -1868,7 +1868,7 @@ class M2OCascadeDeleteOrphanTestOne(fixtures.MappedTest): Pref, User = self.classes.Pref, self.classes.User - sess = Session() + sess = fixture_session() jack = sess.query(User).filter_by(name="jack").one() newpref = Pref(data="newpref") @@ -1961,7 +1961,7 @@ class M2OCascadeDeleteOrphanTestTwo(fixtures.MappedTest): def test_cascade_delete(self): T2, T3, T1 = (self.classes.T2, self.classes.T3, self.classes.T1) - sess = Session() + sess = fixture_session() x = T1(data="t1a", t2=T2(data="t2a", t3=T3(data="t3a"))) sess.add(x) sess.flush() @@ -1975,7 +1975,7 @@ class M2OCascadeDeleteOrphanTestTwo(fixtures.MappedTest): def test_deletes_orphans_onelevel(self): T2, T3, T1 = (self.classes.T2, self.classes.T3, self.classes.T1) - sess = Session() + sess = fixture_session() x2 = T1(data="t1b", t2=T2(data="t2b", t3=T3(data="t3b"))) sess.add(x2) sess.flush() @@ -1990,7 +1990,7 @@ class M2OCascadeDeleteOrphanTestTwo(fixtures.MappedTest): def test_deletes_orphans_twolevel(self): T2, T3, T1 = (self.classes.T2, self.classes.T3, self.classes.T1) - sess = Session() + sess = fixture_session() x = T1(data="t1a", t2=T2(data="t2a", t3=T3(data="t3a"))) sess.add(x) sess.flush() @@ -2005,7 +2005,7 @@ class M2OCascadeDeleteOrphanTestTwo(fixtures.MappedTest): def test_finds_orphans_twolevel(self): T2, T3, T1 = (self.classes.T2, self.classes.T3, self.classes.T1) - sess = Session() + sess = fixture_session() x = T1(data="t1a", t2=T2(data="t2a", t3=T3(data="t3a"))) sess.add(x) sess.flush() @@ -2102,7 +2102,7 @@ class M2OCascadeDeleteNoOrphanTest(fixtures.MappedTest): def test_cascade_delete(self): T2, T3, T1 = (self.classes.T2, self.classes.T3, self.classes.T1) - sess = Session() + sess = fixture_session() x = T1(data="t1a", t2=T2(data="t2a", t3=T3(data="t3a"))) sess.add(x) sess.flush() @@ -2116,7 +2116,7 @@ class M2OCascadeDeleteNoOrphanTest(fixtures.MappedTest): def test_cascade_delete_postappend_onelevel(self): T2, T3, T1 = (self.classes.T2, self.classes.T3, self.classes.T1) - sess = Session() + sess = fixture_session() x1 = T1(data="t1") x2 = T2(data="t2") x3 = T3(data="t3") @@ -2134,7 +2134,7 @@ class M2OCascadeDeleteNoOrphanTest(fixtures.MappedTest): def test_cascade_delete_postappend_twolevel(self): T2, T3, T1 = (self.classes.T2, self.classes.T3, self.classes.T1) - sess = Session() + sess = fixture_session() x1 = T1(data="t1", t2=T2(data="t2")) x3 = T3(data="t3") sess.add_all((x1, x3)) @@ -2150,7 +2150,7 @@ class M2OCascadeDeleteNoOrphanTest(fixtures.MappedTest): def test_preserves_orphans_onelevel(self): T2, T3, T1 = (self.classes.T2, self.classes.T3, self.classes.T1) - sess = Session() + sess = fixture_session() x2 = T1(data="t1b", t2=T2(data="t2b", t3=T3(data="t3b"))) sess.add(x2) sess.flush() @@ -2166,7 +2166,7 @@ class M2OCascadeDeleteNoOrphanTest(fixtures.MappedTest): def test_preserves_orphans_onelevel_postremove(self): T2, T3, T1 = (self.classes.T2, self.classes.T3, self.classes.T1) - sess = Session() + sess = fixture_session() x2 = T1(data="t1b", t2=T2(data="t2b", t3=T3(data="t3b"))) sess.add(x2) sess.flush() @@ -2181,7 +2181,7 @@ class M2OCascadeDeleteNoOrphanTest(fixtures.MappedTest): def test_preserves_orphans_twolevel(self): T2, T3, T1 = (self.classes.T2, self.classes.T3, self.classes.T1) - sess = Session() + sess = fixture_session() x = T1(data="t1a", t2=T2(data="t2a", t3=T3(data="t3a"))) sess.add(x) sess.flush() @@ -2270,7 +2270,7 @@ class M2MCascadeTest(fixtures.MappedTest): ) mapper(B, b) - sess = Session() + sess = fixture_session() b1 = B(data="b1") a1 = A(data="a1", bs=[b1]) sess.add(a1) @@ -2310,7 +2310,7 @@ class M2MCascadeTest(fixtures.MappedTest): # failed until [ticket:427] was fixed mapper(B, b) - sess = Session() + sess = fixture_session() b1 = B(data="b1") a1 = A(data="a1", bs=[b1]) sess.add(a1) @@ -2356,7 +2356,7 @@ class M2MCascadeTest(fixtures.MappedTest): ) mapper(C, c) - sess = Session() + sess = fixture_session() b1 = B(data="b1", cs=[C(data="c1")]) a1 = A(data="a1", bs=[b1]) sess.add(a1) @@ -2394,7 +2394,7 @@ class M2MCascadeTest(fixtures.MappedTest): ) mapper(B, b) - sess = Session() + sess = fixture_session() a1 = A(data="a1", bs=[B(data="b1")]) sess.add(a1) sess.flush() @@ -2513,7 +2513,7 @@ class M2MCascadeTest(fixtures.MappedTest): ) mapper(B, b) - s = Session() + s = fixture_session() a1 = A(bs=[None]) s.add(a1) eq_(a1.bs, [None]) @@ -2540,7 +2540,7 @@ class M2MCascadeTest(fixtures.MappedTest): ) mapper(B, b) - s = Session() + s = fixture_session() a1 = A() a1.bs.append(None) s.add(a1) @@ -2588,7 +2588,7 @@ class O2MSelfReferentialDetelOrphanTest(fixtures.MappedTest): def test_self_referential_delete(self): Node = self.classes.Node - s = Session() + s = fixture_session() n1, n2, n3, n4 = Node(), Node(), Node(), Node() n1.children = [n2, n3] @@ -2640,7 +2640,7 @@ class NoBackrefCascadeTest(_fixtures.FixtureTest): def test_o2m_basic(self): User, Address = self.classes.User, self.classes.Address - sess = Session() + sess = fixture_session() u1 = User(name="u1") sess.add(u1) @@ -2652,7 +2652,7 @@ class NoBackrefCascadeTest(_fixtures.FixtureTest): def test_o2m_commit_warns(self): User, Address = self.classes.User, self.classes.Address - sess = Session() + sess = fixture_session() u1 = User(name="u1") sess.add(u1) @@ -2667,7 +2667,7 @@ class NoBackrefCascadeTest(_fixtures.FixtureTest): def test_o2m_flag_on_backref(self): Dingaling, Address = self.classes.Dingaling, self.classes.Address - sess = Session() + sess = fixture_session() a1 = Address(email_address="a1") sess.add(a1) @@ -2686,7 +2686,7 @@ class NoBackrefCascadeTest(_fixtures.FixtureTest): def test_m2o_basic(self): Dingaling, Address = self.classes.Dingaling, self.classes.Address - sess = Session() + sess = fixture_session() a1 = Address(email_address="a1") d1 = Dingaling() @@ -2698,7 +2698,7 @@ class NoBackrefCascadeTest(_fixtures.FixtureTest): def test_m2o_flag_on_backref(self): User, Address = self.classes.User, self.classes.Address - sess = Session() + sess = fixture_session() a1 = Address(email_address="a1") sess.add(a1) @@ -2714,7 +2714,7 @@ class NoBackrefCascadeTest(_fixtures.FixtureTest): def test_m2o_commit_warns(self): Dingaling, Address = self.classes.Dingaling, self.classes.Address - sess = Session() + sess = fixture_session() a1 = Address(email_address="a1") d1 = Dingaling() @@ -2810,7 +2810,7 @@ class PendingOrphanTestSingleLevel(fixtures.MappedTest): orders=relationship(Order, cascade="all, delete-orphan"), ), ) - s = Session() + s = fixture_session() # the standalone Address goes in, its foreign key # allows NULL @@ -2855,7 +2855,7 @@ class PendingOrphanTestSingleLevel(fixtures.MappedTest): ) ), ) - s = Session() + s = fixture_session() u = User() s.add(u) @@ -2891,7 +2891,7 @@ class PendingOrphanTestSingleLevel(fixtures.MappedTest): ) ), ) - s = Session() + s = fixture_session() u = User(name="u1", addresses=[Address(email_address="ad1")]) s.add(u) a1 = u.addresses[0] @@ -2964,7 +2964,7 @@ class PendingOrphanTestTwoLevel(fixtures.MappedTest): }, ) mapper(Item, item) - s = Session() + s = fixture_session() o1 = Order() s.add(o1) @@ -3001,7 +3001,7 @@ class PendingOrphanTestTwoLevel(fixtures.MappedTest): }, ) mapper(Attribute, attribute) - s = Session() + s = fixture_session() o1 = Order() s.add(o1) @@ -3117,7 +3117,7 @@ class DoubleParentO2MOrphanTest(fixtures.MappedTest): ) ), ) - s = Session(expire_on_commit=False, autoflush=False) + s = fixture_session(expire_on_commit=False, autoflush=False) a = Account(balance=0) sr = SalesRep(name="John") @@ -3282,7 +3282,7 @@ class DoubleParentM2OOrphanTest(fixtures.MappedTest): }, ) - session = Session() + session = fixture_session() h1 = Home(description="home1", address=Address(street="address1")) b1 = Business( description="business1", address=Address(street="address2") @@ -3341,7 +3341,7 @@ class DoubleParentM2OOrphanTest(fixtures.MappedTest): ) }, ) - session = Session() + session = fixture_session() a1 = Address() session.add(a1) session.flush() @@ -3386,7 +3386,7 @@ class CollectionAssignmentOrphanTest(fixtures.MappedTest): a1 = A(name="a1", bs=[B(name="b1"), B(name="b2"), B(name="b3")]) - sess = Session() + sess = fixture_session() sess.add(a1) sess.flush() @@ -3490,7 +3490,7 @@ class OrphanCriterionTest(fixtures.MappedTest): RelatedTwo(cores=[c1]) if persistent: - s = Session() + s = fixture_session() s.add(c1) s.flush() @@ -3629,7 +3629,7 @@ class O2MConflictTest(fixtures.MappedTest): def _do_move_test(self, delete_old): Parent, Child = self.classes.Parent, self.classes.Child - with Session(autoflush=False) as sess: + with fixture_session(autoflush=False) as sess: p1, p2, c1 = Parent(), Parent(), Child() if Parent.child.property.uselist: p1.child.append(c1) @@ -3880,7 +3880,7 @@ class PartialFlushTest(fixtures.MappedTest): ) mapper(Child, noninh_child) - sess = Session() + sess = fixture_session() c1, c2 = Child(), Child() b1 = Base(descr="b1", children=[c1, c2]) @@ -3897,7 +3897,7 @@ class PartialFlushTest(fixtures.MappedTest): assert c2 in sess and c2 not in sess.new assert b1 in sess and b1 not in sess.new - sess = Session() + sess = fixture_session() c1, c2 = Child(), Child() b1 = Base(descr="b1", children=[c1, c2]) sess.add(b1) @@ -3907,7 +3907,7 @@ class PartialFlushTest(fixtures.MappedTest): assert c2 in sess and c2 in sess.new assert b1 in sess and b1 in sess.new - sess = Session() + sess = fixture_session() c1, c2 = Child(), Child() b1 = Base(descr="b1", children=[c1, c2]) sess.add(b1) @@ -3952,7 +3952,7 @@ class PartialFlushTest(fixtures.MappedTest): mapper(Parent, parent, inherits=Base) - sess = Session() + sess = fixture_session() p1 = Parent() c1, c2, c3 = Child(), Child(), Child() @@ -4097,7 +4097,7 @@ class SubclassCascadeTest(fixtures.DeclarativeMappedTest): ) ] ) - s = Session() + s = fixture_session() s.add(obj) s.commit() @@ -4188,7 +4188,7 @@ class ViewonlyFlagWarningTest(fixtures.MappedTest): }, ) - sess = Session() + sess = fixture_session() u = User(id=1, name="jack") sess.add(u) sess.add_all( @@ -4235,7 +4235,7 @@ class ViewonlyFlagWarningTest(fixtures.MappedTest): }, ) - sess = Session() + sess = fixture_session() u1 = User(id=1, name="jack") sess.add(u1) @@ -4274,7 +4274,7 @@ class ViewonlyFlagWarningTest(fixtures.MappedTest): }, ) - sess = Session() + sess = fixture_session() u1 = User(id=1, name="jack") o1, o2 = ( diff --git a/test/orm/test_collection.py b/test/orm/test_collection.py index 7c7662618..3d09bd446 100644 --- a/test/orm/test_collection.py +++ b/test/orm/test_collection.py @@ -10,7 +10,6 @@ from sqlalchemy import testing from sqlalchemy import text from sqlalchemy import util from sqlalchemy.orm import attributes -from sqlalchemy.orm import create_session from sqlalchemy.orm import instrumentation from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship @@ -23,6 +22,7 @@ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_false from sqlalchemy.testing import is_true from sqlalchemy.testing import ne_ +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -1753,7 +1753,7 @@ class DictHelpersTest(OrderedDictFixture, fixtures.MappedTest): p = Parent() p.children["foo"] = Child("foo", "value") p.children["bar"] = Child("bar", "value") - session = create_session() + session = fixture_session() session.add(p) session.flush() pid = p.id @@ -1839,7 +1839,7 @@ class DictHelpersTest(OrderedDictFixture, fixtures.MappedTest): p.children[("foo", "1")] = Child("foo", "1", "value 1") p.children[("foo", "2")] = Child("foo", "2", "value 2") - session = create_session() + session = fixture_session() session.add(p) session.flush() pid = p.id @@ -2104,7 +2104,7 @@ class CustomCollectionsTest(fixtures.MappedTest): f = Foo() f.bars.add(Bar()) f.bars.add(Bar()) - sess = create_session() + sess = fixture_session() sess.add(f) sess.flush() sess.expunge_all() @@ -2147,7 +2147,7 @@ class CustomCollectionsTest(fixtures.MappedTest): f = Foo() f.bars.set(Bar()) f.bars.set(Bar()) - sess = create_session() + sess = fixture_session() sess.add(f) sess.flush() sess.expunge_all() @@ -2189,7 +2189,7 @@ class CustomCollectionsTest(fixtures.MappedTest): col = collections.collection_adapter(f.bars) col.append_with_event(Bar("a")) col.append_with_event(Bar("b")) - sess = create_session() + sess = fixture_session() sess.add(f) sess.flush() sess.expunge_all() @@ -2444,7 +2444,7 @@ class CustomCollectionsTest(fixtures.MappedTest): p1.children.append(o) assert control == list(p1.children) - sess = create_session() + sess = fixture_session() sess.add(p1) sess.flush() sess.expunge_all() diff --git a/test/orm/test_compile.py b/test/orm/test_compile.py index c6a1226d4..df652daf4 100644 --- a/test/orm/test_compile.py +++ b/test/orm/test_compile.py @@ -5,14 +5,13 @@ from sqlalchemy import Integer from sqlalchemy import MetaData from sqlalchemy import String from sqlalchemy import Table -from sqlalchemy import testing from sqlalchemy import Unicode from sqlalchemy.orm import backref from sqlalchemy.orm import clear_mappers from sqlalchemy.orm import configure_mappers -from sqlalchemy.orm import create_session from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship +from sqlalchemy.orm import Session from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import fixtures @@ -24,7 +23,7 @@ class CompileTest(fixtures.ORMTest): clear_mappers() def test_with_polymorphic(self): - metadata = MetaData(testing.db) + metadata = MetaData() order = Table( "orders", @@ -122,7 +121,7 @@ class CompileTest(fixtures.ORMTest): def test_conflicting_backref_one(self): """test that conflicting backrefs raises an exception""" - metadata = MetaData(testing.db) + metadata = MetaData() order = Table( "orders", @@ -190,9 +189,7 @@ class CompileTest(fixtures.ORMTest): sa_exc.ArgumentError, "Error creating backref", configure_mappers ) - @testing.provide_metadata - def test_misc_one(self, connection): - metadata = self.metadata + def test_misc_one(self, connection, metadata): node_table = Table( "node", metadata, @@ -235,7 +232,7 @@ class CompileTest(fixtures.ORMTest): "host": relationship(Host), }, ) - sess = create_session(connection) + sess = Session(connection) assert sess.query(Node).get(1).names == [] def test_conflicting_backref_two(self): diff --git a/test/orm/test_composites.py b/test/orm/test_composites.py index a164034da..6ee87eefe 100644 --- a/test/orm/test_composites.py +++ b/test/orm/test_composites.py @@ -15,6 +15,7 @@ from sqlalchemy.orm import Session from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -151,7 +152,7 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL): # pending/transient object. e1 = Edge() assert e1.end is None - sess = Session() + sess = fixture_session() sess.add(e1) # however, once it's persistent, the code as of 0.7.3 @@ -382,7 +383,7 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL): Graph, Edge = self.classes.Graph, self.classes.Edge - sess = Session() + sess = fixture_session() g = Graph(id=1) e = Edge(None, None) g.edges.append(e) @@ -488,7 +489,7 @@ class NestedTest(fixtures.MappedTest, testing.AssertsCompiledSQL): def test_round_trip(self): Thing, AB, CD = self._fixture() - s = Session() + s = fixture_session() s.add(Thing(AB("a", "b", CD("c", "d")))) s.commit() @@ -553,7 +554,7 @@ class PrimaryKeyTest(fixtures.MappedTest): def _fixture(self): Graph, Version = self.classes.Graph, self.classes.Version - sess = Session() + sess = fixture_session() g = Graph(Version(1, 1)) sess.add(g) sess.commit() @@ -593,7 +594,7 @@ class PrimaryKeyTest(fixtures.MappedTest): def test_null_pk(self): Graph, Version = self.classes.Graph, self.classes.Version - sess = Session() + sess = fixture_session() # test pk with one column NULL # only sqlite can really handle this @@ -674,7 +675,7 @@ class DefaultsTest(fixtures.MappedTest): def test_attributes_with_defaults(self): Foobar, FBComposite = self.classes.Foobar, self.classes.FBComposite - sess = Session() + sess = fixture_session() f1 = Foobar() f1.foob = FBComposite(None, 5, None, None) sess.add(f1) @@ -690,7 +691,7 @@ class DefaultsTest(fixtures.MappedTest): def test_set_composite_values(self): Foobar, FBComposite = self.classes.Foobar, self.classes.FBComposite - sess = Session() + sess = fixture_session() f1 = Foobar() f1.foob = FBComposite(None, 5, None, None) sess.add(f1) @@ -783,7 +784,7 @@ class MappedSelectTest(fixtures.MappedTest): self.tables.descriptions, ) - session = Session() + session = fixture_session() d = Descriptions( custom_descriptions=CustomValues("Color", "Number"), values=[ @@ -866,7 +867,7 @@ class ManyToOneTest(fixtures.MappedTest): def test_persist(self): A, C, B = (self.classes.A, self.classes.C, self.classes.B) - sess = Session() + sess = fixture_session() sess.add(A(c=C("b1", B(data="b2")))) sess.commit() @@ -876,7 +877,7 @@ class ManyToOneTest(fixtures.MappedTest): def test_query(self): A, C, B = (self.classes.A, self.classes.C, self.classes.B) - sess = Session() + sess = fixture_session() b1, b2 = B(data="b1"), B(data="b2") a1 = A(c=C("a1b1", b1)) a2 = A(c=C("a2b1", b2)) @@ -888,7 +889,7 @@ class ManyToOneTest(fixtures.MappedTest): def test_query_aliased(self): A, C, B = (self.classes.A, self.classes.C, self.classes.B) - sess = Session() + sess = fixture_session() b1, b2 = B(data="b1"), B(data="b2") a1 = A(c=C("a1b1", b1)) a2 = A(c=C("a2b1", b2)) @@ -941,7 +942,7 @@ class ConfigurationTest(fixtures.MappedTest): Edge, Point = self.classes.Edge, self.classes.Point e1 = Edge(start=Point(3, 4), end=Point(5, 6)) - sess = Session() + sess = fixture_session() sess.add(e1) sess.commit() @@ -1131,7 +1132,7 @@ class ComparatorTest(fixtures.MappedTest, testing.AssertsCompiledSQL): def _test_comparator_behavior(self): Edge, Point = (self.classes.Edge, self.classes.Point) - sess = Session() + sess = fixture_session() e1 = Edge(Point(3, 4), Point(5, 6)) e2 = Edge(Point(14, 5), Point(2, 7)) sess.add_all([e1, e2]) @@ -1159,7 +1160,7 @@ class ComparatorTest(fixtures.MappedTest, testing.AssertsCompiledSQL): Edge(Point(0, 1), Point(3, 5)), ) - sess = Session() + sess = fixture_session() sess.add_all([edge_1, edge_2]) sess.commit() @@ -1179,7 +1180,7 @@ class ComparatorTest(fixtures.MappedTest, testing.AssertsCompiledSQL): def test_order_by(self): self._fixture(False) Edge = self.classes.Edge - s = Session() + s = fixture_session() self.assert_compile( s.query(Edge).order_by(Edge.start, Edge.end), "SELECT edge.id AS edge_id, edge.x1 AS edge_x1, " @@ -1190,7 +1191,7 @@ class ComparatorTest(fixtures.MappedTest, testing.AssertsCompiledSQL): def test_order_by_aliased(self): self._fixture(False) Edge = self.classes.Edge - s = Session() + s = fixture_session() ea = aliased(Edge) self.assert_compile( s.query(ea).order_by(ea.start, ea.end), diff --git a/test/orm/test_core_compilation.py b/test/orm/test_core_compilation.py index 12d3f7bfb..1a58356e3 100644 --- a/test/orm/test_core_compilation.py +++ b/test/orm/test_core_compilation.py @@ -14,7 +14,6 @@ from sqlalchemy.orm import joinedload from sqlalchemy.orm import mapper from sqlalchemy.orm import query_expression from sqlalchemy.orm import relationship -from sqlalchemy.orm import Session from sqlalchemy.orm import with_expression from sqlalchemy.orm import with_polymorphic from sqlalchemy.sql import sqltypes @@ -24,6 +23,7 @@ from sqlalchemy.sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ +from sqlalchemy.testing.fixtures import fixture_session from .inheritance import _poly_fixtures from .test_query import QueryTest @@ -293,7 +293,7 @@ class LoadersInSubqueriesTest(QueryTest, AssertsCompiledSQL): def test_no_joinedload_in_subquery_select_rows(self, joinedload_fixture): User, Address = joinedload_fixture - sess = Session() + sess = fixture_session() stmt1 = sess.query(User).subquery() stmt1 = sess.query(stmt1) @@ -316,7 +316,7 @@ class LoadersInSubqueriesTest(QueryTest, AssertsCompiledSQL): def test_no_joinedload_in_subquery_select_entity(self, joinedload_fixture): User, Address = joinedload_fixture - sess = Session() + sess = fixture_session() stmt1 = sess.query(User).subquery() ua = aliased(User, stmt1) stmt1 = sess.query(ua) @@ -645,7 +645,7 @@ class RelationshipNaturalCompileTest(QueryTest, AssertsCompiledSQL): stmt1 = select(u1).where(u1.addresses.of_type(a1)) stmt2 = ( - Session() + fixture_session() .query(u1) .filter(u1.addresses.of_type(a1)) ._final_statement(legacy_query_style=False) @@ -844,7 +844,7 @@ class ImplicitWithPolymorphicTest( .order_by(Person.person_id) ) - sess = Session() + sess = fixture_session() q = ( sess.query(Person.person_id, Person.name) .filter(Person.name == "some name") @@ -884,7 +884,7 @@ class ImplicitWithPolymorphicTest( .order_by(Person.person_id) ) - sess = Session() + sess = fixture_session() q = ( sess.query(Person) .filter(Person.name == "some name") @@ -931,7 +931,7 @@ class ImplicitWithPolymorphicTest( .order_by(Engineer.person_id) ) - sess = Session() + sess = fixture_session() q = ( sess.query(Engineer) .filter(Engineer.name == "some name") @@ -990,7 +990,7 @@ class ImplicitWithPolymorphicTest( .order_by(Engineer.person_id) ) - sess = Session() + sess = fixture_session() q = ( sess.query(Engineer.person_id, Engineer.name) .filter(Engineer.name == "some name") @@ -1079,7 +1079,7 @@ class RelationshipNaturalInheritedTest(InheritedTest, AssertsCompiledSQL): ) stmt2 = select(Company).join(Company.employees) stmt3 = ( - Session() + fixture_session() .query(Company) .join(Company.employees) ._final_statement(legacy_query_style=False) @@ -1113,7 +1113,7 @@ class RelationshipNaturalInheritedTest(InheritedTest, AssertsCompiledSQL): select(Company).join(Company.employees).where(Person.name == "ed") ) stmt3 = ( - Session() + fixture_session() .query(Company) .join(Company.employees) .filter(Person.name == "ed") @@ -1137,7 +1137,7 @@ class RelationshipNaturalInheritedTest(InheritedTest, AssertsCompiledSQL): stmt2 = select(Company).join(Company.employees).join(Person.paperwork) stmt3 = ( - Session() + fixture_session() .query(Company) .join(Company.employees) .join(Person.paperwork) @@ -1161,7 +1161,7 @@ class RelationshipNaturalInheritedTest(InheritedTest, AssertsCompiledSQL): stmt2 = select(Company).join(Company.employees.of_type(p1)) stmt3 = ( - Session() + fixture_session() .query(Company) .join(Company.employees.of_type(p1)) ._final_statement(legacy_query_style=False) @@ -1179,7 +1179,7 @@ class RelationshipNaturalInheritedTest(InheritedTest, AssertsCompiledSQL): Company, Person, Manager, Engineer = self.classes( "Company", "Person", "Manager", "Engineer" ) - s = Session() + s = fixture_session() p1 = with_polymorphic(Person, "*", aliased=True) @@ -1218,7 +1218,7 @@ class RelationshipNaturalInheritedTest(InheritedTest, AssertsCompiledSQL): stmt2 = select(Company).join(p1, Company.employees.of_type(p1)) stmt3 = ( - Session() + fixture_session() .query(Company) .join(Company.employees.of_type(p1)) ._final_statement(legacy_query_style=False) @@ -1479,7 +1479,7 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): stmt1 = select(User).where(User.addresses) stmt2 = ( - Session() + fixture_session() .query(User) .filter(User.addresses) ._final_statement(legacy_query_style=False) @@ -1505,7 +1505,7 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): stmt1 = select(Item).where(Item.keywords) stmt2 = ( - Session() + fixture_session() .query(Item) .filter(Item.keywords) ._final_statement(legacy_query_style=False) @@ -1519,7 +1519,7 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): expected = "SELECT * FROM users" stmt1 = select(literal_column("*")).select_from(User) stmt2 = ( - Session() + fixture_session() .query(literal_column("*")) .select_from(User) ._final_statement(legacy_query_style=False) @@ -1534,7 +1534,7 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): stmt1 = select(literal_column("*")).select_from(ua) stmt2 = ( - Session() + fixture_session() .query(literal_column("*")) .select_from(ua) ._final_statement(legacy_query_style=False) @@ -1565,7 +1565,7 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): .scalar_subquery(), ) stmt2 = ( - Session() + fixture_session() .query( User.name, Address.id, @@ -1595,7 +1595,7 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): ) stmt2 = ( - Session() + fixture_session() .query( uu.name, Address.id, @@ -1624,7 +1624,9 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): stmt1 = select(User) stmt2 = ( - Session().query(User)._final_statement(legacy_query_style=False) + fixture_session() + .query(User) + ._final_statement(legacy_query_style=False) ) self.assert_compile(stmt1, expected) @@ -1637,7 +1639,7 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): stmt1 = select(User.id, User.name) stmt2 = ( - Session() + fixture_session() .query(User.id, User.name) ._final_statement(legacy_query_style=False) ) @@ -1651,7 +1653,7 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): stmt1 = select(ua.id, ua.name) stmt2 = ( - Session() + fixture_session() .query(ua.id, ua.name) ._final_statement(legacy_query_style=False) ) @@ -1665,7 +1667,11 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): ua = aliased(User, name="ua") stmt1 = select(ua) - stmt2 = Session().query(ua)._final_statement(legacy_query_style=False) + stmt2 = ( + fixture_session() + .query(ua) + ._final_statement(legacy_query_style=False) + ) expected = "SELECT ua.id, ua.name FROM users AS ua" self.assert_compile(stmt1, expected) @@ -1695,7 +1701,7 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): User = self.classes.User Address = self.classes.Address - s = Session() + s = fixture_session() q = s.query(User.id, User.name).filter_by(name="ed") self.assert_compile( insert(Address).from_select(("id", "email_address"), q), @@ -1708,7 +1714,7 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): User = self.classes.User Address = self.classes.Address - s = Session() + s = fixture_session() q = s.query(User.id, User.name).filter_by(name="ed") self.assert_compile( insert(Address).from_select( @@ -1781,7 +1787,7 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): stmt1 = select(Foo).where(Foo.foob == "somename").order_by(Foo.foob) stmt2 = ( - Session() + fixture_session() .query(Foo) .filter(Foo.foob == "somename") .order_by(Foo.foob) diff --git a/test/orm/test_cycles.py b/test/orm/test_cycles.py index 47b5404c9..e1ef67fed 100644 --- a/test/orm/test_cycles.py +++ b/test/orm/test_cycles.py @@ -14,11 +14,8 @@ from sqlalchemy import Integer from sqlalchemy import String from sqlalchemy import testing from sqlalchemy.orm import backref -from sqlalchemy.orm import create_session from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship -from sqlalchemy.orm import Session -from sqlalchemy.orm import sessionmaker from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ @@ -27,6 +24,7 @@ from sqlalchemy.testing.assertsql import AllOf from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.testing.assertsql import Conditional from sqlalchemy.testing.assertsql import RegexSQL +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -88,7 +86,7 @@ class SelfReferentialTest(fixtures.MappedTest): a = C1("head c1") a.c1s.append(C1("another c1")) - sess = create_session() + sess = fixture_session() sess.add(a) sess.flush() sess.delete(a) @@ -119,7 +117,7 @@ class SelfReferentialTest(fixtures.MappedTest): c1 = C1() - sess = create_session() + sess = fixture_session() sess.add(c1) sess.flush() sess.expunge_all() @@ -156,7 +154,7 @@ class SelfReferentialTest(fixtures.MappedTest): a.c1s[0].c1s.append(C1("subchild2")) a.c1s[1].c2s.append(C2("child2 data1")) a.c1s[1].c2s.append(C2("child2 data2")) - sess = create_session() + sess = fixture_session() sess.add(a) sess.flush() @@ -168,7 +166,7 @@ class SelfReferentialTest(fixtures.MappedTest): mapper(C1, t1, properties={"children": relationship(C1)}) - sess = create_session() + sess = fixture_session() c1 = C1() c2 = C1() c1.children.append(c2) @@ -234,7 +232,7 @@ class SelfReferentialNoPKTest(fixtures.MappedTest): t1.children.append(TT()) t1.children.append(TT()) - s = create_session() + s = fixture_session() s.add(t1) s.flush() s.expunge_all() @@ -244,7 +242,7 @@ class SelfReferentialNoPKTest(fixtures.MappedTest): def test_lazy_clause(self): TT = self.classes.TT - s = create_session() + s = fixture_session() t1 = TT() t2 = TT() t1.children.append(t2) @@ -327,7 +325,7 @@ class InheritTestOne(fixtures.MappedTest): Child1, Child2 = self.classes.Child1, self.classes.Child2 - session = create_session() + session = fixture_session() c1 = Child1() c1.child1_data = "qwerty" @@ -419,7 +417,7 @@ class InheritTestTwo(fixtures.MappedTest): }, ) - sess = create_session() + sess = fixture_session() bobj = B() sess.add(bobj) cobj = C() @@ -506,7 +504,7 @@ class BiDirectionalManyToOneTest(fixtures.MappedTest): o1 = T1() o1.t2 = T2() - sess = create_session() + sess = fixture_session() sess.add(o1) sess.flush() @@ -528,7 +526,7 @@ class BiDirectionalManyToOneTest(fixtures.MappedTest): o1 = T1() o1.t2 = T2() - sess = create_session() + sess = fixture_session() sess.add(o1) sess.flush() @@ -621,7 +619,7 @@ class BiDirectionalOneToManyTest(fixtures.MappedTest): a.c2s.append(b) d.c1s.append(c) b.c1s.append(c) - sess = create_session() + sess = fixture_session() sess.add_all((a, b, c, d, e, f)) sess.flush() @@ -726,7 +724,7 @@ class BiDirectionalOneToManyTest2(fixtures.MappedTest): a.data.append(C1Data(data="c1data1")) a.data.append(C1Data(data="c1data2")) c.data.append(C1Data(data="c1data3")) - sess = create_session() + sess = fixture_session() sess.add_all((a, b, c, d, e, f)) sess.flush() @@ -818,7 +816,7 @@ class OneToManyManyToOneTest(fixtures.MappedTest): b = Ball() p = Person() p.balls.append(b) - sess = create_session() + sess = fixture_session() sess.add(p) sess.flush() @@ -845,7 +843,7 @@ class OneToManyManyToOneTest(fixtures.MappedTest): b = Ball(data="some data") p = Person(data="some data") p.favorite = b - sess = create_session() + sess = fixture_session() sess.add(b) sess.add(p) sess.flush() @@ -903,7 +901,7 @@ class OneToManyManyToOneTest(fixtures.MappedTest): p.balls.append(Ball(data="some data")) p.balls.append(Ball(data="some data")) p.favorite = b - sess = create_session() + sess = fixture_session() sess.add(b) sess.add(p) @@ -1001,7 +999,7 @@ class OneToManyManyToOneTest(fixtures.MappedTest): ), ) - sess = sessionmaker()() + sess = fixture_session() p1 = Person(data="p1") p2 = Person(data="p2") p3 = Person(data="p3") @@ -1065,7 +1063,7 @@ class OneToManyManyToOneTest(fixtures.MappedTest): b4 = Ball(data="some data") p.balls.append(b4) p.favorite = b - sess = create_session() + sess = fixture_session() sess.add_all((b, p, b2, b3, b4)) self.assert_sql_execution( @@ -1176,7 +1174,7 @@ class OneToManyManyToOneTest(fixtures.MappedTest): ) mapper(Person, person) - sess = create_session(autocommit=False, expire_on_commit=True) + sess = fixture_session(autocommit=False, expire_on_commit=True) sess.add(Ball(person=Person())) sess.commit() b1 = sess.query(Ball).first() @@ -1267,7 +1265,7 @@ class SelfReferentialPostUpdateTest(fixtures.MappedTest): }, ) - session = create_session() + session = fixture_session(autoflush=False) def append_child(parent, child): if parent.children: @@ -1421,7 +1419,7 @@ class SelfReferentialPostUpdateTest2(fixtures.MappedTest): }, ) - session = create_session() + session = fixture_session() f1 = A(fui="f1") session.add(f1) @@ -1509,7 +1507,7 @@ class SelfReferentialPostUpdateTest3(fixtures.MappedTest): properties={"parent": relationship(Child, remote_side=child.c.id)}, ) - session = create_session() + session = fixture_session() p1 = Parent("p1") c1 = Child("c1") c2 = Child("c2") @@ -1668,7 +1666,7 @@ class PostUpdateBatchingTest(fixtures.MappedTest): mapper(Child2, child2) mapper(Child3, child3) - sess = create_session() + sess = fixture_session() p1 = Parent("p1") c11, c12, c13 = Child1("c1"), Child1("c2"), Child1("c3") @@ -1753,7 +1751,7 @@ class PostUpdateOnUpdateTest(fixtures.DeclarativeMappedTest): def test_update_defaults(self): A, B = self.classes("A", "B") - s = Session() + s = fixture_session() a1 = A() b1 = B() @@ -1772,7 +1770,7 @@ class PostUpdateOnUpdateTest(fixtures.DeclarativeMappedTest): event.listen(A, "refresh_flush", canary.refresh_flush) event.listen(A, "expire", canary.expire) - s = Session() + s = fixture_session() a1 = A() b1 = B() @@ -1800,7 +1798,7 @@ class PostUpdateOnUpdateTest(fixtures.DeclarativeMappedTest): event.listen(A, "refresh_flush", canary.refresh_flush) event.listen(A, "expire", canary.expire) - s = Session() + s = fixture_session() a1 = A() s.add(a1) @@ -1831,7 +1829,7 @@ class PostUpdateOnUpdateTest(fixtures.DeclarativeMappedTest): event.listen(A, "refresh_flush", canary.refresh_flush) event.listen(A, "expire", canary.expire) - s = Session() + s = fixture_session() a1 = A() b1 = B() @@ -1885,7 +1883,7 @@ class PostUpdateOnUpdateTest(fixtures.DeclarativeMappedTest): event.listen(A, "refresh_flush", canary.refresh_flush) event.listen(A, "expire", canary.expire) - s = Session() + s = fixture_session() a1 = A() s.add(a1) @@ -1936,7 +1934,7 @@ class PostUpdateOnUpdateTest(fixtures.DeclarativeMappedTest): def test_update_defaults_can_set_value(self): A, B = self.classes("A", "B") - s = Session() + s = fixture_session() a1 = A() b1 = B() diff --git a/test/orm/test_default_strategies.py b/test/orm/test_default_strategies.py index 3bd5d97db..e5206d2ae 100644 --- a/test/orm/test_default_strategies.py +++ b/test/orm/test_default_strategies.py @@ -1,12 +1,11 @@ import sqlalchemy as sa from sqlalchemy import testing from sqlalchemy import util -from sqlalchemy.orm import create_session from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship -from sqlalchemy.orm import Session from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ +from sqlalchemy.testing.fixtures import fixture_session from test.orm import _fixtures @@ -111,7 +110,7 @@ class DefaultStrategyOptionsTest(_fixtures.FixtureTest): ), ) - return create_session() + return fixture_session() def _upgrade_fixture(self): ( @@ -183,7 +182,7 @@ class DefaultStrategyOptionsTest(_fixtures.FixtureTest): ), ) - return create_session() + return fixture_session() def test_downgrade_baseline(self): """Mapper strategy defaults load as expected @@ -630,7 +629,7 @@ class NoLoadTest(_fixtures.FixtureTest): ) ), ) - q = create_session().query(m) + q = fixture_session().query(m) result = [None] def go(): @@ -661,7 +660,7 @@ class NoLoadTest(_fixtures.FixtureTest): ) ), ) - q = create_session().query(m).options(sa.orm.lazyload("addresses")) + q = fixture_session().query(m).options(sa.orm.lazyload("addresses")) result = [None] def go(): @@ -684,7 +683,7 @@ class NoLoadTest(_fixtures.FixtureTest): ) mapper(Address, addresses, properties={"user": relationship(User)}) mapper(User, users) - s = Session() + s = fixture_session() a1 = ( s.query(Address) .filter_by(id=1) diff --git a/test/orm/test_defaults.py b/test/orm/test_defaults.py index aa1f2b88d..97743b5de 100644 --- a/test/orm/test_defaults.py +++ b/test/orm/test_defaults.py @@ -5,14 +5,13 @@ from sqlalchemy import Identity from sqlalchemy import Integer from sqlalchemy import String from sqlalchemy import testing -from sqlalchemy.orm import create_session from sqlalchemy.orm import mapper -from sqlalchemy.orm import Session from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing.assertsql import assert_engine from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.testing.assertsql import Conditional +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -166,7 +165,7 @@ class TriggerDefaultsTest(fixtures.MappedTest): eq_(d1.col3, None) eq_(d1.col4, None) - session = create_session() + session = fixture_session() session.add(d1) session.flush() @@ -181,7 +180,7 @@ class TriggerDefaultsTest(fixtures.MappedTest): d1 = Default(id=1) - session = create_session() + session = fixture_session() session.add(d1) session.flush() d1.col1 = "set" @@ -214,10 +213,10 @@ class ExcludedDefaultsTest(fixtures.MappedTest): mapper(Foo, dt, exclude_properties=("col1",)) f1 = Foo() - sess = create_session() + sess = fixture_session() sess.add(f1) sess.flush() - eq_(dt.select().execute().fetchall(), [(1, "hello")]) + eq_(sess.connection().execute(dt.select()).fetchall(), [(1, "hello")]) class ComputedDefaultsOnUpdateTest(fixtures.MappedTest): @@ -261,7 +260,7 @@ class ComputedDefaultsOnUpdateTest(fixtures.MappedTest): else: Thing = self.classes.ThingNoEager - s = Session() + s = fixture_session() t1, t2 = (Thing(id=1, foo=5), Thing(id=2, foo=10)) @@ -342,7 +341,7 @@ class ComputedDefaultsOnUpdateTest(fixtures.MappedTest): else: Thing = self.classes.ThingNoEager - s = Session() + s = fixture_session() t1, t2 = (Thing(id=1, foo=1), Thing(id=2, foo=2)) @@ -445,7 +444,7 @@ class IdentityDefaultsOnUpdateTest(fixtures.MappedTest): def test_insert_identity(self): Thing = self.classes.Thing - s = Session() + s = fixture_session() t1, t2 = (Thing(foo=5), Thing(foo=10)) diff --git a/test/orm/test_deferred.py b/test/orm/test_deferred.py index 6be967337..6d1cd0184 100644 --- a/test/orm/test_deferred.py +++ b/test/orm/test_deferred.py @@ -9,7 +9,6 @@ from sqlalchemy import util from sqlalchemy.orm import aliased from sqlalchemy.orm import attributes from sqlalchemy.orm import contains_eager -from sqlalchemy.orm import create_session from sqlalchemy.orm import defaultload from sqlalchemy.orm import defer from sqlalchemy.orm import deferred @@ -31,6 +30,7 @@ from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from test.orm import _fixtures @@ -56,7 +56,7 @@ class DeferredTest(AssertsCompiledSQL, _fixtures.FixtureTest): o = Order() self.assert_(o.description is None) - q = create_session().query(Order).order_by(Order.id) + q = fixture_session().query(Order).order_by(Order.id) def go(): result = q.all() @@ -90,7 +90,7 @@ class DeferredTest(AssertsCompiledSQL, _fixtures.FixtureTest): mapper(Order, orders, properties={"id": deferred(orders.c.id)}) # right now, it's not that graceful :) - q = create_session().query(Order) + q = fixture_session().query(Order) assert_raises_message( sa.exc.NoSuchColumnError, "Could not locate", q.first ) @@ -106,7 +106,7 @@ class DeferredTest(AssertsCompiledSQL, _fixtures.FixtureTest): properties={"description": deferred(orders.c.description)}, ) - sess = create_session() + sess = fixture_session() o = Order() sess.add(o) o.id = 7 @@ -128,7 +128,7 @@ class DeferredTest(AssertsCompiledSQL, _fixtures.FixtureTest): }, ) - sess = create_session() + sess = fixture_session() o1 = sess.query(Order).get(1) eq_(o1.description, "order 1") @@ -141,7 +141,7 @@ class DeferredTest(AssertsCompiledSQL, _fixtures.FixtureTest): properties={"description": deferred(orders.c.description)}, ) - sess = create_session() + sess = fixture_session() o = Order() sess.add(o) @@ -164,7 +164,7 @@ class DeferredTest(AssertsCompiledSQL, _fixtures.FixtureTest): ), ) - sess = create_session() + sess = fixture_session() o = Order() sess.add(o) o.id = 7 @@ -186,7 +186,7 @@ class DeferredTest(AssertsCompiledSQL, _fixtures.FixtureTest): ), ) - sess = create_session() + sess = fixture_session() o = Order() sess.add(o) @@ -204,7 +204,7 @@ class DeferredTest(AssertsCompiledSQL, _fixtures.FixtureTest): properties={"description": deferred(orders.c.description)}, ) - sess = create_session() + sess = fixture_session() o2 = sess.query(Order).get(2) o2.isopen = 1 sess.flush() @@ -233,7 +233,7 @@ class DeferredTest(AssertsCompiledSQL, _fixtures.FixtureTest): ), ) - sess = create_session() + sess = fixture_session() q = sess.query(Order).order_by(Order.id) def go(): @@ -287,7 +287,7 @@ class DeferredTest(AssertsCompiledSQL, _fixtures.FixtureTest): "opened": deferred(orders.c.isopen, group="primary"), }, ) - sess = create_session() + sess = fixture_session(autoflush=False) o = sess.query(Order).get(3) assert "userident" not in o.__dict__ o.description = "somenewdescription" @@ -319,7 +319,7 @@ class DeferredTest(AssertsCompiledSQL, _fixtures.FixtureTest): }, ) - sess = create_session() + sess = fixture_session() o2 = sess.query(Order).get(3) # this will load the group of attributes @@ -351,7 +351,7 @@ class DeferredTest(AssertsCompiledSQL, _fixtures.FixtureTest): properties={"description": deferred(order_select.c.description)}, ) - sess = Session() + sess = fixture_session() o1 = sess.query(Order).order_by(Order.id).first() assert "description" not in o1.__dict__ eq_(o1.description, "order 1") @@ -367,7 +367,7 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): mapper(Order, orders) - sess = create_session() + sess = fixture_session() q = sess.query(Order).order_by(Order.id).options(defer("user_id")) def go(): @@ -427,7 +427,7 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): ), ) - sess = create_session() + sess = fixture_session() q = sess.query(Order).order_by(Order.id) def go(): @@ -470,7 +470,7 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): ), ) - sess = create_session() + sess = fixture_session() q = sess.query(Order).order_by(Order.id) def go(): @@ -515,7 +515,7 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): ), ) - sess = create_session() + sess = fixture_session() q = sess.query(Order).order_by(Order.id) def go(): @@ -570,7 +570,7 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): ), ) - sess = create_session() + sess = fixture_session() q = ( sess.query(User) .filter(User.id == 7) @@ -633,7 +633,7 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): ), ) - sess = create_session() + sess = fixture_session() q = ( sess.query(User) .filter(User.id == 7) @@ -699,7 +699,7 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): ), ) - sess = create_session() + sess = fixture_session() q = ( sess.query(User) .filter(User.id == 7) @@ -762,7 +762,7 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): ), ) - sess = create_session() + sess = fixture_session() q = ( sess.query(User) .filter(User.id == 7) @@ -810,7 +810,7 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): ), ) - sess = create_session() + sess = fixture_session() q = sess.query(Order).options(Load(Order).undefer("*")) self.assert_compile( q, @@ -834,7 +834,7 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): properties={"description": deferred(orders.c.description)}, ) - sess = create_session() + sess = fixture_session() o1 = ( sess.query(Order) .order_by(Order.id) @@ -867,7 +867,7 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): properties={"description": deferred(orders.c.description)}, ) - sess = create_session() + sess = fixture_session() stmt = sa.select(Order).order_by(Order.id) o1 = (sess.query(Order).from_statement(stmt).all())[0] @@ -889,7 +889,7 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): }, ) - sess = create_session() + sess = fixture_session() stmt = sa.select(Order).order_by(Order.id) o1 = (sess.query(Order).from_statement(stmt).all())[0] @@ -906,7 +906,7 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): mapper(Order, orders) - sess = create_session() + sess = fixture_session() stmt = sa.select(Order).order_by(Order.id) o1 = ( sess.query(Order) @@ -927,7 +927,7 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): mapper(Order, orders) - sess = create_session() + sess = fixture_session() stmt = sa.select(Order).order_by(Order.id) o1 = ( sess.query(Order) @@ -971,7 +971,7 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): properties=dict(orders=relationship(Order, order_by=orders.c.id)), ) - sess = create_session() + sess = fixture_session() q = sess.query(User).order_by(User.id) result = q.all() item = result[0].orders[1].items[1] @@ -1020,7 +1020,7 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): ) mapper(Item, items) - sess = create_session() + sess = fixture_session() exp = ( "SELECT users.id AS users_id, users.name AS users_name, " @@ -1049,7 +1049,7 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): mapper(User, users, properties={"orders": relationship(Order)}) mapper(Order, orders) - sess = create_session() + sess = fixture_session() q = sess.query(User).options( joinedload(User.orders).defer("description").defer("isopen") ) @@ -1070,7 +1070,7 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): mapper(Order, orders) - sess = create_session() + sess = fixture_session() q = sess.query(Order).options(load_only("isopen", "description")) self.assert_compile( q, @@ -1084,7 +1084,7 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): mapper(Order, orders) - sess = create_session() + sess = fixture_session() q = ( sess.query(Order) .order_by(Order.id) @@ -1101,7 +1101,7 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): properties={"description": deferred(orders.c.description)}, ) - sess = create_session() + sess = fixture_session() q = sess.query(Order).options( load_only("isopen", "description"), undefer("user_id") ) @@ -1129,7 +1129,7 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): mapper(User, users, properties={"addresses": relationship(Address)}) mapper(Address, addresses) - sess = create_session() + sess = fixture_session() expected = [ ( "SELECT users.id AS users_id, users.name AS users_name " @@ -1179,7 +1179,7 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): mapper(Address, addresses) mapper(Order, orders) - sess = create_session() + sess = fixture_session() q = sess.query(User, Order, Address).options( Load(User).load_only("name"), Load(Order).load_only("id"), @@ -1219,7 +1219,7 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): mapper(Address, addresses) mapper(Order, orders) - sess = create_session() + sess = fixture_session() q = sess.query(User).options( load_only("name") @@ -1274,7 +1274,7 @@ class SelfReferentialMultiPathTest(testing.fixtures.DeclarativeMappedTest): def test_present_overrides_deferred(self): Node = self.classes.Node - session = Session() + session = fixture_session() q = session.query(Node).options( joinedload(Node.parent).load_only(Node.id, Node.parent_id) @@ -1305,7 +1305,7 @@ class InheritanceTest(_Polymorphic): ) def test_load_only_subclass(self): - s = Session() + s = fixture_session() q = ( s.query(Manager) .order_by(Manager.person_id) @@ -1324,7 +1324,7 @@ class InheritanceTest(_Polymorphic): ) def test_load_only_subclass_bound(self): - s = Session() + s = fixture_session() q = ( s.query(Manager) .order_by(Manager.person_id) @@ -1343,7 +1343,7 @@ class InheritanceTest(_Polymorphic): ) def test_load_only_subclass_and_superclass(self): - s = Session() + s = fixture_session() q = ( s.query(Boss) .order_by(Person.person_id) @@ -1362,7 +1362,7 @@ class InheritanceTest(_Polymorphic): ) def test_load_only_subclass_and_superclass_bound(self): - s = Session() + s = fixture_session() q = ( s.query(Boss) .order_by(Person.person_id) @@ -1381,7 +1381,7 @@ class InheritanceTest(_Polymorphic): ) def test_load_only_alias_subclass(self): - s = Session() + s = fixture_session() m1 = aliased(Manager, flat=True) q = ( s.query(m1) @@ -1401,7 +1401,7 @@ class InheritanceTest(_Polymorphic): ) def test_load_only_alias_subclass_bound(self): - s = Session() + s = fixture_session() m1 = aliased(Manager, flat=True) q = ( s.query(m1) @@ -1421,7 +1421,7 @@ class InheritanceTest(_Polymorphic): ) def test_load_only_subclass_from_relationship_polymorphic(self): - s = Session() + s = fixture_session() wp = with_polymorphic(Person, [Manager], flat=True) q = ( s.query(Company) @@ -1448,7 +1448,7 @@ class InheritanceTest(_Polymorphic): ) def test_load_only_subclass_from_relationship_polymorphic_bound(self): - s = Session() + s = fixture_session() wp = with_polymorphic(Person, [Manager], flat=True) q = ( s.query(Company) @@ -1475,7 +1475,7 @@ class InheritanceTest(_Polymorphic): ) def test_load_only_subclass_from_relationship(self): - s = Session() + s = fixture_session() q = ( s.query(Company) .join(Company.managers) @@ -1499,7 +1499,7 @@ class InheritanceTest(_Polymorphic): ) def test_load_only_subclass_from_relationship_bound(self): - s = Session() + s = fixture_session() q = ( s.query(Company) .join(Company.managers) @@ -1528,7 +1528,7 @@ class InheritanceTest(_Polymorphic): # TODO: what is ".*"? this is not documented anywhere, how did this # get implemented without docs ? see #4390 - s = Session() + s = fixture_session() q = ( s.query(Manager) .order_by(Person.person_id) @@ -1545,7 +1545,7 @@ class InheritanceTest(_Polymorphic): # to have this ".*" featue. def test_load_only_subclass_of_type(self): - s = Session() + s = fixture_session() q = s.query(Company).options( joinedload(Company.employees.of_type(Manager)).load_only("status") ) @@ -1571,7 +1571,7 @@ class InheritanceTest(_Polymorphic): ) def test_wildcard_subclass_of_type(self): - s = Session() + s = fixture_session() q = s.query(Company).options( joinedload(Company.employees.of_type(Manager)).defer("*") ) @@ -1593,7 +1593,7 @@ class InheritanceTest(_Polymorphic): ) def test_defer_super_name_on_subclass(self): - s = Session() + s = fixture_session() q = s.query(Manager).order_by(Person.person_id).options(defer("name")) self.assert_compile( q, @@ -1608,7 +1608,7 @@ class InheritanceTest(_Polymorphic): ) def test_defer_super_name_on_subclass_bound(self): - s = Session() + s = fixture_session() q = ( s.query(Manager) .order_by(Person.person_id) @@ -1627,7 +1627,7 @@ class InheritanceTest(_Polymorphic): ) def test_load_only_from_with_polymorphic(self): - s = Session() + s = fixture_session() wp = with_polymorphic(Person, [Manager], flat=True) @@ -1652,7 +1652,7 @@ class InheritanceTest(_Polymorphic): ) def test_load_only_of_type_with_polymorphic(self): - s = Session() + s = fixture_session() wp = with_polymorphic(Person, [Manager], flat=True) @@ -1755,7 +1755,7 @@ class WithExpressionTest(fixtures.DeclarativeMappedTest): def test_simple_expr(self): A = self.classes.A - s = Session() + s = fixture_session() a1 = ( s.query(A) .options(with_expression(A.my_expr, A.x + A.y)) @@ -1768,7 +1768,7 @@ class WithExpressionTest(fixtures.DeclarativeMappedTest): def test_expr_default_value(self): A = self.classes.A C = self.classes.C - s = Session() + s = fixture_session() a1 = s.query(A).order_by(A.id).filter(A.x > 1) eq_(a1.all(), [A(my_expr=None), A(my_expr=None), A(my_expr=None)]) @@ -1789,7 +1789,7 @@ class WithExpressionTest(fixtures.DeclarativeMappedTest): def test_reuse_expr(self): A = self.classes.A - s = Session() + s = fixture_session() # so people will obv. want to say, "filter(A.my_expr > 10)". # but that means Query or Core has to post-modify the statement @@ -1807,7 +1807,7 @@ class WithExpressionTest(fixtures.DeclarativeMappedTest): def test_in_joinedload(self): A, B = self.classes("A", "B") - s = Session() + s = fixture_session() q = ( s.query(A) @@ -1823,7 +1823,7 @@ class WithExpressionTest(fixtures.DeclarativeMappedTest): def test_no_refresh_unless_populate_existing(self): A = self.classes.A - s = Session() + s = fixture_session() a1 = s.query(A).first() def go(): @@ -1855,7 +1855,7 @@ class WithExpressionTest(fixtures.DeclarativeMappedTest): def test_no_sql_not_set_up(self): A = self.classes.A - s = Session() + s = fixture_session() a1 = s.query(A).first() def go(): @@ -1866,7 +1866,7 @@ class WithExpressionTest(fixtures.DeclarativeMappedTest): def test_dont_explode_on_expire_individual(self): A = self.classes.A - s = Session() + s = fixture_session() q = ( s.query(A) .options(with_expression(A.my_expr, A.x + A.y)) @@ -1895,7 +1895,7 @@ class WithExpressionTest(fixtures.DeclarativeMappedTest): def test_dont_explode_on_expire_whole(self): A = self.classes.A - s = Session() + s = fixture_session() q = ( s.query(A) .options(with_expression(A.my_expr, A.x + A.y)) @@ -1943,7 +1943,7 @@ class RaiseLoadTest(fixtures.DeclarativeMappedTest): def test_mapper_raise(self): A = self.classes.A - s = Session() + s = fixture_session() a1 = s.query(A).first() assert_raises_message( sa.exc.InvalidRequestError, @@ -1956,7 +1956,7 @@ class RaiseLoadTest(fixtures.DeclarativeMappedTest): def test_mapper_defer_unraise(self): A = self.classes.A - s = Session() + s = fixture_session() a1 = s.query(A).options(defer(A.z)).first() assert "z" not in a1.__dict__ eq_(a1.z, 4) @@ -1964,7 +1964,7 @@ class RaiseLoadTest(fixtures.DeclarativeMappedTest): def test_mapper_undefer_unraise(self): A = self.classes.A - s = Session() + s = fixture_session() a1 = s.query(A).options(undefer(A.z)).first() assert "z" in a1.__dict__ eq_(a1.z, 4) @@ -1972,7 +1972,7 @@ class RaiseLoadTest(fixtures.DeclarativeMappedTest): def test_deferred_raise_option_raise_column_plain(self): A = self.classes.A - s = Session() + s = fixture_session() a1 = s.query(A).options(defer(A.x)).first() a1.x @@ -1989,7 +1989,7 @@ class RaiseLoadTest(fixtures.DeclarativeMappedTest): def test_deferred_raise_option_load_column_unexpire(self): A = self.classes.A - s = Session() + s = fixture_session() a1 = s.query(A).options(defer(A.x, raiseload=True)).first() s.expire(a1, ["x"]) @@ -1999,7 +1999,7 @@ class RaiseLoadTest(fixtures.DeclarativeMappedTest): def test_mapper_raise_after_expire_attr(self): A = self.classes.A - s = Session() + s = fixture_session() a1 = s.query(A).first() s.expire(a1, ["z"]) @@ -2015,7 +2015,7 @@ class RaiseLoadTest(fixtures.DeclarativeMappedTest): def test_mapper_raise_after_expire_obj(self): A = self.classes.A - s = Session() + s = fixture_session() a1 = s.query(A).first() s.expire(a1) @@ -2031,7 +2031,7 @@ class RaiseLoadTest(fixtures.DeclarativeMappedTest): def test_mapper_raise_after_modify_attr_expire_obj(self): A = self.classes.A - s = Session() + s = fixture_session() a1 = s.query(A).first() a1.z = 10 @@ -2048,7 +2048,7 @@ class RaiseLoadTest(fixtures.DeclarativeMappedTest): def test_deferred_raise_option_load_after_expire_obj(self): A = self.classes.A - s = Session() + s = fixture_session() a1 = s.query(A).options(defer(A.y, raiseload=True)).first() s.expire(a1) @@ -2059,7 +2059,7 @@ class RaiseLoadTest(fixtures.DeclarativeMappedTest): def test_option_raiseload_unexpire_modified_obj(self): A = self.classes.A - s = Session() + s = fixture_session() a1 = s.query(A).options(defer(A.y, raiseload=True)).first() a1.y = 10 @@ -2071,7 +2071,7 @@ class RaiseLoadTest(fixtures.DeclarativeMappedTest): def test_option_raise_deferred(self): A = self.classes.A - s = Session() + s = fixture_session() a1 = s.query(A).options(defer(A.y, raiseload=True)).first() assert_raises_message( @@ -2084,7 +2084,7 @@ class RaiseLoadTest(fixtures.DeclarativeMappedTest): def test_does_expire_cancel_normal_defer_option(self): A = self.classes.A - s = Session() + s = fixture_session() a1 = s.query(A).options(defer(A.x)).first() # expire object @@ -2119,7 +2119,7 @@ class AutoflushTest(fixtures.DeclarativeMappedTest): def test_deferred_autoflushes(self): A, B = self.classes("A", "B") - s = Session() + s = fixture_session() a1 = A(id=1, bs=[B()]) s.add(a1) @@ -2190,14 +2190,14 @@ class DeferredPopulationTest(fixtures.MappedTest): def test_no_previous_query(self): Thing = self.classes.Thing - session = create_session() + session = fixture_session() thing = session.query(Thing).options(sa.orm.undefer("name")).first() self._test(thing) def test_query_twice_with_clear(self): Thing = self.classes.Thing - session = create_session() + session = fixture_session() result = session.query(Thing).first() # noqa session.expunge_all() thing = session.query(Thing).options(sa.orm.undefer("name")).first() @@ -2206,7 +2206,7 @@ class DeferredPopulationTest(fixtures.MappedTest): def test_query_twice_no_clear(self): Thing = self.classes.Thing - session = create_session() + session = fixture_session() result = session.query(Thing).first() # noqa thing = session.query(Thing).options(sa.orm.undefer("name")).first() self._test(thing) @@ -2214,7 +2214,7 @@ class DeferredPopulationTest(fixtures.MappedTest): def test_joinedload_with_clear(self): Thing, Human = self.classes.Thing, self.classes.Human - session = create_session() + session = fixture_session() human = ( # noqa session.query(Human).options(sa.orm.joinedload("thing")).first() ) @@ -2225,7 +2225,7 @@ class DeferredPopulationTest(fixtures.MappedTest): def test_joinedload_no_clear(self): Thing, Human = self.classes.Thing, self.classes.Human - session = create_session() + session = fixture_session() human = ( # noqa session.query(Human).options(sa.orm.joinedload("thing")).first() ) @@ -2235,7 +2235,7 @@ class DeferredPopulationTest(fixtures.MappedTest): def test_join_with_clear(self): Thing, Human = self.classes.Thing, self.classes.Human - session = create_session() + session = fixture_session() result = ( # noqa session.query(Human).add_entity(Thing).join("thing").first() ) @@ -2246,7 +2246,7 @@ class DeferredPopulationTest(fixtures.MappedTest): def test_join_no_clear(self): Thing, Human = self.classes.Thing, self.classes.Human - session = create_session() + session = fixture_session() result = ( # noqa session.query(Human).add_entity(Thing).join("thing").first() ) diff --git a/test/orm/test_deprecations.py b/test/orm/test_deprecations.py index c0d5a93d5..6d946cfe6 100644 --- a/test/orm/test_deprecations.py +++ b/test/orm/test_deprecations.py @@ -9,6 +9,7 @@ from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import Integer from sqlalchemy import literal_column +from sqlalchemy import MetaData from sqlalchemy import or_ from sqlalchemy import select from sqlalchemy import String @@ -16,6 +17,7 @@ from sqlalchemy import table from sqlalchemy import testing from sqlalchemy import text from sqlalchemy import true +from sqlalchemy.engine import default from sqlalchemy.orm import aliased from sqlalchemy.orm import as_declarative from sqlalchemy.orm import attributes @@ -25,7 +27,6 @@ from sqlalchemy.orm import column_property from sqlalchemy.orm import configure_mappers from sqlalchemy.orm import contains_alias from sqlalchemy.orm import contains_eager -from sqlalchemy.orm import create_session from sqlalchemy.orm import declarative_base from sqlalchemy.orm import declared_attr from sqlalchemy.orm import defer @@ -51,18 +52,21 @@ from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import assertions from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ +from sqlalchemy.testing import eq_ignore_whitespace from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import is_true from sqlalchemy.testing import mock from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.testing.fixtures import ComparableEntity +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.mock import call from sqlalchemy.testing.mock import Mock from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from . import _fixtures from .inheritance import _poly_fixtures +from .test_bind import GetBindTest as _GetBindTest from .test_dynamic import _DynamicFixture from .test_events import _RemoveListeners from .test_options import PathTest as OptionsPathTest @@ -114,7 +118,7 @@ class DeprecatedQueryTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_deprecated_negative_slices(self): User = self.classes.User - sess = create_session() + sess = fixture_session() q = sess.query(User).order_by(User.id) with testing.expect_deprecated( @@ -143,7 +147,7 @@ class DeprecatedQueryTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_deprecated_negative_slices_compile(self): User = self.classes.User - sess = create_session() + sess = fixture_session() q = sess.query(User).order_by(User.id) with testing.expect_deprecated( @@ -181,7 +185,7 @@ class DeprecatedQueryTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_aliased(self): User = self.classes.User - s = create_session() + s = fixture_session() with testing.expect_deprecated_20(join_aliased_dep): q1 = s.query(User).join(User.addresses, aliased=True) @@ -197,7 +201,7 @@ class DeprecatedQueryTest(_fixtures.FixtureTest, AssertsCompiledSQL): User = self.classes.User Address = self.classes.Address - s = create_session() + s = fixture_session() u1 = aliased(User) @@ -219,7 +223,7 @@ class DeprecatedQueryTest(_fixtures.FixtureTest, AssertsCompiledSQL): User = self.classes.User Address = self.classes.Address - s = create_session() + s = fixture_session() u1 = aliased(User) @@ -236,7 +240,7 @@ class DeprecatedQueryTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_str_join_target(self): User = self.classes.User - s = create_session() + s = fixture_session() with testing.expect_deprecated_20(join_strings_dep): q1 = s.query(User).join("addresses") @@ -251,7 +255,7 @@ class DeprecatedQueryTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_str_rel_loader_opt(self): User = self.classes.User - s = create_session() + s = fixture_session() q1 = s.query(User).options(joinedload("addresses")) @@ -272,7 +276,7 @@ class DeprecatedQueryTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_str_col_loader_opt(self): User = self.classes.User - s = create_session() + s = fixture_session() q1 = s.query(User).options(defer("name")) @@ -286,7 +290,7 @@ class DeprecatedQueryTest(_fixtures.FixtureTest, AssertsCompiledSQL): User = self.classes.User Address = self.classes.Address - s = create_session() + s = fixture_session() u1 = User(id=1) @@ -321,7 +325,7 @@ class DeprecatedQueryTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_invalid_column(self): User = self.classes.User - s = create_session() + s = fixture_session() q = s.query(User.id) with testing.expect_deprecated(r"Query.add_column\(\) is deprecated"): @@ -334,7 +338,7 @@ class DeprecatedQueryTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_via_textasfrom_select_from(self): User = self.classes.User - s = create_session() + s = fixture_session() with self._expect_implicit_subquery(): eq_( @@ -350,7 +354,7 @@ class DeprecatedQueryTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_text_as_column(self): User = self.classes.User - s = create_session() + s = fixture_session() # TODO: this works as of "use rowproxy for ORM keyed tuple" # Ieb9085e9bcff564359095b754da9ae0af55679f0 @@ -374,7 +378,7 @@ class DeprecatedQueryTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_query_as_scalar(self): User = self.classes.User - s = Session() + s = fixture_session() with assertions.expect_deprecated( r"The Query.as_scalar\(\) method is deprecated and will " "be removed in a future release." @@ -385,7 +389,7 @@ class DeprecatedQueryTest(_fixtures.FixtureTest, AssertsCompiledSQL): User, users = self.classes.User, self.tables.users sel = users.select() - sess = create_session() + sess = fixture_session() with self._expect_implicit_subquery(): eq_( @@ -399,7 +403,7 @@ class DeprecatedQueryTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_select_entity_from_select(self): User, users = self.classes.User, self.tables.users - sess = create_session() + sess = fixture_session() with self._expect_implicit_subquery(): self.assert_compile( sess.query(User.name).select_entity_from( @@ -413,7 +417,7 @@ class DeprecatedQueryTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_select_entity_from_q_statement(self): User = self.classes.User - sess = create_session() + sess = fixture_session() q = sess.query(User) with self._expect_implicit_subquery(): @@ -427,7 +431,7 @@ class DeprecatedQueryTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_select_from_q_statement_no_aliasing(self): User = self.classes.User - sess = create_session() + sess = fixture_session() q = sess.query(User) with self._expect_implicit_subquery(): @@ -455,7 +459,7 @@ class DeprecatedQueryTest(_fixtures.FixtureTest, AssertsCompiledSQL): use_labels=True, order_by=[text("ulist.id"), addresses.c.id] ) ) - sess = create_session() + sess = fixture_session() # better way. use select_entity_from() def go(): @@ -477,7 +481,7 @@ class DeprecatedQueryTest(_fixtures.FixtureTest, AssertsCompiledSQL): self.tables.users, ) - sess = create_session() + sess = fixture_session() # same thing, but alias addresses, so that the adapter # generated by select_entity_from() is wrapped within @@ -506,7 +510,7 @@ class DeprecatedQueryTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_select(self): users = self.tables.users - sess = create_session() + sess = fixture_session() with self._expect_implicit_subquery(): self.assert_compile( @@ -531,7 +535,7 @@ class DeprecatedQueryTest(_fixtures.FixtureTest, AssertsCompiledSQL): # mapper(Address, addresses) sel = users.select(users.c.id.in_([7, 8])) - sess = create_session() + sess = fixture_session() with self._expect_implicit_subquery(): result = ( @@ -606,7 +610,7 @@ class DeprecatedQueryTest(_fixtures.FixtureTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() sel = users.select(users.c.id.in_([7, 8])) with self._expect_implicit_subquery(): @@ -633,7 +637,7 @@ class DeprecatedQueryTest(_fixtures.FixtureTest, AssertsCompiledSQL): User, users = self.classes.User, self.tables.users sel = users.select(users.c.id.in_([7, 8])) - sess = create_session() + sess = fixture_session() with self._expect_implicit_subquery(): eq_( @@ -649,7 +653,7 @@ class DeprecatedQueryTest(_fixtures.FixtureTest, AssertsCompiledSQL): ) sel = users.select(users.c.id.in_([7, 8])) - sess = create_session() + sess = fixture_session() def go(): with self._expect_implicit_subquery(): @@ -727,7 +731,7 @@ class DeprecatedQueryTest(_fixtures.FixtureTest, AssertsCompiledSQL): self.classes.User, ) - sess = Session() + sess = fixture_session() oalias = orders.select() @@ -800,7 +804,7 @@ class SelfRefFromSelfTest(fixtures.MappedTest, AssertsCompiledSQL): def insert_data(cls, connection): Node = cls.classes.Node - sess = create_session(connection) + sess = Session(connection) n1 = Node(data="n1") n1.append(Node(data="n11")) n1.append(Node(data="n12")) @@ -819,7 +823,7 @@ class SelfRefFromSelfTest(fixtures.MappedTest, AssertsCompiledSQL): Node = self.classes.Node - sess = create_session() + sess = fixture_session() n1 = aliased(Node) @@ -886,7 +890,7 @@ class SelfRefFromSelfTest(fixtures.MappedTest, AssertsCompiledSQL): def test_multiple_explicit_entities_two(self): Node = self.classes.Node - sess = create_session() + sess = fixture_session() parent = aliased(Node) grandparent = aliased(Node) @@ -906,7 +910,7 @@ class SelfRefFromSelfTest(fixtures.MappedTest, AssertsCompiledSQL): def test_multiple_explicit_entities_three(self): Node = self.classes.Node - sess = create_session() + sess = fixture_session() parent = aliased(Node) grandparent = aliased(Node) @@ -927,7 +931,7 @@ class SelfRefFromSelfTest(fixtures.MappedTest, AssertsCompiledSQL): def test_multiple_explicit_entities_five(self): Node = self.classes.Node - sess = create_session() + sess = fixture_session() parent = aliased(Node) grandparent = aliased(Node) @@ -952,7 +956,7 @@ class SelfRefFromSelfTest(fixtures.MappedTest, AssertsCompiledSQL): class DynamicTest(_DynamicFixture, _fixtures.FixtureTest): def test_negative_slice_access_raises(self): User, Address = self._user_address_fixture() - sess = create_session(testing.db) + sess = fixture_session() u1 = sess.get(User, 8) with testing.expect_deprecated_20( @@ -986,7 +990,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): User = self.classes.User - s = Session() + s = fixture_session() with self._from_self_deprecated(): q = s.query(User).from_self() @@ -1007,7 +1011,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): def test_columns_augmented_distinct_on(self): User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() with self._from_self_deprecated(): q = ( @@ -1047,7 +1051,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): """ User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() with self._from_self_deprecated(): q = ( sess.query(User, Address.email_address) @@ -1068,7 +1072,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() with self._from_self_deprecated(): q = ( @@ -1105,7 +1109,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() # explicit onclause with from_self(), means # the onclause must be aliased against the query's custom # FROM object @@ -1127,7 +1131,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): Item, Keyword = self.classes.Item, self.classes.Keyword - sess = create_session() + sess = fixture_session() with self._from_self_deprecated(): self.assert_compile( @@ -1153,7 +1157,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): def test_single_prop_9(self): User = self.classes.User - sess = create_session() + sess = fixture_session() with self._from_self_deprecated(): self.assert_compile( sess.query(User) @@ -1171,7 +1175,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): def test_anonymous_expression_from_self_twice_oldstyle(self): # relies upon _orm_only_from_obj_alias setting - sess = create_session() + sess = fixture_session() c1, c2 = column("c1"), column("c2") q1 = sess.query(c1, c2).filter(c1 == "dog") with self._from_self_deprecated(): @@ -1193,7 +1197,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): Address = self.classes.Address addresses = self.tables.addresses - sess = create_session() + sess = fixture_session() q1 = sess.query(User.id).filter(User.id > 5) with self._from_self_deprecated(): q1 = q1.from_self() @@ -1220,7 +1224,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): Address = self.classes.Address addresses = self.tables.addresses - sess = create_session() + sess = fixture_session() q1 = sess.query(User.id).filter(User.id > 5) with self._from_self_deprecated(): q1 = q1.from_self() @@ -1242,7 +1246,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): # relies upon _orm_only_from_obj_alias setting from sqlalchemy.sql import column - sess = create_session() + sess = fixture_session() t1 = table("t1", column("c1"), column("c2")) q1 = sess.query(t1.c.c1, t1.c.c2).filter(t1.c.c1 == "dog") with self._from_self_deprecated(): @@ -1261,7 +1265,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): def test_self_referential(self): Order = self.classes.Order - sess = create_session() + sess = fixture_session() oalias = aliased(Order) with self._from_self_deprecated(): @@ -1364,7 +1368,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): # relies upon _orm_only_from_obj_alias setting Order = self.classes.Order - sess = create_session() + sess = fixture_session() # ensure column expressions are taken from inside the subquery, not # restated at the top @@ -1394,7 +1398,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): def test_column_access_from_self(self): User = self.classes.User - sess = create_session() + sess = fixture_session() with self._from_self_deprecated(): q = sess.query(User).from_self() @@ -1408,7 +1412,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): def test_column_access_from_self_twice(self): User = self.classes.User - sess = create_session() + sess = fixture_session() with self._from_self_deprecated(): q = sess.query(User).from_self(User.id, User.name).from_self() @@ -1428,7 +1432,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() adalias = aliased(Address) # select from aliasing + explicit aliasing @@ -1455,7 +1459,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() # anon + select from aliasing aa = aliased(Address) @@ -1475,7 +1479,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() adalias = aliased(Address) # test eager aliasing, with/without select_entity_from aliasing @@ -1606,7 +1610,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): with self._from_self_deprecated(): eq_( [User(id=8), User(id=9)], - create_session() + fixture_session() .query(User) .filter(User.id.in_([8, 9])) .from_self() @@ -1616,7 +1620,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): with self._from_self_deprecated(): eq_( [User(id=8), User(id=9)], - create_session() + fixture_session() .query(User) .order_by(User.id) .slice(1, 3) @@ -1628,7 +1632,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): eq_( [User(id=8)], list( - create_session() + fixture_session() .query(User) .filter(User.id.in_([8, 9])) .from_self() @@ -1647,7 +1651,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): (User(id=8), Address(id=4)), (User(id=9), Address(id=5)), ], - create_session() + fixture_session() .query(User) .filter(User.id.in_([8, 9])) .from_self() @@ -1661,7 +1665,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): Address = self.classes.Address eq_( - create_session() + fixture_session() .query(Address.user_id, func.count(Address.id).label("count")) .group_by(Address.user_id) .order_by(Address.user_id) @@ -1671,7 +1675,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): with self._from_self_deprecated(): eq_( - create_session() + fixture_session() .query(Address.user_id, Address.id) .from_self(Address.user_id, func.count(Address.id)) .group_by(Address.user_id) @@ -1683,7 +1687,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): def test_having(self): User = self.classes.User - s = create_session() + s = fixture_session() with self._from_self_deprecated(): self.assert_compile( @@ -1702,7 +1706,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): User = self.classes.User - s = create_session() + s = fixture_session() with self._from_self_deprecated(): q = s.query(User).options(joinedload(User.addresses)).from_self() @@ -1723,7 +1727,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): User, Address = self.classes.User, self.classes.Address - s = create_session() + s = fixture_session() ualias = aliased(User) @@ -1775,7 +1779,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): def test_multiple_entities(self): User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() with self._from_self_deprecated(): eq_( @@ -1805,7 +1809,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): # relies upon _orm_only_from_obj_alias setting User = self.classes.User - sess = create_session() + sess = fixture_session() with self._from_self_deprecated(): eq_( @@ -1879,7 +1883,7 @@ class SubqRelationsFromSelfTest(fixtures.DeclarativeMappedTest): def test_subq_w_from_self_one(self): A, B, C = self.classes("A", "B", "C") - s = Session() + s = fixture_session() cache = {} @@ -1956,7 +1960,7 @@ class SubqRelationsFromSelfTest(fixtures.DeclarativeMappedTest): A, B, C = self.classes("A", "B", "C") - s = Session() + s = fixture_session() cache = {} for i in range(3): @@ -2165,7 +2169,7 @@ class SessionTest(fixtures.RemovesEvents, _LocalFixture): with testing.expect_deprecated_20( "The Session.autocommit parameter is deprecated" ): - sess = Session(autocommit=True) + sess = Session(testing.db, autocommit=True) with sess.begin(): sess.add(User(name="u1")) @@ -2178,7 +2182,7 @@ class SessionTest(fixtures.RemovesEvents, _LocalFixture): with testing.expect_deprecated_20( "The Session.autocommit parameter is deprecated" ): - sess = Session(autocommit=True) + sess = Session(testing.db, autocommit=True) def go(): with sess.begin(): @@ -2193,6 +2197,58 @@ class SessionTest(fixtures.RemovesEvents, _LocalFixture): eq_(sess.query(User).count(), 1) +class AutocommitClosesOnFailTest(fixtures.MappedTest): + __requires__ = ("deferrable_fks",) + + @classmethod + def define_tables(cls, metadata): + Table("t1", metadata, Column("id", Integer, primary_key=True)) + + Table( + "t2", + metadata, + Column("id", Integer, primary_key=True), + Column( + "t1_id", + Integer, + ForeignKey("t1.id", deferrable=True, initially="deferred"), + ), + ) + + @classmethod + def setup_classes(cls): + class T1(cls.Comparable): + pass + + class T2(cls.Comparable): + pass + + @classmethod + def setup_mappers(cls): + T2, T1, t2, t1 = ( + cls.classes.T2, + cls.classes.T1, + cls.tables.t2, + cls.tables.t1, + ) + + mapper(T1, t1) + mapper(T2, t2) + + def test_close_transaction_on_commit_fail(self): + T2 = self.classes.T2 + + session = fixture_session(autocommit=True) + + # with a deferred constraint, this fails at COMMIT time instead + # of at INSERT time. + session.add(T2(t1_id=123)) + + assert_raises(sa.exc.IntegrityError, session.flush) + + assert session._legacy_transaction() is None + + class DeprecatedInhTest(_poly_fixtures._Polymorphic): def test_with_polymorphic(self): Person = _poly_fixtures.Person @@ -2217,7 +2273,7 @@ class DeprecatedInhTest(_poly_fixtures._Polymorphic): engineers = self.tables.engineers machines = self.tables.machines - sess = create_session() + sess = fixture_session() mach_alias = machines.select() with DeprecatedQueryTest._expect_implicit_subquery(): @@ -2364,7 +2420,7 @@ class DeprecatedMapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): m.add_property("_name", deferred(users.c.name)) m.add_property("name", synonym("_name")) - sess = create_session(autocommit=False) + sess = fixture_session(autocommit=False) assert sess.query(User).get(7) u = sess.query(User).filter_by(name="jack").one() @@ -2433,7 +2489,7 @@ class DeprecatedOptionAllTest(OptionsPathTest, _fixtures.FixtureTest): assert_raises_message( sa.exc.ArgumentError, message, - create_session() + fixture_session() .query(*entity_list) .options(*options) ._compile_context, @@ -2458,7 +2514,7 @@ class DeprecatedOptionAllTest(OptionsPathTest, _fixtures.FixtureTest): }, ) - sess = create_session() + sess = fixture_session() with testing.expect_deprecated( r"The \*addl_attrs on orm.defer is deprecated. " @@ -2779,7 +2835,7 @@ class NonPrimaryRelationshipLoaderTest(_fixtures.FixtureTest): User, Address, Order, Item = self.classes( "User", "Address", "Order", "Item" ) - q = create_session().query(User).order_by(User.id) + q = fixture_session().query(User).order_by(User.id) def go(): eq_( @@ -2813,21 +2869,21 @@ class NonPrimaryRelationshipLoaderTest(_fixtures.FixtureTest): self.assert_sql_count(testing.db, go, count) - sess = create_session() + sess = fixture_session() user = sess.query(User).get(7) closed_mapper = User.closed_orders.entity open_mapper = User.open_orders.entity eq_( [Order(id=1), Order(id=5)], - create_session() + fixture_session() .query(closed_mapper) .with_parent(user, property="closed_orders") .all(), ) eq_( [Order(id=3)], - create_session() + fixture_session() .query(open_mapper) .with_parent(user, property="open_orders") .all(), @@ -3004,7 +3060,7 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): use_labels=True, order_by=[text("ulist.id"), addresses.c.id] ) ) - sess = create_session() + sess = fixture_session() q = sess.query(User) # note this has multiple problems because we aren't giving Query @@ -3040,7 +3096,7 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): use_labels=True, order_by=[text("ulist.id"), addresses.c.id] ) ) - sess = create_session() + sess = fixture_session() q = sess.query(User) def go(): @@ -3066,7 +3122,7 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() selectquery = users.outerjoin(addresses).select( users.c.id < 10, @@ -3081,7 +3137,7 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): ): result = list( q.options(contains_eager("addresses")).instances( - selectquery.execute() + sess.execute(selectquery) ) ) assert self.static.user_address_result[0:3] == result @@ -3096,7 +3152,7 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): ): result = list( q.options(contains_eager(User.addresses)).instances( - selectquery.execute() + sess.connection().execute(selectquery) ) ) assert self.static.user_address_result[0:3] == result @@ -3110,7 +3166,7 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() q = sess.query(User) adalias = addresses.alias("adalias") @@ -3131,7 +3187,7 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): result = list( q.options( contains_eager("addresses", alias="adalias") - ).instances(selectquery.execute()) + ).instances(sess.connection().execute(selectquery)) ) assert self.static.user_address_result == result @@ -3144,7 +3200,7 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() q = sess.query(User) adalias = addresses.alias("adalias") @@ -3161,7 +3217,7 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): result = list( q.options( contains_eager("addresses", alias=adalias) - ).instances(selectquery.execute()) + ).instances(sess.connection().execute(selectquery)) ) assert self.static.user_address_result == result @@ -3176,7 +3232,7 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() q = sess.query(User) oalias = orders.alias("o1") @@ -3202,7 +3258,7 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): q.options( contains_eager("orders", alias="o1"), contains_eager("orders.items", alias="i1"), - ).instances(query.execute()) + ).instances(sess.connection().execute(query)) ) assert self.static.user_order_result == result @@ -3217,7 +3273,7 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() q = sess.query(User) oalias = orders.alias("o1") @@ -3245,7 +3301,7 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): q.options( contains_eager("orders", alias=oalias), contains_eager("orders.items", alias=ialias), - ).instances(query.execute()) + ).instances(sess.connection().execute(query)) ) assert self.static.user_order_result == result @@ -3268,7 +3324,7 @@ class DistinctOrderByImplicitTest(QueryTest, AssertsCompiledSQL): def test_columns_augmented_roundtrip_one(self): User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() with testing.expect_deprecated_20(join_strings_dep): q = ( sess.query(User) @@ -3284,7 +3340,7 @@ class DistinctOrderByImplicitTest(QueryTest, AssertsCompiledSQL): def test_columns_augmented_roundtrip_two(self): User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() with testing.expect_deprecated_20(join_strings_dep): q = ( sess.query(User) @@ -3300,7 +3356,7 @@ class DistinctOrderByImplicitTest(QueryTest, AssertsCompiledSQL): def test_columns_augmented_roundtrip_three(self): User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() q = ( sess.query(User.id, User.name.label("foo"), Address.id) @@ -3331,7 +3387,7 @@ class DistinctOrderByImplicitTest(QueryTest, AssertsCompiledSQL): def test_columns_augmented_sql_one(self): User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() q = ( sess.query(User.id, User.name.label("foo"), Address.id) @@ -3360,7 +3416,7 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): def test_on_bulk_update_hook(self): User, users = self.classes.User, self.tables.users - sess = Session() + sess = fixture_session() canary = Mock() event.listen(sess, "after_bulk_update", canary.after_bulk_update) @@ -3390,7 +3446,7 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): def test_on_bulk_delete_hook(self): User, users = self.classes.User, self.tables.users - sess = Session() + sess = fixture_session() canary = Mock() event.listen(sess, "after_bulk_delete", canary.after_bulk_delete) @@ -3438,7 +3494,7 @@ class ImmediateTest(_fixtures.FixtureTest): def test_value(self): User = self.classes.User - sess = create_session() + sess = fixture_session() with testing.expect_deprecated(r"Query.value\(\) is deprecated"): eq_(sess.query(User).filter_by(id=7).value(User.id), 7) @@ -3457,7 +3513,7 @@ class ImmediateTest(_fixtures.FixtureTest): def test_value_cancels_loader_opts(self): User = self.classes.User - sess = create_session() + sess = fixture_session() q = ( sess.query(User) @@ -3479,7 +3535,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() with testing.expect_deprecated(r"Query.values?\(\) is deprecated"): assert list(sess.query(User).values()) == list() @@ -3587,7 +3643,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): def test_values_specific_order_by(self): users, User = self.tables.users, self.classes.User - sess = create_session() + sess = fixture_session() with testing.expect_deprecated(r"Query.values?\(\) is deprecated"): assert list(sess.query(User).values()) == list() @@ -3626,6 +3682,13 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): "pg8000 parses the SQL itself before passing on " "to PG, doesn't parse this", ) + @testing.fails_on( + "postgresql+asyncpg", + "Asyncpg uses preprated statements that are not compatible with how " + "sqlalchemy passes the query. Fails with " + 'ERROR: column "users.name" must appear in the GROUP BY clause' + " or be used in an aggregate function", + ) @testing.fails_on("firebird", "unknown") def test_values_with_boolean_selects(self): """Tests a values clause that works with select boolean @@ -3633,7 +3696,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): User = self.classes.User - sess = create_session() + sess = fixture_session() q = sess.query(User) with testing.expect_deprecated(r"Query.values?\(\) is deprecated"): @@ -3690,7 +3753,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): self.classes.Order, ) - sess = create_session() + sess = fixture_session() OrderAlias = aliased(Order) with testing.expect_deprecated_20(join_strings_dep): @@ -3777,7 +3840,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): self.classes.Order, ) - sess = create_session() + sess = fixture_session() # no arg error with testing.expect_deprecated_20(join_aliased_dep): ( @@ -3799,7 +3862,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): self.classes.Address, ) - sess = create_session() + sess = fixture_session() # test a basic aliasized path with testing.expect_deprecated(join_aliased_dep, join_strings_dep): @@ -3885,7 +3948,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): def test_overlapping_paths_two(self): User = self.classes.User - sess = create_session() + sess = fixture_session() # test overlapping paths. User->orders is used by both joins, but # rendered once. @@ -3918,7 +3981,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): with testing.expect_deprecated_20(*warnings): result = ( - create_session() + fixture_session() .query(User) .join("orders", "items", aliased=aliased_) .filter_by(id=3) @@ -3931,7 +3994,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): def test_overlapping_paths_multilevel(self): User = self.classes.User - s = Session() + s = fixture_session() with testing.expect_deprecated_20(join_strings_dep, join_chain_dep): q = ( @@ -3959,7 +4022,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): self.classes.Order, ) - sess = create_session() + sess = fixture_session() for oalias, ialias in [ (True, True), @@ -4023,7 +4086,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() with testing.expect_deprecated(join_tuple_form): q = ( @@ -4051,7 +4114,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() with testing.expect_deprecated_20(join_tuple_form): q = ( @@ -4080,7 +4143,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() # the old "backwards" form with testing.expect_deprecated_20(join_tuple_form, join_strings_dep): @@ -4108,7 +4171,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() with testing.expect_deprecated_20(join_strings_dep): q = ( @@ -4128,7 +4191,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "User", "Order", "Item", "Keyword" ) - sess = create_session() + sess = fixture_session() # ensure when the tokens are broken up that from_joinpoint # is set between them @@ -4156,7 +4219,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): def test_single_name(self): User = self.classes.User - sess = create_session() + sess = fixture_session() with testing.expect_deprecated_20(join_strings_dep): self.assert_compile( @@ -4206,7 +4269,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): User, ) = (self.classes.Order, self.classes.User) - sess = create_session() + sess = fixture_session() with testing.expect_deprecated_20(join_chain_dep): self.assert_compile( sess.query(User).join(User.orders, Order.items), @@ -4221,7 +4284,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): def test_single_prop_7(self): Order, User = (self.classes.Order, self.classes.User) - sess = create_session() + sess = fixture_session() # this query is somewhat nonsensical. the old system didn't render a # correct query for this. In this case its the most faithful to what # was asked - there's no linkage between User.orders and "oalias", @@ -4244,7 +4307,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): User, ) = (self.classes.Order, self.classes.User) - sess = create_session() + sess = fixture_session() # same as before using an aliased() for User as well ualias = aliased(User) oalias = aliased(Order) @@ -4263,7 +4326,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): def test_single_prop_10(self): User, Address = (self.classes.User, self.classes.Address) - sess = create_session() + sess = fixture_session() with testing.expect_deprecated_20(join_aliased_dep): self.assert_compile( sess.query(User) @@ -4282,7 +4345,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() with testing.expect_deprecated_20(join_aliased_dep, join_chain_dep): self.assert_compile( sess.query(User) @@ -4304,7 +4367,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() with testing.expect_deprecated_20(join_chain_dep, join_aliased_dep): self.assert_compile( @@ -4332,7 +4395,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() # this is now a very weird test, nobody should really # be using the aliased flag in this way. @@ -4397,7 +4460,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): with testing.expect_deprecated_20(join_strings_dep, join_chain_dep): result = ( - create_session() + fixture_session() .query(User) .select_from(users.join(oalias)) .filter( @@ -4411,7 +4474,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): with testing.expect_deprecated_20(join_strings_dep, join_chain_dep): result = ( - create_session() + fixture_session() .query(User) .select_from(users.join(oalias)) .filter( @@ -4437,7 +4500,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): # id 1 (order 3, owned by jack) with testing.expect_deprecated_20(*warnings): result = ( - create_session() + fixture_session() .query(User) .join("orders", "items", aliased=aliased_) .filter_by(id=3) @@ -4450,7 +4513,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): with testing.expect_deprecated_20(*warnings): result = ( - create_session() + fixture_session() .query(User) .join("orders", "items", aliased=aliased_, isouter=True) .filter_by(id=3) @@ -4463,7 +4526,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): with testing.expect_deprecated_20(*warnings): result = ( - create_session() + fixture_session() .query(User) .outerjoin("orders", "items", aliased=aliased_) .filter_by(id=3) @@ -4533,7 +4596,7 @@ class AliasFromCorrectLeftTest( def test_join_prop_to_string(self): A, B, X = self.classes("A", "B", "X") - s = Session() + s = fixture_session() with testing.expect_deprecated_20(join_strings_dep): q = s.query(B).join(B.a_list, "x_list").filter(X.name == "x1") @@ -4555,7 +4618,7 @@ class AliasFromCorrectLeftTest( def test_join_prop_to_prop(self): A, B, X = self.classes("A", "B", "X") - s = Session() + s = fixture_session() # B -> A, but both are Object. So when we say A.x_list, make sure # we pick the correct right side @@ -4622,7 +4685,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): def insert_data(cls, connection): Node = cls.classes.Node - sess = create_session(connection) + sess = Session(connection) n1 = Node(data="n1") n1.append(Node(data="n11")) n1.append(Node(data="n12")) @@ -4636,7 +4699,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): def test_join_1(self): Node = self.classes.Node - sess = create_session() + sess = fixture_session() with testing.expect_deprecated_20(join_strings_dep, join_aliased_dep): node = ( @@ -4649,7 +4712,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): def test_join_2(self): Node = self.classes.Node - sess = create_session() + sess = fixture_session() with testing.expect_deprecated_20(join_aliased_dep): ret = ( sess.query(Node.data) @@ -4661,7 +4724,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): def test_join_3_filter_by(self): Node = self.classes.Node - sess = create_session() + sess = fixture_session() with testing.expect_deprecated_20( join_strings_dep, join_aliased_dep, join_chain_dep ): @@ -4683,7 +4746,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): def test_join_3_filter(self): Node = self.classes.Node - sess = create_session() + sess = fixture_session() with testing.expect_deprecated_20( join_strings_dep, join_aliased_dep, join_chain_dep ): @@ -4705,7 +4768,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): def test_join_4_filter_by(self): Node = self.classes.Node - sess = create_session() + sess = fixture_session() with testing.expect_deprecated_20(join_strings_dep, join_aliased_dep): q = ( @@ -4732,7 +4795,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): def test_join_4_filter(self): Node = self.classes.Node - sess = create_session() + sess = fixture_session() with testing.expect_deprecated_20(join_strings_dep, join_aliased_dep): q = ( @@ -4765,7 +4828,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): Node = self.classes.Node - sess = create_session() + sess = fixture_session() nalias = aliased( Node, sess.query(Node).filter_by(data="n1").subquery() ) @@ -4804,7 +4867,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): def test_string_or_prop_aliased_two(self): Node = self.classes.Node - sess = create_session() + sess = fixture_session() nalias = aliased( Node, sess.query(Node).filter_by(data="n1").subquery() ) @@ -4873,7 +4936,7 @@ class InheritedJoinTest(_poly_fixtures._Polymorphic, AssertsCompiledSQL): self.classes.Engineer, ) - sess = create_session() + sess = fixture_session() mach_alias = aliased(Machine, machines.select().subquery()) @@ -4906,7 +4969,7 @@ class InheritedJoinTest(_poly_fixtures._Polymorphic, AssertsCompiledSQL): self.classes.Paperwork, ) - sess = create_session() + sess = fixture_session() with testing.expect_deprecated_20(join_strings_dep, w_polymorphic_dep): self.assert_compile( @@ -4938,7 +5001,7 @@ class InheritedJoinTest(_poly_fixtures._Polymorphic, AssertsCompiledSQL): self.classes.Paperwork, ) - sess = create_session() + sess = fixture_session() with testing.expect_deprecated_20( join_strings_dep, w_polymorphic_dep, join_aliased_dep @@ -4995,7 +5058,7 @@ class JoinFromSelectableTest(fixtures.MappedTest, AssertsCompiledSQL): def test_mapped_to_select_implicit_left_w_aliased(self): T1, T2 = self.classes.T1, self.classes.T2 - sess = Session() + sess = fixture_session() subq = ( sess.query(T2.t1_id, func.count(T2.id).label("count")) .group_by(T2.t1_id) @@ -5075,7 +5138,7 @@ class MultiplePathTest(fixtures.MappedTest, AssertsCompiledSQL): with testing.expect_deprecated_20(join_strings_dep): q = ( - create_session() + fixture_session() .query(T1) .join("t2s_1") .filter(t2.c.id == 5) @@ -5092,3 +5155,79 @@ class MultiplePathTest(fixtures.MappedTest, AssertsCompiledSQL): "WHERE t2.id = :id_1", use_default_dialect=True, ) + + +class BindSensitiveStringifyTest(fixtures.TestBase): + def _fixture(self): + # building a totally separate metadata /mapping here + # because we need to control if the MetaData is bound or not + + class User(object): + pass + + m = MetaData() + user_table = Table( + "users", + m, + Column("id", Integer, primary_key=True), + Column("name", String(50)), + ) + + mapper(User, user_table) + return User + + def _dialect_fixture(self): + class MyDialect(default.DefaultDialect): + default_paramstyle = "qmark" + + from sqlalchemy.engine import base + + return base.Engine(mock.Mock(), MyDialect(), mock.Mock()) + + def _test(self, bound_session, session_present, expect_bound): + if bound_session: + eng = self._dialect_fixture() + else: + eng = None + + User = self._fixture() + + s = Session(eng if bound_session else None) + q = s.query(User).filter(User.id == 7) + if not session_present: + q = q.with_session(None) + + eq_ignore_whitespace( + str(q), + "SELECT users.id AS users_id, users.name AS users_name " + "FROM users WHERE users.id = ?" + if expect_bound + else "SELECT users.id AS users_id, users.name AS users_name " + "FROM users WHERE users.id = :id_1", + ) + + def test_query_bound_session(self): + self._test(True, True, True) + + def test_query_no_session(self): + self._test(False, False, False) + + def test_query_unbound_session(self): + self._test(False, True, False) + + +class GetBindTest(_GetBindTest): + @classmethod + def define_tables(cls, metadata): + super(GetBindTest, cls).define_tables(metadata) + metadata.bind = testing.db + + def test_fallback_table_metadata(self): + session = self._fixture({}) + is_(session.get_bind(self.classes.BaseClass), testing.db) + + def test_bind_base_table_concrete_sub_class(self): + base_class_bind = Mock() + session = self._fixture({self.tables.base_table: base_class_bind}) + + is_(session.get_bind(self.classes.ConcreteSubClass), testing.db) diff --git a/test/orm/test_dynamic.py b/test/orm/test_dynamic.py index 5f18b9bce..942e8383a 100644 --- a/test/orm/test_dynamic.py +++ b/test/orm/test_dynamic.py @@ -9,11 +9,9 @@ from sqlalchemy import testing from sqlalchemy.orm import attributes from sqlalchemy.orm import backref from sqlalchemy.orm import configure_mappers -from sqlalchemy.orm import create_session from sqlalchemy.orm import exc as orm_exc from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship -from sqlalchemy.orm import Session from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL @@ -21,6 +19,7 @@ from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import is_ from sqlalchemy.testing.assertsql import CompiledSQL +from sqlalchemy.testing.fixtures import fixture_session from test.orm import _fixtures @@ -126,7 +125,7 @@ class _DynamicFixture(object): class DynamicTest(_DynamicFixture, _fixtures.FixtureTest, AssertsCompiledSQL): def test_basic(self): User, Address = self._user_address_fixture() - sess = create_session() + sess = fixture_session() q = sess.query(User) eq_( @@ -152,7 +151,7 @@ class DynamicTest(_DynamicFixture, _fixtures.FixtureTest, AssertsCompiledSQL): def test_slice_access(self): User, Address = self._user_address_fixture() - sess = create_session() + sess = fixture_session() u1 = sess.get(User, 8) eq_(u1.addresses.limit(1).one(), Address(id=2)) @@ -162,7 +161,7 @@ class DynamicTest(_DynamicFixture, _fixtures.FixtureTest, AssertsCompiledSQL): def test_negative_slice_access_raises(self): User, Address = self._user_address_fixture() - sess = create_session(testing.db, future=True) + sess = fixture_session(future=True) u1 = sess.get(User, 8) with expect_raises_message( @@ -194,7 +193,7 @@ class DynamicTest(_DynamicFixture, _fixtures.FixtureTest, AssertsCompiledSQL): would render, without any _clones called.""" User, Address = self._user_address_fixture() - sess = create_session() + sess = fixture_session() q = sess.query(User) u = q.filter(User.id == 7).first() @@ -208,7 +207,7 @@ class DynamicTest(_DynamicFixture, _fixtures.FixtureTest, AssertsCompiledSQL): def test_detached_raise(self): User, Address = self._user_address_fixture() - sess = create_session() + sess = fixture_session() u = sess.query(User).get(8) sess.expunge(u) @@ -274,7 +273,7 @@ class DynamicTest(_DynamicFixture, _fixtures.FixtureTest, AssertsCompiledSQL): def test_order_by(self): User, Address = self._user_address_fixture() - sess = create_session() + sess = fixture_session() u = sess.query(User).get(8) eq_( list(u.addresses.order_by(desc(Address.email_address))), @@ -291,7 +290,7 @@ class DynamicTest(_DynamicFixture, _fixtures.FixtureTest, AssertsCompiledSQL): addresses_args={"order_by": addresses.c.email_address.desc()} ) - sess = create_session() + sess = fixture_session() u = sess.query(User).get(8) eq_( list(u.addresses), @@ -326,7 +325,7 @@ class DynamicTest(_DynamicFixture, _fixtures.FixtureTest, AssertsCompiledSQL): def test_count(self): User, Address = self._user_address_fixture() - sess = create_session() + sess = fixture_session() u = sess.query(User).first() eq_(u.addresses.count(), 1) @@ -349,7 +348,7 @@ class DynamicTest(_DynamicFixture, _fixtures.FixtureTest, AssertsCompiledSQL): ) mapper(User, users) - sess = create_session() + sess = fixture_session() ad = sess.query(Address).get(1) def go(): @@ -362,7 +361,7 @@ class DynamicTest(_DynamicFixture, _fixtures.FixtureTest, AssertsCompiledSQL): def test_no_count(self): User, Address = self._user_address_fixture() - sess = create_session() + sess = fixture_session() q = sess.query(User) # dynamic collection cannot implement __len__() (at least one that @@ -400,7 +399,7 @@ class DynamicTest(_DynamicFixture, _fixtures.FixtureTest, AssertsCompiledSQL): items_args={"backref": backref("orders", lazy="dynamic")} ) - sess = create_session() + sess = fixture_session() o1 = Order(id=15, description="order 10") i1 = Item(id=10, description="item 8") o1.items.append(i1) @@ -439,7 +438,7 @@ class DynamicTest(_DynamicFixture, _fixtures.FixtureTest, AssertsCompiledSQL): ) mapper(Item, items) - sess = create_session() + sess = fixture_session() o = sess.query(Order).first() self.assert_compile( @@ -477,7 +476,7 @@ class DynamicTest(_DynamicFixture, _fixtures.FixtureTest, AssertsCompiledSQL): ) mapper(Item, items) - sess = create_session() + sess = fixture_session() u1 = sess.query(User).first() self.assert_compile( @@ -529,7 +528,7 @@ class DynamicTest(_DynamicFixture, _fixtures.FixtureTest, AssertsCompiledSQL): properties={"item_keywords": relationship(ItemKeyword)}, ) - sess = create_session() + sess = fixture_session() order = sess.query(Order).first() self.assert_compile( @@ -572,14 +571,14 @@ class UOWTest( addresses = self.tables.addresses User, Address = self._user_address_fixture() - sess = create_session() + sess = fixture_session() u1 = User(name="jack") a1 = Address(email_address="foo") sess.add_all([u1, a1]) sess.flush() eq_( - testing.db.scalar( + sess.connection().scalar( select(func.count(cast(1, Integer))).where( addresses.c.user_id != None ) @@ -591,16 +590,18 @@ class UOWTest( sess.flush() eq_( - testing.db.execute( + sess.connection() + .execute( select(addresses).where(addresses.c.user_id != None) # noqa - ).fetchall(), + ) + .fetchall(), [(a1.id, u1.id, "foo")], ) u1.addresses.remove(a1) sess.flush() eq_( - testing.db.scalar( + sess.connection().scalar( select(func.count(cast(1, Integer))).where( addresses.c.user_id != None ) @@ -611,9 +612,11 @@ class UOWTest( u1.addresses.append(a1) sess.flush() eq_( - testing.db.execute( + sess.connection() + .execute( select(addresses).where(addresses.c.user_id != None) # noqa - ).fetchall(), + ) + .fetchall(), [(a1.id, u1.id, "foo")], ) @@ -622,9 +625,11 @@ class UOWTest( u1.addresses.append(a2) sess.flush() eq_( - testing.db.execute( + sess.connection() + .execute( select(addresses).where(addresses.c.user_id != None) # noqa - ).fetchall(), + ) + .fetchall(), [(a2.id, u1.id, "bar")], ) @@ -633,7 +638,7 @@ class UOWTest( User, Address = self._user_address_fixture( addresses_args={"order_by": addresses.c.email_address} ) - sess = create_session() + sess = fixture_session(autoflush=False) u1 = User(name="jack") a1 = Address(email_address="a1") a2 = Address(email_address="a2") @@ -669,7 +674,7 @@ class UOWTest( User, Address = self._user_address_fixture( addresses_args={"order_by": addresses.c.email_address} ) - sess = create_session(autoflush=True, autocommit=False) + sess = fixture_session(autoflush=True, autocommit=False) u1 = User(name="jack") a1 = Address(email_address="a1") a2 = Address(email_address="a2") @@ -691,7 +696,7 @@ class UOWTest( # when flushing an append User, Address = self._user_address_fixture() - sess = Session() + sess = fixture_session() u1 = User(name="jack", addresses=[Address(email_address="a1")]) sess.add(u1) sess.commit() @@ -721,7 +726,7 @@ class UOWTest( # when flushing a remove User, Address = self._user_address_fixture() - sess = Session() + sess = fixture_session() u1 = User(name="jack", addresses=[Address(email_address="a1")]) a2 = Address(email_address="a2") u1.addresses.append(a2) @@ -757,7 +762,7 @@ class UOWTest( def test_rollback(self): User, Address = self._user_address_fixture() - sess = create_session( + sess = fixture_session( expire_on_commit=False, autocommit=False, autoflush=True ) u1 = User(name="jack") @@ -786,7 +791,7 @@ class UOWTest( } ) - sess = create_session(autoflush=True, autocommit=False) + sess = fixture_session(autoflush=True, autocommit=False) u = User(name="ed") u.addresses.extend( [Address(email_address=letter) for letter in "abcdef"] @@ -854,7 +859,7 @@ class UOWTest( }, ) - sess = Session() + sess = fixture_session() n2, n3 = Node(), Node() n1 = Node(children=[n2, n3]) sess.add(n1) @@ -872,7 +877,7 @@ class UOWTest( } ) - sess = create_session(autoflush=True, autocommit=False) + sess = fixture_session(autoflush=True, autocommit=False) u = User(name="ed") u.addresses.extend( [Address(email_address=letter) for letter in "abcdef"] @@ -893,7 +898,7 @@ class UOWTest( User, Address = self._user_address_fixture( addresses_args={"backref": "user"} ) - sess = create_session(autoflush=autoflush, autocommit=False) + sess = fixture_session(autoflush=autoflush, autocommit=False) u = User(name="buffy") @@ -948,29 +953,28 @@ class UOWTest( addresses_args={"backref": "user"} ) - session = create_session() - user = User() - user.name = "joe" - user.fullname = "Joe User" - user.password = "Joe's secret" - address = Address() - address.email_address = "joe@joesdomain.example" - address.user = user - session.add(user) - session.flush() - session.expunge_all() + with fixture_session() as session: + user = User() + user.name = "joe" + user.fullname = "Joe User" + user.password = "Joe's secret" + address = Address() + address.email_address = "joe@joesdomain.example" + address.user = user + session.add(user) + session.commit() def query1(): - session = create_session(testing.db) + session = fixture_session() user = session.query(User).first() return user.addresses.all() def query2(): - session = create_session(testing.db) + session = fixture_session() return session.query(User).first().addresses.all() def query3(): - session = create_session(testing.db) + session = fixture_session() return session.query(User).first().addresses.all() eq_(query1(), [Address(email_address="joe@joesdomain.example")]) @@ -997,7 +1001,7 @@ class HistoryTest(_DynamicFixture, _fixtures.FixtureTest): u1 = User(name="u1") a1 = Address(email_address="a1") - s = Session(autoflush=autoflush) + s = fixture_session(autoflush=autoflush) s.add(u1) s.flush() return u1, a1, s @@ -1007,7 +1011,7 @@ class HistoryTest(_DynamicFixture, _fixtures.FixtureTest): o1 = Order() i1 = Item(description="i1") - s = Session(autoflush=autoflush) + s = fixture_session(autoflush=autoflush) s.add(o1) s.flush() return o1, i1, s diff --git a/test/orm/test_eager_relations.py b/test/orm/test_eager_relations.py index 7bc82b2a3..4498fc1ff 100644 --- a/test/orm/test_eager_relations.py +++ b/test/orm/test_eager_relations.py @@ -17,7 +17,6 @@ from sqlalchemy.orm import backref from sqlalchemy.orm import close_all_sessions from sqlalchemy.orm import column_property from sqlalchemy.orm import contains_eager -from sqlalchemy.orm import create_session from sqlalchemy.orm import defaultload from sqlalchemy.orm import deferred from sqlalchemy.orm import joinedload @@ -38,6 +37,7 @@ from sqlalchemy.testing import in_ from sqlalchemy.testing import is_ from sqlalchemy.testing import is_not from sqlalchemy.testing.assertsql import CompiledSQL +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from sqlalchemy.util import OrderedDict as odict @@ -68,7 +68,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) }, ) - sess = create_session() + sess = fixture_session() q = sess.query(User) eq_( @@ -91,7 +91,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) m = mapper(User, users) - sess = create_session() + sess = fixture_session() sess.query(User).all() m.add_property("addresses", relationship(mapper(Address, addresses))) @@ -136,7 +136,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) mapper(Address, addresses) - sess = create_session() + sess = fixture_session() user = sess.query(User).get(7) assert getattr(User, "addresses").hasparent( sa.orm.attributes.instance_state(user.addresses[0]), @@ -165,7 +165,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) }, ) - q = create_session().query(User) + q = fixture_session().query(User) eq_( [ User(id=7, addresses=[Address(id=1)]), @@ -202,7 +202,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) }, ) - q = create_session().query(User) + q = fixture_session().query(User) eq_( [ User(id=7, addresses=[Address(id=1)]), @@ -242,7 +242,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ), ) - q = create_session().query(User) + q = fixture_session().query(User) result = ( q.filter(User.id == Address.user_id) .order_by(Address.email_address) @@ -285,7 +285,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) ), ) - sess = create_session() + sess = fixture_session() eq_( [ User(id=7, addresses=[Address(id=1)]), @@ -318,7 +318,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): mapper(Address, addresses) mapper(User, users, properties=dict(addresses=relationship(Address))) - sess = create_session() + sess = fixture_session() q = ( sess.query(User) .join("addresses") @@ -369,7 +369,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) mapper(User, users) - sess = create_session() + sess = fixture_session() for q in [ sess.query(Address) @@ -559,7 +559,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): 5, ), ]: - sess = create_session() + sess = fixture_session() def go(): eq_( @@ -585,7 +585,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): properties={"addresses": relationship(Address, lazy="dynamic")}, ) mapper(Address, addresses) - sess = create_session() + sess = fixture_session() assert_raises_message( sa.exc.InvalidRequestError, "User.addresses' does not support object " @@ -616,7 +616,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ), ) - q = create_session().query(Item).order_by(Item.id) + q = fixture_session().query(Item).order_by(Item.id) def go(): eq_(self.static.item_keyword_result, q.all()) @@ -662,7 +662,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) ), ) - q = create_session().query(Item) + q = fixture_session().query(Item) def go(): eq_( @@ -704,7 +704,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): eq_(sa.orm.class_mapper(User).get_property("addresses").lazy, "joined") eq_(sa.orm.class_mapper(Address).get_property("user").lazy, "joined") - sess = create_session() + sess = fixture_session() eq_( self.static.user_address_result, sess.query(User).order_by(User.id).all(), @@ -926,7 +926,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): User, Address, Order, Item = self.classes( "User", "Address", "Order", "Item" ) - q = create_session().query(User).order_by(User.id) + q = fixture_session().query(User).order_by(User.id) def items(*ids): if no_items: @@ -993,14 +993,14 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): mapper(Address, addresses) mapper(Order, orders) - self.allusers = create_session().query(User).all() + self.allusers = fixture_session().query(User).all() # using a textual select, the columns will be 'id' and 'name'. the # eager loaders have aliases which should not hit on those columns, # they should be required to locate only their aliased/fully table # qualified column name. noeagers = ( - create_session() + fixture_session() .query(User) .from_statement(text("select * from users")) .all() @@ -1061,7 +1061,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): }, ) - sess = create_session() + sess = fixture_session() q = sess.query(User) result = q.order_by(User.id).limit(2).offset(1).all() @@ -1098,7 +1098,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): }, ) - sess = create_session() + sess = fixture_session() q = sess.query(User) def go(): @@ -1134,7 +1134,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): }, ) - q = create_session().query(User) + q = fixture_session().query(User) eq_( [ User(id=7, addresses=[Address(id=1)]), @@ -1175,7 +1175,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ), ) - sess = create_session() + sess = fixture_session() q = sess.query(Item) result = ( q.filter( @@ -1240,7 +1240,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ), ), ) - sess = create_session() + sess = fixture_session() q = sess.query(User) @@ -1316,7 +1316,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) mapper(Order, orders) - sess = create_session() + sess = fixture_session() eq_( sess.query(User).first(), User( @@ -1369,7 +1369,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): }, ) - sess = create_session() + sess = fixture_session() u1 = sess.query(User).filter(User.id == 8).one() def go(): @@ -1413,7 +1413,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): class MyBogusOption(MapperOption): propagate_to_loaders = True - sess = create_session() + sess = fixture_session() u1 = ( sess.query(User) .options(MyBogusOption()) @@ -1485,7 +1485,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): mapper(Address, addresses) mapper(Item, items) - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(User).options(joinedload(User.orders)).limit(10), @@ -1641,7 +1641,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) ), ) - q = create_session().query(User) + q = fixture_session().query(User) def go(): result = q.filter(users.c.id == 7).all() @@ -1666,7 +1666,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) ), ) - q = create_session().query(User) + q = fixture_session().query(User) q = q.filter(users.c.id == 7).limit(1) self.assert_compile( @@ -1697,7 +1697,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): user=relationship(mapper(User, users), lazy="joined") ), ) - sess = create_session() + sess = fixture_session() q = sess.query(Address) def go(): @@ -1737,7 +1737,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) ), ) - sess = create_session() + sess = fixture_session() def go(): o1 = ( @@ -1795,7 +1795,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ), ) - q = create_session().query(User) + q = fixture_session().query(User) result = q.filter(text("users.id in (7, 8, 9)")).order_by( text("users.id") @@ -1838,7 +1838,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): }, ) - q = create_session().query(User) + q = fixture_session().query(User) def go(): eq_( @@ -1878,7 +1878,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): properties={"order": relationship(Order, uselist=False)}, ) mapper(Order, orders) - s = create_session() + s = fixture_session() assert_raises( sa.exc.SAWarning, s.query(User).options(joinedload(User.order)).all ) @@ -1931,7 +1931,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): orders=relationship(Order, lazy=False, order_by=orders.c.id), ), ) - q = create_session().query(User) + q = fixture_session().query(User) def go(): eq_(self.static.user_all_result, q.order_by(User.id).all()) @@ -1959,7 +1959,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): mapper(User, users) mapper(Item, items) - q = create_session().query(Order) + q = fixture_session().query(Order) eq_( [Order(id=3, user=User(id=7)), Order(id=4, user=User(id=9))], q.all(), @@ -1992,7 +1992,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) ), ) - q = create_session().query(User) + q = fixture_session().query(User) result = ( q.filter(addresses.c.email_address == "ed@lala.com") .filter(Address.user_id == User.id) @@ -2020,7 +2020,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) ), ) - sess = create_session() + sess = fixture_session() eq_( [ User(id=7, addresses=[Address(id=1)]), @@ -2079,7 +2079,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) mapper(Item, items) - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(User), "SELECT users.id AS users_id, users.name AS users_name, " @@ -2171,7 +2171,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) mapper(Item, items) - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(User), "SELECT users.id AS users_id, users.name AS users_name, " @@ -2286,7 +2286,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) mapper(Item, items) - sess = create_session() + sess = fixture_session() q = sess.query(User).options( joinedload("orders", innerjoin=False).joinedload( "items", innerjoin=True @@ -2361,7 +2361,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): mapper(Order, orders) mapper(Address, addresses) - sess = create_session() + sess = fixture_session() q = sess.query(User).options( joinedload("orders"), joinedload("addresses", innerjoin="unnested") ) @@ -2402,7 +2402,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): mapper(Order, orders) mapper(Address, addresses) - sess = create_session() + sess = fixture_session() q = sess.query(User).options( joinedload("orders"), joinedload("addresses", innerjoin=True) ) @@ -2478,7 +2478,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) mapper(Keyword, keywords) - sess = create_session() + sess = fixture_session() q = ( sess.query(User) .join(User.orders) @@ -2549,7 +2549,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) mapper(Item, items) - sess = create_session() + sess = fixture_session() # joining from user, its all LEFT OUTER JOINs self.assert_compile( @@ -2610,7 +2610,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) mapper(Item, items) - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(User), @@ -2662,7 +2662,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ), ) mapper(Item, items) - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(User).options(joinedload(User.orders, innerjoin=True)), "SELECT users.id AS users_id, users.name AS users_name, " @@ -2782,7 +2782,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) mapper(Item, items) - sess = create_session() + sess = fixture_session() if use_load: opt = Load(User).defaultload("orders").lazyload("*") @@ -2833,7 +2833,7 @@ class SelectUniqueTest(_fixtures.FixtureTest): .order_by(Address.id) ) - s = create_session() + s = fixture_session() result = s.execute(stmt) eq_(result.scalars().all(), self.static.address_user_result) @@ -2842,7 +2842,7 @@ class SelectUniqueTest(_fixtures.FixtureTest): User = self.classes.User stmt = select(User).options(joinedload(User.addresses)) - s = create_session() + s = fixture_session() result = s.execute(stmt) with expect_raises_message( @@ -2857,7 +2857,7 @@ class SelectUniqueTest(_fixtures.FixtureTest): stmt = ( select(User).options(joinedload(User.addresses)).order_by(User.id) ) - s = create_session() + s = fixture_session() result = s.execute(stmt) eq_( @@ -2871,7 +2871,7 @@ class SelectUniqueTest(_fixtures.FixtureTest): stmt = ( select(User).options(joinedload(User.addresses)).order_by(User.id) ) - s = create_session() + s = fixture_session() result = s.execute(stmt) eq_(result.scalars().unique().all(), self.static.user_address_result) @@ -2886,7 +2886,7 @@ class SelectUniqueTest(_fixtures.FixtureTest): .options(joinedload(User.addresses)) .order_by(User.id, Address.id) ) - s = create_session() + s = fixture_session() result = s.execute(stmt) eq_( @@ -2908,7 +2908,7 @@ class SelectUniqueTest(_fixtures.FixtureTest): .options(joinedload(User.addresses)) .order_by(User.id) ) - s = create_session() + s = fixture_session() result = s.execute(stmt) eq_( @@ -3119,7 +3119,7 @@ class InnerJoinSplicingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): self.classes.C2, ) - s = Session() + s = fixture_session() q = s.query(A).options( joinedload(A.bs, innerjoin=False) @@ -3153,7 +3153,7 @@ class InnerJoinSplicingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): # test #3447 A = self.classes.A - s = Session() + s = fixture_session() q = s.query(A).options( joinedload("bs"), @@ -3181,7 +3181,7 @@ class InnerJoinSplicingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): def test_multiple_splice_points(self): A = self.classes.A - s = Session() + s = fixture_session() q = s.query(A).options( joinedload("bs", innerjoin=False), @@ -3229,7 +3229,7 @@ class InnerJoinSplicingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): a_mapper = inspect(A) a_mapper.add_property("bs_np", relationship(b_np, viewonly=True)) - s = Session() + s = fixture_session() q = s.query(A).options(joinedload("bs_np", innerjoin=False)) self.assert_compile( @@ -3345,7 +3345,7 @@ class InnerJoinSplicingWSecondaryTest( def test_joined_across(self): A = self.classes.A - s = Session() + s = fixture_session() q = s.query(A).options( joinedload("b") .joinedload("c", innerjoin=True) @@ -3407,7 +3407,7 @@ class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): self._fixture({"summation": column_property(cp.scalar_subquery())}) self.assert_compile( - create_session() + fixture_session() .query(A) .options(joinedload("bs")) .order_by(A.summation) @@ -3430,7 +3430,7 @@ class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): self._fixture({"summation": column_property(cp.scalar_subquery())}) self.assert_compile( - create_session() + fixture_session() .query(A) .options(joinedload("bs")) .order_by(A.summation.desc()) @@ -3455,7 +3455,7 @@ class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): self._fixture({"summation": column_property(cp.scalar_subquery())}) self.assert_compile( - create_session() + fixture_session() .query(A) .options(joinedload("bs")) .order_by(A.summation) @@ -3484,7 +3484,7 @@ class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): # the removal of a separate _make_proxy() from ScalarSelect # fixed that. self.assert_compile( - create_session() + fixture_session() .query(A) .options(joinedload("bs")) .order_by(cp) @@ -3511,7 +3511,7 @@ class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): .label("foo") ) self.assert_compile( - create_session() + fixture_session() .query(A) .options(joinedload("bs")) .order_by(cp) @@ -3540,7 +3540,7 @@ class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): # TODO: there is no test in Core that asserts what is happening # here as far as the label generation for the ORDER BY self.assert_compile( - create_session() + fixture_session() .query(A) .options(joinedload("bs")) .order_by(~cp) @@ -3581,7 +3581,7 @@ class LoadOnExistingTest(_fixtures.FixtureTest): ) mapper(Dingaling, self.tables.dingalings) - sess = Session(autoflush=False) + sess = fixture_session(autoflush=False) return User, Address, Dingaling, sess def _collection_to_collection_fixture(self): @@ -3602,7 +3602,7 @@ class LoadOnExistingTest(_fixtures.FixtureTest): ) mapper(Item, self.tables.items) - sess = Session(autoflush=False) + sess = fixture_session(autoflush=False) return User, Order, Item, sess def _eager_config_fixture(self): @@ -3613,7 +3613,7 @@ class LoadOnExistingTest(_fixtures.FixtureTest): properties={"addresses": relationship(Address, lazy="joined")}, ) mapper(Address, self.tables.addresses) - sess = Session(autoflush=False) + sess = fixture_session(autoflush=False) return User, Address, sess def test_runs_query_on_refresh(self): @@ -3804,7 +3804,7 @@ class AddEntityTest(_fixtures.FixtureTest): ) mapper(Item, items) - sess = create_session() + sess = fixture_session() oalias = sa.orm.aliased(Order) def go(): @@ -3861,7 +3861,7 @@ class AddEntityTest(_fixtures.FixtureTest): ) mapper(Item, items) - sess = create_session() + sess = fixture_session() oalias = sa.orm.aliased(Order) @@ -3958,7 +3958,7 @@ class OrderBySecondaryTest(fixtures.MappedTest): ) mapper(B, b) - sess = create_session() + sess = fixture_session() eq_( sess.query(A).all(), [ @@ -3997,7 +3997,7 @@ class SelfReferentialEagerTest(fixtures.MappedTest): ) }, ) - sess = create_session() + sess = fixture_session() n1 = Node(data="n1") n1.append(Node(data="n11")) n1.append(Node(data="n12")) @@ -4073,7 +4073,7 @@ class SelfReferentialEagerTest(fixtures.MappedTest): ) }, ) - sess = create_session() + sess = fixture_session() n1 = Node(data="n1") n1.append(Node(data="n11")) n1.append(Node(data="n12")) @@ -4121,7 +4121,7 @@ class SelfReferentialEagerTest(fixtures.MappedTest): "data": deferred(nodes.c.data), }, ) - sess = create_session() + sess = fixture_session() n1 = Node(data="n1") n1.append(Node(data="n11")) n1.append(Node(data="n12")) @@ -4178,7 +4178,7 @@ class SelfReferentialEagerTest(fixtures.MappedTest): ) }, ) - sess = create_session() + sess = fixture_session() n1 = Node(data="n1") n1.append(Node(data="n11")) n1.append(Node(data="n12")) @@ -4249,7 +4249,7 @@ class SelfReferentialEagerTest(fixtures.MappedTest): nodes, properties={"children": relationship(Node, lazy="joined")}, ) - sess = create_session() + sess = fixture_session() n1 = Node(data="n1") n1.append(Node(data="n11")) n1.append(Node(data="n12")) @@ -4372,7 +4372,7 @@ class MixedSelfReferentialEagerTest(fixtures.MappedTest): def test_eager_load(self): A, B = self.classes.A, self.classes.B - session = create_session() + session = fixture_session() def go(): eq_( @@ -4451,7 +4451,7 @@ class SelfReferentialM2MEagerTest(fixtures.MappedTest): }, ) - sess = create_session() + sess = fixture_session() w1 = Widget(name="w1") w2 = Widget(name="w2") w1.children.append(w2) @@ -4540,7 +4540,7 @@ class MixedEntitiesTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.classes.Address, ) - sess = create_session() + sess = fixture_session() # two FROM clauses def go(): @@ -4603,7 +4603,7 @@ class MixedEntitiesTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.classes.Address, ) - sess = create_session() + sess = fixture_session() # two FROM clauses where there's a join on each one def go(): @@ -4711,7 +4711,7 @@ class MixedEntitiesTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.classes.Address, ) - sess = create_session() + sess = fixture_session() oalias = sa.orm.aliased(Order) @@ -4748,7 +4748,7 @@ class MixedEntitiesTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.classes.Address, ) - sess = create_session() + sess = fixture_session() oalias = sa.orm.aliased(Order) @@ -4780,7 +4780,7 @@ class MixedEntitiesTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): def test_aliased_entity_three(self): Order, User = (self.classes.Order, self.classes.User) - sess = create_session() + sess = fixture_session() oalias = sa.orm.aliased(Order) @@ -4894,7 +4894,7 @@ class SubqueryTest(fixtures.MappedTest): }, ) - session = create_session() + session = fixture_session() session.add( User( name="joe", @@ -5088,7 +5088,7 @@ class CorrelatedSubqueryTest(fixtures.MappedTest): }, ) - sess = create_session() + sess = fixture_session() def go(): eq_( @@ -5105,7 +5105,7 @@ class CorrelatedSubqueryTest(fixtures.MappedTest): self.assert_sql_count(testing.db, go, 1) - sess = create_session() + sess = fixture_session() def go(): eq_( @@ -5115,7 +5115,7 @@ class CorrelatedSubqueryTest(fixtures.MappedTest): self.assert_sql_count(testing.db, go, 2) - sess = create_session() + sess = fixture_session() def go(): eq_( @@ -5128,7 +5128,7 @@ class CorrelatedSubqueryTest(fixtures.MappedTest): self.assert_sql_count(testing.db, go, 1) - sess = create_session() + sess = fixture_session() def go(): eq_( @@ -5199,7 +5199,7 @@ class CyclicalInheritingEagerTestOne(fixtures.MappedTest): mapper(SubT2, None, inherits=T2, polymorphic_identity="subt2") # testing a particular endless loop condition in eager load setup - create_session().query(SubT).all() + fixture_session().query(SubT).all() class CyclicalInheritingEagerTestTwo( @@ -5231,7 +5231,7 @@ class CyclicalInheritingEagerTestTwo( def test_from_subclass(self): Director = self.classes.Director - s = create_session() + s = fixture_session() self.assert_compile( s.query(Director).options(joinedload("*")), @@ -5298,7 +5298,7 @@ class CyclicalInheritingEagerTestThree( def test_gen_query_nodepth(self): PersistentObject = self.classes.PersistentObject - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(PersistentObject), "SELECT persistent.id AS persistent_id, " @@ -5311,7 +5311,7 @@ class CyclicalInheritingEagerTestThree( def test_gen_query_depth(self): PersistentObject = self.classes.PersistentObject Director = self.classes.Director - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(PersistentObject).options(joinedload(Director.other)), "SELECT persistent.id AS persistent_id, " @@ -5374,7 +5374,7 @@ class EnsureColumnsAddedTest( def test_joinedload_defered_pk_limit_o2m(self): Parent = self.classes.Parent - s = Session() + s = fixture_session() self.assert_compile( s.query(Parent) @@ -5394,7 +5394,7 @@ class EnsureColumnsAddedTest( def test_joinedload_defered_pk_limit_m2m(self): Parent = self.classes.Parent - s = Session() + s = fixture_session() self.assert_compile( s.query(Parent) @@ -5416,7 +5416,7 @@ class EnsureColumnsAddedTest( def test_joinedload_defered_pk_o2m(self): Parent = self.classes.Parent - s = Session() + s = fixture_session() self.assert_compile( s.query(Parent).options( @@ -5432,7 +5432,7 @@ class EnsureColumnsAddedTest( def test_joinedload_defered_pk_m2m(self): Parent = self.classes.Parent - s = Session() + s = fixture_session() self.assert_compile( s.query(Parent).options( @@ -5492,7 +5492,7 @@ class EntityViaMultiplePathTestOne(fixtures.DeclarativeMappedTest): def test_multi_path_load(self): A, B, C, D = self.classes("A", "B", "C", "D") - s = Session() + s = fixture_session() c = C(d=D()) @@ -5521,7 +5521,7 @@ class EntityViaMultiplePathTestOne(fixtures.DeclarativeMappedTest): def test_multi_path_load_of_type(self): A, B, C, D = self.classes("A", "B", "C", "D") - s = Session() + s = fixture_session() c = C(d=D()) @@ -5593,7 +5593,7 @@ class EntityViaMultiplePathTestTwo(fixtures.DeclarativeMappedTest): def test_multi_path_load_legacy_join_style(self): User, LD, A, LDA = self.classes("User", "LD", "A", "LDA") - s = Session() + s = fixture_session() u0 = User(data=42) l0 = LD(user=u0) @@ -5624,7 +5624,7 @@ class EntityViaMultiplePathTestTwo(fixtures.DeclarativeMappedTest): def test_multi_path_load_of_type(self): User, LD, A, LDA = self.classes("User", "LD", "A", "LDA") - s = Session() + s = fixture_session() u0 = User(data=42) l0 = LD(user=u0) @@ -5696,7 +5696,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest): def test_string_options_aliased_whatever(self): A, B, C = self.classes("A", "B", "C") - s = Session() + s = fixture_session() aa = aliased(A) q = ( s.query(aa, A) @@ -5709,7 +5709,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest): def test_string_options_unaliased_whatever(self): A, B, C = self.classes("A", "B", "C") - s = Session() + s = fixture_session() aa = aliased(A) q = ( s.query(A, aa) @@ -5722,7 +5722,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest): def test_lazyload_aliased_abs_bcs_one(self): A, B, C = self.classes("A", "B", "C") - s = Session() + s = fixture_session() aa = aliased(A) q = ( s.query(aa, A) @@ -5735,7 +5735,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest): def test_lazyload_aliased_abs_bcs_two(self): A, B, C = self.classes("A", "B", "C") - s = Session() + s = fixture_session() aa = aliased(A) q = ( s.query(aa, A) @@ -5748,7 +5748,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest): def test_pathed_lazyload_aliased_abs_bcs(self): A, B, C = self.classes("A", "B", "C") - s = Session() + s = fixture_session() aa = aliased(A) opt = Load(A).joinedload(A.bs).joinedload(B.cs) @@ -5763,7 +5763,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest): def test_pathed_lazyload_plus_joined_aliased_abs_bcs(self): A, B, C = self.classes("A", "B", "C") - s = Session() + s = fixture_session() aa = aliased(A) opt = Load(aa).defaultload(aa.bs).joinedload(B.cs) @@ -5778,7 +5778,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest): def test_pathed_joinedload_aliased_abs_bcs(self): A, B, C = self.classes("A", "B", "C") - s = Session() + s = fixture_session() aa = aliased(A) opt = Load(aa).joinedload(aa.bs).joinedload(B.cs) @@ -5793,7 +5793,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest): def test_lazyload_plus_joined_aliased_abs_bcs(self): A, B, C = self.classes("A", "B", "C") - s = Session() + s = fixture_session() aa = aliased(A) q = ( s.query(aa, A) @@ -5806,7 +5806,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest): def test_joinedload_aliased_abs_bcs(self): A, B, C = self.classes("A", "B", "C") - s = Session() + s = fixture_session() aa = aliased(A) q = ( s.query(aa, A) @@ -5819,7 +5819,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest): def test_lazyload_unaliased_abs_bcs_one(self): A, B, C = self.classes("A", "B", "C") - s = Session() + s = fixture_session() aa = aliased(A) q = ( s.query(A, aa) @@ -5832,7 +5832,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest): def test_lazyload_unaliased_abs_bcs_two(self): A, B, C = self.classes("A", "B", "C") - s = Session() + s = fixture_session() aa = aliased(A) q = ( s.query(A, aa) @@ -5845,7 +5845,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest): def test_lazyload_plus_joined_unaliased_abs_bcs(self): A, B, C = self.classes("A", "B", "C") - s = Session() + s = fixture_session() aa = aliased(A) q = ( s.query(A, aa) @@ -5858,7 +5858,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest): def test_joinedload_unaliased_abs_bcs(self): A, B, C = self.classes("A", "B", "C") - s = Session() + s = fixture_session() aa = aliased(A) q = ( s.query(A, aa) @@ -5886,7 +5886,7 @@ class EntityViaMultiplePathTestThree(fixtures.DeclarativeMappedTest): def test_multi_path_load_lazy_none(self): A = self.classes.A - s = Session() + s = fixture_session() s.add_all( [ A(id=1, parent_id=None), @@ -5986,7 +5986,7 @@ class DeepOptionsTest(_fixtures.FixtureTest): def test_deep_options_1(self): User = self.classes.User - sess = create_session() + sess = fixture_session() # joinedload nothing. u = sess.query(User).order_by(User.id).all() @@ -6001,7 +6001,7 @@ class DeepOptionsTest(_fixtures.FixtureTest): User = self.classes.User - sess = create_session() + sess = fixture_session() result = ( sess.query(User) @@ -6018,7 +6018,7 @@ class DeepOptionsTest(_fixtures.FixtureTest): self.sql_count_(0, go) - sess = create_session() + sess = fixture_session() result = ( sess.query(User).options( @@ -6036,7 +6036,7 @@ class DeepOptionsTest(_fixtures.FixtureTest): def test_deep_options_3(self): User = self.classes.User - sess = create_session() + sess = fixture_session() # same thing, with separate options calls q2 = ( @@ -6060,7 +6060,7 @@ class DeepOptionsTest(_fixtures.FixtureTest): self.classes.Order, ) - sess = create_session() + sess = fixture_session() assert_raises_message( sa.exc.ArgumentError, @@ -6088,7 +6088,7 @@ class DeepOptionsTest(_fixtures.FixtureTest): self.sql_count_(2, go) - sess = create_session() + sess = fixture_session() q3 = ( sess.query(User) .order_by(User.id) @@ -6220,7 +6220,7 @@ class SecondaryOptionsTest(fixtures.MappedTest): def test_contains_eager(self): Child1, Related = self.classes.Child1, self.classes.Related - sess = create_session() + sess = fixture_session() child1s = ( sess.query(Child1) @@ -6258,7 +6258,7 @@ class SecondaryOptionsTest(fixtures.MappedTest): def test_joinedload_on_other(self): Child1, Related = self.classes.Child1, self.classes.Related - sess = create_session() + sess = fixture_session() child1s = ( sess.query(Child1) @@ -6300,7 +6300,7 @@ class SecondaryOptionsTest(fixtures.MappedTest): self.classes.Related, ) - sess = create_session() + sess = fixture_session() child1s = ( sess.query(Child1) diff --git a/test/orm/test_evaluator.py b/test/orm/test_evaluator.py index ec843d1c5..db56eeb83 100644 --- a/test/orm/test_evaluator.py +++ b/test/orm/test_evaluator.py @@ -13,12 +13,12 @@ from sqlalchemy.orm import evaluator from sqlalchemy.orm import exc as orm_exc from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship -from sqlalchemy.orm import Session from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -290,7 +290,7 @@ class M2OEvaluateTest(fixtures.DeclarativeMappedTest): def test_delete_not_expired(self): Parent, Child = self.classes("Parent", "Child") - session = Session(expire_on_commit=False) + session = fixture_session(expire_on_commit=False) p = Parent(id=1) session.add(p) @@ -307,7 +307,7 @@ class M2OEvaluateTest(fixtures.DeclarativeMappedTest): def test_delete_expired(self): Parent, Child = self.classes("Parent", "Child") - session = Session() + session = fixture_session() p = Parent(id=1) session.add(p) diff --git a/test/orm/test_events.py b/test/orm/test_events.py index fb05f6601..1c918a88c 100644 --- a/test/orm/test_events.py +++ b/test/orm/test_events.py @@ -13,7 +13,6 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import attributes from sqlalchemy.orm import class_mapper from sqlalchemy.orm import configure_mappers -from sqlalchemy.orm import create_session from sqlalchemy.orm import deferred from sqlalchemy.orm import events from sqlalchemy.orm import EXT_SKIP @@ -35,6 +34,7 @@ from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_not from sqlalchemy.testing.assertsql import CompiledSQL +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.mock import ANY from sqlalchemy.testing.mock import call from sqlalchemy.testing.mock import Mock @@ -676,7 +676,7 @@ class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): canary = self.listen_all(User) named_canary = self.listen_all(User, named=True) - sess = create_session() + sess = fixture_session() u = User(name="u1") sess.add(u) sess.flush() @@ -769,13 +769,13 @@ class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): event.listen(mapper, "load", load) - s = Session() + s = fixture_session() u = User(name="u1") s.add(u) s.commit() - s = Session() + s = fixture_session() u2 = s.merge(u) - s = Session() + s = fixture_session() u2 = s.merge(User(name="u2")) # noqa s.commit() s.query(User).order_by(User.id).first() @@ -803,7 +803,7 @@ class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): canary2 = self.listen_all(User) canary3 = self.listen_all(AdminUser) - sess = create_session() + sess = fixture_session() am = AdminUser(name="au1", email_address="au1@e1") sess.add(am) sess.flush() @@ -871,7 +871,7 @@ class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): ) canary3 = self.listen_all(AdminUser) - sess = create_session() + sess = fixture_session() am = AdminUser(name="au1", email_address="au1@e1") sess.add(am) sess.flush() @@ -942,7 +942,7 @@ class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): canary1 = self.listen_all(Item) canary2 = self.listen_all(Keyword) - sess = create_session() + sess = fixture_session() i1 = Item(description="i1") k1 = Keyword(name="k1") sess.add(i1) @@ -998,7 +998,7 @@ class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): event.listen(mapper, "before_configured", m1) event.listen(mapper, "after_configured", m2) - s = Session() + s = fixture_session() s.query(User) eq_(m1.mock_calls, [call()]) @@ -1117,7 +1117,7 @@ class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): # fails by default because Mammal needs to be configured, and cannot # be: def probe(): - s = Session() + s = fixture_session() s.query(User) assert_raises(sa.exc.InvalidRequestError, probe) @@ -1181,7 +1181,7 @@ class RestoreLoadContextTest(fixtures.DeclarativeMappedTest): @_combinations def test_warning(self, target, event_name, fn): A = self.classes.A - s = Session() + s = fixture_session() target = testing.util.resolve_lambda(target, A=A, session=s) event.listen(target, event_name, fn) @@ -1205,7 +1205,7 @@ class RestoreLoadContextTest(fixtures.DeclarativeMappedTest): @_combinations def test_flag_resolves_existing(self, target, event_name, fn): A = self.classes.A - s = Session() + s = fixture_session() target = testing.util.resolve_lambda(target, A=A, session=s) a1 = s.query(A).all()[0] @@ -1244,7 +1244,7 @@ class RestoreLoadContextTest(fixtures.DeclarativeMappedTest): @_combinations def test_flag_resolves(self, target, event_name, fn): A = self.classes.A - s = Session() + s = fixture_session() target = testing.util.resolve_lambda(target, A=A, session=s) event.listen(target, event_name, fn, restore_load_context=True) @@ -1681,7 +1681,7 @@ class LoadTest(_fixtures.FixtureTest): canary = self._fixture() - sess = Session() + sess = fixture_session() u1 = User(name="u1") sess.add(u1) @@ -1696,7 +1696,7 @@ class LoadTest(_fixtures.FixtureTest): canary = self._fixture() - sess = Session() + sess = fixture_session() u1 = User(name="u1") sess.add(u1) @@ -1828,7 +1828,7 @@ class RefreshTest(_fixtures.FixtureTest): canary = self._fixture() - sess = Session() + sess = fixture_session() u1 = User(name="u1") sess.add(u1) @@ -1855,7 +1855,7 @@ class RefreshTest(_fixtures.FixtureTest): def canary2(obj, context, props): obj.name = "refreshed name!" - sess = Session() + sess = fixture_session() u1 = User(name="u1") sess.add(u1) sess.commit() @@ -1877,7 +1877,7 @@ class RefreshTest(_fixtures.FixtureTest): canary = self._fixture() - sess = Session() + sess = fixture_session() u1 = User(name="u1") sess.add(u1) @@ -1891,7 +1891,7 @@ class RefreshTest(_fixtures.FixtureTest): canary = self._fixture() - sess = Session() + sess = fixture_session() u1 = User(name="u1") sess.add(u1) @@ -1905,7 +1905,7 @@ class RefreshTest(_fixtures.FixtureTest): canary = self._fixture() - sess = Session() + sess = fixture_session() u1 = User(name="u1") sess.add(u1) @@ -1920,7 +1920,7 @@ class RefreshTest(_fixtures.FixtureTest): canary = self._fixture() - sess = Session() + sess = fixture_session() u1 = User(name="u1") sess.add(u1) @@ -1934,7 +1934,7 @@ class RefreshTest(_fixtures.FixtureTest): canary = self._fixture() - sess = Session() + sess = fixture_session() u1 = User(name="u1") sess.add(u1) @@ -1949,7 +1949,7 @@ class RefreshTest(_fixtures.FixtureTest): canary = self._fixture() - sess = Session() + sess = fixture_session() u1 = User(name="u1") sess.add(u1) @@ -1968,7 +1968,7 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): event.listen(Session, "before_flush", my_listener) - s = Session() + s = fixture_session() assert my_listener in s.dispatch.before_flush def test_sessionmaker_listen(self): @@ -2001,7 +2001,7 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): def my_listener_one(*arg, **kw): pass - scope = scoped_session(lambda: Session()) + scope = scoped_session(lambda: fixture_session()) assert_raises_message( sa.exc.ArgumentError, @@ -2021,7 +2021,7 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): class NotASession(object): def __call__(self): - return Session() + return fixture_session() scope = scoped_session(NotASession) @@ -2055,7 +2055,7 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): return go - sess = Session(**kw) + sess = fixture_session(**kw) for evt in [ "after_transaction_create", @@ -2152,7 +2152,7 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): User, users = self.classes.User, self.tables.users mapper(User, users) - sess = Session() + sess = fixture_session() assertions = [] @@ -2227,7 +2227,7 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): def test_state_before_attach(self): User, users = self.classes.User, self.tables.users - sess = Session() + sess = fixture_session() @event.listens_for(sess, "before_attach") def listener(session, inst): @@ -2246,7 +2246,7 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): def test_state_after_attach(self): User, users = self.classes.User, self.tables.users - sess = Session() + sess = fixture_session() @event.listens_for(sess, "after_attach") def listener(session, inst): @@ -2279,7 +2279,7 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): def test_on_bulk_update_hook(self): User, users = self.classes.User, self.tables.users - sess = Session() + sess = fixture_session() canary = Mock() event.listen(sess, "after_begin", canary.after_begin) @@ -2299,7 +2299,7 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): def test_on_bulk_delete_hook(self): User, users = self.classes.User, self.tables.users - sess = Session() + sess = fixture_session() canary = Mock() event.listen(sess, "after_begin", canary.after_begin) @@ -2317,7 +2317,7 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): eq_(upd.result.rowcount, 0) def test_connection_emits_after_begin(self): - sess, canary = self._listener_fixture(bind=testing.db) + sess, canary = self._listener_fixture() sess.connection() # changed due to #5074 eq_(canary, ["after_transaction_create", "after_begin"]) @@ -2331,7 +2331,7 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): def before_flush(session, flush_context, objects): session.flush() - sess = Session() + sess = fixture_session() event.listen(sess, "before_flush", before_flush) sess.add(User(name="foo")) assert_raises_message( @@ -2356,7 +2356,7 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): ) session.delete(x) - sess = Session() + sess = fixture_session() event.listen(sess, "before_flush", before_flush) u = User(name="u1") @@ -2400,7 +2400,7 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): for obj in list(session.identity_map.values()): obj.name += " modified" - sess = Session(autoflush=True) + sess = fixture_session(autoflush=True) event.listen(sess, "before_flush", before_flush) u = User(name="u1") @@ -2421,7 +2421,7 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): mapper(User, users) - sess = Session() + sess = fixture_session() u1 = User(name="u1") @@ -2443,7 +2443,7 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): mapper(User, users) - sess = Session() + sess = fixture_session() u1 = User(name="u1") @@ -2484,7 +2484,7 @@ class SessionLifecycleEventsTest(_RemoveListeners, _fixtures.FixtureTest): listener = Mock() - sess = Session() + sess = fixture_session() def start_events(): event.listen( @@ -3027,7 +3027,7 @@ class QueryEventsTest( return query User = self.classes.User - s = Session() + s = fixture_session() q = s.query(User).filter_by(id=7) self.assert_compile( @@ -3046,7 +3046,7 @@ class QueryEventsTest( counter[0] += 1 User = self.classes.User - s = Session() + s = fixture_session() q = s.query(User).filter_by(id=7) str(q) @@ -3060,7 +3060,7 @@ class QueryEventsTest( def fn(query): return query.add_columns(User.name) - s = Session() + s = fixture_session() q = s.query(User.id).filter_by(id=7) self.assert_compile( @@ -3088,7 +3088,7 @@ class QueryEventsTest( return query User = self.classes.User - s = Session() + s = fixture_session() with self.sql_execution_asserter() as asserter: s.query(User).filter_by(id=7).update({"name": "ed"}) @@ -3112,7 +3112,7 @@ class QueryEventsTest( return query User = self.classes.User - s = Session() + s = fixture_session() # note this deletes no rows with self.sql_execution_asserter() as asserter: @@ -3140,7 +3140,7 @@ class QueryEventsTest( ): opts.update(context.execution_options) - sess = create_session(bind=testing.db, autocommit=False) + sess = fixture_session(autocommit=False) sess.query(User).first() eq_(opts["my_option"], True) @@ -3185,7 +3185,7 @@ class RefreshFlushInReturningTest(fixtures.MappedTest): mock = Mock() event.listen(Thing, "refresh_flush", mock) t1 = Thing() - s = Session() + s = fixture_session() s.add(t1) s.flush() diff --git a/test/orm/test_expire.py b/test/orm/test_expire.py index 5abaa03db..5771ee538 100644 --- a/test/orm/test_expire.py +++ b/test/orm/test_expire.py @@ -25,7 +25,7 @@ from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures -from sqlalchemy.testing.fixtures import create_session +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from sqlalchemy.testing.util import gc_collect @@ -48,7 +48,7 @@ class ExpireTest(_fixtures.FixtureTest): ) mapper(Address, addresses) - sess = create_session() + sess = fixture_session(autoflush=False) u = sess.query(User).get(7) assert len(u.addresses) == 1 u.name = "foo" @@ -89,7 +89,7 @@ class ExpireTest(_fixtures.FixtureTest): mapper(User, users) mapper(Address, addresses, properties={"user": relationship(User)}) - s = Session() + s = fixture_session() a1 = s.query(Address).get(2) u1 = s.query(User).get(7) @@ -104,7 +104,7 @@ class ExpireTest(_fixtures.FixtureTest): users, User = self.tables.users, self.classes.User mapper(User, users) - s = create_session() + s = fixture_session() u = s.query(User).get(7) s.expunge_all() @@ -119,7 +119,7 @@ class ExpireTest(_fixtures.FixtureTest): users, User = self.tables.users, self.classes.User mapper(User, users) - s = create_session(autocommit=False) + s = fixture_session(autocommit=False) u = s.query(User).get(10) s.expire_all() @@ -142,7 +142,7 @@ class ExpireTest(_fixtures.FixtureTest): users, User = self.tables.users, self.classes.User mapper(User, users) - s = create_session(autocommit=False) + s = fixture_session(autocommit=False) u = s.query(User).get(10) s.expire_all() @@ -157,7 +157,7 @@ class ExpireTest(_fixtures.FixtureTest): users, User = self.tables.users, self.classes.User mapper(User, users) - s = create_session(autocommit=False) + s = fixture_session(autocommit=False) u = s.query(User).get(10) s.expire_all() @@ -178,7 +178,7 @@ class ExpireTest(_fixtures.FixtureTest): users, User = self.tables.users, self.classes.User mapper(User, users) - s = create_session(autocommit=False) + s = fixture_session(autocommit=False) u = s.query(User).get(10) s.expire_all() s.execute(users.delete().where(User.id == 10)) @@ -206,7 +206,7 @@ class ExpireTest(_fixtures.FixtureTest): properties={"description": deferred(orders.c.description)}, ) - s = create_session() + s = fixture_session() o1 = s.query(Order).first() assert "description" not in o1.__dict__ s.expire(o1) @@ -233,7 +233,7 @@ class ExpireTest(_fixtures.FixtureTest): users, User = self.tables.users, self.classes.User mapper(User, users, properties={"name": deferred(users.c.name)}) - s = create_session(autocommit=False) + s = fixture_session(autocommit=False) u = s.query(User).get(10) assert "name" not in u.__dict__ @@ -265,7 +265,7 @@ class ExpireTest(_fixtures.FixtureTest): }, ) mapper(Address, addresses) - s = create_session(autoflush=True, autocommit=False) + s = fixture_session(autoflush=True, autocommit=False) u = s.query(User).get(8) adlist = u.addresses eq_( @@ -309,7 +309,7 @@ class ExpireTest(_fixtures.FixtureTest): }, ) mapper(Address, addresses) - s = create_session(autoflush=True, autocommit=False) + s = fixture_session(autoflush=True, autocommit=False) u = s.query(User).get(8) assert_raises_message( sa_exc.InvalidRequestError, @@ -323,7 +323,7 @@ class ExpireTest(_fixtures.FixtureTest): users, User = self.tables.users, self.classes.User mapper(User, users) - s = create_session() + s = fixture_session() u = s.query(User).get(7) s.expire(u) s.refresh(u) @@ -339,7 +339,7 @@ class ExpireTest(_fixtures.FixtureTest): mapper(User, users) - sess = create_session() + sess = fixture_session(autoflush=False) u = sess.query(User).get(7) sess.expire(u, attribute_names=["name"]) @@ -356,7 +356,7 @@ class ExpireTest(_fixtures.FixtureTest): users, User = self.tables.users, self.classes.User mapper(User, users) - sess = create_session() + sess = fixture_session() u = sess.query(User).get(7) sess.expire(u, attribute_names=["name"]) @@ -369,7 +369,7 @@ class ExpireTest(_fixtures.FixtureTest): # this was the opposite in 0.4, but the reasoning there seemed off. # expiring a pending instance makes no sense, so should raise mapper(User, users) - sess = create_session() + sess = fixture_session() u = User(id=15) sess.add(u) assert_raises(sa_exc.InvalidRequestError, sess.expire, u, ["name"]) @@ -382,7 +382,7 @@ class ExpireTest(_fixtures.FixtureTest): # is actually part of a larger behavior when postfetch needs to # occur during a flush() on an instance that was just inserted mapper(User, users) - sess = create_session() + sess = fixture_session(autoflush=False) u = sess.query(User).get(7) sess.expire(u, attribute_names=["name"]) @@ -398,7 +398,7 @@ class ExpireTest(_fixtures.FixtureTest): # same as test_no_instance_key, but the PK columns # are absent. ensure an error is raised. mapper(User, users) - sess = create_session() + sess = fixture_session() u = sess.query(User).get(7) sess.expire(u, attribute_names=["name", "id"]) @@ -415,7 +415,7 @@ class ExpireTest(_fixtures.FixtureTest): Order, orders = self.classes.Order, self.tables.orders mapper(Order, orders) - sess = create_session() + sess = fixture_session(autoflush=False) o = sess.query(Order).get(3) sess.expire(o) @@ -467,7 +467,7 @@ class ExpireTest(_fixtures.FixtureTest): mapper(Order, orders) - sess = create_session() + sess = fixture_session(autoflush=False) o = sess.query(Order).get(3) sess.expire(o) @@ -501,7 +501,7 @@ class ExpireTest(_fixtures.FixtureTest): }, ) mapper(Address, addresses) - s = create_session() + s = fixture_session(autoflush=False) u = s.query(User).get(8) assert u.addresses[0].email_address == "ed@wood.com" @@ -527,7 +527,7 @@ class ExpireTest(_fixtures.FixtureTest): }, ) mapper(Address, addresses) - s = create_session() + s = fixture_session(autoflush=False) u = s.query(User).get(8) assert u.addresses[0].email_address == "ed@wood.com" @@ -565,7 +565,7 @@ class ExpireTest(_fixtures.FixtureTest): properties={"addresses": relationship(Address, cascade=cascade)}, ) mapper(Address, addresses) - s = create_session() + s = fixture_session(autoflush=False) u = s.query(User).get(8) a = Address(id=12, email_address="foobar") @@ -598,7 +598,7 @@ class ExpireTest(_fixtures.FixtureTest): ) mapper(Address, addresses) - sess = create_session() + sess = fixture_session() u = sess.query(User).get(7) sess.expire(u) @@ -633,7 +633,7 @@ class ExpireTest(_fixtures.FixtureTest): ) mapper(Address, addresses) - sess = create_session() + sess = fixture_session() u = sess.query(User).get(7) sess.expire(u) @@ -681,7 +681,7 @@ class ExpireTest(_fixtures.FixtureTest): }, ) mapper(Address, addresses) - sess = create_session() + sess = fixture_session(autoflush=False) u = sess.query(User).get(8) sess.expire(u, ["name", "addresses"]) u.addresses @@ -711,7 +711,7 @@ class ExpireTest(_fixtures.FixtureTest): }, ) mapper(Address, addresses) - sess = create_session() + sess = fixture_session() u = sess.query(User).get(8) sess.expire(u) u.id @@ -736,7 +736,7 @@ class ExpireTest(_fixtures.FixtureTest): properties={"addresses": relationship(Address, backref="user")}, ) mapper(Address, addresses) - sess = create_session() + sess = fixture_session() u = sess.query(User).options(joinedload(User.addresses)).get(8) sess.expire(u) u.id @@ -761,7 +761,7 @@ class ExpireTest(_fixtures.FixtureTest): }, ) mapper(Address, addresses) - sess = create_session() + sess = fixture_session() u = sess.query(User).get(8) sess.expire(u) @@ -780,7 +780,7 @@ class ExpireTest(_fixtures.FixtureTest): mapper(User, users, properties={"uname": sa.orm.synonym("name")}) - sess = create_session() + sess = fixture_session() u = sess.query(User).get(7) assert "name" in u.__dict__ assert u.uname == u.name @@ -804,7 +804,7 @@ class ExpireTest(_fixtures.FixtureTest): mapper(Order, orders) - sess = create_session() + sess = fixture_session(autoflush=False) o = sess.query(Order).get(3) sess.expire(o, attribute_names=["description"]) @@ -873,7 +873,7 @@ class ExpireTest(_fixtures.FixtureTest): ) mapper(Address, addresses) - sess = create_session() + sess = fixture_session(autoflush=False) u = sess.query(User).get(8) sess.expire(u, ["name", "addresses"]) @@ -935,7 +935,7 @@ class ExpireTest(_fixtures.FixtureTest): ) mapper(Address, addresses) - sess = create_session() + sess = fixture_session(autoflush=False) u = sess.query(User).get(8) sess.expire(u, ["name", "addresses"]) @@ -989,7 +989,7 @@ class ExpireTest(_fixtures.FixtureTest): ) mapper(Address, addresses) - sess = create_session() + sess = fixture_session(autoflush=False) u = sess.query(User).get(8) assert "name" in u.__dict__ u.addresses @@ -1016,7 +1016,7 @@ class ExpireTest(_fixtures.FixtureTest): properties={"description": sa.orm.deferred(orders.c.description)}, ) - sess = create_session() + sess = fixture_session(autoflush=False) o = sess.query(Order).get(3) sess.expire(o, ["description", "isopen"]) assert "isopen" not in o.__dict__ @@ -1104,7 +1104,7 @@ class ExpireTest(_fixtures.FixtureTest): ) mapper(Address, addresses) - sess = create_session() + sess = fixture_session(autoflush=False) u = sess.query(User).get(8) assert len(u.addresses) == 3 sess.expire(u) @@ -1136,7 +1136,7 @@ class ExpireTest(_fixtures.FixtureTest): ) mapper(Address, addresses) - sess = create_session() + sess = fixture_session(autoflush=False) userlist = sess.query(User).order_by(User.id).all() eq_(self.static.user_address_result, userlist) eq_(len(list(sess)), 9) @@ -1158,7 +1158,7 @@ class ExpireTest(_fixtures.FixtureTest): mapper(User, users) - sess = create_session() + sess = fixture_session(autoflush=False) # deferred attribute option, gets the LoadDeferredColumns # callable @@ -1208,7 +1208,7 @@ class ExpireTest(_fixtures.FixtureTest): mapper(User, users, properties={"name": deferred(users.c.name)}) - sess = create_session() + sess = fixture_session(autoflush=False) u1 = sess.query(User).options(undefer(User.name)).first() assert "name" not in attributes.instance_state(u1).callables @@ -1262,7 +1262,7 @@ class ExpireTest(_fixtures.FixtureTest): ) mapper(Address, addresses) - sess = create_session() + sess = fixture_session(autoflush=False) u1 = sess.query(User).options(lazyload(User.addresses)).first() assert isinstance( attributes.instance_state(u1).callables["addresses"], @@ -1308,7 +1308,7 @@ class ExpireTest(_fixtures.FixtureTest): properties={"description": deferred(orders.c.description)}, ) - s = Session() + s = fixture_session() item = Order(id=1) make_transient_to_detached(item) @@ -1324,7 +1324,7 @@ class ExpireTest(_fixtures.FixtureTest): properties={"description": deferred(orders.c.description)}, ) - s = Session() + s = fixture_session() item = s.query(Order).first() s.expire(item) @@ -1339,7 +1339,7 @@ class ExpireTest(_fixtures.FixtureTest): properties={"description": deferred(orders.c.description)}, ) - s = Session() + s = fixture_session() item = s.query(Order).first() s.expire(item, ["isopen", "description"]) @@ -1431,7 +1431,7 @@ class PolymorphicExpireTest(fixtures.MappedTest): self.classes.Engineer, ) - sess = create_session() + sess = fixture_session(autoflush=False) [p1, e1, e2] = sess.query(Person).order_by(people.c.person_id).all() sess.expire(p1) @@ -1473,7 +1473,7 @@ class PolymorphicExpireTest(fixtures.MappedTest): def test_no_instance_key(self): Engineer = self.classes.Engineer - sess = create_session() + sess = fixture_session(autoflush=False) e1 = sess.query(Engineer).get(2) sess.expire(e1, attribute_names=["name"]) @@ -1488,7 +1488,7 @@ class PolymorphicExpireTest(fixtures.MappedTest): # same as test_no_instance_key, but the PK columns # are absent. ensure an error is raised. - sess = create_session() + sess = fixture_session(autoflush=False) e1 = sess.query(Engineer).get(2) sess.expire(e1, attribute_names=["name", "person_id"]) @@ -1520,7 +1520,7 @@ class ExpiredPendingTest(_fixtures.FixtureTest): ) mapper(Address, addresses) - sess = create_session() + sess = fixture_session(autoflush=False) a1 = Address(email_address="a1") sess.add(a1) sess.flush() @@ -1608,7 +1608,7 @@ class LifecycleTest(fixtures.MappedTest): def test_attr_not_inserted(self): Data = self.classes.Data - sess = create_session() + sess = fixture_session() d1 = Data() sess.add(d1) @@ -1627,7 +1627,7 @@ class LifecycleTest(fixtures.MappedTest): def test_attr_not_inserted_expired(self): Data = self.classes.Data - sess = create_session() + sess = fixture_session(autoflush=False) d1 = Data() sess.add(d1) @@ -1646,7 +1646,7 @@ class LifecycleTest(fixtures.MappedTest): def test_attr_not_inserted_fetched(self): Data = self.classes.DataFetched - sess = create_session() + sess = fixture_session() d1 = Data() sess.add(d1) @@ -1667,7 +1667,7 @@ class LifecycleTest(fixtures.MappedTest): d1 = Data(data="d1") sess.add(d1) - sess = create_session() + sess = fixture_session() d1 = sess.query(Data).from_statement(select(Data.id)).first() # cols not present in the row are implicitly expired @@ -1720,7 +1720,7 @@ class RefreshTest(_fixtures.FixtureTest): ) }, ) - s = create_session() + s = fixture_session(autoflush=False) u = s.query(User).get(7) u.name = "foo" a = Address() @@ -1755,7 +1755,7 @@ class RefreshTest(_fixtures.FixtureTest): users, User = self.tables.users, self.classes.User mapper(User, users) - s = create_session() + s = fixture_session() u = s.query(User).get(7) s.expunge_all() assert_raises_message( @@ -1771,7 +1771,7 @@ class RefreshTest(_fixtures.FixtureTest): mapper(User, users) mapper(Address, addresses, properties={"user": relationship(User)}) - s = Session() + s = fixture_session() a1 = s.query(Address).get(2) u1 = s.query(User).get(7) @@ -1786,7 +1786,7 @@ class RefreshTest(_fixtures.FixtureTest): User, users = self.classes.User, self.tables.users mapper(User, users) - s = create_session() + s = fixture_session() u = s.query(User).get(7) s.expire(u) assert "name" not in u.__dict__ @@ -1805,7 +1805,7 @@ class RefreshTest(_fixtures.FixtureTest): self.tables.users, ) - s = create_session() + s = fixture_session() mapper( User, users, @@ -1840,13 +1840,13 @@ class RefreshTest(_fixtures.FixtureTest): }, ) - s = create_session() + s = fixture_session() u = s.query(User).get(8) assert len(u.addresses) == 3 s.refresh(u) assert len(u.addresses) == 3 - s = create_session() + s = fixture_session() u = s.query(User).get(8) assert len(u.addresses) == 3 s.expire(u) @@ -1870,7 +1870,7 @@ class RefreshTest(_fixtures.FixtureTest): mapper(Dingaling, dingalings) - s = create_session() + s = fixture_session() q = ( s.query(User) .filter_by(name="fred") @@ -1908,7 +1908,7 @@ class RefreshTest(_fixtures.FixtureTest): self.classes.User, ) - s = create_session() + s = fixture_session() mapper(Address, addresses) mapper( diff --git a/test/orm/test_froms.py b/test/orm/test_froms.py index 055f24b5c..cc9596466 100644 --- a/test/orm/test_froms.py +++ b/test/orm/test_froms.py @@ -23,7 +23,6 @@ from sqlalchemy.orm import clear_mappers from sqlalchemy.orm import column_property from sqlalchemy.orm import configure_mappers from sqlalchemy.orm import contains_eager -from sqlalchemy.orm import create_session from sqlalchemy.orm import joinedload from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship @@ -37,6 +36,7 @@ from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from test.orm import _fixtures @@ -198,7 +198,7 @@ class QueryCorrelatesLikeSelect(QueryTest, AssertsCompiledSQL): ) def test_scalar_subquery_query_auto_correlate(self): - sess = create_session() + sess = fixture_session() Address, User = self.classes.Address, self.classes.User query = ( sess.query(func.count(Address.id)) @@ -211,7 +211,7 @@ class QueryCorrelatesLikeSelect(QueryTest, AssertsCompiledSQL): ) def test_scalar_subquery_query_explicit_correlate(self): - sess = create_session() + sess = fixture_session() Address, User = self.classes.Address, self.classes.User query = ( sess.query(func.count(Address.id)) @@ -226,7 +226,7 @@ class QueryCorrelatesLikeSelect(QueryTest, AssertsCompiledSQL): @testing.combinations(False, None) def test_scalar_subquery_query_correlate_off(self, value): - sess = create_session() + sess = fixture_session() Address, User = self.classes.Address, self.classes.User query = ( sess.query(func.count(Address.id)) @@ -241,7 +241,7 @@ class QueryCorrelatesLikeSelect(QueryTest, AssertsCompiledSQL): def test_correlate_to_union(self): User = self.classes.User - sess = create_session() + sess = fixture_session() q = sess.query(User) q = sess.query(User).union(q) @@ -312,7 +312,7 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): def test_select(self): addresses, users = self.tables.addresses, self.tables.users - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(users) @@ -379,19 +379,19 @@ class EntityFromSubqueryTest(QueryTest, AssertsCompiledSQL): User = self.classes.User subq = select(User).filter(User.id.in_([8, 9])).subquery() - q = create_session().query(aliased(User, subq)) + q = fixture_session().query(aliased(User, subq)) eq_( [User(id=8), User(id=9)], q.all(), ) subq = select(User).order_by(User.id).slice(1, 3).subquery() - q = create_session().query(aliased(User, subq)) + q = fixture_session().query(aliased(User, subq)) eq_([User(id=8), User(id=9)], q.all()) subq = select(User).filter(User.id.in_([8, 9])).subquery() u = aliased(User, subq) - q = create_session().query(u).order_by(u.id) + q = fixture_session().query(u).order_by(u.id) eq_( [User(id=8)], list(q[0:1]), @@ -405,7 +405,7 @@ class EntityFromSubqueryTest(QueryTest, AssertsCompiledSQL): u = aliased(User, stmt) q = ( - create_session() + fixture_session() .query(u) .join(u.addresses) .add_entity(Address) @@ -433,7 +433,7 @@ class EntityFromSubqueryTest(QueryTest, AssertsCompiledSQL): # there's no reason to do aliased(Address) in this case but we're just # testing aq = aliased(Address, subq) - q = create_session().query(aq.user_id, subq.c.count) + q = fixture_session().query(aq.user_id, subq.c.count) eq_( q.all(), [(7, 1), (8, 3), (9, 1)], @@ -443,7 +443,7 @@ class EntityFromSubqueryTest(QueryTest, AssertsCompiledSQL): aq = aliased(Address, subq) q = ( - create_session() + fixture_session() .query(aq.user_id, func.count(aq.id)) .group_by(aq.user_id) .order_by(aq.user_id) @@ -457,7 +457,7 @@ class EntityFromSubqueryTest(QueryTest, AssertsCompiledSQL): def test_error_w_aliased_against_select(self): User = self.classes.User - s = create_session() + s = fixture_session() stmt = select(User.id) @@ -474,7 +474,7 @@ class EntityFromSubqueryTest(QueryTest, AssertsCompiledSQL): def test_having(self): User = self.classes.User - s = create_session() + s = fixture_session() stmt = ( select(User.id) @@ -496,7 +496,7 @@ class EntityFromSubqueryTest(QueryTest, AssertsCompiledSQL): User = self.classes.User - s = create_session() + s = fixture_session() subq = ( select(User) @@ -533,7 +533,7 @@ class EntityFromSubqueryTest(QueryTest, AssertsCompiledSQL): User, Address = self.classes.User, self.classes.Address - s = create_session() + s = fixture_session() ualias = aliased(User) @@ -587,7 +587,7 @@ class EntityFromSubqueryTest(QueryTest, AssertsCompiledSQL): def test_multiple_entities(self): User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() subq = ( select(User, Address) @@ -629,7 +629,7 @@ class EntityFromSubqueryTest(QueryTest, AssertsCompiledSQL): ) uq2 = aliased(User, subq2) - sess = create_session() + sess = fixture_session() eq_( sess.query(uq2.id, subq2.c.foo).all(), @@ -639,7 +639,7 @@ class EntityFromSubqueryTest(QueryTest, AssertsCompiledSQL): def test_multiple_with_column_entities_newstyle(self): User = self.classes.User - sess = create_session() + sess = fixture_session() q1 = sess.query(User.id) @@ -665,7 +665,7 @@ class ColumnAccessTest(QueryTest, AssertsCompiledSQL): def test_select_entity_from(self): User = self.classes.User - sess = create_session() + sess = fixture_session() q = sess.query(User) q = sess.query(User).select_entity_from(q.statement.subquery()) @@ -678,7 +678,7 @@ class ColumnAccessTest(QueryTest, AssertsCompiledSQL): def test_select_entity_from_no_entities(self): User = self.classes.User - sess = create_session() + sess = fixture_session() assert_raises_message( sa.exc.ArgumentError, @@ -689,7 +689,7 @@ class ColumnAccessTest(QueryTest, AssertsCompiledSQL): def test_select_from_no_aliasing(self): User = self.classes.User - sess = create_session() + sess = fixture_session() q = sess.query(User) q = sess.query(User).select_from(q.statement.subquery()) @@ -704,7 +704,7 @@ class ColumnAccessTest(QueryTest, AssertsCompiledSQL): # relies upon _orm_only_from_obj_alias setting from sqlalchemy.sql import column - sess = create_session() + sess = fixture_session() c1, c2 = column("c1"), column("c2") q1 = sess.query(c1, c2).filter(c1 == "dog") q2 = sess.query(c1, c2).filter(c1 == "cat") @@ -806,7 +806,7 @@ class ColumnAccessTest(QueryTest, AssertsCompiledSQL): def test_anonymous_labeled_expression_oldstyle(self): # relies upon _orm_only_from_obj_alias setting - sess = create_session() + sess = fixture_session() c1, c2 = column("c1"), column("c2") q1 = sess.query(c1.label("foo"), c2.label("bar")).filter(c1 == "dog") q2 = sess.query(c1.label("foo"), c2.label("bar")).filter(c1 == "cat") @@ -839,7 +839,7 @@ class ColumnAccessTest(QueryTest, AssertsCompiledSQL): Address = self.classes.Address addresses = self.tables.addresses - sess = create_session() + sess = fixture_session() q1 = sess.query(User.id).filter(User.id > 5) uq = aliased(User, q1.apply_labels().subquery()) @@ -868,7 +868,7 @@ class ColumnAccessTest(QueryTest, AssertsCompiledSQL): Address = self.classes.Address addresses = self.tables.addresses - sess = create_session() + sess = fixture_session() q1 = sess.query(User.id).filter(User.id > 5).apply_labels().subquery() uq = aliased(User, q1) @@ -970,7 +970,7 @@ class AddEntityEquivalenceTest(fixtures.MappedTest, AssertsCompiledSQL): def insert_data(cls, connection): A, C, B = (cls.classes.A, cls.classes.C, cls.classes.B) - sess = create_session(connection) + sess = Session(connection) sess.add_all( [ B(name="b1"), @@ -984,7 +984,7 @@ class AddEntityEquivalenceTest(fixtures.MappedTest, AssertsCompiledSQL): def test_add_entity_equivalence(self): A, C, B = (self.classes.A, self.classes.C, self.classes.B) - sess = create_session() + sess = fixture_session() for q in [ sess.query(A, B).join(A.link), @@ -1033,7 +1033,7 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): use_labels=True, order_by=[text("ulist.id"), addresses.c.id] ) ) - sess = create_session() + sess = fixture_session() q = sess.query(User) def go(): @@ -1062,7 +1062,7 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): use_labels=True, order_by=[text("ulist.id"), addresses.c.id] ) ) - sess = create_session() + sess = fixture_session() q = sess.query(User) def go(): @@ -1092,7 +1092,7 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): use_labels=True, order_by=[text("ulist.id"), addresses.c.id] ) ) - sess = create_session() + sess = fixture_session() # better way. use select_entity_from() def go(): @@ -1113,7 +1113,7 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): self.tables.users, ) - sess = create_session() + sess = fixture_session() # same thing, but alias addresses, so that the adapter # generated by select_entity_from() is wrapped within @@ -1141,7 +1141,7 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): def test_contains_eager_one(self): addresses, User = (self.tables.addresses, self.classes.User) - sess = create_session() + sess = fixture_session() # test that contains_eager suppresses the normal outer join rendering q = ( @@ -1175,7 +1175,7 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() adalias = addresses.alias() q = ( @@ -1197,7 +1197,7 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() selectquery = users.outerjoin(addresses).select( users.c.id < 10, @@ -1224,7 +1224,7 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session(testing.db, future=True) + sess = fixture_session(future=True) selectquery = users.outerjoin(addresses).select( users.c.id < 10, @@ -1252,7 +1252,7 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): def test_contains_eager_aliased(self): User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() q = sess.query(User) # Aliased object @@ -1277,7 +1277,7 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() q = sess.query(User) oalias = orders.alias("o1") @@ -1315,7 +1315,7 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): self.classes.Order, ) - sess = create_session() + sess = fixture_session() q = sess.query(User) # test using Aliased with more than one level deep @@ -1344,7 +1344,7 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): self.classes.Order, ) - sess = create_session() + sess = fixture_session() q = sess.query(User) # test using Aliased with more than one level deep @@ -1375,7 +1375,7 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): self.classes.Address, ) - sess = create_session() + sess = fixture_session() q = ( sess.query(User) .join(User.addresses) @@ -1423,7 +1423,7 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): self.classes.Address, ) - sess = create_session() + sess = fixture_session() da = aliased(Dingaling, name="foob") q = ( sess.query(User) @@ -1471,7 +1471,7 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): self.classes.Address, ) - sess = create_session() + sess = fixture_session() q = sess.query(User) @@ -1584,7 +1584,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): def test_alias_naming(self): User = self.classes.User - sess = create_session() + sess = fixture_session() ua = aliased(User, name="foobar") q = sess.query(ua) @@ -1604,7 +1604,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() subq = ( select(func.count()) @@ -1644,7 +1644,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): def test_column_queries_one(self): User = self.classes.User - sess = create_session() + sess = fixture_session() eq_( sess.query(User.name).all(), @@ -1657,7 +1657,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() sel = users.select(User.id.in_([7, 8])).alias() q = sess.query(User.name) q2 = q.select_entity_from(sel).all() @@ -1669,7 +1669,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() eq_( sess.query(User.name, Address.email_address) .filter(User.id == Address.user_id) @@ -1689,7 +1689,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() eq_( sess.query(User.name, func.count(Address.email_address)) .outerjoin(User.addresses) @@ -1705,7 +1705,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() eq_( sess.query(User, func.count(Address.email_address)) .outerjoin(User.addresses) @@ -1726,7 +1726,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() eq_( sess.query(func.count(Address.email_address), User) .outerjoin(User.addresses) @@ -1747,7 +1747,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() adalias = aliased(Address) eq_( sess.query(User, func.count(adalias.email_address)) @@ -1769,7 +1769,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() adalias = aliased(Address) eq_( sess.query(func.count(adalias.email_address), User) @@ -1791,7 +1791,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() adalias = aliased(Address) @@ -1823,7 +1823,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() # anon + select from aliasing aa = aliased(Address) @@ -1847,7 +1847,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() adalias = aliased(Address) @@ -1986,7 +1986,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): def test_column_from_limited_joinedload(self): User = self.classes.User - sess = create_session() + sess = fixture_session() def go(): results = ( @@ -2003,7 +2003,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): def test_self_referential_from_self(self): Order = self.classes.Order - sess = create_session() + sess = fixture_session() oalias = aliased(Order) q1 = ( @@ -2159,7 +2159,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - test_session = create_session() + test_session = fixture_session() (user7, user8, user9, user10) = test_session.query(User).all() ( @@ -2179,7 +2179,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): (user10, None), ] - sess = create_session(testing.db, future=True) + sess = fixture_session(future=True) selectquery = users.outerjoin(addresses).select( use_labels=True, order_by=[users.c.id, addresses.c.id] @@ -2235,7 +2235,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): self.classes.Address, ) - sess = create_session() + sess = fixture_session() (user7, user8, user9, user10) = sess.query(User).all() (address1, address2, address3, address4, address5) = sess.query( @@ -2272,7 +2272,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): def test_with_entities(self): User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() q = sess.query(User).filter(User.id == 7).order_by(User.name) @@ -2292,7 +2292,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): def test_multi_columns(self): users, User = self.tables.users, self.classes.User - sess = create_session() + sess = fixture_session() expected = [(u, u.name) for u in sess.query(User).all()] @@ -2309,7 +2309,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): users, User = self.tables.users, self.classes.User - sess = create_session() + sess = fixture_session() eq_( sess.query(User.id).add_columns(users).all(), @@ -2326,7 +2326,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): self.tables.users, ) - sess = create_session() + sess = fixture_session() (user7, user8, user9, user10) = sess.query(User).all() expected = [(user7, 1), (user8, 3), (user9, 1), (user10, 0)] @@ -2370,7 +2370,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): User = self.classes.User users = self.tables.users - sess = create_session() + sess = fixture_session() q = sess.query(User.id, User.name) stmt = select(users).order_by(users.c.id) @@ -2385,7 +2385,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() (user7, user8, user9, user10) = sess.query(User).all() expected = [ (user7, 1, "Name:jack"), @@ -2396,7 +2396,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): adalias = addresses.alias() q = ( - create_session() + fixture_session() .query(User) .add_columns(func.count(adalias.c.id), ("Name:" + users.c.name)) .outerjoin(adalias, "addresses") @@ -2417,7 +2417,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): .group_by(*[c for c in users.c]) .order_by(users.c.id) ) - q = create_session().query(User) + q = fixture_session().query(User) result = ( q.add_columns(s.selected_columns.count, s.selected_columns.concat) .from_statement(s) @@ -2429,7 +2429,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): # test with select_entity_from() q = ( - create_session() + fixture_session() .query(User) .add_columns(func.count(addresses.c.id), ("Name:" + users.c.name)) .select_entity_from(users.outerjoin(addresses)) @@ -2441,7 +2441,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): sess.expunge_all() q = ( - create_session() + fixture_session() .query(User) .add_columns(func.count(addresses.c.id), ("Name:" + users.c.name)) .outerjoin("addresses") @@ -2453,7 +2453,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): sess.expunge_all() q = ( - create_session() + fixture_session() .query(User) .add_columns(func.count(adalias.c.id), ("Name:" + users.c.name)) .outerjoin(adalias, "addresses") @@ -2469,7 +2469,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): ua = aliased(User) aa = aliased(Address) - s = create_session() + s = fixture_session() for crit, j, exp in [ ( User.id + Address.id, @@ -2539,7 +2539,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): def test_aliased_adapt_on_names(self): User, Address = self.classes.User, self.classes.Address - sess = Session() + sess = fixture_session() agg_address = sess.query( Address.id, func.sum(func.length(Address.email_address)).label( @@ -2597,7 +2597,7 @@ class SelectFromTest(QueryTest, AssertsCompiledSQL): mapper(Address, addresses) sel = users.select(users.c.id.in_([7, 8])).alias() - sess = create_session() + sess = fixture_session() eq_( sess.query(User).select_entity_from(sel).all(), @@ -2641,7 +2641,7 @@ class SelectFromTest(QueryTest, AssertsCompiledSQL): mapper(User, users) - sess = create_session() + sess = fixture_session() not_users = table("users", column("id"), column("name")) ua = aliased(User, select(not_users).alias(), adapt_on_names=True) @@ -2659,7 +2659,7 @@ class SelectFromTest(QueryTest, AssertsCompiledSQL): mapper(User, users) - sess = create_session() + sess = fixture_session() ua = aliased(User) @@ -2676,7 +2676,7 @@ class SelectFromTest(QueryTest, AssertsCompiledSQL): mapper(User, users) - sess = create_session() + sess = fixture_session() ua = users.alias() @@ -2696,7 +2696,7 @@ class SelectFromTest(QueryTest, AssertsCompiledSQL): mapper(User, users) - sess = create_session() + sess = fixture_session() sel = sess.query(User).filter(User.id.in_([7, 8])).subquery() ualias = aliased(User) @@ -2772,7 +2772,7 @@ class SelectFromTest(QueryTest, AssertsCompiledSQL): ua = aliased(User) - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(User).select_from(ua).join(User, ua.name > User.name), "SELECT users.id AS users_id, users.name AS users_name " @@ -2823,7 +2823,7 @@ class SelectFromTest(QueryTest, AssertsCompiledSQL): mapper(User, users) sel = users.select(users.c.id.in_([7, 8])) - sess = create_session() + sess = fixture_session() eq_( sess.query(User).select_entity_from(sel.subquery()).all(), @@ -2843,7 +2843,7 @@ class SelectFromTest(QueryTest, AssertsCompiledSQL): }, ) - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(User).select_from(Address).join("user"), @@ -2860,7 +2860,7 @@ class SelectFromTest(QueryTest, AssertsCompiledSQL): properties={"addresses": relationship(mapper(Address, addresses))}, ) - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(User) @@ -2881,7 +2881,7 @@ class SelectFromTest(QueryTest, AssertsCompiledSQL): properties={"addresses": relationship(mapper(Address, addresses))}, ) - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(User).select_from(Address).join(User), @@ -2901,7 +2901,7 @@ class SelectFromTest(QueryTest, AssertsCompiledSQL): mapper(Address, addresses) sel = users.select(users.c.id.in_([7, 8])) - sess = create_session() + sess = fixture_session() eq_( sess.query(User) @@ -3011,7 +3011,7 @@ class SelectFromTest(QueryTest, AssertsCompiledSQL): ) # m2m mapper(Keyword, keywords) - sess = create_session() + sess = fixture_session() sel = users.select(users.c.id.in_([7, 8])) eq_( @@ -3073,7 +3073,7 @@ class SelectFromTest(QueryTest, AssertsCompiledSQL): ) # m2m mapper(Keyword, keywords) - sess = create_session() + sess = fixture_session() sel = users.select(users.c.id.in_([7, 8])) @@ -3188,7 +3188,7 @@ class SelectFromTest(QueryTest, AssertsCompiledSQL): mapper(Address, addresses) sel = users.select(users.c.id.in_([7, 8])) - sess = create_session() + sess = fixture_session() def go(): eq_( @@ -3316,7 +3316,7 @@ class CustomJoinTest(QueryTest): ), ), ) - q = create_session().query(User) + q = fixture_session().query(User) eq_( q.join("open_orders", "items", aliased=True) @@ -3389,7 +3389,7 @@ class CustomJoinTest(QueryTest): ), ), ) - q = create_session().query(User) + q = fixture_session().query(User) oo = aliased(Order) co = aliased(Order) @@ -3468,7 +3468,7 @@ class ExternalColumnsTest(QueryTest): }, ) - sess = create_session() + sess = fixture_session() sess.query(Address).options(joinedload("user")).all() @@ -3633,7 +3633,7 @@ class ExternalColumnsTest(QueryTest): Order, orders, properties={"address": relationship(Address)} ) # m2o - sess = create_session() + sess = fixture_session() def go(): o1 = ( @@ -3645,7 +3645,7 @@ class ExternalColumnsTest(QueryTest): self.assert_sql_count(testing.db, go, 1) - sess = create_session() + sess = fixture_session() def go(): o1 = ( @@ -3683,11 +3683,11 @@ class ExternalColumnsTest(QueryTest): ) }, ) - sess = create_session() + sess = fixture_session() a1 = sess.query(Address).first() eq_(a1.username, "jack") - sess = create_session() + sess = fixture_session() subq = sess.query(Address).subquery() aa = aliased(Address, subq) a1 = sess.query(aa).first() @@ -3753,7 +3753,7 @@ class TestOverlyEagerEquivalentCols(fixtures.MappedTest): mapper(Sub1, sub1) mapper(Sub2, sub2) - sess = create_session() + sess = fixture_session() s11 = Sub1(data="s11") s12 = Sub1(data="s12") @@ -3802,7 +3802,7 @@ class LabelCollideTest(fixtures.MappedTest): s.commit() def test_overlap_plain(self): - s = Session() + s = fixture_session() row = ( s.query(self.classes.Foo, self.classes.Bar) .join(self.classes.Bar, true()) @@ -3819,7 +3819,7 @@ class LabelCollideTest(fixtures.MappedTest): self.assert_sql_count(testing.db, go, 0) def test_overlap_subquery(self): - s = Session() + s = fixture_session() subq = ( s.query(self.classes.Foo, self.classes.Bar) diff --git a/test/orm/test_generative.py b/test/orm/test_generative.py index 0bca2f975..f6f1b5d74 100644 --- a/test/orm/test_generative.py +++ b/test/orm/test_generative.py @@ -3,11 +3,11 @@ from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import Integer from sqlalchemy import testing -from sqlalchemy.orm import create_session from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from test.orm import _fixtures @@ -45,14 +45,14 @@ class GenerativeQueryTest(fixtures.MappedTest): def test_selectby(self): Foo = self.classes.Foo - res = create_session().query(Foo).filter_by(range=5) + res = fixture_session().query(Foo).filter_by(range=5) assert res.order_by(Foo.bar)[0].bar == 5 assert res.order_by(sa.desc(Foo.bar))[0].bar == 95 def test_slice(self): Foo = self.classes.Foo - sess = create_session() + sess = fixture_session() query = sess.query(Foo).order_by(Foo.id) orig = query.all() @@ -73,7 +73,7 @@ class GenerativeQueryTest(fixtures.MappedTest): def test_aggregate(self): foo, Foo = self.tables.foo, self.classes.Foo - sess = create_session() + sess = fixture_session() query = sess.query(Foo) assert query.count() == 100 assert sess.query(func.min(foo.c.bar)).filter( @@ -99,7 +99,7 @@ class GenerativeQueryTest(fixtures.MappedTest): def test_aggregate_1(self): foo = self.tables.foo - query = create_session().query(func.sum(foo.c.bar)) + query = fixture_session().query(func.sum(foo.c.bar)) assert query.filter(foo.c.bar < 30).one() == (435,) @testing.fails_on("firebird", "FIXME: unknown") @@ -110,7 +110,7 @@ class GenerativeQueryTest(fixtures.MappedTest): def test_aggregate_2(self): foo = self.tables.foo - query = create_session().query(func.avg(foo.c.bar)) + query = fixture_session().query(func.avg(foo.c.bar)) avg = query.filter(foo.c.bar < 30).one()[0] eq_(float(round(avg, 1)), 14.5) @@ -121,7 +121,7 @@ class GenerativeQueryTest(fixtures.MappedTest): def test_aggregate_3(self): foo, Foo = self.tables.foo, self.classes.Foo - query = create_session().query(Foo) + query = fixture_session().query(Foo) avg_f = ( query.filter(foo.c.bar < 30) @@ -140,7 +140,7 @@ class GenerativeQueryTest(fixtures.MappedTest): def test_filter(self): Foo = self.classes.Foo - query = create_session().query(Foo) + query = fixture_session().query(Foo) assert query.count() == 100 assert query.filter(Foo.bar < 30).count() == 30 res2 = query.filter(Foo.bar < 30).filter(Foo.bar > 10) @@ -149,20 +149,20 @@ class GenerativeQueryTest(fixtures.MappedTest): def test_order_by(self): Foo = self.classes.Foo - query = create_session().query(Foo) + query = fixture_session().query(Foo) assert query.order_by(Foo.bar)[0].bar == 0 assert query.order_by(sa.desc(Foo.bar))[0].bar == 99 def test_offset_order_by(self): Foo = self.classes.Foo - query = create_session().query(Foo) + query = fixture_session().query(Foo) assert list(query.order_by(Foo.bar).offset(10))[0].bar == 10 def test_offset(self): Foo = self.classes.Foo - query = create_session().query(Foo) + query = fixture_session().query(Foo) assert len(list(query.limit(10))) == 10 @@ -212,7 +212,7 @@ class GenerativeTest2(fixtures.MappedTest): self.tables.table1, ) - query = create_session().query(Obj1) + query = fixture_session().query(Obj1) eq_(query.count(), 4) res = query.filter( @@ -264,7 +264,7 @@ class RelationshipsTest(_fixtures.FixtureTest): User, Address = self.classes.User, self.classes.Address - session = create_session() + session = fixture_session() q = ( session.query(User) .join("orders", "addresses") @@ -281,7 +281,7 @@ class RelationshipsTest(_fixtures.FixtureTest): self.classes.Address, ) - session = create_session() + session = fixture_session() q = ( session.query(User) .outerjoin("orders", "addresses") @@ -298,7 +298,7 @@ class RelationshipsTest(_fixtures.FixtureTest): self.classes.Address, ) - session = create_session() + session = fixture_session() q = ( session.query(User) @@ -317,7 +317,7 @@ class RelationshipsTest(_fixtures.FixtureTest): self.tables.addresses, ) - session = create_session() + session = fixture_session() sel = users.outerjoin(orders).outerjoin( addresses, orders.c.address_id == addresses.c.id @@ -376,7 +376,7 @@ class CaseSensitiveTest(fixtures.MappedTest): self.tables.Table1, ) - q = create_session(bind=testing.db).query(Obj1) + q = fixture_session().query(Obj1) assert q.count() == 4 res = q.filter( sa.and_(Table1.c.ID == Table2.c.T1ID, Table2.c.T1ID == 1) diff --git a/test/orm/test_hasparent.py b/test/orm/test_hasparent.py index ffc41fb86..50f577240 100644 --- a/test/orm/test_hasparent.py +++ b/test/orm/test_hasparent.py @@ -1,6 +1,5 @@ """test the current state of the hasparent() flag.""" - from sqlalchemy import ForeignKey from sqlalchemy import Integer from sqlalchemy import testing @@ -8,10 +7,10 @@ from sqlalchemy.orm import attributes from sqlalchemy.orm import exc as orm_exc from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship -from sqlalchemy.orm import Session from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from sqlalchemy.testing.util import gc_collect @@ -83,7 +82,7 @@ class ParentRemovalTest(fixtures.MappedTest): def _fixture(self): User, Address = self.classes.User, self.classes.Address - s = Session() + s = fixture_session() u1 = User() a1 = Address() diff --git a/test/orm/test_immediate_load.py b/test/orm/test_immediate_load.py index 2fdf1afd9..7efd3436c 100644 --- a/test/orm/test_immediate_load.py +++ b/test/orm/test_immediate_load.py @@ -1,10 +1,10 @@ """basic tests of lazy loaded attributes""" -from sqlalchemy.orm import create_session from sqlalchemy.orm import immediateload from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship from sqlalchemy.testing import eq_ +from sqlalchemy.testing.fixtures import fixture_session from test.orm import _fixtures @@ -22,7 +22,7 @@ class ImmediateTest(_fixtures.FixtureTest): mapper(Address, addresses) mapper(User, users, properties={"addresses": relationship(Address)}) - sess = create_session() + sess = fixture_session() result = ( sess.query(User) @@ -58,7 +58,7 @@ class ImmediateTest(_fixtures.FixtureTest): users, properties={"addresses": relationship(Address, lazy="immediate")}, ) - sess = create_session() + sess = fixture_session() result = sess.query(User).filter(users.c.id == 7).all() eq_(len(sess.identity_map), 2) diff --git a/test/orm/test_inspect.py b/test/orm/test_inspect.py index 8effb583c..d19d65e22 100644 --- a/test/orm/test_inspect.py +++ b/test/orm/test_inspect.py @@ -16,6 +16,7 @@ from sqlalchemy.orm.util import identity_key from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ from sqlalchemy.testing import is_ +from sqlalchemy.testing.fixtures import fixture_session from test.orm import _fixtures @@ -370,7 +371,7 @@ class TestORMInspection(_fixtures.FixtureTest): def test_instance_state_scalar_attr_hist(self): User = self.classes.User u1 = User(name="ed") - sess = Session() + sess = fixture_session() sess.add(u1) sess.commit() assert "name" not in u1.__dict__ @@ -393,7 +394,7 @@ class TestORMInspection(_fixtures.FixtureTest): def test_instance_state_scalar_attr_hist_load(self): User = self.classes.User u1 = User(name="ed") - sess = Session() + sess = fixture_session() sess.add(u1) sess.commit() assert "name" not in u1.__dict__ @@ -640,7 +641,7 @@ class %s(SuperCls): insp = inspect(u1) is_(insp.session, None) - s = Session() + s = fixture_session() s.add(u1) is_(insp.session, s) diff --git a/test/orm/test_instrumentation.py b/test/orm/test_instrumentation.py index 16ccff936..c9b2442be 100644 --- a/test/orm/test_instrumentation.py +++ b/test/orm/test_instrumentation.py @@ -7,7 +7,6 @@ from sqlalchemy import util from sqlalchemy.orm import attributes from sqlalchemy.orm import class_mapper from sqlalchemy.orm import clear_mappers -from sqlalchemy.orm import create_session from sqlalchemy.orm import instrumentation from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship @@ -16,6 +15,7 @@ from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import ne_ +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -788,7 +788,7 @@ class MiscTest(fixtures.ORMTest): a = A() b.a = a - session = create_session() + session = fixture_session() session.add(b) assert a in session, "base is %s" % base @@ -832,7 +832,7 @@ class MiscTest(fixtures.ORMTest): b = B() b.a = a - session = create_session() + session = fixture_session() session.add(a) assert b in session, "base: %s" % base clear_mappers() diff --git a/test/orm/test_joins.py b/test/orm/test_joins.py index c2548c879..79d3db8f0 100644 --- a/test/orm/test_joins.py +++ b/test/orm/test_joins.py @@ -17,7 +17,6 @@ from sqlalchemy import true from sqlalchemy.engine import default from sqlalchemy.orm import aliased from sqlalchemy.orm import backref -from sqlalchemy.orm import create_session from sqlalchemy.orm import join from sqlalchemy.orm import joinedload from sqlalchemy.orm import mapper @@ -30,6 +29,7 @@ from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from test.orm import _fixtures from .inheritance import _poly_fixtures @@ -44,7 +44,7 @@ class InheritedJoinTest(InheritedTest, AssertsCompiledSQL): def test_single_prop(self): Company = self.classes.Company - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(Company).join(Company.employees), @@ -63,7 +63,7 @@ class InheritedJoinTest(InheritedTest, AssertsCompiledSQL): self.classes.Engineer, ) - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(Company) @@ -81,7 +81,7 @@ class InheritedJoinTest(InheritedTest, AssertsCompiledSQL): def test_force_via_select_from(self): Company, Engineer = self.classes.Company, self.classes.Engineer - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(Company) @@ -114,7 +114,7 @@ class InheritedJoinTest(InheritedTest, AssertsCompiledSQL): def test_single_prop_of_type(self): Company, Engineer = self.classes.Company, self.classes.Engineer - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(Company).join(Company.employees.of_type(Engineer)), @@ -130,7 +130,7 @@ class InheritedJoinTest(InheritedTest, AssertsCompiledSQL): def test_explicit_polymorphic_join_one(self): Company, Engineer = self.classes.Company, self.classes.Engineer - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(Company) @@ -149,7 +149,7 @@ class InheritedJoinTest(InheritedTest, AssertsCompiledSQL): def test_explicit_polymorphic_join_two(self): Company, Engineer = self.classes.Company, self.classes.Engineer - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(Company) .join(Engineer, Company.company_id == Engineer.company_id) @@ -167,7 +167,7 @@ class InheritedJoinTest(InheritedTest, AssertsCompiledSQL): def test_auto_aliasing_multi_link(self): # test [ticket:2903] - sess = create_session() + sess = fixture_session() Company, Engineer, Manager, Boss = ( self.classes.Company, @@ -221,7 +221,7 @@ class JoinOnSynonymTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_join_on_synonym(self): User = self.classes.User self.assert_compile( - Session().query(User).join(User.ad_syn), + fixture_session().query(User).join(User.ad_syn), "SELECT users.id AS users_id, users.name AS users_name " "FROM users JOIN addresses ON users.id = addresses.user_id", ) @@ -233,7 +233,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): def test_filter_by_from_full_join(self): User, Address = self.classes("User", "Address") - sess = create_session() + sess = fixture_session() q = ( sess.query(User) @@ -249,7 +249,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): def test_invalid_kwarg_join(self): User = self.classes.User - sess = create_session() + sess = fixture_session() assert_raises_message( TypeError, "unknown arguments: bar, foob", @@ -271,7 +271,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): User = self.classes.User Address = self.classes.Address - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(User, literal_column("x")).join(Address), @@ -288,7 +288,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): def test_left_is_none_and_query_has_no_entities(self): Address = self.classes.Address - sess = create_session() + sess = fixture_session() assert_raises_message( sa_exc.InvalidRequestError, @@ -301,7 +301,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): User = self.classes.User self.assert_compile( - create_session().query(User).join(User.orders, isouter=True), + fixture_session().query(User).join(User.orders, isouter=True), "SELECT users.id AS users_id, users.name AS users_name " "FROM users LEFT OUTER JOIN orders ON users.id = orders.user_id", ) @@ -310,7 +310,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): User = self.classes.User self.assert_compile( - create_session().query(User).outerjoin(User.orders, full=True), + fixture_session().query(User).outerjoin(User.orders, full=True), "SELECT users.id AS users_id, users.name AS users_name " "FROM users FULL OUTER JOIN orders ON users.id = orders.user_id", ) @@ -318,7 +318,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): def test_single_prop_1(self): User = self.classes.User - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(User).join(User.orders), "SELECT users.id AS users_id, users.name AS users_name " @@ -328,7 +328,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): def test_single_prop_2(self): Order, User = (self.classes.Order, self.classes.User) - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(User).join(Order.user), "SELECT users.id AS users_id, users.name AS users_name " @@ -338,7 +338,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): def test_single_prop_3(self): Order, User = (self.classes.Order, self.classes.User) - sess = create_session() + sess = fixture_session() oalias1 = aliased(Order) self.assert_compile( @@ -354,7 +354,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): User, ) = (self.classes.Order, self.classes.User) - sess = create_session() + sess = fixture_session() oalias1 = aliased(Order) oalias2 = aliased(Order) # another nonsensical query. (from [ticket:1537]). @@ -370,7 +370,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): def test_single_prop_6(self): User = self.classes.User - sess = create_session() + sess = fixture_session() ualias = aliased(User) self.assert_compile( sess.query(ualias).join(ualias.orders), @@ -381,7 +381,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): def test_single_prop_9(self): User = self.classes.User - sess = create_session() + sess = fixture_session() subq = ( sess.query(User) @@ -409,7 +409,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): self.classes.Address, ) - sess = create_session() + sess = fixture_session() oalias1 = aliased(Order) # test #1 for [ticket:1706] ualias = aliased(User) @@ -430,7 +430,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): self.classes.Address, ) - sess = create_session() + sess = fixture_session() # test #2 for [ticket:1706] ualias = aliased(User) ualias2 = aliased(User) @@ -451,7 +451,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): User = self.classes.User Order = self.classes.Order - sess = create_session() + sess = fixture_session() # test overlapping paths. User->orders is used by both joins, but # rendered once. @@ -475,7 +475,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): Order = self.classes.Order Address = self.classes.Address - s = Session() + s = fixture_session() q = ( s.query(User) .join(User.orders) @@ -501,7 +501,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): Order, User = self.classes.Order, self.classes.User - sess = create_session() + sess = fixture_session() # intentionally join() with a non-existent "left" side self.assert_compile( @@ -516,7 +516,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): # a more controversial feature. join from # User->Address, but the onclause is Address.user. - sess = create_session() + sess = fixture_session() eq_( sess.query(User) @@ -554,7 +554,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): def test_multiple_with_aliases(self): Order, User = self.classes.Order, self.classes.User - sess = create_session() + sess = fixture_session() ualias = aliased(User) oalias1 = aliased(Order) @@ -577,7 +577,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): def test_select_from_orm_joins(self): User, Order = self.classes.User, self.classes.Order - sess = create_session() + sess = fixture_session() ualias = aliased(User) oalias1 = aliased(Order) @@ -715,7 +715,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): def test_overlapping_backwards_joins(self): User, Order = self.classes.User, self.classes.Order - sess = create_session() + sess = fixture_session() oalias1 = aliased(Order) oalias2 = aliased(Order) @@ -740,7 +740,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): self.classes.Address, ) - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(Address, User) @@ -762,7 +762,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): def test_invalid_join_entity_from_single_from_clause(self): Address, Item = (self.classes.Address, self.classes.Item) - sess = create_session() + sess = fixture_session() q = sess.query(Address).select_from(Address) @@ -776,7 +776,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): def test_invalid_join_entity_from_no_from_clause(self): Address, Item = (self.classes.Address, self.classes.Item) - sess = create_session() + sess = fixture_session() q = sess.query(Address) @@ -797,7 +797,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): self.classes.Address, self.classes.Item, ) - sess = create_session() + sess = fixture_session() q = sess.query(Address, User).join(Address.dingaling).join(User.orders) @@ -816,7 +816,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): User = self.classes.User - sess = create_session() + sess = fixture_session() u1 = aliased(User) @@ -852,7 +852,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): User = self.classes.User - sess = create_session() + sess = fixture_session() u1 = aliased(User) u2 = aliased(User) @@ -892,7 +892,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): self.classes.Dingaling, ) - sess = create_session() + sess = fixture_session() q = sess.query(Address, User).join(Address.dingaling).join(User.orders) @@ -948,7 +948,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): self.classes.Dingaling, ) - sess = create_session() + sess = fixture_session() q = sess.query(Order, Dingaling) @@ -1000,7 +1000,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() a1 = aliased(Address) @@ -1043,7 +1043,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() a1 = aliased(Address) @@ -1068,7 +1068,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): def test_pure_expression_error(self): addresses, users = self.tables.addresses, self.tables.users - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(users).join(addresses), @@ -1083,7 +1083,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): self.classes.Order, ) - sess = create_session() + sess = fixture_session() eq_( sess.query(User) @@ -1119,7 +1119,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() eq_( sess.query(User) @@ -1176,7 +1176,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): def test_aliased_classes(self): User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() (user7, user8, user9, user10) = sess.query(User).all() (address1, address2, address3, address4, address5) = sess.query( @@ -1248,7 +1248,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): def test_expression_onclauses(self): Order, User = self.classes.Order, self.classes.User - sess = create_session() + sess = fixture_session() subq = sess.query(User).subquery() @@ -1281,7 +1281,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): def test_aliased_classes_m2m(self): Item, Order = self.classes.Item, self.classes.Order - sess = create_session() + sess = fixture_session() (order1, order2, order3, order4, order5) = sess.query(Order).all() (item1, item2, item3, item4, item5) = sess.query(Item).all() @@ -1323,7 +1323,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): # test for #1853 - session = create_session() + session = fixture_session() first = session.query(User) second = session.query(User) unioned = first.union(second) @@ -1366,7 +1366,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): # test for #1853 - session = create_session() + session = fixture_session() first = session.query(User) second = session.query(User) unioned = first.union(second) @@ -1415,7 +1415,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): oalias = orders.alias("oalias") result = ( - create_session() + fixture_session() .query(User) .select_from(users.join(oalias)) .filter( @@ -1429,7 +1429,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): assert [User(id=7, name="jack"), User(id=9, name="fred")] == result result = ( - create_session() + fixture_session() .query(User) .select_from(users.join(oalias)) .filter( @@ -1445,7 +1445,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): def test_aliased_order_by(self): User = self.classes.User - sess = create_session() + sess = fixture_session() ualias = aliased(User) eq_( @@ -1466,7 +1466,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): def test_plain_table(self): addresses, User = self.tables.addresses, self.classes.User - sess = create_session() + sess = fixture_session() eq_( sess.query(User.name) @@ -1479,7 +1479,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): def test_no_joinpoint_expr(self): User, users = self.classes.User, self.tables.users - sess = create_session() + sess = fixture_session() # these are consistent regardless of # select_from() being present. @@ -1506,7 +1506,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): def test_on_clause_no_right_side_one(self): User = self.classes.User Address = self.classes.Address - sess = create_session() + sess = fixture_session() # coercions does not catch this due to the # legacy=True flag for JoinTargetRole @@ -1532,7 +1532,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): def test_on_clause_no_right_side_two(self): User = self.classes.User Address = self.classes.Address - sess = create_session() + sess = fixture_session() assert_raises_message( sa_exc.ArgumentError, @@ -1562,7 +1562,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): self.classes.User, ) - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(Item.id) .select_from(User) @@ -1617,7 +1617,7 @@ class JoinFromSelectableTest(fixtures.MappedTest, AssertsCompiledSQL): def test_select_mapped_to_mapped_explicit_left(self): T1, T2 = self.classes.T1, self.classes.T2 - sess = Session() + sess = fixture_session() subq = ( sess.query(T2.t1_id, func.count(T2.id).label("count")) .group_by(T2.t1_id) @@ -1638,7 +1638,7 @@ class JoinFromSelectableTest(fixtures.MappedTest, AssertsCompiledSQL): def test_select_mapped_to_mapped_implicit_left(self): T1, T2 = self.classes.T1, self.classes.T2 - sess = Session() + sess = fixture_session() subq = ( sess.query(T2.t1_id, func.count(T2.id).label("count")) .group_by(T2.t1_id) @@ -1657,7 +1657,7 @@ class JoinFromSelectableTest(fixtures.MappedTest, AssertsCompiledSQL): def test_select_mapped_to_select_explicit_left(self): T1, T2 = self.classes.T1, self.classes.T2 - sess = Session() + sess = fixture_session() subq = ( sess.query(T2.t1_id, func.count(T2.id).label("count")) .group_by(T2.t1_id) @@ -1677,7 +1677,7 @@ class JoinFromSelectableTest(fixtures.MappedTest, AssertsCompiledSQL): def test_select_mapped_to_select_implicit_left(self): T1, T2 = self.classes.T1, self.classes.T2 - sess = Session() + sess = fixture_session() subq = ( sess.query(T2.t1_id, func.count(T2.id).label("count")) .group_by(T2.t1_id) @@ -1709,7 +1709,7 @@ class JoinFromSelectableTest(fixtures.MappedTest, AssertsCompiledSQL): def test_mapped_select_to_mapped_implicit_left(self): T1, T2 = self.classes.T1, self.classes.T2 - sess = Session() + sess = fixture_session() subq = ( sess.query(T2.t1_id, func.count(T2.id).label("count")) .group_by(T2.t1_id) @@ -1739,7 +1739,7 @@ class JoinFromSelectableTest(fixtures.MappedTest, AssertsCompiledSQL): def test_mapped_select_to_mapped_explicit_left(self): T1, T2 = self.classes.T1, self.classes.T2 - sess = Session() + sess = fixture_session() subq = ( sess.query(T2.t1_id, func.count(T2.id).label("count")) .group_by(T2.t1_id) @@ -1759,7 +1759,7 @@ class JoinFromSelectableTest(fixtures.MappedTest, AssertsCompiledSQL): def test_mapped_select_to_select_explicit_left(self): T1, T2 = self.classes.T1, self.classes.T2 - sess = Session() + sess = fixture_session() subq = ( sess.query(T2.t1_id, func.count(T2.id).label("count")) .group_by(T2.t1_id) @@ -1780,7 +1780,7 @@ class JoinFromSelectableTest(fixtures.MappedTest, AssertsCompiledSQL): def test_mapped_select_to_select_implicit_left(self): T1, T2 = self.classes.T1, self.classes.T2 - sess = Session() + sess = fixture_session() subq = ( sess.query(T2.t1_id, func.count(T2.id).label("count")) .group_by(T2.t1_id) @@ -1866,7 +1866,7 @@ class SelfRefMixedTest(fixtures.MappedTest, AssertsCompiledSQL): def test_o2m_aliased_plus_o2m(self): Node, Sub = self.classes.Node, self.classes.Sub - sess = create_session() + sess = fixture_session() n1 = aliased(Node) self.assert_compile( @@ -1886,7 +1886,7 @@ class SelfRefMixedTest(fixtures.MappedTest, AssertsCompiledSQL): def test_m2m_aliased_plus_o2m(self): Node, Sub = self.classes.Node, self.classes.Sub - sess = create_session() + sess = fixture_session() n1 = aliased(Node) self.assert_compile( @@ -1962,7 +1962,7 @@ class CreateJoinsTest(fixtures.ORMTest, AssertsCompiledSQL): def test_double_level_aliased_exists(self): A, B, C, Base = self._inherits_fixture() - s = Session() + s = fixture_session() self.assert_compile( s.query(A).filter(A.b.has(B.c.has(C.id == 5))), "SELECT a.id AS a_id, base.id AS base_id, a.b_id AS a_b_id " @@ -2029,7 +2029,7 @@ class JoinToNonPolyAliasesTest(fixtures.MappedTest, AssertsCompiledSQL): def test_join_parent_child(self): Parent = self.classes.Parent - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(Parent) .join(Parent.npc) @@ -2045,7 +2045,7 @@ class JoinToNonPolyAliasesTest(fixtures.MappedTest, AssertsCompiledSQL): def test_join_parent_child_select_from(self): Parent = self.classes.Parent npc = self.npc - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(npc) .select_from(Parent) @@ -2061,7 +2061,7 @@ class JoinToNonPolyAliasesTest(fixtures.MappedTest, AssertsCompiledSQL): def test_join_select_parent_child(self): Parent = self.classes.Parent npc = self.npc - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(Parent, npc) .join(Parent.npc) @@ -2120,7 +2120,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): def insert_data(cls, connection): Node = cls.classes.Node - sess = create_session(connection) + sess = Session(connection) n1 = Node(data="n1") n1.append(Node(data="n11")) n1.append(Node(data="n12")) @@ -2134,7 +2134,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): def test_join_4_explicit_join(self): Node = self.classes.Node - sess = create_session() + sess = fixture_session() na = aliased(Node) na2 = aliased(Node) @@ -2189,7 +2189,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): Node = self.classes.Node - sess = create_session() + sess = fixture_session() n1 = aliased(Node) @@ -2266,7 +2266,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): def test_join_to_self_no_aliases_raises(self): Node = self.classes.Node - s = Session() + s = fixture_session() assert_raises_message( sa.exc.InvalidRequestError, "Can't construct a join from mapped class Node->nodes to mapped " @@ -2316,7 +2316,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): def test_explicit_join_4(self): Node = self.classes.Node - sess = create_session() + sess = fixture_session() n1 = aliased(Node) n2 = aliased(Node) @@ -2331,7 +2331,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): def test_explicit_join_5(self): Node = self.classes.Node - sess = create_session() + sess = fixture_session() n1 = aliased(Node) n2 = aliased(Node) @@ -2346,7 +2346,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): def test_explicit_join_6(self): Node = self.classes.Node - sess = create_session() + sess = fixture_session() n1 = aliased(Node) node = ( @@ -2359,7 +2359,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): def test_explicit_join_7(self): Node = self.classes.Node - sess = create_session() + sess = fixture_session() n1 = aliased(Node) n2 = aliased(Node) @@ -2373,7 +2373,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): def test_explicit_join_8(self): Node = self.classes.Node - sess = create_session() + sess = fixture_session() n1 = aliased(Node) n2 = aliased(Node) @@ -2390,7 +2390,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): def test_explicit_join_9(self): Node = self.classes.Node - sess = create_session() + sess = fixture_session() n1 = aliased(Node) n2 = aliased(Node) @@ -2406,7 +2406,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): def test_explicit_join_10(self): Node = self.classes.Node - sess = create_session() + sess = fixture_session() n1 = aliased(Node) n2 = aliased(Node) @@ -2427,7 +2427,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): def test_join_to_nonaliased(self): Node = self.classes.Node - sess = create_session() + sess = fixture_session() n1 = aliased(Node) @@ -2457,7 +2457,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): def test_multiple_explicit_entities_one(self): Node = self.classes.Node - sess = create_session() + sess = fixture_session() parent = aliased(Node) grandparent = aliased(Node) @@ -2475,7 +2475,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): def test_multiple_explicit_entities_two(self): Node = self.classes.Node - sess = create_session() + sess = fixture_session() parent = aliased(Node) grandparent = aliased(Node) @@ -2502,7 +2502,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): def test_multiple_explicit_entities_three(self): Node = self.classes.Node - sess = create_session() + sess = fixture_session() parent = aliased(Node) grandparent = aliased(Node) @@ -2529,7 +2529,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): def test_multiple_explicit_entities_four(self): Node = self.classes.Node - sess = create_session() + sess = fixture_session() parent = aliased(Node) grandparent = aliased(Node) @@ -2548,7 +2548,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): def test_multiple_explicit_entities_five(self): Node = self.classes.Node - sess = create_session() + sess = fixture_session() parent = aliased(Node) grandparent = aliased(Node) @@ -2575,7 +2575,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): def test_any(self): Node = self.classes.Node - sess = create_session() + sess = fixture_session() eq_( sess.query(Node) .filter(Node.children.any(Node.data == "n1")) @@ -2605,7 +2605,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): def test_has(self): Node = self.classes.Node - sess = create_session() + sess = fixture_session() eq_( sess.query(Node) @@ -2628,7 +2628,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): def test_contains(self): Node = self.classes.Node - sess = create_session() + sess = fixture_session() n122 = sess.query(Node).filter(Node.data == "n122").one() eq_( @@ -2645,7 +2645,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): def test_eq_ne(self): Node = self.classes.Node - sess = create_session() + sess = fixture_session() n12 = sess.query(Node).filter(Node.data == "n12").one() eq_( @@ -2723,7 +2723,7 @@ class SelfReferentialM2MTest(fixtures.MappedTest): ) }, ) - sess = create_session(connection) + sess = Session(connection) n1 = Node(data="n1") n2 = Node(data="n2") n3 = Node(data="n3") @@ -2746,7 +2746,7 @@ class SelfReferentialM2MTest(fixtures.MappedTest): def test_any(self): Node = self.classes.Node - sess = create_session() + sess = fixture_session() eq_( sess.query(Node) .filter(Node.children.any(Node.data == "n3")) @@ -2758,7 +2758,7 @@ class SelfReferentialM2MTest(fixtures.MappedTest): def test_contains(self): Node = self.classes.Node - sess = create_session() + sess = fixture_session() n4 = sess.query(Node).filter_by(data="n4").one() eq_( @@ -2785,7 +2785,7 @@ class SelfReferentialM2MTest(fixtures.MappedTest): def test_explicit_join(self): Node = self.classes.Node - sess = create_session() + sess = fixture_session() n1 = aliased(Node) eq_( @@ -2863,7 +2863,7 @@ class JoinLateralTest(fixtures.MappedTest, AssertsCompiledSQL): def test_select_subquery(self): Person, Book = self.classes("Person", "Book") - s = Session() + s = fixture_session() subq = ( s.query(Book.book_id) @@ -2889,7 +2889,7 @@ class JoinLateralTest(fixtures.MappedTest, AssertsCompiledSQL): def test_select_subquery_sef_implicit_correlate(self): Person, Book = self.classes("Person", "Book") - s = Session() + s = fixture_session() stmt = s.query(Person).subquery() @@ -2922,7 +2922,7 @@ class JoinLateralTest(fixtures.MappedTest, AssertsCompiledSQL): def test_select_subquery_sef_implicit_correlate_coreonly(self): Person, Book = self.classes("Person", "Book") - s = Session() + s = fixture_session() stmt = s.query(Person).subquery() @@ -2955,7 +2955,7 @@ class JoinLateralTest(fixtures.MappedTest, AssertsCompiledSQL): def test_select_subquery_sef_explicit_correlate_coreonly(self): Person, Book = self.classes("Person", "Book") - s = Session() + s = fixture_session() stmt = s.query(Person).subquery() @@ -2989,7 +2989,7 @@ class JoinLateralTest(fixtures.MappedTest, AssertsCompiledSQL): def test_select_subquery_sef_explicit_correlate(self): Person, Book = self.classes("Person", "Book") - s = Session() + s = fixture_session() stmt = s.query(Person).subquery() @@ -3023,7 +3023,7 @@ class JoinLateralTest(fixtures.MappedTest, AssertsCompiledSQL): def test_from_function(self): Bookcase = self.classes.Bookcase - s = Session() + s = fixture_session() srf = lateral(func.generate_series(1, Bookcase.bookcase_shelves)) @@ -3041,7 +3041,7 @@ class JoinLateralTest(fixtures.MappedTest, AssertsCompiledSQL): def test_from_function_select_entity_from(self): Bookcase = self.classes.Bookcase - s = Session() + s = fixture_session() subq = s.query(Bookcase).subquery() diff --git a/test/orm/test_lambdas.py b/test/orm/test_lambdas.py index b190f46d6..7591f844f 100644 --- a/test/orm/test_lambdas.py +++ b/test/orm/test_lambdas.py @@ -18,6 +18,7 @@ from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from .inheritance import _poly_fixtures @@ -70,7 +71,7 @@ class LambdaTest(QueryTest, AssertsCompiledSQL): def test_user_cols_single_lambda_query(self, plain_fixture): User, Address = plain_fixture - s = Session() + s = fixture_session() q = s.query(lambda: (User.id, User.name)).select_from(lambda: User) self.assert_compile( diff --git a/test/orm/test_lazy_relations.py b/test/orm/test_lazy_relations.py index c81de142c..3061de309 100644 --- a/test/orm/test_lazy_relations.py +++ b/test/orm/test_lazy_relations.py @@ -18,7 +18,6 @@ from sqlalchemy import util from sqlalchemy.orm import aliased from sqlalchemy.orm import attributes from sqlalchemy.orm import configure_mappers -from sqlalchemy.orm import create_session from sqlalchemy.orm import exc as orm_exc from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship @@ -29,6 +28,7 @@ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_false from sqlalchemy.testing import is_true from sqlalchemy.testing.assertsql import CompiledSQL +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from sqlalchemy.types import TypeDecorator @@ -56,7 +56,7 @@ class LazyTest(_fixtures.FixtureTest): ) }, ) - sess = create_session() + sess = fixture_session() q = sess.query(User) eq_( [ @@ -87,7 +87,7 @@ class LazyTest(_fixtures.FixtureTest): ) }, ) - sess = create_session() + sess = fixture_session() q = sess.query(User) u = q.filter(users.c.id == 7).first() sess.expunge(u) @@ -112,7 +112,7 @@ class LazyTest(_fixtures.FixtureTest): ) }, ) - q = create_session().query(User) + q = fixture_session().query(User) assert [ User(id=7, addresses=[Address(id=1)]), User( @@ -145,7 +145,7 @@ class LazyTest(_fixtures.FixtureTest): users, properties=dict(addresses=relationship(Address, lazy="select")), ) - q = create_session().query(User) + q = fixture_session().query(User) result = ( q.filter(users.c.id == addresses.c.user_id) .order_by(addresses.c.email_address) @@ -185,7 +185,7 @@ class LazyTest(_fixtures.FixtureTest): ) ), ) - sess = create_session() + sess = fixture_session() assert [ User(id=7, addresses=[Address(id=1)]), User( @@ -221,7 +221,7 @@ class LazyTest(_fixtures.FixtureTest): ) mapper(Address, addresses) - sess = create_session() + sess = fixture_session() user = sess.query(User).get(7) assert getattr(User, "addresses").hasparent( attributes.instance_state(user.addresses[0]), optimistic=True @@ -276,7 +276,7 @@ class LazyTest(_fixtures.FixtureTest): }, ) - sess = create_session() + sess = fixture_session() q = sess.query(User) if testing.against("mssql"): @@ -330,7 +330,7 @@ class LazyTest(_fixtures.FixtureTest): }, ) - sess = create_session() + sess = fixture_session() q = sess.query(User) # use a union all to get a lot of rows to join against @@ -362,7 +362,7 @@ class LazyTest(_fixtures.FixtureTest): properties={"order": relationship(Order, uselist=False)}, ) mapper(Order, orders) - s = create_session() + s = fixture_session() u1 = s.query(User).filter(User.id == 7).one() assert_raises(sa.exc.SAWarning, getattr, u1, "order") @@ -390,7 +390,7 @@ class LazyTest(_fixtures.FixtureTest): ), ) - s = Session() + s = fixture_session() ed = s.query(User).filter_by(name="ed").one() eq_( ed.addresses, @@ -421,7 +421,7 @@ class LazyTest(_fixtures.FixtureTest): ) ), ) - q = create_session().query(User) + q = fixture_session().query(User) result = q.filter(users.c.id == 7).all() assert [User(id=7, address=Address(id=1))] == result @@ -453,7 +453,7 @@ class LazyTest(_fixtures.FixtureTest): ) ), ) - q = create_session().query(User) + q = fixture_session().query(User) eq_( [ User(id=7, address=None), @@ -597,7 +597,7 @@ class LazyTest(_fixtures.FixtureTest): User, Address, Order, Item = self.classes( "User", "Address", "Order", "Item" ) - q = create_session().query(User).order_by(User.id) + q = fixture_session().query(User).order_by(User.id) def items(*ids): if no_items: @@ -643,21 +643,21 @@ class LazyTest(_fixtures.FixtureTest): else: self.assert_sql_count(testing.db, go, 15) - sess = create_session() + sess = fixture_session() user = sess.query(User).get(7) closed_mapper = User.closed_orders.entity open_mapper = User.open_orders.entity eq_( [Order(id=1), Order(id=5)], - create_session() + fixture_session() .query(closed_mapper) .with_parent(user, property="closed_orders") .all(), ) eq_( [Order(id=3)], - create_session() + fixture_session() .query(open_mapper) .with_parent(user, property="open_orders") .all(), @@ -683,7 +683,7 @@ class LazyTest(_fixtures.FixtureTest): ), ) - q = create_session().query(Item) + q = fixture_session().query(Item) assert self.static.item_keyword_result == q.all() eq_( @@ -717,7 +717,7 @@ class LazyTest(_fixtures.FixtureTest): ), ) - sess = create_session() + sess = fixture_session() # load address a1 = ( @@ -789,7 +789,7 @@ class LazyTest(_fixtures.FixtureTest): properties=dict(user=relationship(mapper(User, users))), ) - sess = create_session(bind=testing.db) + sess = fixture_session() # load address a1 = ( @@ -823,7 +823,7 @@ class LazyTest(_fixtures.FixtureTest): user=relationship(mapper(User, users), lazy="select") ), ) - sess = create_session() + sess = fixture_session() q = sess.query(Address) a = q.filter(addresses.c.id == 1).one() @@ -847,7 +847,7 @@ class LazyTest(_fixtures.FixtureTest): properties={"addresses": relationship(Address, backref="user")}, ) mapper(Address, addresses) - sess = create_session() + sess = fixture_session(autoflush=False) ad = sess.query(Address).filter_by(id=1).one() assert ad.user.id == 7 @@ -938,8 +938,8 @@ class GetterStateTest(_fixtures.FixtureTest): }, ) - metadata.create_all() - sess = Session(autoflush=False) + metadata.create_all(testing.db) + sess = Session(testing.db, autoflush=False) data = {"im": "unhashable"} a1 = Article(id=1, data=data) c1 = Category(id=1, data=data) @@ -983,7 +983,7 @@ class GetterStateTest(_fixtures.FixtureTest): }, ) - sess = create_session() + sess = fixture_session() a1 = Address(email_address="a1") sess.add(a1) if populate_user: @@ -1143,7 +1143,7 @@ class M2OGetTest(_fixtures.FixtureTest): mapper(Address, addresses, properties={"user": relationship(User)}) - sess = create_session() + sess = fixture_session() ad1 = Address(email_address="somenewaddress", id=12) sess.add(ad1) sess.flush() @@ -1232,7 +1232,7 @@ class CorrelatedTest(fixtures.MappedTest): }, ) - sess = create_session() + sess = fixture_session() eq_( sess.query(User).all(), @@ -1433,7 +1433,7 @@ class RefersToSelfLazyLoadInterferenceTest(fixtures.MappedTest): def test_lazy_doesnt_interfere(self): A, B, C = self.classes("A", "B", "C") - session = Session() + session = fixture_session() b = B() session.add(b) session.flush() @@ -1512,7 +1512,7 @@ class TypeCoerceTest(fixtures.MappedTest, testing.AssertsExecutionResults): Person = self.classes.Person Pet = self.classes.Pet - s = Session() + s = fixture_session() s.add_all([Person(id=5), Pet(id=1, person_id=5)]) s.commit() diff --git a/test/orm/test_loading.py b/test/orm/test_loading.py index 819bc8bed..e15dbb09f 100644 --- a/test/orm/test_loading.py +++ b/test/orm/test_loading.py @@ -6,11 +6,11 @@ from sqlalchemy.orm import aliased from sqlalchemy.orm import loading from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship -from sqlalchemy.orm import Session from sqlalchemy.testing import mock from sqlalchemy.testing.assertions import assert_raises from sqlalchemy.testing.assertions import assert_raises_message from sqlalchemy.testing.assertions import eq_ +from sqlalchemy.testing.fixtures import fixture_session from . import _fixtures # class GetFromIdentityTest(_fixtures.FixtureTest): @@ -40,7 +40,7 @@ class InstanceProcessorTest(_fixtures.FixtureTest): ) mapper(Address, addresses) - s = Session() + s = fixture_session() def go(): eq_( @@ -69,7 +69,7 @@ class InstancesTest(_fixtures.FixtureTest): def test_cursor_close_w_failed_rowproc(self): User = self.classes.User - s = Session() + s = fixture_session() q = s.query(User) @@ -83,7 +83,7 @@ class InstancesTest(_fixtures.FixtureTest): def test_row_proc_not_created(self): User = self.classes.User - s = Session() + s = fixture_session() q = s.query(User.id, User.name) stmt = select(User.id) @@ -107,7 +107,7 @@ class MergeResultTest(_fixtures.FixtureTest): def _fixture(self): User = self.classes.User - s = Session() + s = fixture_session() u1, u2, u3, u4 = ( User(id=1, name="u1"), User(id=2, name="u2"), @@ -130,7 +130,7 @@ class MergeResultTest(_fixtures.FixtureTest): def test_single_column(self): User = self.classes.User - s = Session() + s = fixture_session() q = s.query(User.id) collection = [(1,), (2,), (7,), (8,)] diff --git a/test/orm/test_lockmode.py b/test/orm/test_lockmode.py index a3dd42fc2..f82c5cf7c 100644 --- a/test/orm/test_lockmode.py +++ b/test/orm/test_lockmode.py @@ -4,10 +4,10 @@ from sqlalchemy.engine import default from sqlalchemy.orm import joinedload from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship -from sqlalchemy.orm import Session from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ +from sqlalchemy.testing.fixtures import fixture_session from test.orm import _fixtures @@ -27,7 +27,7 @@ class ForUpdateTest(_fixtures.FixtureTest): assert_sel_of=None, ): User = self.classes.User - s = Session() + s = fixture_session() q = s.query(User).with_for_update( read=read, nowait=nowait, of=of, key_share=key_share ) @@ -81,7 +81,7 @@ class BackendTest(_fixtures.FixtureTest): def test_inner_joinedload_w_limit(self): User = self.classes.User - sess = Session() + sess = fixture_session() q = ( sess.query(User) .options(joinedload(User.addresses, innerjoin=True)) @@ -97,7 +97,7 @@ class BackendTest(_fixtures.FixtureTest): def test_inner_joinedload_wo_limit(self): User = self.classes.User - sess = Session() + sess = fixture_session() sess.query(User).options( joinedload(User.addresses, innerjoin=True) ).with_for_update().all() @@ -105,7 +105,7 @@ class BackendTest(_fixtures.FixtureTest): def test_outer_joinedload_w_limit(self): User = self.classes.User - sess = Session() + sess = fixture_session() q = sess.query(User).options( joinedload(User.addresses, innerjoin=False) ) @@ -125,7 +125,7 @@ class BackendTest(_fixtures.FixtureTest): def test_outer_joinedload_wo_limit(self): User = self.classes.User - sess = Session() + sess = fixture_session() q = sess.query(User).options( joinedload(User.addresses, innerjoin=False) ) @@ -141,14 +141,14 @@ class BackendTest(_fixtures.FixtureTest): def test_join_w_subquery(self): User = self.classes.User Address = self.classes.Address - sess = Session() + sess = fixture_session() q1 = sess.query(User).with_for_update().subquery() sess.query(q1).join(Address).all() sess.close() def test_plain(self): User = self.classes.User - sess = Session() + sess = fixture_session() sess.query(User).with_for_update().all() sess.close() @@ -167,7 +167,7 @@ class CompileTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_default_update(self): User = self.classes.User - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(User.id).with_for_update(), "SELECT users.id AS users_id FROM users FOR UPDATE", @@ -176,7 +176,7 @@ class CompileTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_not_supported_by_dialect_should_just_use_update(self): User = self.classes.User - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(User.id).with_for_update(read=True), "SELECT users.id AS users_id FROM users FOR UPDATE", @@ -185,7 +185,7 @@ class CompileTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_postgres_read(self): User = self.classes.User - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(User.id).with_for_update(read=True), "SELECT users.id AS users_id FROM users FOR SHARE", @@ -194,7 +194,7 @@ class CompileTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_postgres_read_nowait(self): User = self.classes.User - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(User.id).with_for_update(read=True, nowait=True), "SELECT users.id AS users_id FROM users FOR SHARE NOWAIT", @@ -203,7 +203,7 @@ class CompileTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_postgres_update(self): User = self.classes.User - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(User.id).with_for_update(), "SELECT users.id AS users_id FROM users FOR UPDATE", @@ -212,7 +212,7 @@ class CompileTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_postgres_update_of(self): User = self.classes.User - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(User.id).with_for_update(of=User.id), "SELECT users.id AS users_id FROM users FOR UPDATE OF users", @@ -221,7 +221,7 @@ class CompileTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_postgres_update_of_entity(self): User = self.classes.User - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(User.id).with_for_update(of=User), "SELECT users.id AS users_id FROM users FOR UPDATE OF users", @@ -232,7 +232,7 @@ class CompileTest(_fixtures.FixtureTest, AssertsCompiledSQL): User = self.classes.User Address = self.classes.Address - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(User.id, Address.id).with_for_update( of=[User, Address] @@ -244,7 +244,7 @@ class CompileTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_postgres_for_no_key_update(self): User = self.classes.User - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(User.id).with_for_update(key_share=True), "SELECT users.id AS users_id FROM users FOR NO KEY UPDATE", @@ -253,7 +253,7 @@ class CompileTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_postgres_for_no_key_nowait_update(self): User = self.classes.User - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(User.id).with_for_update(key_share=True, nowait=True), "SELECT users.id AS users_id FROM users FOR NO KEY UPDATE NOWAIT", @@ -262,7 +262,7 @@ class CompileTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_postgres_update_of_list(self): User = self.classes.User - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(User.id).with_for_update( of=[User.id, User.id, User.id] @@ -273,7 +273,7 @@ class CompileTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_postgres_update_skip_locked(self): User = self.classes.User - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(User.id).with_for_update(skip_locked=True), "SELECT users.id AS users_id FROM users FOR UPDATE SKIP LOCKED", @@ -282,7 +282,7 @@ class CompileTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_oracle_update(self): User = self.classes.User - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(User.id).with_for_update(), "SELECT users.id AS users_id FROM users FOR UPDATE", @@ -291,7 +291,7 @@ class CompileTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_oracle_update_skip_locked(self): User = self.classes.User - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(User.id).with_for_update(skip_locked=True), "SELECT users.id AS users_id FROM users FOR UPDATE SKIP LOCKED", @@ -300,7 +300,7 @@ class CompileTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_mysql_read(self): User = self.classes.User - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(User.id).with_for_update(read=True), "SELECT users.id AS users_id FROM users LOCK IN SHARE MODE", @@ -309,7 +309,7 @@ class CompileTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_for_update_on_inner_w_joinedload(self): User = self.classes.User - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(User) .options(joinedload(User.addresses)) @@ -328,7 +328,7 @@ class CompileTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_for_update_on_inner_w_joinedload_no_render_oracle(self): User = self.classes.User - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(User) .options(joinedload(User.addresses)) diff --git a/test/orm/test_manytomany.py b/test/orm/test_manytomany.py index 8b51d7e20..79c63872d 100644 --- a/test/orm/test_manytomany.py +++ b/test/orm/test_manytomany.py @@ -7,11 +7,10 @@ from sqlalchemy.orm import backref from sqlalchemy.orm import exc as orm_exc from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship -from sqlalchemy.orm import Session -from sqlalchemy.orm import sessionmaker from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -167,7 +166,7 @@ class M2MTest(fixtures.MappedTest): }, ) - sess = Session() + sess = fixture_session() p1 = Place("place1") p2 = Place("place2") p3 = Place("place3") @@ -216,7 +215,7 @@ class M2MTest(fixtures.MappedTest): }, ) - sess = Session() + sess = fixture_session() p1 = Place("place1") p2 = Place("place2") p2.parent_places = [p1] @@ -271,7 +270,7 @@ class M2MTest(fixtures.MappedTest): tran.inputs.append(Place("place1")) tran.outputs.append(Place("place2")) tran.outputs.append(Place("place3")) - sess = Session() + sess = fixture_session() sess.add(tran) sess.commit() @@ -327,7 +326,7 @@ class M2MTest(fixtures.MappedTest): p2 = Place("place2") p3 = Place("place3") - sess = Session() + sess = fixture_session() sess.add_all([p3, p1, t1, t2, p2, t3]) t1.inputs.append(p1) @@ -380,7 +379,7 @@ class M2MTest(fixtures.MappedTest): p1 = Place("place1") t1 = Transition("t1") p1.transitions.append(t1) - sess = sessionmaker()() + sess = fixture_session() sess.add_all([p1, t1]) sess.commit() @@ -494,7 +493,7 @@ class AssortedPersistenceTests(fixtures.MappedTest): A, B = self.classes.A, self.classes.B secondary = self.tables.secondary - sess = Session() + sess = fixture_session() sess.add_all( [A(data="a1", bs=[B(data="b1")]), A(data="a2", bs=[B(data="b2")])] ) @@ -516,7 +515,7 @@ class AssortedPersistenceTests(fixtures.MappedTest): A, B = self.classes.A, self.classes.B secondary = self.tables.secondary - sess = Session() + sess = fixture_session() sess.add_all([A(data="a1", bs=[B(data="b1"), B(data="b2")])]) sess.commit() diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py index edbb4b0cd..013eb21e1 100644 --- a/test/orm/test_mapper.py +++ b/test/orm/test_mapper.py @@ -18,7 +18,6 @@ from sqlalchemy.orm import class_mapper from sqlalchemy.orm import column_property from sqlalchemy.orm import composite from sqlalchemy.orm import configure_mappers -from sqlalchemy.orm import create_session from sqlalchemy.orm import deferred from sqlalchemy.orm import dynamic_loader from sqlalchemy.orm import mapper @@ -36,6 +35,7 @@ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import ne_ from sqlalchemy.testing.fixtures import ComparableMixin +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from test.orm import _fixtures @@ -274,7 +274,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): properties={"user_name": synonym("_name")}, ) - s = create_session() + s = fixture_session() u = s.query(User).get(7) eq_(u._name, "jack") eq_(u._id, 7) @@ -324,7 +324,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): User, users = self.classes.User, self.tables.users m = self.mapper(User, users) - session = create_session() + session = fixture_session() session.connection(mapper=m) def test_incomplete_columns(self): @@ -333,7 +333,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): addresses, Address = self.tables.addresses, self.classes.Address self.mapper(Address, addresses) - s = create_session() + s = fixture_session() a = ( s.query(Address) .from_statement( @@ -708,7 +708,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): m.add_property("name", synonym("_name")) m.add_property("addresses", relationship(Address)) - sess = create_session(autocommit=False) + sess = fixture_session(autocommit=False) assert sess.query(User).get(7) u = sess.query(User).filter_by(name="jack").one() @@ -754,7 +754,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): m.add_property("_name", users.c.name) m.add_property("name", synonym("_name")) - sess = create_session() + sess = fixture_session() u = sess.query(User).filter_by(name="jack").one() eq_(u._name, "jack") eq_(u.name, "jack") @@ -810,7 +810,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): # add property using annotated User.name, # needs to be deannotated m.add_property("x", column_property(User.name + "name")) - s = create_session() + s = fixture_session() q = s.query(m2).select_from(Address).join(Address.foo) self.assert_compile( q, @@ -884,7 +884,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): properties={"addresses": relationship(Address, backref="_user")}, ) - sess = create_session() + sess = fixture_session() u1 = sess.query(User).get(7) u2 = sess.query(User).get(8) # comparaison ops need to work @@ -1118,9 +1118,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): assert hasattr(Foo, "type") assert Foo.type.property.columns[0] is t.c.type - @testing.provide_metadata - def test_prop_filters_defaults(self): - metadata = self.metadata + def test_prop_filters_defaults(self, metadata, connection): t = Table( "t", metadata, @@ -1132,13 +1130,14 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): ), Column("x", Integer(), nullable=False, server_default="0"), ) - t.create() + + t.create(connection) class A(object): pass self.mapper(A, t, include_properties=["id"]) - s = Session() + s = Session(connection) s.add(A()) s.commit() @@ -1219,7 +1218,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): primary_key=[users.c.id], properties={"add_id": addresses.c.id}, ) - result = create_session().query(User).order_by(users.c.id).all() + result = fixture_session().query(User).order_by(users.c.id).all() eq_(result, self.static.user_result[:3]) def test_mapping_to_join_exclude_prop(self): @@ -1240,7 +1239,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): primary_key=[users.c.id], exclude_properties=[addresses.c.id], ) - result = create_session().query(User).order_by(users.c.id).all() + result = fixture_session().query(User).order_by(users.c.id).all() eq_(result, self.static.user_result[:3]) def test_mapping_to_join_no_pk(self): @@ -1259,13 +1258,23 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): assert addresses in m._pks_by_table assert email_bounces not in m._pks_by_table - sess = create_session() + sess = fixture_session() a = Address(id=10, email_address="e1") sess.add(a) sess.flush() - eq_(select(func.count("*")).select_from(addresses).scalar(), 6) - eq_(select(func.count("*")).select_from(email_bounces).scalar(), 5) + eq_( + sess.connection().scalar( + select(func.count("*")).select_from(addresses) + ), + 6, + ) + eq_( + sess.connection().scalar( + select(func.count("*")).select_from(email_bounces) + ), + 5, + ) def test_mapping_to_outerjoin(self): """Mapping to an outer join with a nullable composite primary key.""" @@ -1283,7 +1292,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): properties=dict(address_id=addresses.c.id), ) - session = create_session() + session = fixture_session() result = session.query(User).order_by(User.id, User.address_id).all() eq_( @@ -1315,7 +1324,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): properties=dict(address_id=addresses.c.id), ) - session = create_session() + session = fixture_session() result = session.query(User).order_by(User.id, User.address_id).all() eq_( @@ -1371,7 +1380,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): self.mapper(User, users, properties=dict(orders=relationship(Order))) - session = create_session() + session = fixture_session() result = ( session.query(User) .select_from(users.join(orders).join(order_items).join(items)) @@ -1403,7 +1412,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): ) self.mapper(User, s) - sess = create_session() + sess = fixture_session() result = sess.query(User).order_by(s.c.id).all() for idx, total in enumerate((14, 16)): @@ -1417,7 +1426,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): self.mapper(User, users) - session = create_session() + session = fixture_session() q = session.query(User) eq_(q.count(), 4) @@ -1445,7 +1454,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): ), ) - session = create_session() + session = fixture_session() q = ( session.query(Item) .join("keywords") @@ -1566,7 +1575,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): assert User.uname.property assert User.adlist.property - sess = create_session() + sess = fixture_session() # test RowTuple names row = sess.query(User.id, User.uname).first() @@ -1601,7 +1610,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): User, users, properties={"x": synonym("id"), "y": synonym("x")} ) - s = Session() + s = fixture_session() u = s.query(User).filter(User.y == 8).one() eq_(u.y, 8) @@ -1733,7 +1742,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): assert hasattr(User, "name") assert hasattr(User, "_name") - sess = create_session() + sess = fixture_session() u = sess.query(User).filter(User.name == "jack").one() eq_(u.name, "jack") u.name = "foo" @@ -1815,7 +1824,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): User() eq_(recon, []) - create_session().query(User).first() + fixture_session().query(User).first() eq_(recon, ["go"]) def test_reconstructor_inheritance(self): @@ -1852,7 +1861,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): C() eq_(recon, []) - sess = create_session() + sess = fixture_session() sess.query(A).first() sess.query(B).first() sess.query(C).first() @@ -1875,7 +1884,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): eq_(recon, ["go"]) recon[:] = [] - create_session().query(User).first() + fixture_session().query(User).first() eq_(recon, ["go"]) def test_reconstructor_init_inheritance(self): @@ -1913,7 +1922,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): eq_(recon, ["A", "B", "C"]) recon[:] = [] - sess = create_session() + sess = fixture_session() sess.query(A).first() sess.query(B).first() sess.query(C).first() @@ -1937,7 +1946,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): User() eq_(recon, []) - create_session().query(User).first() + fixture_session().query(User).first() eq_(recon, ["go"]) def test_unmapped_error(self): @@ -2041,7 +2050,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): # using it with an ORM operation, raises assert_raises( - sa.orm.exc.UnmappedClassError, create_session().add, Sub() + sa.orm.exc.UnmappedClassError, fixture_session().add, Sub() ) def test_unmapped_subclass_error_premap(self): @@ -2065,7 +2074,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): # using it with an ORM operation, raises assert_raises( - sa.orm.exc.UnmappedClassError, create_session().add, Sub() + sa.orm.exc.UnmappedClassError, fixture_session().add, Sub() ) def test_oldstyle_mixin(self): @@ -2223,7 +2232,7 @@ class RequirementsTest(fixtures.MappedTest): self.mapper(H3, ht3) self.mapper(H6, ht6) - s = create_session() + s = fixture_session() s.add_all([H1("abc"), H1("def")]) h1 = H1("ghi") s.add(h1) @@ -2232,7 +2241,7 @@ class RequirementsTest(fixtures.MappedTest): h1.h1s.append(H1()) s.flush() - eq_(select(func.count("*")).select_from(ht1).scalar(), 4) + eq_(s.connection().scalar(select(func.count("*")).select_from(ht1)), 4) h6 = H6() h6.h1a = h1 @@ -2300,7 +2309,7 @@ class RequirementsTest(fixtures.MappedTest): H1, ht1, properties={"h2s": relationship(H2, backref="h1")} ) self.mapper(H2, ht2) - s = Session() + s = fixture_session() s.add_all( [ H1( @@ -2479,7 +2488,7 @@ class MagicNamesTest(fixtures.MappedTest): ) Map(state="AK", mapper=c) - sess = create_session() + sess = fixture_session() sess.add(c) sess.flush() sess.expunge_all() @@ -2633,7 +2642,7 @@ class ORMLoggingTest(_fixtures.FixtureTest): User, users = self.classes.User, self.tables.users tb = users.select().alias() self.mapper(User, tb) - s = Session() + s = fixture_session() s.add(User(name="ed")) s.commit() diff --git a/test/orm/test_merge.py b/test/orm/test_merge.py index 57d3ce01d..e0a76c6a0 100644 --- a/test/orm/test_merge.py +++ b/test/orm/test_merge.py @@ -13,14 +13,12 @@ from sqlalchemy import Text from sqlalchemy.orm import attributes from sqlalchemy.orm import backref from sqlalchemy.orm import configure_mappers -from sqlalchemy.orm import create_session from sqlalchemy.orm import defer from sqlalchemy.orm import deferred from sqlalchemy.orm import foreign from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship from sqlalchemy.orm import Session -from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import synonym from sqlalchemy.orm.collections import attribute_mapped_collection from sqlalchemy.orm.interfaces import MapperOption @@ -30,6 +28,7 @@ from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import in_ from sqlalchemy.testing import not_in +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from sqlalchemy.util import OrderedSet @@ -57,7 +56,7 @@ class MergeTest(_fixtures.FixtureTest): User, users = self.classes.User, self.tables.users mapper(User, users) - sess = create_session() + sess = fixture_session() load = self.load_tracker(User) u = User(id=7, name="fred") @@ -77,7 +76,7 @@ class MergeTest(_fixtures.FixtureTest): User, users = self.classes.User, self.tables.users mapper(User, users) - sess = create_session() + sess = fixture_session() u = User(name="fred") def go(): @@ -89,7 +88,7 @@ class MergeTest(_fixtures.FixtureTest): User, users = self.classes.User, self.tables.users mapper(User, users) - sess = create_session() + sess = fixture_session(autoflush=False) u = User(name="fred") sess.add(u) @@ -104,7 +103,7 @@ class MergeTest(_fixtures.FixtureTest): User, users = self.classes.User, self.tables.users mapper(User, users) - sess = create_session() + sess = fixture_session(autoflush=False) u = User(id=1, name="fred") sess.add(u) @@ -148,7 +147,7 @@ class MergeTest(_fixtures.FixtureTest): ) eq_(load.called, 0) - sess = create_session() + sess = fixture_session() sess.merge(u) eq_(load.called, 3) @@ -188,7 +187,7 @@ class MergeTest(_fixtures.FixtureTest): ) mapper(Address, addresses) - s = Session() + s = fixture_session() u = User( id=7, name="fred", @@ -239,7 +238,7 @@ class MergeTest(_fixtures.FixtureTest): ) eq_(load.called, 0) - sess = create_session() + sess = fixture_session() sess.merge(u) eq_(load.called, 3) @@ -269,7 +268,7 @@ class MergeTest(_fixtures.FixtureTest): mapper(User, users) load = self.load_tracker(User) - sess = create_session() + sess = fixture_session() u = User(id=7, name="fred") sess.add(u) sess.flush() @@ -323,7 +322,7 @@ class MergeTest(_fixtures.FixtureTest): ] ), ) - sess = create_session() + sess = fixture_session() sess.add(u) sess.flush() sess.expunge_all() @@ -408,7 +407,7 @@ class MergeTest(_fixtures.FixtureTest): name="fred", addresses=OrderedSet([a, Address(id=2, email_address="fred2")]), ) - sess = create_session() + sess = fixture_session() sess.add(u) sess.flush() sess.expunge_all() @@ -459,7 +458,7 @@ class MergeTest(_fixtures.FixtureTest): ) load = self.load_tracker(User) self.load_tracker(Address, load) - sess = create_session() + sess = fixture_session() u = User(id=7, name="fred") a1 = Address(email_address="foo@bar.com") @@ -515,7 +514,7 @@ class MergeTest(_fixtures.FixtureTest): mapper(User, dingalings) - sess = create_session() + sess = fixture_session(autoflush=False) # merge empty stuff. goes in as NULL. # not sure what this was originally trying to @@ -597,7 +596,7 @@ class MergeTest(_fixtures.FixtureTest): ) u1 = User(id=7, name="fred") u1.addresses["foo@bar.com"] = Address(email_address="foo@bar.com") - sess = create_session() + sess = fixture_session() sess.merge(u1) sess.flush() assert list(u1.addresses.keys()) == ["foo@bar.com"] @@ -625,22 +624,21 @@ class MergeTest(_fixtures.FixtureTest): load = self.load_tracker(User) self.load_tracker(Address, load) - sess = create_session() + with fixture_session(expire_on_commit=False) as sess, sess.begin(): - # set up data and save - u = User( - id=7, - name="fred", - addresses=[ - Address(email_address="foo@bar.com"), - Address(email_address="hoho@la.com"), - ], - ) - sess.add(u) - sess.flush() + # set up data and save + u = User( + id=7, + name="fred", + addresses=[ + Address(email_address="foo@bar.com"), + Address(email_address="hoho@la.com"), + ], + ) + sess.add(u) # assert data was saved - sess2 = create_session() + sess2 = fixture_session() u2 = sess2.query(User).get(7) eq_( u2, @@ -661,88 +659,91 @@ class MergeTest(_fixtures.FixtureTest): eq_(load.called, 3) # new session, merge modified data into session - sess3 = create_session() - u3 = sess3.merge(u) - eq_(load.called, 6) + with fixture_session(expire_on_commit=False) as sess3: + u3 = sess3.merge(u) + eq_(load.called, 6) - # ensure local changes are pending - eq_( - u3, - User( - id=7, - name="fred2", - addresses=[ - Address(email_address="foo@bar.com"), - Address(email_address="hoho@lalala.com"), - ], - ), - ) + # ensure local changes are pending + eq_( + u3, + User( + id=7, + name="fred2", + addresses=[ + Address(email_address="foo@bar.com"), + Address(email_address="hoho@lalala.com"), + ], + ), + ) - # save merged data - sess3.flush() + # save merged data + sess3.commit() # assert modified/merged data was saved - sess.expunge_all() - u = sess.query(User).get(7) - eq_( - u, - User( - id=7, - name="fred2", - addresses=[ - Address(email_address="foo@bar.com"), - Address(email_address="hoho@lalala.com"), - ], - ), - ) - eq_(load.called, 9) + with fixture_session() as sess: + u = sess.query(User).get(7) + eq_( + u, + User( + id=7, + name="fred2", + addresses=[ + Address(email_address="foo@bar.com"), + Address(email_address="hoho@lalala.com"), + ], + ), + ) + eq_(load.called, 9) # merge persistent object into another session - sess4 = create_session() - u = sess4.merge(u) - assert len(u.addresses) - for a in u.addresses: - assert a.user is u + with fixture_session(expire_on_commit=False) as sess4: + u = sess4.merge(u) + assert len(u.addresses) + for a in u.addresses: + assert a.user is u - def go(): - sess4.flush() + def go(): + sess4.flush() - # no changes; therefore flush should do nothing - self.assert_sql_count(testing.db, go, 0) - eq_(load.called, 12) + # no changes; therefore flush should do nothing + self.assert_sql_count(testing.db, go, 0) - # test with "dontload" merge - sess5 = create_session() - u = sess5.merge(u, load=False) - assert len(u.addresses) - for a in u.addresses: - assert a.user is u + sess4.commit() - def go(): - sess5.flush() + eq_(load.called, 12) - # no changes; therefore flush should do nothing - # but also, load=False wipes out any difference in committed state, - # so no flush at all - self.assert_sql_count(testing.db, go, 0) + # test with "dontload" merge + with fixture_session(expire_on_commit=False) as sess5: + u = sess5.merge(u, load=False) + assert len(u.addresses) + for a in u.addresses: + assert a.user is u + + def go(): + sess5.flush() + + # no changes; therefore flush should do nothing + # but also, load=False wipes out any difference in committed state, + # so no flush at all + self.assert_sql_count(testing.db, go, 0) eq_(load.called, 15) - sess4 = create_session() - u = sess4.merge(u, load=False) - # post merge change - u.addresses[1].email_address = "afafds" + with fixture_session(expire_on_commit=False) as sess4, sess4.begin(): + u = sess4.merge(u, load=False) + # post merge change + u.addresses[1].email_address = "afafds" - def go(): - sess4.flush() + def go(): + sess4.flush() - # afafds change flushes - self.assert_sql_count(testing.db, go, 1) + # afafds change flushes + self.assert_sql_count(testing.db, go, 1) eq_(load.called, 18) - sess5 = create_session() - u2 = sess5.query(User).get(u.id) - eq_(u2.name, "fred2") - eq_(u2.addresses[1].email_address, "afafds") + with fixture_session(expire_on_commit=False) as sess5: + u2 = sess5.query(User).get(u.id) + eq_(u2.name, "fred2") + eq_(u2.addresses[1].email_address, "afafds") eq_(load.called, 21) def test_dont_send_neverset_to_get(self): @@ -754,7 +755,7 @@ class MergeTest(_fixtures.FixtureTest): mapper(CompositePk, composite_pk_table) cp1 = CompositePk(j=1, k=1) - sess = Session() + sess = fixture_session() rec = [] @@ -788,7 +789,7 @@ class MergeTest(_fixtures.FixtureTest): u1 = User(id=5, name="some user") cp1 = CompositePk(j=1, k=1) u1.elements.append(cp1) - sess = Session() + sess = fixture_session() rec = [] @@ -819,7 +820,7 @@ class MergeTest(_fixtures.FixtureTest): properties={"user": relationship(User, cascade="save-update")}, ) mapper(User, users) - sess = create_session() + sess = fixture_session() u1 = User(name="fred") a1 = Address(email_address="asdf", user=u1) sess.add(a1) @@ -858,18 +859,18 @@ class MergeTest(_fixtures.FixtureTest): load = self.load_tracker(User) self.load_tracker(Address, load) - sess = create_session() + sess = fixture_session(expire_on_commit=False) u = User(name="fred") a1 = Address(email_address="foo@bar") a2 = Address(email_address="foo@quux") u.addresses.extend([a1, a2]) sess.add(u) - sess.flush() + sess.commit() eq_(load.called, 0) - sess2 = create_session() + sess2 = fixture_session() u2 = sess2.query(User).get(u.id) eq_(load.called, 1) @@ -878,7 +879,7 @@ class MergeTest(_fixtures.FixtureTest): eq_(u2.addresses[1].email_address, "addr 2 modified") eq_(load.called, 3) - sess3 = create_session() + sess3 = fixture_session() u3 = sess3.query(User).get(u.id) eq_(load.called, 4) @@ -902,23 +903,23 @@ class MergeTest(_fixtures.FixtureTest): a1 = Address(id=1, email_address="a1", user=u1) u2 = User(id=2, name="u2") - sess = create_session() + sess = fixture_session(expire_on_commit=False) sess.add_all([a1, u2]) - sess.flush() + sess.commit() a1.user = u2 - sess2 = create_session() - a2 = sess2.merge(a1) - eq_(attributes.get_history(a2, "user"), ([u2], (), ())) - assert a2 in sess2.dirty + with fixture_session(expire_on_commit=False) as sess2: + a2 = sess2.merge(a1) + eq_(attributes.get_history(a2, "user"), ([u2], (), ())) + assert a2 in sess2.dirty sess.refresh(a1) - sess2 = create_session() - a2 = sess2.merge(a1, load=False) - eq_(attributes.get_history(a2, "user"), ((), [u1], ())) - assert a2 not in sess2.dirty + with fixture_session(expire_on_commit=False) as sess2: + a2 = sess2.merge(a1, load=False) + eq_(attributes.get_history(a2, "user"), ((), [u1], ())) + assert a2 not in sess2.dirty def test_many_to_many_cascade(self): items, Order, orders, order_items, Item = ( @@ -942,41 +943,41 @@ class MergeTest(_fixtures.FixtureTest): load = self.load_tracker(Order) self.load_tracker(Item, load) - sess = create_session() + with fixture_session(expire_on_commit=False) as sess: - i1 = Item() - i1.description = "item 1" + i1 = Item() + i1.description = "item 1" - i2 = Item() - i2.description = "item 2" + i2 = Item() + i2.description = "item 2" - o = Order() - o.description = "order description" - o.items.append(i1) - o.items.append(i2) + o = Order() + o.description = "order description" + o.items.append(i1) + o.items.append(i2) - sess.add(o) - sess.flush() + sess.add(o) + sess.commit() eq_(load.called, 0) - sess2 = create_session() - o2 = sess2.query(Order).get(o.id) - eq_(load.called, 1) + with fixture_session(expire_on_commit=False) as sess2: + o2 = sess2.query(Order).get(o.id) + eq_(load.called, 1) - o.items[1].description = "item 2 modified" - sess2.merge(o) - eq_(o2.items[1].description, "item 2 modified") - eq_(load.called, 3) + o.items[1].description = "item 2 modified" + sess2.merge(o) + eq_(o2.items[1].description, "item 2 modified") + eq_(load.called, 3) - sess3 = create_session() - o3 = sess3.query(Order).get(o.id) - eq_(load.called, 4) + with fixture_session(expire_on_commit=False) as sess3: + o3 = sess3.query(Order).get(o.id) + eq_(load.called, 4) - o.description = "desc modified" - sess3.merge(o) - eq_(load.called, 6) - eq_(o3.description, "desc modified") + o.description = "desc modified" + sess3.merge(o) + eq_(load.called, 6) + eq_(o3.description, "desc modified") def test_one_to_one_cascade(self): users, Address, addresses, User = ( @@ -997,7 +998,7 @@ class MergeTest(_fixtures.FixtureTest): ) load = self.load_tracker(User) self.load_tracker(Address, load) - sess = create_session() + sess = fixture_session(expire_on_commit=False) u = User() u.id = 7 @@ -1007,11 +1008,11 @@ class MergeTest(_fixtures.FixtureTest): u.address = a1 sess.add(u) - sess.flush() + sess.commit() eq_(load.called, 0) - sess2 = create_session() + sess2 = fixture_session() u2 = sess2.query(User).get(7) eq_(load.called, 1) u2.name = "fred2" @@ -1039,7 +1040,7 @@ class MergeTest(_fixtures.FixtureTest): ) }, ) - sess = sessionmaker()() + sess = fixture_session() u = User( id=7, name="fred", @@ -1065,7 +1066,7 @@ class MergeTest(_fixtures.FixtureTest): mapper(User, users) - sess = create_session() + sess = fixture_session() u = User() assert_raises_message( sa.exc.InvalidRequestError, @@ -1104,13 +1105,13 @@ class MergeTest(_fixtures.FixtureTest): Address(email_address="ad2"), ], ) - sess = create_session() + sess = fixture_session() sess.add(u) sess.flush() sess.close() assert "user" in u.addresses[1].__dict__ - sess = create_session() + sess = fixture_session() u2 = sess.merge(u, load=False) assert "user" in u2.addresses[1].__dict__ eq_(u2.addresses[1].user, User(id=7, name="fred")) @@ -1119,7 +1120,7 @@ class MergeTest(_fixtures.FixtureTest): assert "user" not in u2.addresses[1].__dict__ sess.close() - sess = create_session() + sess = fixture_session() u = sess.merge(u2, load=False) assert "user" not in u.addresses[1].__dict__ eq_(u.addresses[1].user, User(id=7, name="fred")) @@ -1150,21 +1151,21 @@ class MergeTest(_fixtures.FixtureTest): users, properties={"addresses": relationship(mapper(Address, addresses))}, ) - sess = create_session() - u = User() - u.id = 7 - u.name = "fred" - a1 = Address() - a1.email_address = "foo@bar.com" - u.addresses.append(a1) + with fixture_session(expire_on_commit=False) as sess: + u = User() + u.id = 7 + u.name = "fred" + a1 = Address() + a1.email_address = "foo@bar.com" + u.addresses.append(a1) - sess.add(u) - sess.flush() + sess.add(u) + sess.commit() - sess2 = create_session() + sess2 = fixture_session() u2 = sess2.query(User).options(sa.orm.joinedload("addresses")).get(7) - sess3 = create_session() + sess3 = fixture_session() u3 = sess3.merge(u2, load=False) # noqa def go(): @@ -1182,15 +1183,15 @@ class MergeTest(_fixtures.FixtureTest): users, User = self.tables.users, self.classes.User mapper(User, users) - sess = create_session() - u = User() - u.id = 7 - u.name = "fred" - sess.add(u) - sess.flush() + with fixture_session(expire_on_commit=False) as sess: + u = User() + u.id = 7 + u.name = "fred" + sess.add(u) + sess.commit() u.name = "ed" - sess2 = create_session() + sess2 = fixture_session() try: sess2.merge(u, load=False) assert False @@ -1203,7 +1204,7 @@ class MergeTest(_fixtures.FixtureTest): u2 = sess2.query(User).get(7) - sess3 = create_session() + sess3 = fixture_session() u3 = sess3.merge(u2, load=False) # noqa assert not sess3.dirty @@ -1230,7 +1231,7 @@ class MergeTest(_fixtures.FixtureTest): }, ) - sess = create_session() + sess = fixture_session() u = User() u.id = 7 u.name = "fred" @@ -1243,7 +1244,7 @@ class MergeTest(_fixtures.FixtureTest): assert u.addresses[0].user is u - sess2 = create_session() + sess2 = fixture_session() u2 = sess2.merge(u, load=False) assert not sess2.dirty @@ -1285,33 +1286,33 @@ class MergeTest(_fixtures.FixtureTest): ) }, ) - sess = create_session() - u = User() - u.id = 7 - u.name = "fred" - a1 = Address() - a1.email_address = "foo@bar.com" - u.addresses.append(a1) - sess.add(u) - sess.flush() + with fixture_session(expire_on_commit=False) as sess: + u = User() + u.id = 7 + u.name = "fred" + a1 = Address() + a1.email_address = "foo@bar.com" + u.addresses.append(a1) + sess.add(u) + sess.commit() assert u.addresses[0].user is u - sess2 = create_session() - u2 = sess2.merge(u, load=False) - assert not sess2.dirty - a2 = u2.addresses[0] - a2.email_address = "somenewaddress" - assert not sa.orm.object_mapper(a2)._is_orphan( - sa.orm.attributes.instance_state(a2) - ) - sess2.flush() - sess2.expunge_all() + with fixture_session(expire_on_commit=False) as sess2: + u2 = sess2.merge(u, load=False) + assert not sess2.dirty + a2 = u2.addresses[0] + a2.email_address = "somenewaddress" + assert not sa.orm.object_mapper(a2)._is_orphan( + sa.orm.attributes.instance_state(a2) + ) + sess2.commit() - eq_( - sess2.query(User).get(u2.id).addresses[0].email_address, - "somenewaddress", - ) + with fixture_session() as sess2: + eq_( + sess2.query(User).get(u2.id).addresses[0].email_address, + "somenewaddress", + ) # this use case is not supported; this is with a pending Address # on the pre-merged object, and we currently don't support @@ -1321,10 +1322,11 @@ class MergeTest(_fixtures.FixtureTest): # instances. so if we start supporting 'dirty' with load=False, # this test will need to pass - sess = create_session() + sess2 = fixture_session() + sess = fixture_session() u = sess.query(User).get(7) u.addresses.append(Address()) - sess2 = create_session() + sess2 = fixture_session() try: u2 = sess2.merge(u, load=False) assert False @@ -1359,7 +1361,7 @@ class MergeTest(_fixtures.FixtureTest): mapper(User, users, properties={"uid": synonym("id")}) - sess = create_session() + sess = fixture_session() u = User() u.name = "ed" sess.add(u) @@ -1377,7 +1379,7 @@ class MergeTest(_fixtures.FixtureTest): self.tables.users, ) - s = create_session(autoflush=True, autocommit=False) + s = fixture_session(autoflush=True, autocommit=False) mapper( User, users, @@ -1406,7 +1408,7 @@ class MergeTest(_fixtures.FixtureTest): self.tables.users, ) - sess = create_session(autoflush=True, autocommit=False) + sess = fixture_session(autoflush=True, autocommit=False) mapper( User, users, @@ -1446,7 +1448,7 @@ class MergeTest(_fixtures.FixtureTest): u = User( id=7, name="fred", addresses=[Address(id=1, email_address="fred1")] ) - sess = create_session(autoflush=True, autocommit=False) + sess = fixture_session(autoflush=True, autocommit=False) sess.add(u) sess.commit() @@ -1471,7 +1473,7 @@ class MergeTest(_fixtures.FixtureTest): mapper(User, users) u = User(id=7) - sess = create_session(autoflush=True, autocommit=False) + sess = fixture_session(autoflush=True, autocommit=False) u = sess.merge(u) assert not bool(attributes.instance_state(u).expired_attributes) @@ -1493,18 +1495,18 @@ class MergeTest(_fixtures.FixtureTest): opt1, opt2 = Option(), Option() - sess = sessionmaker()() + sess = fixture_session() umapper = mapper(User, users) sess.add_all([User(id=1, name="u1"), User(id=2, name="u2")]) sess.commit() - sess2 = sessionmaker()() + sess2 = fixture_session() s2_users = sess2.query(User).options(opt2).all() # test 1. no options are replaced by merge options - sess = sessionmaker()() + sess = fixture_session() s1_users = sess.query(User).all() for u in s1_users: @@ -1521,7 +1523,7 @@ class MergeTest(_fixtures.FixtureTest): eq_(ustate.load_options, set([opt2])) # test 2. present options are replaced by merge options - sess = sessionmaker()() + sess = fixture_session() s1_users = sess.query(User).options(opt1).all() for u in s1_users: ustate = attributes.instance_state(u) @@ -1559,7 +1561,7 @@ class MergeTest(_fixtures.FixtureTest): Order(description="o3", address=Address(email_address="c")), ] - sess = Session() + sess = fixture_session() sess.merge(u1) sess.flush() @@ -1593,7 +1595,7 @@ class MergeTest(_fixtures.FixtureTest): Order(description="o3", address=Address(id=1, email_address="c")), ] - sess = Session() + sess = fixture_session() sess.merge(u1) sess.flush() @@ -1615,7 +1617,7 @@ class MergeTest(_fixtures.FixtureTest): mapper(Order, orders, properties={"address": relationship(Address)}) mapper(Address, addresses) - sess = Session() + sess = fixture_session() sess.add(Address(id=1, email_address="z")) sess.commit() @@ -1626,7 +1628,7 @@ class MergeTest(_fixtures.FixtureTest): Order(description="o3", address=Address(id=1, email_address="c")), ] - sess = Session() + sess = fixture_session() sess.merge(u1) sess.flush() @@ -1718,7 +1720,7 @@ class M2ONoUseGetLoadingTest(fixtures.MappedTest): # address.user_id is 1, you get a load. def test_persistent_access_none(self): User, Address = self.classes.User, self.classes.Address - s = Session() + s = fixture_session() def go(): u1 = User(id=1, addresses=[Address(id=1), Address(id=2)]) @@ -1728,7 +1730,7 @@ class M2ONoUseGetLoadingTest(fixtures.MappedTest): def test_persistent_access_one(self): User, Address = self.classes.User, self.classes.Address - s = Session() + s = fixture_session() def go(): u1 = User(id=1, addresses=[Address(id=1), Address(id=2)]) @@ -1740,7 +1742,7 @@ class M2ONoUseGetLoadingTest(fixtures.MappedTest): def test_persistent_access_two(self): User, Address = self.classes.User, self.classes.Address - s = Session() + s = fixture_session() def go(): u1 = User(id=1, addresses=[Address(id=1), Address(id=2)]) @@ -1759,7 +1761,7 @@ class M2ONoUseGetLoadingTest(fixtures.MappedTest): # persistent. def test_pending_access_one(self): User, Address = self.classes.User, self.classes.Address - s = Session() + s = fixture_session() def go(): u1 = User( @@ -1778,7 +1780,7 @@ class M2ONoUseGetLoadingTest(fixtures.MappedTest): def test_pending_access_two(self): User, Address = self.classes.User, self.classes.Address - s = Session() + s = fixture_session() def go(): u1 = User( @@ -1819,7 +1821,7 @@ class DeferredMergeTest(fixtures.MappedTest): # defer 'excerpt' at mapping level instead of query level Book, book = self.classes.Book, self.tables.book mapper(Book, book, properties={"excerpt": deferred(book.c.excerpt)}) - sess = sessionmaker()() + sess = fixture_session() b = Book( id=1, @@ -1864,7 +1866,7 @@ class DeferredMergeTest(fixtures.MappedTest): def test_deferred_column_query(self): Book, book = self.classes.Book, self.tables.book mapper(Book, book) - sess = sessionmaker()() + sess = fixture_session() b = Book( id=1, @@ -1929,7 +1931,7 @@ class MutableMergeTest(fixtures.MappedTest): Data, data = self.classes.Data, self.tables.data mapper(Data, data) - sess = sessionmaker()() + sess = fixture_session() d = Data(data=["this", "is", "a", "list"]) sess.add(d) @@ -1959,7 +1961,7 @@ class CompositeNullPksTest(fixtures.MappedTest): Data, data = self.classes.Data, self.tables.data mapper(Data, data) - sess = sessionmaker()() + sess = fixture_session() d1 = Data(pk1="someval", pk2=None) @@ -1972,7 +1974,7 @@ class CompositeNullPksTest(fixtures.MappedTest): Data, data = self.classes.Data, self.tables.data mapper(Data, data, allow_partial_pks=False) - sess = sessionmaker()() + sess = fixture_session() d1 = Data(pk1="someval", pk2=None) @@ -2022,7 +2024,7 @@ class LoadOnPendingTest(fixtures.MappedTest): }, ) mapper(self.classes.Bug, self.tables.bugs) - self.sess = sessionmaker()() + self.sess = fixture_session() def _merge_delete_orphan_o2o_with(self, bug): # create a transient rock with passed bug @@ -2104,7 +2106,7 @@ class PolymorphicOnTest(fixtures.MappedTest): inherits=employee_mapper, polymorphic_identity="engineer", ) - self.sess = sessionmaker()() + self.sess = fixture_session() def test_merge_polymorphic_on(self): """merge() should succeed with a polymorphic object even when diff --git a/test/orm/test_naturalpks.py b/test/orm/test_naturalpks.py index d814b0cab..4a3fee634 100644 --- a/test/orm/test_naturalpks.py +++ b/test/orm/test_naturalpks.py @@ -14,7 +14,6 @@ from sqlalchemy import testing from sqlalchemy import TypeDecorator from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship -from sqlalchemy.orm import Session from sqlalchemy.orm.session import make_transient from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message @@ -22,14 +21,17 @@ from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import ne_ -from sqlalchemy.testing.fixtures import create_session +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from test.orm import _fixtures def _backend_specific_fk_args(): - if testing.requires.deferrable_fks.enabled: + if ( + testing.requires.deferrable_fks.enabled + and testing.requires.non_updating_cascade.enabled + ): fk_args = dict(deferrable=True, initially="deferred") elif not testing.requires.on_update_cascade.enabled: fk_args = dict() @@ -108,7 +110,7 @@ class NaturalPKTest(fixtures.MappedTest): mapper(User, users) - sess = create_session() + sess = fixture_session() u1 = User(username="jack", fullname="jack") sess.add(u1) @@ -134,7 +136,7 @@ class NaturalPKTest(fixtures.MappedTest): mapper(User, users) - sess = create_session() + sess = fixture_session() u1 = User(username="jack", fullname="jack") sess.add(u1) @@ -161,7 +163,7 @@ class NaturalPKTest(fixtures.MappedTest): mapper(User, users) - sess = create_session() + sess = fixture_session() u1 = User(username="jack", fullname="jack") sess.add(u1) @@ -180,7 +182,7 @@ class NaturalPKTest(fixtures.MappedTest): mapper(User, users) - sess = create_session() + sess = fixture_session() u1 = User(username="jack", fullname="jack") sess.add(u1) @@ -196,7 +198,7 @@ class NaturalPKTest(fixtures.MappedTest): User, users = self.classes.User, self.tables.users mapper(User, users) - sess = create_session() + sess = fixture_session() u1 = User(username="jack", fullname="jack") sess.add(u1) @@ -235,7 +237,7 @@ class NaturalPKTest(fixtures.MappedTest): ) mapper(Address, addresses) - sess = create_session() + sess = fixture_session() u1 = User(username="jack", fullname="jack") u1.addresses.append(Address(email="jack1")) u1.addresses.append(Address(email="jack2")) @@ -324,7 +326,7 @@ class NaturalPKTest(fixtures.MappedTest): properties={"user": relationship(User, passive_updates=False)}, ) - sess = create_session() + sess = fixture_session() u1 = sess.query(User).first() a1, a2 = sess.query(Address).all() u1.username = "ed" @@ -353,7 +355,7 @@ class NaturalPKTest(fixtures.MappedTest): }, ) - sess = create_session() + sess = fixture_session() a1 = Address(email="jack1") a2 = Address(email="jack2") a3 = Address(email="fred") @@ -432,7 +434,7 @@ class NaturalPKTest(fixtures.MappedTest): ) mapper(Address, addresses) - sess = create_session() + sess = fixture_session() u1 = User(username="jack", fullname="jack") sess.add(u1) sess.flush() @@ -487,7 +489,7 @@ class NaturalPKTest(fixtures.MappedTest): }, ) - sess = create_session() + sess = fixture_session(autoflush=False) a1 = Address(email="jack1") a2 = Address(email="jack2") @@ -570,7 +572,7 @@ class NaturalPKTest(fixtures.MappedTest): ) mapper(Item, items) - sess = create_session() + sess = fixture_session() u1 = User(username="jack") u2 = User(username="fred") i1 = Item(itemname="item1") @@ -639,7 +641,7 @@ class NaturalPKTest(fixtures.MappedTest): }, ) - s = Session() + s = fixture_session() a1 = Address(email="jack1") u1 = User(username="jack", fullname="jack") @@ -756,7 +758,7 @@ class TransientExceptionTesst(_fixtures.FixtureTest): mapper(User, users) mapper(Address, addresses, properties={"user": relationship(User)}) - sess = create_session() + sess = fixture_session() u1 = User(id=5, name="u1") ad1 = Address(email_address="e1", user=u1) sess.add_all([u1, ad1]) @@ -809,7 +811,7 @@ class ReversePKsTest(fixtures.MappedTest): mapper(User, user) - session = sa.orm.sessionmaker()() + session = fixture_session() a_published = User(1, PUBLISHED, "a") session.add(a_published) @@ -849,7 +851,7 @@ class ReversePKsTest(fixtures.MappedTest): mapper(User, user) - session = sa.orm.sessionmaker()() + session = fixture_session() a_published = User(1, PUBLISHED, "a") session.add(a_published) @@ -916,7 +918,7 @@ class SelfReferentialTest(fixtures.MappedTest): }, ) - sess = Session() + sess = fixture_session() n1 = Node(name="n1") sess.add(n1) n2 = Node(name="n11", parentnode=n1) @@ -954,7 +956,7 @@ class SelfReferentialTest(fixtures.MappedTest): }, ) - sess = Session() + sess = fixture_session() n1 = Node(name="n1") n1.children.append(Node(name="n11")) n1.children.append(Node(name="n12")) @@ -995,7 +997,7 @@ class SelfReferentialTest(fixtures.MappedTest): }, ) - sess = Session() + sess = fixture_session() n1 = Node(name="n1") n11 = Node(name="n11", parentnode=n1) n12 = Node(name="n12", parentnode=n1) @@ -1082,7 +1084,7 @@ class NonPKCascadeTest(fixtures.MappedTest): ) mapper(Address, addresses) - sess = create_session() + sess = fixture_session() u1 = User(username="jack", fullname="jack") u1.addresses.append(Address(email="jack1")) u1.addresses.append(Address(email="jack2")) @@ -1230,7 +1232,7 @@ class CascadeToFKPKTest(fixtures.MappedTest, testing.AssertsCompiledSQL): ) mapper(Address, addresses) - sess = create_session() + sess = fixture_session() a1 = Address(username="ed", email="ed@host1") u1 = User(username="ed", addresses=[a1]) u2 = User(username="jack") @@ -1271,7 +1273,7 @@ class CascadeToFKPKTest(fixtures.MappedTest, testing.AssertsCompiledSQL): ) mapper(Address, addresses) - sess = create_session() + sess = fixture_session(autoflush=False) a1 = Address(username="ed", email="ed@host1") u1 = User(username="ed", addresses=[a1]) u2 = User(username="jack") @@ -1318,7 +1320,7 @@ class CascadeToFKPKTest(fixtures.MappedTest, testing.AssertsCompiledSQL): }, ) - sess = create_session() + sess = fixture_session() u1 = User(username="jack") if uselist: a1 = Address(user=[u1], email="foo@bar") @@ -1358,7 +1360,7 @@ class CascadeToFKPKTest(fixtures.MappedTest, testing.AssertsCompiledSQL): }, ) - sess = create_session() + sess = fixture_session() u1 = User(username="jack") u2 = User(username="ed") a1 = Address(user=u1, email="foo@bar") @@ -1383,7 +1385,7 @@ class CascadeToFKPKTest(fixtures.MappedTest, testing.AssertsCompiledSQL): properties={"user": relationship(User, passive_updates=True)}, ) - sess = create_session() + sess = fixture_session() u1 = User(username="ed") a1 = Address(user=u1, email="ed@host1") @@ -1445,7 +1447,7 @@ class CascadeToFKPKTest(fixtures.MappedTest, testing.AssertsCompiledSQL): ) mapper(Address, addresses) - sess = create_session() + sess = fixture_session() a1, a2 = ( Address(username="ed", email="ed@host1"), Address(username="ed", email="ed@host2"), @@ -1625,7 +1627,7 @@ class JoinedInheritanceTest(fixtures.MappedTest): def _test_pk(self, passive_updates): (Engineer,) = self.classes("Engineer") self._mapping_fixture(False, passive_updates) - sess = sa.orm.sessionmaker()() + sess = fixture_session() e1 = Engineer(name="dilbert", primary_language="java") sess.add(e1) @@ -1655,7 +1657,7 @@ class JoinedInheritanceTest(fixtures.MappedTest): self._mapping_fixture(False, passive_updates) - sess = sa.orm.sessionmaker()() + sess = fixture_session() m1 = Manager(name="dogbert", paperwork="lots") e1, e2 = ( @@ -1700,7 +1702,7 @@ class JoinedInheritanceTest(fixtures.MappedTest): self._mapping_fixture(True, passive_updates) - sess = sa.orm.sessionmaker()() + sess = fixture_session() o1 = Owner(name="dogbert", owner_name="dog") sess.add(o1) @@ -1738,7 +1740,7 @@ class JoinedInheritanceTest(fixtures.MappedTest): Owner, Engineer = self.classes("Owner", "Engineer") self._mapping_fixture(True, passive_updates) - sess = sa.orm.sessionmaker()() + sess = fixture_session() m1 = Owner(name="dogbert", paperwork="lots", owner_name="dog") e1, e2 = ( @@ -1811,7 +1813,7 @@ class UnsortablePKTest(fixtures.MappedTest): def test_updates_sorted(self): Data = self.classes.Data - s = Session() + s = fixture_session() s.add_all( [ @@ -1903,7 +1905,7 @@ class JoinedInheritancePKOnFKTest(fixtures.MappedTest): polymorphic_identity="engineer", ) - sess = sa.orm.sessionmaker()() + sess = fixture_session() e1 = Engineer(name="dilbert", primary_language="java") sess.add(e1) diff --git a/test/orm/test_of_type.py b/test/orm/test_of_type.py index e40e815aa..bc32e322d 100644 --- a/test/orm/test_of_type.py +++ b/test/orm/test_of_type.py @@ -16,6 +16,7 @@ from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.testing.entities import ComparableEntity +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from .inheritance._poly_fixtures import _PolymorphicAliasedJoins from .inheritance._poly_fixtures import _PolymorphicJoins @@ -33,14 +34,14 @@ class _PolymorphicTestBase(object): __dialect__ = "default" def test_any_one(self): - sess = Session() + sess = fixture_session() any_ = Company.employees.of_type(Engineer).any( Engineer.primary_language == "cobol" ) eq_(sess.query(Company).filter(any_).one(), self.c2) def test_any_two(self): - sess = Session() + sess = fixture_session() calias = aliased(Company) any_ = calias.employees.of_type(Engineer).any( Engineer.primary_language == "cobol" @@ -48,26 +49,26 @@ class _PolymorphicTestBase(object): eq_(sess.query(calias).filter(any_).one(), self.c2) def test_any_three(self): - sess = Session() + sess = fixture_session() any_ = Company.employees.of_type(Boss).any(Boss.golf_swing == "fore") eq_(sess.query(Company).filter(any_).one(), self.c1) def test_any_four(self): - sess = Session() + sess = fixture_session() any_ = Company.employees.of_type(Manager).any( Manager.manager_name == "pointy" ) eq_(sess.query(Company).filter(any_).one(), self.c1) def test_any_five(self): - sess = Session() + sess = fixture_session() any_ = Company.employees.of_type(Engineer).any( and_(Engineer.primary_language == "cobol") ) eq_(sess.query(Company).filter(any_).one(), self.c2) def test_join_to_subclass_one(self): - sess = Session() + sess = fixture_session() eq_( sess.query(Company) .join(Company.employees.of_type(Engineer)) @@ -77,7 +78,7 @@ class _PolymorphicTestBase(object): ) def test_join_to_subclass_two(self): - sess = Session() + sess = fixture_session() eq_( sess.query(Company) .join(Company.employees.of_type(Engineer), "machines") @@ -87,7 +88,7 @@ class _PolymorphicTestBase(object): ) def test_join_to_subclass_three(self): - sess = Session() + sess = fixture_session() eq_( sess.query(Company, Engineer) .join(Company.employees.of_type(Engineer)) @@ -97,7 +98,7 @@ class _PolymorphicTestBase(object): ) def test_join_to_subclass_four(self): - sess = Session() + sess = fixture_session() # test [ticket:2093] eq_( sess.query(Company.company_id, Engineer) @@ -108,7 +109,7 @@ class _PolymorphicTestBase(object): ) def test_join_to_subclass_five(self): - sess = Session() + sess = fixture_session() eq_( sess.query(Company) .join(Company.employees.of_type(Engineer)) @@ -118,7 +119,7 @@ class _PolymorphicTestBase(object): ) def test_with_polymorphic_join_compile_one(self): - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(Company).join( @@ -134,7 +135,7 @@ class _PolymorphicTestBase(object): ) def test_with_polymorphic_join_exec_contains_eager_one(self): - sess = Session() + sess = fixture_session() def go(): wp = with_polymorphic( @@ -163,7 +164,7 @@ class _PolymorphicTestBase(object): def test_with_polymorphic_join_exec_contains_eager_two( self, contains_eager_option ): - sess = Session() + sess = fixture_session() wp = with_polymorphic(Person, [Engineer, Manager], aliased=True) contains_eager_option = testing.resolve_lambda( @@ -187,7 +188,7 @@ class _PolymorphicTestBase(object): ) def test_with_polymorphic_any(self): - sess = Session() + sess = fixture_session() wp = with_polymorphic(Person, [Engineer], aliased=True) eq_( sess.query(Company.company_id) @@ -201,7 +202,7 @@ class _PolymorphicTestBase(object): ) def test_subqueryload_implicit_withpoly(self): - sess = Session() + sess = fixture_session() def go(): eq_( @@ -215,7 +216,7 @@ class _PolymorphicTestBase(object): self.assert_sql_count(testing.db, go, 4) def test_joinedload_implicit_withpoly(self): - sess = Session() + sess = fixture_session() def go(): eq_( @@ -229,7 +230,7 @@ class _PolymorphicTestBase(object): self.assert_sql_count(testing.db, go, 3) def test_subqueryload_explicit_withpoly(self): - sess = Session() + sess = fixture_session() def go(): target = with_polymorphic(Person, Engineer) @@ -244,7 +245,7 @@ class _PolymorphicTestBase(object): self.assert_sql_count(testing.db, go, 4) def test_joinedload_explicit_withpoly(self): - sess = Session() + sess = fixture_session() def go(): target = with_polymorphic(Person, Engineer, flat=True) @@ -259,7 +260,7 @@ class _PolymorphicTestBase(object): self.assert_sql_count(testing.db, go, 3) def test_joinedload_stacked_of_type(self): - sess = Session() + sess = fixture_session() def go(): eq_( @@ -467,7 +468,7 @@ class PolymorphicJoinsTest(_PolymorphicTestBase, _PolymorphicJoins): ) def test_joinedload_explicit_with_unaliased_poly_compile(self): - sess = Session() + sess = fixture_session() target = with_polymorphic(Person, Engineer) q = ( sess.query(Company) @@ -481,7 +482,7 @@ class PolymorphicJoinsTest(_PolymorphicTestBase, _PolymorphicJoins): ) def test_joinedload_explicit_with_flataliased_poly_compile(self): - sess = Session() + sess = fixture_session() target = with_polymorphic(Person, Engineer, flat=True) q = ( sess.query(Company) @@ -752,7 +753,7 @@ class SubclassRelationshipTest( Job_P = with_polymorphic(Job, SubJob, aliased=True, flat=True) - s = Session() + s = fixture_session() q = ( s.query(Job) .join(DataContainer.jobs) @@ -782,7 +783,7 @@ class SubclassRelationshipTest( Job_A = aliased(Job) - s = Session() + s = fixture_session() q = ( s.query(Job) .join(DataContainer.jobs) @@ -814,7 +815,7 @@ class SubclassRelationshipTest( Job_P = with_polymorphic(Job, SubJob) - s = Session() + s = fixture_session() q = s.query(DataContainer).join(DataContainer.jobs.of_type(Job_P)) self.assert_compile( q, @@ -832,7 +833,7 @@ class SubclassRelationshipTest( self.classes.SubJob, ) - s = Session() + s = fixture_session() q = s.query(DataContainer).join(DataContainer.jobs.of_type(SubJob)) # note the of_type() here renders JOIN for the Job->SubJob. # this is because it's using the SubJob mapper directly within @@ -856,7 +857,7 @@ class SubclassRelationshipTest( Job_P = with_polymorphic(Job, SubJob, innerjoin=True) - s = Session() + s = fixture_session() q = s.query(DataContainer).join(DataContainer.jobs.of_type(Job_P)) self.assert_compile( q, @@ -875,7 +876,7 @@ class SubclassRelationshipTest( Job_A = aliased(Job) - s = Session() + s = fixture_session() q = s.query(DataContainer).join(DataContainer.jobs.of_type(Job_A)) self.assert_compile( q, @@ -894,7 +895,7 @@ class SubclassRelationshipTest( Job_P = with_polymorphic(Job, SubJob) - s = Session() + s = fixture_session() q = s.query(DataContainer).join(Job_P, DataContainer.jobs) self.assert_compile( q, @@ -915,7 +916,7 @@ class SubclassRelationshipTest( Job_P = with_polymorphic(Job, SubJob, flat=True) - s = Session() + s = fixture_session() q = s.query(DataContainer).join(Job_P, DataContainer.jobs) self.assert_compile( q, @@ -936,7 +937,7 @@ class SubclassRelationshipTest( Job_P = with_polymorphic(Job, SubJob, aliased=True) - s = Session() + s = fixture_session() q = s.query(DataContainer).join(Job_P, DataContainer.jobs) self.assert_compile( q, @@ -1185,7 +1186,7 @@ class SubclassRelationshipTest3( B1 = aliased(B1, name="bbb") C1 = aliased(C1, name="ccc") - sess = Session() + sess = fixture_session() abc = sess.query(A1) if join_of_type: diff --git a/test/orm/test_onetoone.py b/test/orm/test_onetoone.py index a487c17e6..ae9f9b3a1 100644 --- a/test/orm/test_onetoone.py +++ b/test/orm/test_onetoone.py @@ -1,10 +1,10 @@ from sqlalchemy import ForeignKey from sqlalchemy import Integer from sqlalchemy import String -from sqlalchemy.orm import create_session from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship from sqlalchemy.testing import fixtures +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -59,7 +59,7 @@ class O2OTest(fixtures.MappedTest): ), ) - session = create_session() + session = fixture_session() j = Jack(number="101") session.add(j) diff --git a/test/orm/test_options.py b/test/orm/test_options.py index b4befcea3..b22b318e9 100644 --- a/test/orm/test_options.py +++ b/test/orm/test_options.py @@ -9,7 +9,6 @@ from sqlalchemy.orm import aliased from sqlalchemy.orm import attributes from sqlalchemy.orm import class_mapper from sqlalchemy.orm import column_property -from sqlalchemy.orm import create_session from sqlalchemy.orm import defaultload from sqlalchemy.orm import defer from sqlalchemy.orm import exc as orm_exc @@ -18,7 +17,6 @@ from sqlalchemy.orm import Load from sqlalchemy.orm import load_only from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship -from sqlalchemy.orm import Session from sqlalchemy.orm import strategy_options from sqlalchemy.orm import subqueryload from sqlalchemy.orm import synonym @@ -27,6 +25,7 @@ from sqlalchemy.orm import with_polymorphic from sqlalchemy.testing import fixtures from sqlalchemy.testing.assertions import assert_raises_message from sqlalchemy.testing.assertions import eq_ +from sqlalchemy.testing.fixtures import fixture_session from test.orm import _fixtures from .inheritance._poly_fixtures import _Polymorphic from .inheritance._poly_fixtures import Company @@ -205,7 +204,7 @@ class LoadTest(PathTest, QueryTest): def test_gen_path_attr_str_not_mapped(self): OrderWProp = self.classes.OrderWProp - sess = Session() + sess = fixture_session() q = sess.query(OrderWProp).options(defer("some_attr")) assert_raises_message( @@ -280,7 +279,7 @@ class OfTypePathingTest(PathTest, QueryTest): SubAddr.sub_attr ) - sess = Session() + sess = fixture_session() q = sess.query(User) self._assert_path_result( l1, @@ -297,7 +296,7 @@ class OfTypePathingTest(PathTest, QueryTest): .defer(SubAddr.sub_attr) ) - sess = Session() + sess = fixture_session() q = sess.query(User) self._assert_path_result( l1, @@ -310,7 +309,7 @@ class OfTypePathingTest(PathTest, QueryTest): l1 = defaultload(User.addresses.of_type(SubAddr)).defer("sub_attr") - sess = Session() + sess = fixture_session() q = sess.query(User) self._assert_path_result( l1, @@ -327,7 +326,7 @@ class OfTypePathingTest(PathTest, QueryTest): .defer("sub_attr") ) - sess = Session() + sess = fixture_session() q = sess.query(User) self._assert_path_result( l1, @@ -342,7 +341,7 @@ class OfTypePathingTest(PathTest, QueryTest): SubAddr.dings ) - sess = Session() + sess = fixture_session() q = sess.query(User) self._assert_path_result( l1, q, [(User, "addresses"), (User, "addresses", SubAddr, "dings")] @@ -357,7 +356,7 @@ class OfTypePathingTest(PathTest, QueryTest): .joinedload(SubAddr.dings) ) - sess = Session() + sess = fixture_session() q = sess.query(User) self._assert_path_result( l1, q, [(User, "addresses"), (User, "addresses", SubAddr, "dings")] @@ -368,7 +367,7 @@ class OfTypePathingTest(PathTest, QueryTest): l1 = defaultload(User.addresses.of_type(SubAddr)).joinedload("dings") - sess = Session() + sess = fixture_session() q = sess.query(User) self._assert_path_result( l1, q, [(User, "addresses"), (User, "addresses", SubAddr, "dings")] @@ -383,7 +382,7 @@ class OfTypePathingTest(PathTest, QueryTest): .defer("sub_attr") ) - sess = Session() + sess = fixture_session() q = sess.query(User) self._assert_path_result( l1, @@ -401,7 +400,7 @@ class OptionsTest(PathTest, QueryTest): def test_get_path_one_level_string(self): User = self.classes.User - sess = Session() + sess = fixture_session() q = sess.query(User) opt = self._option_fixture("addresses") @@ -410,7 +409,7 @@ class OptionsTest(PathTest, QueryTest): def test_get_path_one_level_attribute(self): User = self.classes.User - sess = Session() + sess = fixture_session() q = sess.query(User) opt = self._option_fixture(User.addresses) @@ -422,7 +421,7 @@ class OptionsTest(PathTest, QueryTest): # ensure "current path" is fully consumed before # matching against current entities. # see [ticket:2098] - sess = Session() + sess = fixture_session() q = sess.query(User) opt = self._option_fixture("email_address", "id") q = sess.query(Address)._with_current_path( @@ -435,7 +434,7 @@ class OptionsTest(PathTest, QueryTest): def test_get_path_one_level_with_unrelated(self): Order = self.classes.Order - sess = Session() + sess = fixture_session() q = sess.query(Order) opt = self._option_fixture("addresses") self._assert_path_result(opt, q, []) @@ -447,7 +446,7 @@ class OptionsTest(PathTest, QueryTest): self.classes.Order, ) - sess = Session() + sess = fixture_session() q = sess.query(User) opt = self._option_fixture("orders.items.keywords") @@ -468,7 +467,7 @@ class OptionsTest(PathTest, QueryTest): self.classes.Order, ) - sess = Session() + sess = fixture_session() q = sess.query(User) opt = self._option_fixture(User.orders, Order.items, Item.keywords) @@ -489,7 +488,7 @@ class OptionsTest(PathTest, QueryTest): self.classes.Order, ) - sess = Session() + sess = fixture_session() q = sess.query(Item)._with_current_path( self._make_path_registry([User, "orders", Order, "items"]) ) @@ -504,7 +503,7 @@ class OptionsTest(PathTest, QueryTest): self.classes.Order, ) - sess = Session() + sess = fixture_session() q = sess.query(Item)._with_current_path( self._make_path_registry([User, "orders", Order, "items"]) ) @@ -519,7 +518,7 @@ class OptionsTest(PathTest, QueryTest): self.classes.Order, ) - sess = Session() + sess = fixture_session() q = sess.query(Item)._with_current_path( self._make_path_registry([User, "orders", Order, "items"]) ) @@ -537,7 +536,7 @@ class OptionsTest(PathTest, QueryTest): self.classes.Order, ) - sess = Session() + sess = fixture_session() q = sess.query(Item)._with_current_path( self._make_path_registry([User, "orders", Order, "items"]) ) @@ -555,7 +554,7 @@ class OptionsTest(PathTest, QueryTest): self.classes.Order, ) - sess = Session() + sess = fixture_session() q = sess.query(Item)._with_current_path( self._make_path_registry( [inspect(aliased(User)), "orders", Order, "items"] @@ -588,7 +587,7 @@ class OptionsTest(PathTest, QueryTest): ) ac = aliased(User) - sess = Session() + sess = fixture_session() q = sess.query(Item)._with_current_path( self._make_path_registry([inspect(ac), "orders", Order, "items"]) ) @@ -602,7 +601,7 @@ class OptionsTest(PathTest, QueryTest): def test_from_base_to_subclass_attr(self): Dingaling, Address = self.classes.Dingaling, self.classes.Address - sess = Session() + sess = fixture_session() class SubAddr(Address): pass @@ -621,7 +620,7 @@ class OptionsTest(PathTest, QueryTest): def test_from_subclass_to_subclass_attr(self): Dingaling, Address = self.classes.Dingaling, self.classes.Address - sess = Session() + sess = fixture_session() class SubAddr(Address): pass @@ -640,7 +639,7 @@ class OptionsTest(PathTest, QueryTest): def test_from_base_to_base_attr_via_subclass(self): Dingaling, Address = self.classes.Dingaling, self.classes.Address - sess = Session() + sess = fixture_session() class SubAddr(Address): pass @@ -661,7 +660,7 @@ class OptionsTest(PathTest, QueryTest): def test_of_type(self): User, Address = self.classes.User, self.classes.Address - sess = Session() + sess = fixture_session() class SubAddr(Address): pass @@ -692,7 +691,7 @@ class OptionsTest(PathTest, QueryTest): def test_of_type_string_attr(self): User, Address = self.classes.User, self.classes.Address - sess = Session() + sess = fixture_session() class SubAddr(Address): pass @@ -725,7 +724,7 @@ class OptionsTest(PathTest, QueryTest): self.classes.Address, ) - sess = Session() + sess = fixture_session() class SubAddr(Address): pass @@ -760,7 +759,7 @@ class OptionsTest(PathTest, QueryTest): def test_aliased_single(self): User = self.classes.User - sess = Session() + sess = fixture_session() ualias = aliased(User) q = sess.query(ualias) opt = self._option_fixture(ualias.addresses) @@ -769,7 +768,7 @@ class OptionsTest(PathTest, QueryTest): def test_with_current_aliased_single(self): User, Address = self.classes.User, self.classes.Address - sess = Session() + sess = fixture_session() ualias = aliased(User) q = sess.query(ualias)._with_current_path( self._make_path_registry([Address, "user"]) @@ -780,7 +779,7 @@ class OptionsTest(PathTest, QueryTest): def test_with_current_aliased_single_nonmatching_option(self): User, Address = self.classes.User, self.classes.Address - sess = Session() + sess = fixture_session() ualias = aliased(User) q = sess.query(User)._with_current_path( self._make_path_registry([Address, "user"]) @@ -791,7 +790,7 @@ class OptionsTest(PathTest, QueryTest): def test_with_current_aliased_single_nonmatching_entity(self): User, Address = self.classes.User, self.classes.Address - sess = Session() + sess = fixture_session() ualias = aliased(User) q = sess.query(ualias)._with_current_path( self._make_path_registry([Address, "user"]) @@ -803,7 +802,7 @@ class OptionsTest(PathTest, QueryTest): Item = self.classes.Item Order = self.classes.Order opt = self._option_fixture(Order.items) - sess = Session() + sess = fixture_session() q = sess.query(Item, Order) self._assert_path_result(opt, q, [(Order, "items")]) @@ -811,7 +810,7 @@ class OptionsTest(PathTest, QueryTest): Item = self.classes.Item Order = self.classes.Order opt = self._option_fixture("items") - sess = Session() + sess = fixture_session() q = sess.query(Item, Order) self._assert_path_result(opt, q, []) @@ -819,7 +818,7 @@ class OptionsTest(PathTest, QueryTest): Item = self.classes.Item Order = self.classes.Order opt = self._option_fixture("items") - sess = Session() + sess = fixture_session() q = sess.query(Item.id, Order.id) self._assert_path_result(opt, q, []) @@ -828,7 +827,7 @@ class OptionsTest(PathTest, QueryTest): Item = self.classes.Item Order = self.classes.Order opt = self._option_fixture(User.orders) - sess = Session() + sess = fixture_session() q = sess.query(Item)._with_current_path( self._make_path_registry([User, "orders", Order, "items"]) ) @@ -837,7 +836,7 @@ class OptionsTest(PathTest, QueryTest): def test_chained(self): User = self.classes.User Order = self.classes.Order - sess = Session() + sess = fixture_session() q = sess.query(User) opt = self._option_fixture(User.orders).joinedload("items") self._assert_path_result( @@ -848,7 +847,7 @@ class OptionsTest(PathTest, QueryTest): User = self.classes.User Order = self.classes.Order Item = self.classes.Item - sess = Session() + sess = fixture_session() q = sess.query(User) opt = self._option_fixture("orders.items").joinedload("keywords") self._assert_path_result( @@ -865,7 +864,7 @@ class OptionsTest(PathTest, QueryTest): User = self.classes.User Order = self.classes.Order Item = self.classes.Item - sess = Session() + sess = fixture_session() q = sess.query(User) opt = self._option_fixture(User.orders, Order.items).joinedload( "keywords" @@ -918,7 +917,7 @@ class FromSubclassOptionsTest(PathTest, fixtures.DeclarativeMappedTest): BaseCls, SubClass, Related, SubRelated = self.classes( "BaseCls", "SubClass", "Related", "SubRelated" ) - sess = Session() + sess = fixture_session() q = sess.query(Related)._with_current_path( self._make_path_registry([inspect(SubClass), "related"]) @@ -1258,7 +1257,7 @@ class OptionsNoPropTest(_fixtures.FixtureTest): Item = self.classes.Item context = ( - create_session() + fixture_session() .query(*entity_list) .options(joinedload(option)) ._compile_state() @@ -1270,7 +1269,7 @@ class OptionsNoPropTest(_fixtures.FixtureTest): assert_raises_message( orm_exc.LoaderStrategyException, message, - create_session() + fixture_session() .query(*entity_list) .options(*options) ._compile_state, @@ -1282,7 +1281,7 @@ class OptionsNoPropTest(_fixtures.FixtureTest): assert_raises_message( sa.exc.ArgumentError, message, - create_session() + fixture_session() .query(*entity_list) .options(*options) ._compile_state, @@ -1294,7 +1293,7 @@ class OptionsNoPropTest(_fixtures.FixtureTest): assert_raises_message( sa.exc.ArgumentError, message, - create_session() + fixture_session() .query(column) .options(joinedload(eager_option)) ._compile_state, @@ -1303,7 +1302,7 @@ class OptionsNoPropTest(_fixtures.FixtureTest): class OptionsNoPropTestInh(_Polymorphic): def test_missing_attr_wpoly_subclasss(self): - s = Session() + s = fixture_session() wp = with_polymorphic(Person, [Manager], flat=True) @@ -1316,7 +1315,7 @@ class OptionsNoPropTestInh(_Polymorphic): ) def test_missing_attr_of_type_subclass(self): - s = Session() + s = fixture_session() assert_raises_message( sa.exc.ArgumentError, @@ -1332,7 +1331,7 @@ class OptionsNoPropTestInh(_Polymorphic): ) def test_missing_attr_of_type_subclass_name_matches(self): - s = Session() + s = fixture_session() # the name "status" is present on Engineer also, make sure # that doesn't get mixed up here @@ -1350,7 +1349,7 @@ class OptionsNoPropTestInh(_Polymorphic): ) def test_missing_str_attr_of_type_subclass(self): - s = Session() + s = fixture_session() assert_raises_message( sa.exc.ArgumentError, @@ -1366,7 +1365,7 @@ class OptionsNoPropTestInh(_Polymorphic): ) def test_missing_attr_of_type_wpoly_subclass(self): - s = Session() + s = fixture_session() wp = with_polymorphic(Person, [Manager], flat=True) @@ -1384,7 +1383,7 @@ class OptionsNoPropTestInh(_Polymorphic): ) def test_missing_attr_is_missing_of_type_for_alias(self): - s = Session() + s = fixture_session() pa = aliased(Person) @@ -1465,7 +1464,7 @@ class PickleTest(PathTest, QueryTest): opt.__setstate__(state) - query = create_session().query(User) + query = fixture_session().query(User) attr = {} load = opt._bind_loader( [ @@ -1502,7 +1501,7 @@ class PickleTest(PathTest, QueryTest): opt.__setstate__(state) - query = create_session().query(User) + query = fixture_session().query(User) attr = {} load = opt._bind_loader( [ @@ -1545,7 +1544,7 @@ class LocalOptsTest(PathTest, QueryTest): def _assert_attrs(self, opts, expected): User = self.classes.User - query = create_session().query(User) + query = fixture_session().query(User) attr = {} for opt in opts: @@ -1702,7 +1701,7 @@ class SubOptionsTest(PathTest, QueryTest): defaultload(User.orders).defer(Order.description), ] - sess = Session() + sess = fixture_session() self._assert_opts(sess.query(User), sub_opt, non_sub_opts) def test_two(self): @@ -1721,7 +1720,7 @@ class SubOptionsTest(PathTest, QueryTest): defaultload(User.orders).defer(Order.description), ] - sess = Session() + sess = fixture_session() self._assert_opts(sess.query(User), sub_opt, non_sub_opts) def test_three(self): @@ -1730,7 +1729,7 @@ class SubOptionsTest(PathTest, QueryTest): ) sub_opt = defaultload(User.orders).options(defer("*")) non_sub_opts = [defaultload(User.orders).defer("*")] - sess = Session() + sess = fixture_session() self._assert_opts(sess.query(User), sub_opt, non_sub_opts) def test_four(self): @@ -1759,7 +1758,7 @@ class SubOptionsTest(PathTest, QueryTest): .defaultload(Item.keywords) .defer(Keyword.name), ] - sess = Session() + sess = fixture_session() self._assert_opts(sess.query(User), sub_opt, non_sub_opts) def test_four_strings(self): @@ -1788,7 +1787,7 @@ class SubOptionsTest(PathTest, QueryTest): .defaultload(Item.keywords) .defer(Keyword.name), ] - sess = Session() + sess = fixture_session() self._assert_opts(sess.query(User), sub_opt, non_sub_opts) def test_five(self): @@ -1800,7 +1799,7 @@ class SubOptionsTest(PathTest, QueryTest): joinedload(User.orders), defaultload(User.orders).load_only(Order.description), ] - sess = Session() + sess = fixture_session() self._assert_opts(sess.query(User), sub_opt, non_sub_opts) def test_five_strings(self): @@ -1812,7 +1811,7 @@ class SubOptionsTest(PathTest, QueryTest): joinedload(User.orders), defaultload(User.orders).load_only(Order.description), ] - sess = Session() + sess = fixture_session() self._assert_opts(sess.query(User), sub_opt, non_sub_opts) def test_invalid_one(self): @@ -1832,7 +1831,7 @@ class SubOptionsTest(PathTest, QueryTest): joinedload(User.orders).joinedload(Item.keywords), defaultload(User.orders).joinedload(Order.items), ] - sess = Session() + sess = fixture_session() self._assert_opts(sess.query(User), sub_opt, non_sub_opts) def test_invalid_two(self): @@ -1852,7 +1851,7 @@ class SubOptionsTest(PathTest, QueryTest): joinedload(User.orders).joinedload(Item.keywords), defaultload(User.orders).joinedload(Order.items), ] - sess = Session() + sess = fixture_session() self._assert_opts(sess.query(User), sub_opt, non_sub_opts) def test_not_implemented_fromload(self): @@ -1905,7 +1904,7 @@ class MapperOptionsTest(_fixtures.FixtureTest): ) def go(): - sess = create_session() + sess = fixture_session() u = ( sess.query(User) .order_by(User.id) @@ -1936,7 +1935,7 @@ class MapperOptionsTest(_fixtures.FixtureTest): ), ) - sess = create_session() + sess = fixture_session() result = ( sess.query(User) .order_by(User.id) @@ -1966,7 +1965,7 @@ class MapperOptionsTest(_fixtures.FixtureTest): ), ) - sess = create_session() + sess = fixture_session() u = ( sess.query(User) .options(sa.orm.joinedload("addresses")) @@ -2003,7 +2002,7 @@ class MapperOptionsTest(_fixtures.FixtureTest): ), ) - sess = create_session() + sess = fixture_session() u = ( sess.query(User) .options(sa.orm.lazyload("addresses")) @@ -2039,7 +2038,7 @@ class MapperOptionsTest(_fixtures.FixtureTest): ), ) - sess = create_session() + sess = fixture_session() # first test straight eager load, 1 statement def go(): @@ -2054,7 +2053,7 @@ class MapperOptionsTest(_fixtures.FixtureTest): # then assert the data, which will launch 3 more lazy loads # (previous users in session fell out of scope and were removed from # session's identity map) - r = users.select().order_by(users.c.id).execute() + r = sess.connection().execute(users.select().order_by(users.c.id)) ctx = sess.query(User)._compile_context() @@ -2140,7 +2139,7 @@ class MapperOptionsTest(_fixtures.FixtureTest): ), ) - sess = create_session() + sess = fixture_session() # first test straight eager load, 1 statement def go(): @@ -2153,7 +2152,7 @@ class MapperOptionsTest(_fixtures.FixtureTest): # then select just from users. run it into instances. # then assert the data, which will launch 6 more lazy loads - r = users.select().execute() + r = sess.connection().execute(users.select()) ctx = sess.query(User)._compile_context() @@ -2183,7 +2182,7 @@ class MapperOptionsTest(_fixtures.FixtureTest): ), ) - sess = create_session() + sess = fixture_session() result = ( sess.query(User) .order_by(User.id) @@ -2214,7 +2213,7 @@ class MapperOptionsTest(_fixtures.FixtureTest): ) mapper(Item, items) - sess = create_session() + sess = fixture_session() oalias = aliased(Order) opt1 = sa.orm.joinedload(User.orders, Order.items) diff --git a/test/orm/test_pickled.py b/test/orm/test_pickled.py index a5a983740..189fd2d27 100644 --- a/test/orm/test_pickled.py +++ b/test/orm/test_pickled.py @@ -9,14 +9,11 @@ from sqlalchemy import testing from sqlalchemy.orm import aliased from sqlalchemy.orm import attributes from sqlalchemy.orm import clear_mappers -from sqlalchemy.orm import create_session from sqlalchemy.orm import exc as orm_exc from sqlalchemy.orm import instrumentation from sqlalchemy.orm import lazyload from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship -from sqlalchemy.orm import Session -from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import state as sa_state from sqlalchemy.orm import subqueryload from sqlalchemy.orm import with_polymorphic @@ -25,6 +22,7 @@ from sqlalchemy.orm.collections import column_mapped_collection from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.pickleable import Address from sqlalchemy.testing.pickleable import Child1 from sqlalchemy.testing.pickleable import Child2 @@ -114,7 +112,7 @@ class PickleTest(fixtures.MappedTest): properties={"dingaling": relationship(Dingaling)}, ) mapper(Dingaling, dingalings) - sess = create_session() + sess = fixture_session() u1 = User(name="ed") u1.addresses.append(Address(email_address="ed@bar.com")) sess.add(u1) @@ -132,7 +130,7 @@ class PickleTest(fixtures.MappedTest): ) mapper(Address, addresses) - sess = create_session() + sess = fixture_session() u1 = User(name="ed") u1.addresses.append(Address(email_address="ed@bar.com")) @@ -195,33 +193,38 @@ class PickleTest(fixtures.MappedTest): "email_address": sa.orm.deferred(addresses.c.email_address) }, ) - sess = create_session() - u1 = User(name="ed") - u1.addresses.append(Address(email_address="ed@bar.com")) - sess.add(u1) - sess.flush() - sess.expunge_all() - u1 = sess.query(User).get(u1.id) - assert "name" not in u1.__dict__ - assert "addresses" not in u1.__dict__ + with fixture_session(expire_on_commit=False) as sess: + u1 = User(name="ed") + u1.addresses.append(Address(email_address="ed@bar.com")) + sess.add(u1) + sess.commit() + + with fixture_session() as sess: + u1 = sess.query(User).get(u1.id) + assert "name" not in u1.__dict__ + assert "addresses" not in u1.__dict__ u2 = pickle.loads(pickle.dumps(u1)) - sess2 = create_session() - sess2.add(u2) - eq_(u2.name, "ed") - eq_( - u2, - User(name="ed", addresses=[Address(email_address="ed@bar.com")]), - ) + with fixture_session() as sess2: + sess2.add(u2) + eq_(u2.name, "ed") + eq_( + u2, + User( + name="ed", addresses=[Address(email_address="ed@bar.com")] + ), + ) u2 = pickle.loads(pickle.dumps(u1)) - sess2 = create_session() - u2 = sess2.merge(u2, load=False) - eq_(u2.name, "ed") - eq_( - u2, - User(name="ed", addresses=[Address(email_address="ed@bar.com")]), - ) + with fixture_session() as sess2: + u2 = sess2.merge(u2, load=False) + eq_(u2.name, "ed") + eq_( + u2, + User( + name="ed", addresses=[Address(email_address="ed@bar.com")] + ), + ) def test_instance_lazy_relation_loaders(self): users, addresses = (self.tables.users, self.tables.addresses) @@ -233,7 +236,7 @@ class PickleTest(fixtures.MappedTest): ) mapper(Address, addresses) - sess = Session() + sess = fixture_session() u1 = User(name="ed", addresses=[Address(email_address="ed@bar.com")]) sess.add(u1) @@ -243,7 +246,7 @@ class PickleTest(fixtures.MappedTest): u1 = sess.query(User).options(lazyload(User.addresses)).first() u2 = pickle.loads(pickle.dumps(u1)) - sess = Session() + sess = fixture_session() sess.add(u2) assert u2.addresses @@ -290,52 +293,57 @@ class PickleTest(fixtures.MappedTest): ) mapper(Address, addresses) - sess = create_session() - u1 = User(name="ed") - u1.addresses.append(Address(email_address="ed@bar.com")) - sess.add(u1) - sess.flush() - sess.expunge_all() - - u1 = ( - sess.query(User) - .options( - sa.orm.defer("name"), sa.orm.defer("addresses.email_address") + with fixture_session(expire_on_commit=False) as sess: + u1 = User(name="ed") + u1.addresses.append(Address(email_address="ed@bar.com")) + sess.add(u1) + sess.commit() + + with fixture_session(expire_on_commit=False) as sess: + u1 = ( + sess.query(User) + .options( + sa.orm.defer("name"), + sa.orm.defer("addresses.email_address"), + ) + .get(u1.id) ) - .get(u1.id) - ) - assert "name" not in u1.__dict__ - assert "addresses" not in u1.__dict__ + assert "name" not in u1.__dict__ + assert "addresses" not in u1.__dict__ u2 = pickle.loads(pickle.dumps(u1)) - sess2 = create_session() - sess2.add(u2) - eq_(u2.name, "ed") - assert "addresses" not in u2.__dict__ - ad = u2.addresses[0] - assert "email_address" not in ad.__dict__ - eq_(ad.email_address, "ed@bar.com") - eq_( - u2, - User(name="ed", addresses=[Address(email_address="ed@bar.com")]), - ) + with fixture_session() as sess2: + sess2.add(u2) + eq_(u2.name, "ed") + assert "addresses" not in u2.__dict__ + ad = u2.addresses[0] + assert "email_address" not in ad.__dict__ + eq_(ad.email_address, "ed@bar.com") + eq_( + u2, + User( + name="ed", addresses=[Address(email_address="ed@bar.com")] + ), + ) u2 = pickle.loads(pickle.dumps(u1)) - sess2 = create_session() - u2 = sess2.merge(u2, load=False) - eq_(u2.name, "ed") - assert "addresses" not in u2.__dict__ - ad = u2.addresses[0] + with fixture_session() as sess2: + u2 = sess2.merge(u2, load=False) + eq_(u2.name, "ed") + assert "addresses" not in u2.__dict__ + ad = u2.addresses[0] - # mapper options now transmit over merge(), - # new as of 0.6, so email_address is deferred. - assert "email_address" not in ad.__dict__ + # mapper options now transmit over merge(), + # new as of 0.6, so email_address is deferred. + assert "email_address" not in ad.__dict__ - eq_(ad.email_address, "ed@bar.com") - eq_( - u2, - User(name="ed", addresses=[Address(email_address="ed@bar.com")]), - ) + eq_(ad.email_address, "ed@bar.com") + eq_( + u2, + User( + name="ed", addresses=[Address(email_address="ed@bar.com")] + ), + ) def test_pickle_protocols(self): users, addresses = (self.tables.users, self.tables.addresses) @@ -347,7 +355,7 @@ class PickleTest(fixtures.MappedTest): ) mapper(Address, addresses) - sess = sessionmaker()() + sess = fixture_session() u1 = User(name="ed") u1.addresses.append(Address(email_address="ed@bar.com")) sess.add(u1) @@ -363,7 +371,7 @@ class PickleTest(fixtures.MappedTest): def test_09_pickle(self): users = self.tables.users mapper(User, users) - sess = Session() + sess = fixture_session() sess.add(User(id=1, name="ed")) sess.commit() sess.close() @@ -389,7 +397,7 @@ class PickleTest(fixtures.MappedTest): state.__setstate__(state_09) eq_(state.expired_attributes, {"name", "id"}) - sess = Session() + sess = fixture_session() sess.add(inst) eq_(inst.name, "ed") # test identity_token expansion @@ -398,7 +406,7 @@ class PickleTest(fixtures.MappedTest): def test_11_pickle(self): users = self.tables.users mapper(User, users) - sess = Session() + sess = fixture_session() u1 = User(id=1, name="ed") sess.add(u1) sess.commit() @@ -658,7 +666,7 @@ class OptionsTest(_Polymorphic): eq_(opt2.__getstate__()["path"], serialized) def test_load(self): - s = Session() + s = fixture_session() with_poly = with_polymorphic(Person, [Engineer, Manager], flat=True) emp = ( @@ -706,17 +714,17 @@ class PolymorphicDeferredTest(fixtures.MappedTest): ) eu = EmailUser(name="user1", email_address="foo@bar.com") - sess = create_session() - sess.add(eu) - sess.flush() - sess.expunge_all() + with fixture_session() as sess: + sess.add(eu) + sess.commit() - eu = sess.query(User).first() - eu2 = pickle.loads(pickle.dumps(eu)) - sess2 = create_session() - sess2.add(eu2) - assert "email_address" not in eu2.__dict__ - eq_(eu2.email_address, "foo@bar.com") + with fixture_session() as sess: + eu = sess.query(User).first() + eu2 = pickle.loads(pickle.dumps(eu)) + sess2 = fixture_session() + sess2.add(eu2) + assert "email_address" not in eu2.__dict__ + eq_(eu2.email_address, "foo@bar.com") class TupleLabelTest(_fixtures.FixtureTest): @@ -750,7 +758,7 @@ class TupleLabelTest(_fixtures.FixtureTest): ) # m2o def test_tuple_labeling(self): - sess = create_session() + sess = fixture_session() # test pickle + all the protocols ! for pickled in False, -1, 0, 1, 2: diff --git a/test/orm/test_query.py b/test/orm/test_query.py index 9e528dc0d..fd8e849fb 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -22,7 +22,6 @@ from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import literal from sqlalchemy import literal_column -from sqlalchemy import MetaData from sqlalchemy import null from sqlalchemy import or_ from sqlalchemy import select @@ -43,7 +42,6 @@ from sqlalchemy.orm import backref from sqlalchemy.orm import Bundle from sqlalchemy.orm import column_property from sqlalchemy.orm import contains_eager -from sqlalchemy.orm import create_session from sqlalchemy.orm import defer from sqlalchemy.orm import joinedload from sqlalchemy.orm import lazyload @@ -70,9 +68,9 @@ from sqlalchemy.testing import mock from sqlalchemy.testing.assertions import assert_raises from sqlalchemy.testing.assertions import assert_raises_message from sqlalchemy.testing.assertions import eq_ -from sqlalchemy.testing.assertions import eq_ignore_whitespace from sqlalchemy.testing.assertions import expect_warnings from sqlalchemy.testing.assertsql import CompiledSQL +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from sqlalchemy.util import collections_abc @@ -95,8 +93,8 @@ class MiscTest(QueryTest): def test_with_session(self): User = self.classes.User - s1 = Session() - s2 = Session() + s1 = fixture_session() + s2 = fixture_session() q1 = s1.query(User) q2 = q1.with_session(s2) assert q2.session is s2 @@ -106,14 +104,14 @@ class MiscTest(QueryTest): class OnlyReturnTuplesTest(QueryTest): def test_single_entity_false(self): User = self.classes.User - query = create_session().query(User).only_return_tuples(False) + query = fixture_session().query(User).only_return_tuples(False) is_true(query.is_single_entity) row = query.first() assert isinstance(row, User) def test_single_entity_true(self): User = self.classes.User - query = create_session().query(User).only_return_tuples(True) + query = fixture_session().query(User).only_return_tuples(True) is_false(query.is_single_entity) row = query.first() assert isinstance(row, collections_abc.Sequence) @@ -121,7 +119,9 @@ class OnlyReturnTuplesTest(QueryTest): def test_multiple_entity_false(self): User = self.classes.User - query = create_session().query(User.id, User).only_return_tuples(False) + query = ( + fixture_session().query(User.id, User).only_return_tuples(False) + ) is_false(query.is_single_entity) row = query.first() assert isinstance(row, collections_abc.Sequence) @@ -129,7 +129,7 @@ class OnlyReturnTuplesTest(QueryTest): def test_multiple_entity_true(self): User = self.classes.User - query = create_session().query(User.id, User).only_return_tuples(True) + query = fixture_session().query(User.id, User).only_return_tuples(True) is_false(query.is_single_entity) row = query.first() assert isinstance(row, collections_abc.Sequence) @@ -145,7 +145,7 @@ class RowTupleTest(QueryTest): mapper(User, users, properties={"uname": users.c.name}) row = ( - create_session() + fixture_session() .query(User.id, User.uname) .filter(User.id == 7) .first() @@ -167,7 +167,7 @@ class RowTupleTest(QueryTest): mapper(User, users) - s = Session() + s = fixture_session() q = testing.resolve_lambda(test_case, **locals()) @@ -189,7 +189,7 @@ class RowTupleTest(QueryTest): mapper(User, users) - s = Session(testing.db) + s = fixture_session() q = testing.resolve_lambda(test_case, **locals()) @@ -207,7 +207,7 @@ class RowTupleTest(QueryTest): mapper(User, users) - s = Session() + s = fixture_session() q = testing.resolve_lambda(test_case, **locals()) @@ -223,7 +223,7 @@ class RowTupleTest(QueryTest): mapper(User, users, properties={"addresses": relationship(Address)}) mapper(Address, addresses) - s = Session() + s = fixture_session() row = s.query(User).only_return_tuples(True).first() eq_(row._mapping[User], row[0]) @@ -451,7 +451,7 @@ class RowTupleTest(QueryTest): mapper(User, users) mapper(Address, addresses) - sess = create_session() + sess = fixture_session() user_alias = aliased(User) user_alias_id_label = user_alias.id.label("foo") address_alias = aliased(Address, name="aalias") @@ -481,7 +481,7 @@ class RowTupleTest(QueryTest): mapper(User, users) - s = Session() + s = fixture_session() q = s.query(User, type_coerce(users.c.id, MyType).label("foo")).filter( User.id == 7 ) @@ -489,78 +489,11 @@ class RowTupleTest(QueryTest): eq_(row, (User(id=7), [7])) -class BindSensitiveStringifyTest(fixtures.TestBase): - def _fixture(self, bind_to=None): - # building a totally separate metadata /mapping here - # because we need to control if the MetaData is bound or not - - class User(object): - pass - - m = MetaData(bind=bind_to) - user_table = Table( - "users", - m, - Column("id", Integer, primary_key=True), - Column("name", String(50)), - ) - - mapper(User, user_table) - return User - - def _dialect_fixture(self): - class MyDialect(default.DefaultDialect): - default_paramstyle = "qmark" - - from sqlalchemy.engine import base - - return base.Engine(mock.Mock(), MyDialect(), mock.Mock()) - - def _test( - self, bound_metadata, bound_session, session_present, expect_bound - ): - if bound_metadata or bound_session: - eng = self._dialect_fixture() - else: - eng = None - - User = self._fixture(bind_to=eng if bound_metadata else None) - - s = Session(eng if bound_session else None) - q = s.query(User).filter(User.id == 7) - if not session_present: - q = q.with_session(None) - - eq_ignore_whitespace( - str(q), - "SELECT users.id AS users_id, users.name AS users_name " - "FROM users WHERE users.id = ?" - if expect_bound - else "SELECT users.id AS users_id, users.name AS users_name " - "FROM users WHERE users.id = :id_1", - ) - - def test_query_unbound_metadata_bound_session(self): - self._test(False, True, True, True) - - def test_query_bound_metadata_unbound_session(self): - self._test(True, False, True, True) - - def test_query_unbound_metadata_no_session(self): - self._test(False, False, False, False) - - def test_query_unbound_metadata_unbound_session(self): - self._test(False, False, True, False) - - def test_query_bound_metadata_bound_session(self): - self._test(True, True, True, True) - - class GetTest(QueryTest): def test_loader_options(self): User = self.classes.User - s = Session() + s = fixture_session() u1 = s.query(User).options(joinedload(User.addresses)).get(8) eq_(len(u1.__dict__["addresses"]), 3) @@ -568,7 +501,7 @@ class GetTest(QueryTest): def test_loader_options_future(self): User = self.classes.User - s = Session() + s = fixture_session() u1 = s.get(User, 8, options=[joinedload(User.addresses)]) eq_(len(u1.__dict__["addresses"]), 3) @@ -576,13 +509,13 @@ class GetTest(QueryTest): def test_get_composite_pk_keyword_based_no_result(self): CompositePk = self.classes.CompositePk - s = Session() + s = fixture_session() is_(s.query(CompositePk).get({"i": 100, "j": 100}), None) def test_get_composite_pk_keyword_based_result(self): CompositePk = self.classes.CompositePk - s = Session() + s = fixture_session() one_two = s.query(CompositePk).get({"i": 1, "j": 2}) eq_(one_two.i, 1) eq_(one_two.j, 2) @@ -591,21 +524,21 @@ class GetTest(QueryTest): def test_get_composite_pk_keyword_based_wrong_keys(self): CompositePk = self.classes.CompositePk - s = Session() + s = fixture_session() q = s.query(CompositePk) assert_raises(sa_exc.InvalidRequestError, q.get, {"i": 1, "k": 2}) def test_get_composite_pk_keyword_based_too_few_keys(self): CompositePk = self.classes.CompositePk - s = Session() + s = fixture_session() q = s.query(CompositePk) assert_raises(sa_exc.InvalidRequestError, q.get, {"i": 1}) def test_get_composite_pk_keyword_based_too_many_keys(self): CompositePk = self.classes.CompositePk - s = Session() + s = fixture_session() q = s.query(CompositePk) assert_raises( sa_exc.InvalidRequestError, q.get, {"i": 1, "j": "2", "k": 3} @@ -614,7 +547,7 @@ class GetTest(QueryTest): def test_get(self): User = self.classes.User - s = create_session() + s = fixture_session() assert s.query(User).get(19) is None u = s.query(User).get(7) u2 = s.query(User).get(7) @@ -626,7 +559,7 @@ class GetTest(QueryTest): def test_get_future(self): User = self.classes.User - s = create_session() + s = fixture_session() assert s.get(User, 19) is None u = s.get(User, 7) u2 = s.get(User, 7) @@ -638,13 +571,13 @@ class GetTest(QueryTest): def test_get_composite_pk_no_result(self): CompositePk = self.classes.CompositePk - s = Session() + s = fixture_session() assert s.query(CompositePk).get((100, 100)) is None def test_get_composite_pk_result(self): CompositePk = self.classes.CompositePk - s = Session() + s = fixture_session() one_two = s.query(CompositePk).get((1, 2)) assert one_two.i == 1 assert one_two.j == 2 @@ -653,28 +586,28 @@ class GetTest(QueryTest): def test_get_too_few_params(self): CompositePk = self.classes.CompositePk - s = Session() + s = fixture_session() q = s.query(CompositePk) assert_raises(sa_exc.InvalidRequestError, q.get, 7) def test_get_too_few_params_tuple(self): CompositePk = self.classes.CompositePk - s = Session() + s = fixture_session() q = s.query(CompositePk) assert_raises(sa_exc.InvalidRequestError, q.get, (7,)) def test_get_too_many_params(self): CompositePk = self.classes.CompositePk - s = Session() + s = fixture_session() q = s.query(CompositePk) assert_raises(sa_exc.InvalidRequestError, q.get, (7, 10, 100)) def test_get_against_col(self): User = self.classes.User - s = Session() + s = fixture_session() q = s.query(User.id) assert_raises(sa_exc.InvalidRequestError, q.get, (5,)) @@ -702,14 +635,14 @@ class GetTest(QueryTest): PK (i.e. map to an outerjoin) works with get().""" UserThing = outerjoin_mapping - sess = create_session() + sess = fixture_session() u10 = sess.query(UserThing).get((10, None)) eq_(u10, UserThing(id=10)) def test_get_fully_null_pk(self): User = self.classes.User - s = Session() + s = fixture_session() q = s.query(User) assert_raises_message( sa_exc.SAWarning, @@ -722,7 +655,7 @@ class GetTest(QueryTest): def test_get_fully_null_composite_pk(self, outerjoin_mapping): UserThing = outerjoin_mapping - s = Session() + s = fixture_session() q = s.query(UserThing) assert_raises_message( @@ -739,7 +672,7 @@ class GetTest(QueryTest): User, Address = self.classes.User, self.classes.Address - s = create_session() + s = fixture_session() q = s.query(User).join("addresses").filter(Address.user_id == 8) assert_raises(sa_exc.InvalidRequestError, q.get, 7) @@ -758,7 +691,7 @@ class GetTest(QueryTest): User, Address = self.classes.User, self.classes.Address - s = create_session() + s = fixture_session() s.query(User).get(7) @@ -775,13 +708,13 @@ class GetTest(QueryTest): m = mapper(SomeUser, s) assert s.primary_key == m.primary_key - sess = create_session() + sess = fixture_session() assert sess.query(SomeUser).get(7).name == "jack" def test_load(self): User, Address = self.classes.User, self.classes.Address - s = create_session() + s = fixture_session(autoflush=False) assert s.query(User).populate_existing().get(19) is None @@ -804,14 +737,8 @@ class GetTest(QueryTest): assert u2.name == "jack" assert a not in u2.addresses - @testing.provide_metadata @testing.requires.unicode_connections - def test_unicode(self, connection): - """test that Query.get properly sets up the type for the bind - parameter. using unicode would normally fail on postgresql, mysql and - oracle unless it is converted to an encoded string""" - - metadata = self.metadata + def test_unicode(self, metadata, connection): table = Table( "unicode_data", metadata, @@ -836,7 +763,7 @@ class GetTest(QueryTest): def test_populate_existing(self): User, Address = self.classes.User, self.classes.Address - s = create_session() + s = fixture_session(autoflush=False) userlist = s.query(User).all() @@ -872,7 +799,7 @@ class GetTest(QueryTest): def test_populate_existing_future(self): User, Address = self.classes.User, self.classes.Address - s = Session(testing.db, autoflush=False) + s = fixture_session(autoflush=False) userlist = s.query(User).all() @@ -923,7 +850,7 @@ class GetTest(QueryTest): stmt = select(User).execution_options( populate_existing=True, autoflush=False, yield_per=10 ) - s = Session(testing.db) + s = fixture_session() m1 = mock.Mock() @@ -948,7 +875,7 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL): def test_no_limit_offset(self, test_case): User = self.classes.User - s = create_session() + s = fixture_session() q = testing.resolve_lambda(test_case, User=User, s=s) @@ -972,7 +899,7 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL): def test_no_from(self): users, User = self.tables.users, self.classes.User - s = create_session() + s = fixture_session() q = s.query(User).select_from(users) assert_raises(sa_exc.InvalidRequestError, q.select_from, users) @@ -994,7 +921,7 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL): def test_invalid_select_from(self): User = self.classes.User - s = create_session() + s = fixture_session() q = s.query(User) assert_raises(sa_exc.ArgumentError, q.select_from, User.id == 5) assert_raises(sa_exc.ArgumentError, q.select_from, User.id) @@ -1006,7 +933,7 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL): self.tables.users, ) - s = create_session() + s = fixture_session() q = s.query(User) assert_raises(sa_exc.ArgumentError, q.from_statement, User.id == 5) assert_raises( @@ -1016,14 +943,14 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL): def test_invalid_column(self): User = self.classes.User - s = create_session() + s = fixture_session() q = s.query(User) assert_raises(sa_exc.ArgumentError, q.add_columns, object()) def test_invalid_column_tuple(self): User = self.classes.User - s = create_session() + s = fixture_session() q = s.query(User) assert_raises(sa_exc.ArgumentError, q.add_columns, (1, 1)) @@ -1033,7 +960,7 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL): User = self.classes.User - s = create_session() + s = fixture_session() q = s.query(User).distinct() assert_raises(sa_exc.InvalidRequestError, q.select_from, User) assert_raises( @@ -1049,7 +976,7 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL): User = self.classes.User - s = create_session() + s = fixture_session() q = s.query(User).order_by(User.id) assert_raises(sa_exc.InvalidRequestError, q.select_from, User) assert_raises( @@ -1062,14 +989,14 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL): def test_only_full_mapper_zero(self): User, Address = self.classes.User, self.classes.Address - s = create_session() + s = fixture_session() q = s.query(User, Address) assert_raises(sa_exc.InvalidRequestError, q.get, 5) def test_entity_or_mapper_zero_from_context(self): User, Address = self.classes.User, self.classes.Address - s = create_session() + s = fixture_session() q = s.query(User, Address)._compile_state() is_(q._mapper_zero(), inspect(User)) @@ -1111,7 +1038,7 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL): def test_from_statement(self, test_case): User = self.classes.User - s = create_session() + s = fixture_session() q = testing.resolve_lambda(test_case, User=User, s=s) @@ -1127,7 +1054,7 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL): def test_from_statement_text(self, meth, test_case): User = self.classes.User - s = Session() + s = fixture_session() q = s.query(User) q = q.from_statement(text("x")) @@ -1160,7 +1087,7 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL): ua, ) - s = Session() + s = fixture_session() assert_raises_message( sa_exc.ArgumentError, "SQL expression element or literal value expected, got .*User", @@ -1192,7 +1119,7 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): if entity is not None: # specify a lead entity, so that when we are testing # correlation, the correlation actually happens - sess = Session() + sess = fixture_session() lead = sess.query(entity) context = lead._compile_context() context.compile_state.statement._label_style = ( @@ -1207,7 +1134,7 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): self, clause, expected, from_, onclause, checkparams=None ): dialect = default.DefaultDialect() - sess = Session() + sess = fixture_session() lead = sess.query(from_).join(onclause, aliased=True) full = lead.filter(clause) context = lead._compile_context() @@ -1244,7 +1171,7 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): lhs = testing.resolve_lambda(lhs, User=User) rhs = testing.resolve_lambda(rhs, User=User) - create_session().query(User) + fixture_session().query(User) self._test(py_op(lhs, rhs), res % sql_op) @testing.combinations( @@ -1277,7 +1204,7 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): def test_comparison(self, py_op, fwd_op, rev_op, fixture): User = self.classes.User - create_session().query(User) + fixture_session().query(User) ualias = aliased(User) lhs, rhs, l_sql, r_sql = fixture(User=User, ualias=ualias) @@ -1562,7 +1489,7 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): self.classes.Address, self.classes.Dingaling, ) - sess = Session() + sess = fixture_session() q = sess.query(User).filter( User.addresses.any( @@ -1653,7 +1580,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): def test_function_element_column_labels(self): users = self.tables.users - sess = Session() + sess = fixture_session() class max_(expression.FunctionElement): name = "max" @@ -1667,7 +1594,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): def test_truly_unlabeled_sql_expressions(self): users = self.tables.users - sess = Session() + sess = fixture_session() class not_named_max(expression.ColumnElement): name = "not_named_max" @@ -1690,7 +1617,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): self.classes.Address, ) - session = create_session() + session = fixture_session() s = ( session.query(User) .filter( @@ -1741,7 +1668,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): User = self.classes.User Address = self.classes.Address - session = create_session() + session = fixture_session() q = session.query(User.id).filter(User.id == 7).scalar_subquery() @@ -1759,7 +1686,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): def test_subquery_no_eagerloads(self): User = self.classes.User - s = Session() + s = fixture_session() self.assert_compile( s.query(User).options(joinedload(User.addresses)).subquery(), @@ -1768,7 +1695,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): def test_exists_no_eagerloads(self): User = self.classes.User - s = Session() + s = fixture_session() self.assert_compile( s.query( @@ -1780,7 +1707,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): def test_named_subquery(self): User = self.classes.User - session = create_session() + session = fixture_session() a1 = session.query(User.id).filter(User.id == 7).subquery("foo1") a2 = session.query(User.id).filter(User.id == 7).subquery(name="foo2") a3 = session.query(User.id).filter(User.id == 7).subquery() @@ -1792,7 +1719,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): def test_labeled_subquery(self): User = self.classes.User - session = create_session() + session = fixture_session() a1 = ( session.query(User.id) .filter(User.id == 7) @@ -1804,7 +1731,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): User = self.classes.User ua = aliased(User) - session = create_session() + session = fixture_session() a1 = ( session.query(User.id, ua.id, ua.name) .filter(User.id == ua.id) @@ -1820,7 +1747,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): def test_label(self): User = self.classes.User - session = create_session() + session = fixture_session() q = session.query(User.id).filter(User.id == 7).label("foo") self.assert_compile( @@ -1832,7 +1759,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): def test_scalar_subquery(self): User = self.classes.User - session = create_session() + session = fixture_session() q = session.query(User.id).filter(User.id == 7).scalar_subquery() @@ -1847,7 +1774,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): def test_param_transfer(self): User = self.classes.User - session = create_session() + session = fixture_session() q = ( session.query(User.id) @@ -1863,7 +1790,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): def test_in(self): User, Address = self.classes.User, self.classes.Address - session = create_session() + session = fixture_session() s = ( session.query(User.id) .join(User.addresses) @@ -1875,7 +1802,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): def test_union(self): User = self.classes.User - s = create_session() + s = fixture_session() q1 = s.query(User).filter(User.name == "ed").with_labels() q2 = s.query(User).filter(User.name == "fred").with_labels() @@ -1889,7 +1816,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): def test_select(self): User = self.classes.User - s = create_session() + s = fixture_session() # this is actually not legal on most DBs since the subquery has no # alias @@ -1906,7 +1833,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): def test_join(self): User, Address = self.classes.User, self.classes.Address - s = create_session() + s = fixture_session() # TODO: do we want aliased() to detect a query and convert to # subquery() automatically ? @@ -1926,7 +1853,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): def test_group_by_plain(self): User = self.classes.User - s = create_session() + s = fixture_session() q1 = s.query(User.id, User.name).group_by(User.name) self.assert_compile( @@ -1939,7 +1866,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): def test_group_by_append(self): User = self.classes.User - s = create_session() + s = fixture_session() q1 = s.query(User.id, User.name).group_by(User.name) @@ -1954,7 +1881,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): def test_group_by_cancellation(self): User = self.classes.User - s = create_session() + s = fixture_session() q1 = s.query(User.id, User.name).group_by(User.name) # test cancellation by using None, replacement with something else @@ -1977,7 +1904,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): def test_group_by_cancelled_still_present(self): User = self.classes.User - s = create_session() + s = fixture_session() q1 = s.query(User.id, User.name).group_by(User.name).group_by(None) @@ -1985,7 +1912,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): def test_order_by_plain(self): User = self.classes.User - s = create_session() + s = fixture_session() q1 = s.query(User.id, User.name).order_by(User.name) self.assert_compile( @@ -1998,7 +1925,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): def test_order_by_append(self): User = self.classes.User - s = create_session() + s = fixture_session() q1 = s.query(User.id, User.name).order_by(User.name) @@ -2013,7 +1940,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): def test_order_by_cancellation(self): User = self.classes.User - s = create_session() + s = fixture_session() q1 = s.query(User.id, User.name).order_by(User.name) # test cancellation by using None, replacement with something else @@ -2036,7 +1963,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): def test_order_by_cancellation_false(self): User = self.classes.User - s = create_session() + s = fixture_session() q1 = s.query(User.id, User.name).order_by(User.name) # test cancellation by using None, replacement with something else @@ -2059,7 +1986,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): def test_order_by_cancelled_allows_assertions(self): User = self.classes.User - s = create_session() + s = fixture_session() q1 = s.query(User.id, User.name).order_by(User.name).order_by(None) @@ -2067,7 +1994,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): def test_legacy_order_by_cancelled_allows_assertions(self): User = self.classes.User - s = create_session() + s = fixture_session() q1 = s.query(User.id, User.name).order_by(User.name).order_by(False) @@ -2126,7 +2053,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): self._func_fixture() User = self.classes.User - s = Session() + s = fixture_session() u1 = aliased(User) self.assert_compile( @@ -2139,7 +2066,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): self._func_fixture(label=True) User = self.classes.User - s = Session() + s = fixture_session() u1 = aliased(User) self.assert_compile( @@ -2152,7 +2079,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): self._func_fixture() User = self.classes.User - s = Session() + s = fixture_session() u1 = aliased(User) self.assert_compile( @@ -2165,7 +2092,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): User, Address = self.classes("User", "Address") self._fixture(label=True) - s = Session() + s = fixture_session() q = s.query(User).order_by("email_ad") self.assert_compile( q, @@ -2180,7 +2107,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): User, Address = self.classes("User", "Address") self._fixture(label=True) - s = Session() + s = fixture_session() ua = aliased(User) q = s.query(ua).order_by("email_ad") @@ -2195,7 +2122,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): self._fixture(label=True) ua = aliased(User) - s = Session() + s = fixture_session() q = s.query(ua).order_by(ua.ead) self.assert_compile( q, @@ -2210,7 +2137,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): self._fixture(label=True) ua = aliased(User) - s = Session() + s = fixture_session() q = s.query(ua.ead).order_by(ua.ead) self.assert_compile( q, @@ -2236,7 +2163,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): self._fixture(label=True) ua = aliased(User) - s = Session() + s = fixture_session() q = s.query(User.ead, ua.ead).order_by(User.ead, ua.ead) self.assert_compile( q, @@ -2264,7 +2191,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): self._fixture(label=True, polymorphic=True) ua = aliased(User) - s = Session() + s = fixture_session() q = s.query(ua, User.id).order_by(ua.ead) self.assert_compile( q, @@ -2280,7 +2207,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): self._fixture(label=False) ua = aliased(User) - s = Session() + s = fixture_session() q = s.query(ua).order_by(ua.ead) self.assert_compile( q, @@ -2295,7 +2222,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): self._fixture(label=False) ua = aliased(User) - s = Session() + s = fixture_session() q = s.query(ua.ead).order_by(ua.ead) self.assert_compile( q, @@ -2321,7 +2248,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): self._fixture(label=False) ua = aliased(User) - s = Session() + s = fixture_session() q = s.query(User.ead, ua.ead).order_by(User.ead, ua.ead) self.assert_compile( q, @@ -2349,7 +2276,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): User, Address = self.classes("User", "Address") self._fixture(label=True) - s = Session() + s = fixture_session() q = s.query(User).order_by(User.ead) # this one is a bit of a surprise; this is compiler # label-order-by logic kicking in, but won't work in more @@ -2367,7 +2294,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): User, Address = self.classes("User", "Address") self._fixture(label=True) - s = Session() + s = fixture_session() q = s.query(User).options(defer(User.ead)).order_by(User.ead) self.assert_compile( q, @@ -2395,7 +2322,7 @@ class ComparatorTest(QueryTest): # this use case isn't exactly needed in this form, however it tests # that we resolve for multiple __clause_element__() calls as is needed # by systems like composites - sess = Session() + sess = fixture_session() eq_( sess.query(Comparator(User.id)) .order_by(Comparator(User.id)) @@ -2412,16 +2339,16 @@ class SliceTest(QueryTest): def test_first(self): User = self.classes.User - assert User(id=7) == create_session().query(User).first() + assert User(id=7) == fixture_session().query(User).first() assert ( - create_session().query(User).filter(User.id == 27).first() is None + fixture_session().query(User).filter(User.id == 27).first() is None ) def test_negative_indexes_raise(self): User = self.classes.User - sess = create_session(future=True) + sess = fixture_session(future=True) q = sess.query(User).order_by(User.id) with expect_raises_message( @@ -2462,7 +2389,7 @@ class SliceTest(QueryTest): User = self.classes.User - sess = create_session() + sess = fixture_session() q = sess.query(User).order_by(User.id) self.assert_sql( @@ -2524,7 +2451,7 @@ class SliceTest(QueryTest): def test_first_against_expression_offset(self): User = self.classes.User - sess = create_session() + sess = fixture_session() q = ( sess.query(User) .order_by(User.id) @@ -2548,7 +2475,7 @@ class SliceTest(QueryTest): def test_full_slice_against_expression_offset(self): User = self.classes.User - sess = create_session() + sess = fixture_session() q = ( sess.query(User) .order_by(User.id) @@ -2571,7 +2498,7 @@ class SliceTest(QueryTest): def test_full_slice_against_integer_offset(self): User = self.classes.User - sess = create_session() + sess = fixture_session() q = sess.query(User).order_by(User.id).offset(2) self.assert_sql( @@ -2591,7 +2518,7 @@ class SliceTest(QueryTest): def test_start_slice_against_expression_offset(self): User = self.classes.User - sess = create_session() + sess = fixture_session() q = sess.query(User).order_by(User.id).offset(literal_column("2")) self.assert_sql( @@ -2614,14 +2541,14 @@ class FilterTest(QueryTest, AssertsCompiledSQL): def test_basic(self): User = self.classes.User - users = create_session().query(User).all() + users = fixture_session().query(User).all() eq_([User(id=7), User(id=8), User(id=9), User(id=10)], users) @testing.requires.offset def test_limit_offset(self): User = self.classes.User - sess = create_session() + sess = fixture_session() assert [User(id=8), User(id=9)] == sess.query(User).order_by( User.id @@ -2640,7 +2567,7 @@ class FilterTest(QueryTest, AssertsCompiledSQL): def test_select_with_bindparam_offset_limit(self): """Does a query allow bindparam for the limit?""" User = self.classes.User - sess = create_session() + sess = fixture_session() q1 = ( sess.query(self.classes.User) .order_by(self.classes.User.id) @@ -2667,7 +2594,7 @@ class FilterTest(QueryTest, AssertsCompiledSQL): @testing.requires.bound_limit_offset def test_select_with_bindparam_offset_limit_w_cast(self): User = self.classes.User - sess = create_session() + sess = fixture_session() eq_( list( sess.query(User) @@ -2685,7 +2612,7 @@ class FilterTest(QueryTest, AssertsCompiledSQL): def test_exists(self): User = self.classes.User - sess = create_session(testing.db) + sess = fixture_session() assert sess.query(exists().where(User.id == 9)).scalar() assert not sess.query(exists().where(User.id == 29)).scalar() @@ -2693,16 +2620,16 @@ class FilterTest(QueryTest, AssertsCompiledSQL): def test_one_filter(self): User = self.classes.User - assert [User(id=8), User(id=9)] == create_session().query(User).filter( - User.name.endswith("ed") - ).all() + assert [User(id=8), User(id=9)] == fixture_session().query( + User + ).filter(User.name.endswith("ed")).all() def test_contains(self): """test comparing a collection to an object instance.""" User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() address = sess.query(Address).get(3) assert [User(id=8)] == sess.query(User).filter( User.addresses.contains(address) @@ -2731,7 +2658,7 @@ class FilterTest(QueryTest, AssertsCompiledSQL): def test_clause_element_ok(self): User = self.classes.User - s = Session() + s = fixture_session() self.assert_compile( s.query(User).filter(User.addresses), "SELECT users.id AS users_id, users.name AS users_name " @@ -2743,7 +2670,7 @@ class FilterTest(QueryTest, AssertsCompiledSQL): unique""" User, Address = self.classes.User, self.classes.Address - sess = Session() + sess = fixture_session() a1, a2 = sess.query(Address).order_by(Address.id)[0:2] self.assert_compile( sess.query(User) @@ -2762,7 +2689,7 @@ class FilterTest(QueryTest, AssertsCompiledSQL): # SQL compilation User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() assert [User(id=8), User(id=9)] == sess.query(User).filter( User.addresses.any(Address.email_address.like("%ed%")) @@ -2797,7 +2724,7 @@ class FilterTest(QueryTest, AssertsCompiledSQL): # SQL compilation User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() # test that any() doesn't overcorrelate assert [User(id=7), User(id=8)] == sess.query(User).join( @@ -2815,7 +2742,7 @@ class FilterTest(QueryTest, AssertsCompiledSQL): self.classes.Address, ) - sess = create_session() + sess = fixture_session() assert [Address(id=5)] == sess.query(Address).filter( Address.user.has(name="fred") ).all() @@ -2864,7 +2791,7 @@ class FilterTest(QueryTest, AssertsCompiledSQL): def test_contains_m2m(self): Item, Order = self.classes.Item, self.classes.Order - sess = create_session() + sess = fixture_session() item = sess.query(Item).get(3) eq_( @@ -2920,7 +2847,7 @@ class FilterTest(QueryTest, AssertsCompiledSQL): self.classes.Address, ) - sess = create_session() + sess = fixture_session() assert_raises_message( sa.exc.ArgumentError, "Mapped instance expected for relationship comparison to object.", @@ -2941,7 +2868,7 @@ class FilterTest(QueryTest, AssertsCompiledSQL): self.classes.Address, ) - sess = create_session() + sess = fixture_session() user = sess.query(User).get(8) assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query( Address @@ -2986,7 +2913,7 @@ class FilterTest(QueryTest, AssertsCompiledSQL): def test_filter_by(self): User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() user = sess.query(User).get(8) assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query( Address @@ -3007,7 +2934,7 @@ class FilterTest(QueryTest, AssertsCompiledSQL): def test_filter_by_tables(self): users = self.tables.users addresses = self.tables.addresses - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(users) .filter_by(name="ed") @@ -3022,7 +2949,7 @@ class FilterTest(QueryTest, AssertsCompiledSQL): def test_empty_filters(self): User = self.classes.User - sess = create_session() + sess = fixture_session() q1 = sess.query(User) @@ -3031,7 +2958,7 @@ class FilterTest(QueryTest, AssertsCompiledSQL): def test_filter_by_no_property(self): addresses = self.tables.addresses - sess = create_session() + sess = fixture_session() assert_raises_message( sa.exc.InvalidRequestError, 'Entity namespace for "addresses" has no property "name"', @@ -3046,7 +2973,7 @@ class FilterTest(QueryTest, AssertsCompiledSQL): self.classes.Address, ) - sess = create_session() + sess = fixture_session() # scalar eq_( @@ -3119,7 +3046,7 @@ class FilterTest(QueryTest, AssertsCompiledSQL): eq_( [(7,), (8,), (9,), (10,)], - create_session() + fixture_session() .query(User.id) .filter_by() .order_by(User.id) @@ -3127,7 +3054,7 @@ class FilterTest(QueryTest, AssertsCompiledSQL): ) eq_( [(7,), (8,), (9,), (10,)], - create_session() + fixture_session() .query(User.id) .filter_by(**{}) .order_by(User.id) @@ -3136,7 +3063,7 @@ class FilterTest(QueryTest, AssertsCompiledSQL): def test_text_coerce(self): User = self.classes.User - s = create_session() + s = fixture_session() self.assert_compile( s.query(User).filter(text("name='ed'")), "SELECT users.id AS users_id, users.name " @@ -3144,7 +3071,7 @@ class FilterTest(QueryTest, AssertsCompiledSQL): ) def test_filter_by_non_entity(self): - s = create_session() + s = fixture_session() e = sa.func.count(123) assert_raises_message( sa_exc.InvalidRequestError, @@ -3205,7 +3132,7 @@ class HasAnyTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): def test_has_composite_secondary(self): A, D = self.classes("A", "D") - s = Session() + s = fixture_session() self.assert_compile( s.query(A).filter(A.d.has(D.id == 1)), "SELECT a.id AS a_id, a.b_id AS a_b_id FROM a WHERE EXISTS " @@ -3215,7 +3142,7 @@ class HasAnyTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): def test_has_many_to_one(self): B, C = self.classes("B", "C") - s = Session() + s = fixture_session() self.assert_compile( s.query(B).filter(B.c.has(C.id == 1)), "SELECT b.id AS b_id, b.c_id AS b_c_id FROM b WHERE " @@ -3224,7 +3151,7 @@ class HasAnyTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): def test_any_many_to_many(self): B, D = self.classes("B", "D") - s = Session() + s = fixture_session() self.assert_compile( s.query(B).filter(B.d.any(D.id == 1)), "SELECT b.id AS b_id, b.c_id AS b_c_id FROM b WHERE " @@ -3234,7 +3161,7 @@ class HasAnyTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): def test_any_one_to_many(self): B, C = self.classes("B", "C") - s = Session() + s = fixture_session() self.assert_compile( s.query(C).filter(C.bs.any(B.id == 1)), "SELECT c.id AS c_id, c.d_id AS c_d_id FROM c WHERE " @@ -3243,7 +3170,7 @@ class HasAnyTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): def test_any_many_to_many_doesnt_overcorrelate(self): B, D = self.classes("B", "D") - s = Session() + s = fixture_session() self.assert_compile( s.query(B).join(B.d).filter(B.d.any(D.id == 1)), @@ -3256,7 +3183,7 @@ class HasAnyTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): def test_has_doesnt_overcorrelate(self): B, C = self.classes("B", "C") - s = Session() + s = fixture_session() self.assert_compile( s.query(B).join(B.c).filter(B.c.has(C.id == 1)), @@ -3268,7 +3195,7 @@ class HasAnyTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): def test_has_doesnt_get_aliased_join_subq(self): B, C = self.classes("B", "C") - s = Session() + s = fixture_session() ca = aliased(C) self.assert_compile( @@ -3281,7 +3208,7 @@ class HasAnyTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): def test_any_many_to_many_doesnt_get_aliased_join_subq(self): B, D = self.classes("B", "D") - s = Session() + s = fixture_session() da = aliased(D) self.assert_compile( @@ -3298,7 +3225,7 @@ class HasAnyTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): class HasMapperEntitiesTest(QueryTest): def test_entity(self): User = self.classes.User - s = Session() + s = fixture_session() q = s.query(User) @@ -3306,7 +3233,7 @@ class HasMapperEntitiesTest(QueryTest): def test_cols(self): User = self.classes.User - s = Session() + s = fixture_session() q = s.query(User.id) @@ -3314,7 +3241,7 @@ class HasMapperEntitiesTest(QueryTest): def test_cols_set_entities(self): User = self.classes.User - s = Session() + s = fixture_session() q = s.query(User.id) @@ -3323,7 +3250,7 @@ class HasMapperEntitiesTest(QueryTest): def test_entity_set_entities(self): User = self.classes.User - s = Session() + s = fixture_session() q = s.query(User) @@ -3337,7 +3264,7 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL): def test_union(self): User = self.classes.User - s = create_session() + s = fixture_session() fred = s.query(User).filter(User.name == "fred") ed = s.query(User).filter(User.name == "ed") @@ -3363,7 +3290,7 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL): User, Address = self.classes.User, self.classes.Address - s = create_session() + s = fixture_session() q1 = ( s.query(User, Address) .join(User.addresses) @@ -3390,7 +3317,7 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL): User = self.classes.User - s = Session() + s = fixture_session() q1 = s.query(User, literal("x")) q2 = s.query(User, literal_column("'y'")) q3 = q1.union(q2) @@ -3409,7 +3336,7 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL): def test_union_literal_expressions_results(self): User = self.classes.User - s = Session() + s = fixture_session() x_literal = literal("x") q1 = s.query(User, x_literal) @@ -3443,7 +3370,7 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL): def test_union_labeled_anonymous_columns(self): User = self.classes.User - s = Session() + s = fixture_session() c1, c2 = column("c1"), column("c2") q1 = s.query(User, c1.label("foo"), c1.label("bar")) @@ -3468,7 +3395,7 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL): def test_order_by_anonymous_col(self): User = self.classes.User - s = Session() + s = fixture_session() c1, c2 = column("c1"), column("c2") f = c1.label("foo") @@ -3501,7 +3428,7 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL): def test_union_mapped_colnames_preserved_across_subquery(self): User = self.classes.User - s = Session() + s = fixture_session() q1 = s.query(User.name) q2 = s.query(User.name) @@ -3521,7 +3448,7 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL): def test_intersect(self): User = self.classes.User - s = create_session() + s = fixture_session() fred = s.query(User).filter(User.name == "fred") ed = s.query(User).filter(User.name == "ed") @@ -3533,7 +3460,7 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL): def test_eager_load(self): User, Address = self.classes.User, self.classes.Address - s = create_session() + s = fixture_session() fred = s.query(User).filter(User.name == "fred") ed = s.query(User).filter(User.name == "ed") @@ -3559,7 +3486,7 @@ class AggregateTest(QueryTest): def test_sum(self): Order = self.classes.Order - sess = create_session() + sess = fixture_session() orders = sess.query(Order).filter(Order.id.in_([2, 3, 4])) eq_( orders.with_entities( @@ -3571,7 +3498,7 @@ class AggregateTest(QueryTest): def test_apply(self): Order = self.classes.Order - sess = create_session() + sess = fixture_session() assert sess.query(func.sum(Order.user_id * Order.address_id)).filter( Order.id.in_([2, 3, 4]) ).one() == (79,) @@ -3579,7 +3506,7 @@ class AggregateTest(QueryTest): def test_having(self): User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() assert [User(name="ed", id=8)] == sess.query(User).order_by( User.id ).group_by(User).join("addresses").having( @@ -3601,7 +3528,7 @@ class ExistsTest(QueryTest, AssertsCompiledSQL): def test_exists(self): User = self.classes.User - sess = create_session() + sess = fixture_session() q1 = sess.query(User) self.assert_compile( @@ -3620,7 +3547,7 @@ class ExistsTest(QueryTest, AssertsCompiledSQL): def test_exists_col_warning(self): User = self.classes.User Address = self.classes.Address - sess = create_session() + sess = fixture_session() q1 = sess.query(User, Address).filter(User.id == Address.user_id) self.assert_compile( @@ -3633,7 +3560,7 @@ class ExistsTest(QueryTest, AssertsCompiledSQL): def test_exists_w_select_from(self): User = self.classes.User - sess = create_session() + sess = fixture_session() q1 = sess.query().select_from(User).exists() self.assert_compile( @@ -3645,7 +3572,7 @@ class CountTest(QueryTest): def test_basic(self): users, User = self.tables.users, self.classes.User - s = create_session() + s = fixture_session() eq_(s.query(User).count(), 4) @@ -3654,7 +3581,7 @@ class CountTest(QueryTest): def test_basic_future(self): User = self.classes.User - s = create_session() + s = fixture_session() eq_( s.execute(select(func.count()).select_from(User)).scalar(), @@ -3670,7 +3597,7 @@ class CountTest(QueryTest): def test_count_char(self): User = self.classes.User - s = create_session() + s = fixture_session() # '*' is favored here as the most common character, # it is reported that Informix doesn't like count(1), # rumors about Oracle preferring count(1) don't appear @@ -3689,7 +3616,7 @@ class CountTest(QueryTest): def test_multiple_entity(self): User, Address = self.classes.User, self.classes.Address - s = create_session() + s = fixture_session() q = s.query(User, Address).join(Address, true()) eq_(q.count(), 20) # cartesian product @@ -3699,7 +3626,7 @@ class CountTest(QueryTest): def test_multiple_entity_future(self): User, Address = self.classes.User, self.classes.Address - s = create_session() + s = fixture_session() stmt = select(User, Address).join(Address, true()) @@ -3714,7 +3641,7 @@ class CountTest(QueryTest): def test_nested(self): User, Address = self.classes.User, self.classes.Address - s = create_session() + s = fixture_session() q = s.query(User, Address).join(Address, true()).limit(2) eq_(q.count(), 2) @@ -3727,7 +3654,7 @@ class CountTest(QueryTest): def test_nested_future(self): User, Address = self.classes.User, self.classes.Address - s = create_session() + s = fixture_session() stmt = select(User, Address).join(Address, true()).limit(2) eq_( @@ -3752,7 +3679,7 @@ class CountTest(QueryTest): User, Address = self.classes.User, self.classes.Address - s = create_session() + s = fixture_session() q = s.query(func.count(distinct(User.name))) eq_(q.count(), 1) @@ -3774,7 +3701,7 @@ class CountTest(QueryTest): User, Address = self.classes.User, self.classes.Address - s = create_session() + s = fixture_session() stmt = select(func.count(distinct(User.name))) eq_( @@ -3822,11 +3749,11 @@ class DistinctTest(QueryTest, AssertsCompiledSQL): eq_( [User(id=7), User(id=8), User(id=9), User(id=10)], - create_session().query(User).order_by(User.id).distinct().all(), + fixture_session().query(User).order_by(User.id).distinct().all(), ) eq_( [User(id=7), User(id=9), User(id=8), User(id=10)], - create_session() + fixture_session() .query(User) .distinct() .order_by(desc(User.name)) @@ -3847,7 +3774,7 @@ class DistinctTest(QueryTest, AssertsCompiledSQL): expr = (User.id.op("+")(2)).label("label") - sess = create_session() + sess = fixture_session() q = sess.query(expr).select_from(User).order_by(desc(expr)).distinct() @@ -3864,7 +3791,7 @@ class DistinctTest(QueryTest, AssertsCompiledSQL): expr = User.id + literal(1) - sess = create_session() + sess = fixture_session() q = sess.query(expr).select_from(User).order_by(asc(expr)).distinct() # no double col in the select list, @@ -3880,7 +3807,7 @@ class DistinctTest(QueryTest, AssertsCompiledSQL): expr = (User.id + literal(1)).label("label") - sess = create_session() + sess = fixture_session() q = sess.query(expr).select_from(User).order_by(asc(expr)).distinct() # no double col in the select list, @@ -3896,7 +3823,7 @@ class DistinctTest(QueryTest, AssertsCompiledSQL): expr = (User.id + literal(1)).label("label") - sess = create_session() + sess = fixture_session() q = ( sess.query(expr) .select_from(User) @@ -3934,7 +3861,7 @@ class DistinctTest(QueryTest, AssertsCompiledSQL): """ User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() subq = ( sess.query(User, Address.email_address) @@ -3957,7 +3884,7 @@ class DistinctTest(QueryTest, AssertsCompiledSQL): """ User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() q = ( sess.query(User, Address.email_address) .join("addresses") @@ -3980,7 +3907,7 @@ class DistinctTest(QueryTest, AssertsCompiledSQL): User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() # test that it works on embedded joinedload/LIMIT subquery q = ( @@ -4009,7 +3936,7 @@ class DistinctTest(QueryTest, AssertsCompiledSQL): User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() subq = ( sess.query( @@ -4054,7 +3981,7 @@ class DistinctTest(QueryTest, AssertsCompiledSQL): User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() q = ( sess.query( @@ -4094,7 +4021,7 @@ class DistinctTest(QueryTest, AssertsCompiledSQL): def test_columns_augmented_sql_one(self): User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() subq = ( sess.query( @@ -4131,7 +4058,7 @@ class DistinctTest(QueryTest, AssertsCompiledSQL): def test_columns_augmented_sql_union_one(self): User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() q = ( sess.query( @@ -4168,7 +4095,7 @@ class DistinctTest(QueryTest, AssertsCompiledSQL): def test_columns_augmented_sql_union_two(self): User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() q = ( sess.query( @@ -4198,7 +4125,7 @@ class DistinctTest(QueryTest, AssertsCompiledSQL): def test_columns_augmented_sql_two(self): User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() q = ( sess.query(User) @@ -4234,7 +4161,7 @@ class DistinctTest(QueryTest, AssertsCompiledSQL): def test_columns_augmented_sql_three(self): User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() q = ( sess.query(User.id, User.name.label("foo"), Address.id) @@ -4254,7 +4181,7 @@ class DistinctTest(QueryTest, AssertsCompiledSQL): def test_columns_augmented_distinct_on(self): User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() subq = ( sess.query( @@ -4293,7 +4220,7 @@ class DistinctTest(QueryTest, AssertsCompiledSQL): def test_columns_augmented_sql_three_using_label_reference(self): User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() q = ( sess.query(User.id, User.name.label("foo"), Address.id) @@ -4313,7 +4240,7 @@ class DistinctTest(QueryTest, AssertsCompiledSQL): def test_columns_augmented_sql_illegal_label_reference(self): User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() q = sess.query(User.id, User.name.label("foo"), Address.id).distinct( "not a label" @@ -4332,7 +4259,7 @@ class DistinctTest(QueryTest, AssertsCompiledSQL): def test_columns_augmented_sql_four(self): User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() q = ( sess.query(User) @@ -4371,14 +4298,14 @@ class DistinctTest(QueryTest, AssertsCompiledSQL): class PrefixSuffixWithTest(QueryTest, AssertsCompiledSQL): def test_one_prefix(self): User = self.classes.User - sess = create_session() + sess = fixture_session() query = sess.query(User.name).prefix_with("PREFIX_1") expected = "SELECT PREFIX_1 " "users.name AS users_name FROM users" self.assert_compile(query, expected, dialect=default.DefaultDialect()) def test_one_suffix(self): User = self.classes.User - sess = create_session() + sess = fixture_session() query = sess.query(User.name).suffix_with("SUFFIX_1") # trailing space for some reason expected = "SELECT users.name AS users_name FROM users SUFFIX_1 " @@ -4386,7 +4313,7 @@ class PrefixSuffixWithTest(QueryTest, AssertsCompiledSQL): def test_many_prefixes(self): User = self.classes.User - sess = create_session() + sess = fixture_session() query = sess.query(User.name).prefix_with("PREFIX_1", "PREFIX_2") expected = ( "SELECT PREFIX_1 PREFIX_2 " "users.name AS users_name FROM users" @@ -4395,7 +4322,7 @@ class PrefixSuffixWithTest(QueryTest, AssertsCompiledSQL): def test_chained_prefixes(self): User = self.classes.User - sess = create_session() + sess = fixture_session() query = ( sess.query(User.name) .prefix_with("PREFIX_1") @@ -4433,7 +4360,7 @@ class YieldTest(_fixtures.FixtureTest): User = self.classes.User - sess = create_session() + sess = fixture_session() q = iter( sess.query(User) .yield_per(1) @@ -4459,7 +4386,7 @@ class YieldTest(_fixtures.FixtureTest): User = self.classes.User - sess = create_session() + sess = fixture_session() @event.listens_for(sess, "do_orm_execute") def check(ctx): @@ -4483,7 +4410,7 @@ class YieldTest(_fixtures.FixtureTest): User = self.classes.User - sess = create_session() + sess = fixture_session() @event.listens_for(sess, "do_orm_execute") def check(ctx): @@ -4508,7 +4435,7 @@ class YieldTest(_fixtures.FixtureTest): self._eagerload_mappings() User = self.classes.User - sess = create_session() + sess = fixture_session() q = sess.query(User).options(joinedload("addresses")).yield_per(1) assert_raises_message( sa_exc.InvalidRequestError, @@ -4521,7 +4448,7 @@ class YieldTest(_fixtures.FixtureTest): self._eagerload_mappings() User = self.classes.User - sess = create_session() + sess = fixture_session() q = sess.query(User).options(subqueryload("addresses")).yield_per(1) assert_raises_message( sa_exc.InvalidRequestError, @@ -4534,7 +4461,7 @@ class YieldTest(_fixtures.FixtureTest): self._eagerload_mappings(addresses_lazy="subquery") User = self.classes.User - sess = create_session() + sess = fixture_session() q = sess.query(User).yield_per(1) assert_raises_message( sa_exc.InvalidRequestError, @@ -4546,7 +4473,7 @@ class YieldTest(_fixtures.FixtureTest): def test_joinedload_m2o_ok(self): self._eagerload_mappings(user_lazy="joined") Address = self.classes.Address - sess = create_session() + sess = fixture_session() q = sess.query(Address).yield_per(1) q.all() @@ -4554,7 +4481,7 @@ class YieldTest(_fixtures.FixtureTest): self._eagerload_mappings() User = self.classes.User - sess = create_session() + sess = fixture_session() q = ( sess.query(User) .options(subqueryload("addresses")) @@ -4574,7 +4501,7 @@ class YieldTest(_fixtures.FixtureTest): def test_m2o_joinedload_not_others(self): self._eagerload_mappings(addresses_lazy="joined") Address = self.classes.Address - sess = create_session() + sess = fixture_session() q = ( sess.query(Address) .options(lazyload("*"), joinedload("user")) @@ -4599,7 +4526,7 @@ class HintsTest(QueryTest, AssertsCompiledSQL): dialect = mysql.dialect() - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(User).with_hint( @@ -4635,7 +4562,7 @@ class HintsTest(QueryTest, AssertsCompiledSQL): def test_statement_hints(self): User = self.classes.User - sess = create_session() + sess = fixture_session() stmt = ( sess.query(User) .with_statement_hint("test hint one") @@ -4666,7 +4593,7 @@ class TextTest(QueryTest, AssertsCompiledSQL): assert_raises_message( sa_exc.ArgumentError, "Textual SQL expression", - create_session().query(User).from_statement, + fixture_session().query(User).from_statement, "select * from users order by id", ) @@ -4674,14 +4601,14 @@ class TextTest(QueryTest, AssertsCompiledSQL): User = self.classes.User eq_( - create_session() + fixture_session() .query(User) .from_statement(text("select * from users order by id")) .first(), User(id=7), ) eq_( - create_session() + fixture_session() .query(User) .from_statement( text("select * from users where name='nonexistent'") @@ -4693,7 +4620,7 @@ class TextTest(QueryTest, AssertsCompiledSQL): def test_select_star_future(self): User = self.classes.User - sess = Session(testing.db) + sess = fixture_session() eq_( sess.execute( select(User).from_statement( @@ -4720,7 +4647,7 @@ class TextTest(QueryTest, AssertsCompiledSQL): # ordering doesn't matter User = self.classes.User - s = create_session() + s = fixture_session() q = s.query(User).from_statement( text( "select name, 27 as foo, id as users_id from users order by id" @@ -4741,7 +4668,7 @@ class TextTest(QueryTest, AssertsCompiledSQL): # ordering doesn't matter User = self.classes.User - s = create_session(testing.db) + s = fixture_session() q = select(User).from_statement( text( "select name, 27 as foo, id as users_id from users order by id" @@ -4763,7 +4690,7 @@ class TextTest(QueryTest, AssertsCompiledSQL): User = self.classes.User Address = self.classes.Address - s = create_session() + s = fixture_session() q = s.query(User, Address).from_statement( text( "select users.name AS users_name, users.id AS users_id, " @@ -4788,7 +4715,7 @@ class TextTest(QueryTest, AssertsCompiledSQL): User = self.classes.User Address = self.classes.Address - s = create_session(testing.db) + s = fixture_session() q = select(User, Address).from_statement( text( "select users.name AS users_name, users.id AS users_id, " @@ -4813,7 +4740,7 @@ class TextTest(QueryTest, AssertsCompiledSQL): User = self.classes.User Address = self.classes.Address - s = create_session() + s = fixture_session() q = ( s.query(User) .from_statement( @@ -4839,7 +4766,7 @@ class TextTest(QueryTest, AssertsCompiledSQL): User = self.classes.User Address = self.classes.Address - s = create_session(testing.db) + s = fixture_session() q = ( select(User) .from_statement( @@ -4865,7 +4792,7 @@ class TextTest(QueryTest, AssertsCompiledSQL): User = self.classes.User Address = self.classes.Address - s = create_session() + s = fixture_session() q = ( s.query(User) .from_statement( @@ -4891,7 +4818,7 @@ class TextTest(QueryTest, AssertsCompiledSQL): User = self.classes.User Address = self.classes.Address - s = create_session(testing.db) + s = fixture_session() q = ( select(User) .from_statement( @@ -4919,7 +4846,7 @@ class TextTest(QueryTest, AssertsCompiledSQL): users = self.tables.users addresses = self.tables.addresses - s = create_session() + s = fixture_session() q = s.query(User.name, User.id, Address.id).from_statement( text( "select users.name AS users_name, users.id AS users_id, " @@ -4958,7 +4885,7 @@ class TextTest(QueryTest, AssertsCompiledSQL): if add_columns: text_clause = text_clause.columns(User.id, User.name) - s = create_session() + s = fixture_session() q = ( s.query(User) .from_statement(text_clause) @@ -4974,12 +4901,12 @@ class TextTest(QueryTest, AssertsCompiledSQL): User = self.classes.User eq_( - create_session().query(User).filter(text("id in (8, 9)")).all(), + fixture_session().query(User).filter(text("id in (8, 9)")).all(), [User(id=8), User(id=9)], ) eq_( - create_session() + fixture_session() .query(User) .filter(text("name='fred'")) .filter(text("id=9")) @@ -4987,7 +4914,7 @@ class TextTest(QueryTest, AssertsCompiledSQL): [User(id=9)], ) eq_( - create_session() + fixture_session() .query(User) .filter(text("name='fred'")) .filter(User.id == 9) @@ -4998,7 +4925,7 @@ class TextTest(QueryTest, AssertsCompiledSQL): def test_whereclause_future(self): User = self.classes.User - s = create_session(testing.db) + s = fixture_session() eq_( s.execute(select(User).filter(text("id in (8, 9)"))) .scalars() @@ -5030,14 +4957,14 @@ class TextTest(QueryTest, AssertsCompiledSQL): sa_exc.ArgumentError, r"Textual SQL expression 'id in \(:id1, :id2\)' " "should be explicitly declared", - create_session().query(User).filter, + fixture_session().query(User).filter, "id in (:id1, :id2)", ) def test_plain_textual_column(self): User = self.classes.User - s = create_session() + s = fixture_session() self.assert_compile( s.query(User.id, text("users.name")), @@ -5056,7 +4983,7 @@ class TextTest(QueryTest, AssertsCompiledSQL): def test_via_select(self): User = self.classes.User - s = create_session() + s = fixture_session() eq_( s.query(User) .from_statement( @@ -5070,7 +4997,7 @@ class TextTest(QueryTest, AssertsCompiledSQL): def test_via_textasfrom_from_statement(self): User = self.classes.User - s = create_session() + s = fixture_session() eq_( s.query(User) @@ -5085,7 +5012,7 @@ class TextTest(QueryTest, AssertsCompiledSQL): def test_columns_via_textasfrom_from_statement(self): User = self.classes.User - s = create_session() + s = fixture_session() eq_( s.query(User.id, User.name) @@ -5100,7 +5027,7 @@ class TextTest(QueryTest, AssertsCompiledSQL): def test_via_textasfrom_use_mapped_columns(self): User = self.classes.User - s = create_session() + s = fixture_session() eq_( s.query(User) @@ -5115,7 +5042,7 @@ class TextTest(QueryTest, AssertsCompiledSQL): def test_via_textasfrom_select_from(self): User = self.classes.User - s = create_session() + s = fixture_session() eq_( s.query(User) @@ -5131,7 +5058,7 @@ class TextTest(QueryTest, AssertsCompiledSQL): def test_group_by_accepts_text(self): User = self.classes.User - s = create_session() + s = fixture_session() q = s.query(User).group_by(text("name")) self.assert_compile( @@ -5148,7 +5075,7 @@ class TextTest(QueryTest, AssertsCompiledSQL): def test_order_by_w_eager_one(self): User = self.classes.User - s = create_session() + s = fixture_session() # from 1.0.0 thru 1.0.2, the "name" symbol here was considered # to be part of the things we need to ORDER BY and it was being @@ -5177,7 +5104,7 @@ class TextTest(QueryTest, AssertsCompiledSQL): def test_order_by_w_eager_two(self): User = self.classes.User - s = create_session() + s = fixture_session() q = ( s.query(User) @@ -5193,7 +5120,7 @@ class TextTest(QueryTest, AssertsCompiledSQL): def test_order_by_w_eager_three(self): User = self.classes.User - s = create_session() + s = fixture_session() self.assert_compile( s.query(User) @@ -5225,7 +5152,7 @@ class TextTest(QueryTest, AssertsCompiledSQL): def test_order_by_w_eager_four(self): User = self.classes.User Address = self.classes.Address - s = create_session() + s = fixture_session() self.assert_compile( s.query(User) @@ -5262,7 +5189,7 @@ class TextTest(QueryTest, AssertsCompiledSQL): User = self.classes.User Address = self.classes.Address - sess = create_session() + sess = fixture_session() q = sess.query(User, Address.email_address.label("email_address")) @@ -5294,16 +5221,20 @@ class TextErrorTest(QueryTest, AssertsCompiledSQL): def test_filter(self): User = self.classes.User - self._test(Session().query(User.id).filter, "myid == 5", "myid == 5") + self._test( + fixture_session().query(User.id).filter, "myid == 5", "myid == 5" + ) def test_having(self): User = self.classes.User - self._test(Session().query(User.id).having, "myid == 5", "myid == 5") + self._test( + fixture_session().query(User.id).having, "myid == 5", "myid == 5" + ) def test_from_statement(self): User = self.classes.User self._test( - Session().query(User.id).from_statement, + fixture_session().query(User.id).from_statement, "select id from user", "select id from user", ) @@ -5319,7 +5250,7 @@ class ParentTest(QueryTest, AssertsCompiledSQL): self.classes.Order, ) - sess = create_session() + sess = fixture_session() q = sess.query(User) u1 = q.filter_by(name="jack").one() @@ -5370,7 +5301,7 @@ class ParentTest(QueryTest, AssertsCompiledSQL): def test_select_from(self): User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() u1 = sess.query(User).get(7) q = sess.query(Address).select_from(Address).with_parent(u1) self.assert_compile( @@ -5385,7 +5316,7 @@ class ParentTest(QueryTest, AssertsCompiledSQL): def test_from_entity_standalone_fn(self): User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() u1 = sess.query(User).get(7) q = sess.query(User, Address).filter( with_parent(u1, "addresses", from_entity=Address) @@ -5404,7 +5335,7 @@ class ParentTest(QueryTest, AssertsCompiledSQL): def test_from_entity_query_entity(self): User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() u1 = sess.query(User).get(7) q = sess.query(User, Address).with_parent( u1, "addresses", from_entity=Address @@ -5423,7 +5354,7 @@ class ParentTest(QueryTest, AssertsCompiledSQL): def test_select_from_alias(self): User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() u1 = sess.query(User).get(7) a1 = aliased(Address) q = sess.query(a1).with_parent(u1) @@ -5440,7 +5371,7 @@ class ParentTest(QueryTest, AssertsCompiledSQL): def test_select_from_alias_explicit_prop(self): User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() u1 = sess.query(User).get(7) a1 = aliased(Address) q = sess.query(a1).with_parent(u1, "addresses") @@ -5457,7 +5388,7 @@ class ParentTest(QueryTest, AssertsCompiledSQL): def test_select_from_alias_from_entity(self): User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() u1 = sess.query(User).get(7) a1 = aliased(Address) a2 = aliased(Address) @@ -5478,7 +5409,7 @@ class ParentTest(QueryTest, AssertsCompiledSQL): def test_select_from_alias_of_type(self): User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() u1 = sess.query(User).get(7) a1 = aliased(Address) a2 = aliased(Address) @@ -5499,7 +5430,7 @@ class ParentTest(QueryTest, AssertsCompiledSQL): def test_noparent(self): Item, User = self.classes.Item, self.classes.User - sess = create_session() + sess = fixture_session() q = sess.query(User) u1 = q.filter_by(name="jack").one() @@ -5516,7 +5447,7 @@ class ParentTest(QueryTest, AssertsCompiledSQL): def test_m2m(self): Item, Keyword = self.classes.Item, self.classes.Keyword - sess = create_session() + sess = fixture_session() i1 = sess.query(Item).filter_by(id=2).one() k = sess.query(Keyword).with_parent(i1).all() assert [ @@ -5528,7 +5459,7 @@ class ParentTest(QueryTest, AssertsCompiledSQL): def test_with_transient(self): User, Order = self.classes.User, self.classes.Order - sess = Session() + sess = fixture_session() q = sess.query(User) u1 = q.filter_by(name="jack").one() @@ -5556,7 +5487,7 @@ class ParentTest(QueryTest, AssertsCompiledSQL): def test_with_pending_autoflush(self): Order, User = self.classes.Order, self.classes.User - sess = Session() + sess = fixture_session() o1 = sess.query(Order).first() opending = Order(id=20, user_id=o1.user_id) @@ -5573,7 +5504,7 @@ class ParentTest(QueryTest, AssertsCompiledSQL): def test_with_pending_no_autoflush(self): Order, User = self.classes.Order, self.classes.User - sess = Session(autoflush=False) + sess = fixture_session(autoflush=False) o1 = sess.query(Order).first() opending = Order(user_id=o1.user_id) @@ -5587,7 +5518,7 @@ class ParentTest(QueryTest, AssertsCompiledSQL): """bindparams used in the 'parent' query are unique""" User, Address = self.classes.User, self.classes.Address - sess = Session() + sess = fixture_session() u1, u2 = sess.query(User).order_by(User.id)[0:2] q1 = sess.query(Address).with_parent(u1, "addresses") @@ -5612,7 +5543,7 @@ class ParentTest(QueryTest, AssertsCompiledSQL): def test_unique_binds_or(self): User, Address = self.classes.User, self.classes.Address - sess = Session() + sess = fixture_session() u1, u2 = sess.query(User).order_by(User.id)[0:2] self.assert_compile( @@ -5680,7 +5611,7 @@ class WithTransientOnNone(_fixtures.FixtureTest, AssertsCompiledSQL): self._fixture1() User, Address = self.classes.User, self.classes.Address - sess = Session() + sess = fixture_session() q = sess.query(Address).filter(Address.user == User()) assert_raises_message( @@ -5694,7 +5625,7 @@ class WithTransientOnNone(_fixtures.FixtureTest, AssertsCompiledSQL): self._fixture1() User, Address = self.classes.User, self.classes.Address - sess = Session() + sess = fixture_session() q = sess.query(Address).filter(Address.user == User(id=None)) with expect_warnings("Got None for value of column "): @@ -5711,7 +5642,7 @@ class WithTransientOnNone(_fixtures.FixtureTest, AssertsCompiledSQL): self._fixture1() User, Address = self.classes.User, self.classes.Address - sess = Session() + sess = fixture_session() u1 = User() # id is not set, so evaluates to NEVER_SET @@ -5733,7 +5664,7 @@ class WithTransientOnNone(_fixtures.FixtureTest, AssertsCompiledSQL): self._fixture1() User, Address = self.classes.User, self.classes.Address - s = Session() + s = fixture_session() q = s.query(Address).filter( Address.special_user == User(id=None, name=None) ) @@ -5757,7 +5688,7 @@ class WithTransientOnNone(_fixtures.FixtureTest, AssertsCompiledSQL): self.classes.Dingaling, self.classes.HasDingaling, ) - s = Session() + s = fixture_session() d = Dingaling(id=1) s.add(d) s.flush() @@ -5781,7 +5712,7 @@ class WithTransientOnNone(_fixtures.FixtureTest, AssertsCompiledSQL): self.classes.Dingaling, self.classes.HasDingaling, ) - s = Session() + s = fixture_session() d = Dingaling() s.add(d) s.flush() @@ -5811,7 +5742,7 @@ class WithTransientOnNone(_fixtures.FixtureTest, AssertsCompiledSQL): self.classes.Dingaling, self.classes.HasDingaling, ) - s = Session() + s = fixture_session() d = Dingaling(data="some data") s.add(d) s.commit() @@ -5833,7 +5764,7 @@ class WithTransientOnNone(_fixtures.FixtureTest, AssertsCompiledSQL): self._fixture1() User, Address = self.classes.User, self.classes.Address - sess = Session() + sess = fixture_session() q = sess.query(User).with_parent(Address(user_id=None), "user") with expect_warnings("Got None for value of column"): @@ -5848,7 +5779,7 @@ class WithTransientOnNone(_fixtures.FixtureTest, AssertsCompiledSQL): self._fixture1() User, Address = self.classes.User, self.classes.Address - s = Session() + s = fixture_session() q = s.query(User).with_parent( Address(user_id=None, email_address=None), "special_user" ) @@ -5866,7 +5797,7 @@ class WithTransientOnNone(_fixtures.FixtureTest, AssertsCompiledSQL): self._fixture1() User, Address = self.classes.User, self.classes.Address - s = Session() + s = fixture_session() q = s.query(Address).filter(Address.user != User(id=None)) with expect_warnings("Got None for value of column"): self.assert_compile( @@ -5884,7 +5815,7 @@ class WithTransientOnNone(_fixtures.FixtureTest, AssertsCompiledSQL): self._fixture1() User, Address = self.classes.User, self.classes.Address - s = Session() + s = fixture_session() # this one does *not* warn because we do the criteria # without deferral @@ -5973,7 +5904,7 @@ class SynonymTest(QueryTest, AssertsCompiledSQL): def test_options(self): User, Order = self.classes.User, self.classes.Order - s = create_session() + s = fixture_session() def go(): result = ( @@ -6002,7 +5933,7 @@ class SynonymTest(QueryTest, AssertsCompiledSQL): def test_options_syn_of_syn(self): User, Order = self.classes.User, self.classes.Order - s = create_session() + s = fixture_session() def go(): result = ( @@ -6031,7 +5962,7 @@ class SynonymTest(QueryTest, AssertsCompiledSQL): def test_options_syn_of_syn_string(self): User, Order = self.classes.User, self.classes.Order - s = create_session() + s = fixture_session() def go(): result = ( @@ -6069,7 +6000,7 @@ class SynonymTest(QueryTest, AssertsCompiledSQL): [User.orders_syn, Order.items_syn], [User.orders_syn_2, Order.items_syn], ): - q = create_session().query(User) + q = fixture_session().query(User) for path in j: q = q.join(path) q = q.filter_by(id=3) @@ -6087,7 +6018,7 @@ class SynonymTest(QueryTest, AssertsCompiledSQL): ("name_syn", "orders_syn"), ("name_syn", "orders_syn_2"), ): - sess = create_session() + sess = fixture_session() q = sess.query(User) u1 = q.filter_by(**{nameprop: "jack"}).one() @@ -6102,7 +6033,7 @@ class SynonymTest(QueryTest, AssertsCompiledSQL): def test_froms_aliased_col(self): Address, User = self.classes.Address, self.classes.User - sess = create_session() + sess = fixture_session() ua = aliased(User) q = sess.query(ua.name_syn).join(Address, ua.id == Address.user_id) @@ -6134,7 +6065,7 @@ class ImmediateTest(_fixtures.FixtureTest): def test_one(self): User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() assert_raises_message( sa.orm.exc.NoResultFound, @@ -6216,7 +6147,7 @@ class ImmediateTest(_fixtures.FixtureTest): def test_one_or_none(self): User, Address = self.classes.User, self.classes.Address - sess = create_session() + sess = fixture_session() eq_(sess.query(User).filter(User.id == 99).one_or_none(), None) @@ -6299,7 +6230,7 @@ class ImmediateTest(_fixtures.FixtureTest): def test_scalar(self): User = self.classes.User - sess = create_session() + sess = fixture_session() eq_(sess.query(User.id).filter_by(id=7).scalar(), 7) eq_(sess.query(User.id, User.name).filter_by(id=7).scalar(), 7) @@ -6320,7 +6251,7 @@ class ExecutionOptionsTest(QueryTest): def test_option_building(self): User = self.classes.User - sess = create_session(bind=testing.db, autocommit=False) + sess = fixture_session(autocommit=False) q1 = sess.query(User) eq_(q1._execution_options, dict()) @@ -6338,7 +6269,7 @@ class ExecutionOptionsTest(QueryTest): def test_get_options(self): User = self.classes.User - sess = create_session(bind=testing.db, autocommit=False) + sess = fixture_session(autocommit=False) q = sess.query(User).execution_options(foo="bar", stream_results=True) eq_(q.get_execution_options(), dict(foo="bar", stream_results=True)) @@ -6358,9 +6289,7 @@ class ExecutionOptionsTest(QueryTest): result.close() return iter([]) - sess = create_session( - bind=testing.db, autocommit=False, query_cls=TQuery - ) + sess = fixture_session(autocommit=False, query_cls=TQuery) q1 = sess.query(User).execution_options(**execution_options) q1.all() @@ -6375,7 +6304,7 @@ class BooleanEvalTest(fixtures.TestBase, testing.AssertsCompiledSQL): return d def test_one(self): - s = Session() + s = fixture_session() c = column("x", Boolean) self.assert_compile( s.query(c).filter(c), @@ -6384,7 +6313,7 @@ class BooleanEvalTest(fixtures.TestBase, testing.AssertsCompiledSQL): ) def test_two(self): - s = Session() + s = fixture_session() c = column("x", Boolean) self.assert_compile( s.query(c).filter(c), @@ -6393,7 +6322,7 @@ class BooleanEvalTest(fixtures.TestBase, testing.AssertsCompiledSQL): ) def test_three(self): - s = Session() + s = fixture_session() c = column("x", Boolean) self.assert_compile( s.query(c).filter(~c), @@ -6402,7 +6331,7 @@ class BooleanEvalTest(fixtures.TestBase, testing.AssertsCompiledSQL): ) def test_four(self): - s = Session() + s = fixture_session() c = column("x", Boolean) self.assert_compile( s.query(c).filter(~c), @@ -6411,7 +6340,7 @@ class BooleanEvalTest(fixtures.TestBase, testing.AssertsCompiledSQL): ) def test_five(self): - s = Session() + s = fixture_session() c = column("x", Boolean) self.assert_compile( s.query(c).having(c), @@ -6439,26 +6368,26 @@ class SessionBindTest(QueryTest): def test_single_entity_q(self): User = self.classes.User - session = Session() + session = fixture_session() with self._assert_bind_args(session, expect_mapped_bind=True): session.query(User).all() def test_aliased_entity_q(self): User = self.classes.User u = aliased(User) - session = Session() + session = fixture_session() with self._assert_bind_args(session, expect_mapped_bind=True): session.query(u).all() def test_sql_expr_entity_q(self): User = self.classes.User - session = Session() + session = fixture_session() with self._assert_bind_args(session, expect_mapped_bind=True): session.query(User.id).all() def test_sql_expr_subquery_from_entity(self): User = self.classes.User - session = Session() + session = fixture_session() with self._assert_bind_args(session, expect_mapped_bind=True): subq = session.query(User.id).subquery() session.query(subq).all() @@ -6466,14 +6395,14 @@ class SessionBindTest(QueryTest): @testing.requires.boolean_col_expressions def test_sql_expr_exists_from_entity(self): User = self.classes.User - session = Session() + session = fixture_session() with self._assert_bind_args(session, expect_mapped_bind=True): subq = session.query(User.id).exists() session.query(subq).all() def test_sql_expr_cte_from_entity(self): User = self.classes.User - session = Session() + session = fixture_session() with self._assert_bind_args(session, expect_mapped_bind=True): cte = session.query(User.id).cte() subq = session.query(cte).subquery() @@ -6481,7 +6410,7 @@ class SessionBindTest(QueryTest): def test_sql_expr_bundle_cte_from_entity(self): User = self.classes.User - session = Session() + session = fixture_session() with self._assert_bind_args(session, expect_mapped_bind=True): cte = session.query(User.id, User.name).cte() subq = session.query(cte).subquery() @@ -6490,63 +6419,63 @@ class SessionBindTest(QueryTest): def test_count(self): User = self.classes.User - session = Session() + session = fixture_session() with self._assert_bind_args(session, expect_mapped_bind=True): session.query(User).count() def test_single_col(self): User = self.classes.User - session = Session() + session = fixture_session() with self._assert_bind_args(session, expect_mapped_bind=True): session.query(User.name).all() def test_single_col_from_subq(self): User = self.classes.User - session = Session() + session = fixture_session() with self._assert_bind_args(session, expect_mapped_bind=True): subq = session.query(User.id, User.name).subquery() session.query(subq.c.name).all() def test_aggregate_fn(self): User = self.classes.User - session = Session() + session = fixture_session() with self._assert_bind_args(session, expect_mapped_bind=True): session.query(func.max(User.name)).all() def test_case(self): User = self.classes.User - session = Session() + session = fixture_session() with self._assert_bind_args(session, expect_mapped_bind=True): session.query(case([(User.name == "x", "C")], else_="W")).all() def test_cast(self): User = self.classes.User - session = Session() + session = fixture_session() with self._assert_bind_args(session, expect_mapped_bind=True): session.query(cast(User.name, String())).all() def test_type_coerce(self): User = self.classes.User - session = Session() + session = fixture_session() with self._assert_bind_args(session, expect_mapped_bind=True): session.query(type_coerce(User.name, String())).all() def test_binary_op(self): User = self.classes.User - session = Session() + session = fixture_session() with self._assert_bind_args(session, expect_mapped_bind=True): session.query(User.name + "x").all() @testing.requires.boolean_col_expressions def test_boolean_op(self): User = self.classes.User - session = Session() + session = fixture_session() with self._assert_bind_args(session, expect_mapped_bind=True): session.query(User.name == "x").all() def test_bulk_update_no_sync(self): User = self.classes.User - session = Session() + session = fixture_session() with self._assert_bind_args(session): session.query(User).filter(User.id == 15).update( {"name": "foob"}, synchronize_session=False @@ -6554,7 +6483,7 @@ class SessionBindTest(QueryTest): def test_bulk_delete_no_sync(self): User = self.classes.User - session = Session() + session = fixture_session() with self._assert_bind_args(session): session.query(User).filter(User.id == 15).delete( synchronize_session=False @@ -6562,7 +6491,7 @@ class SessionBindTest(QueryTest): def test_bulk_update_fetch_sync(self): User = self.classes.User - session = Session() + session = fixture_session() with self._assert_bind_args(session): session.query(User).filter(User.id == 15).update( {"name": "foob"}, synchronize_session="fetch" @@ -6570,7 +6499,7 @@ class SessionBindTest(QueryTest): def test_bulk_delete_fetch_sync(self): User = self.classes.User - session = Session() + session = fixture_session() with self._assert_bind_args(session): session.query(User).filter(User.id == 15).delete( synchronize_session="fetch" @@ -6584,14 +6513,14 @@ class SessionBindTest(QueryTest): "score", column_property(func.coalesce(self.tables.users.c.name, None)), ) - session = Session() + session = fixture_session() with self._assert_bind_args(session, expect_mapped_bind=True): session.query(func.max(User.score)).scalar() def test_plain_table(self): User = self.classes.User - session = Session() + session = fixture_session() with self._assert_bind_args(session, expect_mapped_bind=False): session.query(inspect(User).local_table).all() @@ -6599,14 +6528,14 @@ class SessionBindTest(QueryTest): User = self.classes.User # TODO: this test is dumb - session = Session() + session = fixture_session() with self._assert_bind_args(session, expect_mapped_bind=False): session.query(inspect(User).local_table).from_self().all() def test_plain_table_count(self): User = self.classes.User - session = Session() + session = fixture_session() with self._assert_bind_args(session, expect_mapped_bind=False): session.query(inspect(User).local_table).count() @@ -6614,7 +6543,7 @@ class SessionBindTest(QueryTest): User = self.classes.User table = inspect(User).local_table - session = Session() + session = fixture_session() with self._assert_bind_args(session, expect_mapped_bind=False): session.query(table).select_from(table).all() @@ -6632,7 +6561,7 @@ class SessionBindTest(QueryTest): .scalar_subquery() ), ) - session = Session() + session = fixture_session() with self._assert_bind_args(session): session.query(func.max(User.score)).scalar() @@ -6664,7 +6593,7 @@ class QueryClsTest(QueryTest): def _test_get(self, fixture): User = self.classes.User - s = Session(query_cls=fixture()) + s = fixture_session(query_cls=fixture()) assert s.query(User).get(19) is None u = s.query(User).get(7) @@ -6674,7 +6603,7 @@ class QueryClsTest(QueryTest): def _test_o2m_lazyload(self, fixture): User, Address = self.classes("User", "Address") - s = Session(query_cls=fixture()) + s = fixture_session(query_cls=fixture()) u1 = s.query(User).filter(User.id == 7).first() eq_(u1.addresses, [Address(id=1)]) @@ -6682,7 +6611,7 @@ class QueryClsTest(QueryTest): def _test_m2o_lazyload(self, fixture): User, Address = self.classes("User", "Address") - s = Session(query_cls=fixture()) + s = fixture_session(query_cls=fixture()) a1 = s.query(Address).filter(Address.id == 1).first() eq_(a1.user, User(id=7)) @@ -6690,7 +6619,7 @@ class QueryClsTest(QueryTest): def _test_expr(self, fixture): User, Address = self.classes("User", "Address") - s = Session(query_cls=fixture()) + s = fixture_session(query_cls=fixture()) q = s.query(func.max(User.id).label("max")) eq_(q.scalar(), 10) @@ -6699,7 +6628,7 @@ class QueryClsTest(QueryTest): # see #4269. not documented but already out there. User, Address = self.classes("User", "Address") - s = Session(query_cls=fixture()) + s = fixture_session(query_cls=fixture()) q = Query(func.max(User.id).label("max")).with_session(s) eq_(q.scalar(), 10) diff --git a/test/orm/test_relationships.py b/test/orm/test_relationships.py index 22315e176..5979f08ae 100644 --- a/test/orm/test_relationships.py +++ b/test/orm/test_relationships.py @@ -20,7 +20,6 @@ from sqlalchemy.orm import clear_mappers from sqlalchemy.orm import column_property from sqlalchemy.orm import composite from sqlalchemy.orm import configure_mappers -from sqlalchemy.orm import create_session from sqlalchemy.orm import exc as orm_exc from sqlalchemy.orm import foreign from sqlalchemy.orm import joinedload @@ -29,7 +28,6 @@ from sqlalchemy.orm import relationship from sqlalchemy.orm import remote from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session -from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import subqueryload from sqlalchemy.orm import synonym from sqlalchemy.orm.interfaces import MANYTOONE @@ -41,9 +39,9 @@ from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import in_ from sqlalchemy.testing import is_ -from sqlalchemy.testing import startswith_ from sqlalchemy.testing.assertsql import assert_engine from sqlalchemy.testing.assertsql import CompiledSQL +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from test.orm import _fixtures @@ -288,7 +286,7 @@ class DependencyTwoParentTest(fixtures.MappedTest): cls.classes.D, ) - session = create_session(connection) + session = Session(connection) a = A(name="a1") b = B(name="b1") c = C(name="c1", a_row=a) @@ -303,7 +301,7 @@ class DependencyTwoParentTest(fixtures.MappedTest): def test_DeleteRootTable(self): A = self.classes.A - session = create_session() + session = fixture_session() a = session.query(A).filter_by(name="a1").one() session.delete(a) @@ -312,7 +310,7 @@ class DependencyTwoParentTest(fixtures.MappedTest): def test_DeleteMiddleTable(self): C = self.classes.C - session = create_session() + session = fixture_session() c = session.query(C).filter_by(name="c1").one() session.delete(c) @@ -345,7 +343,7 @@ class M2ODontOverwriteFKTest(fixtures.MappedTest): def test_joinedload_doesnt_produce_bogus_event(self): A, B = self._fixture() - sess = Session() + sess = fixture_session() b1 = B() sess.add(b1) @@ -364,7 +362,7 @@ class M2ODontOverwriteFKTest(fixtures.MappedTest): def test_init_doesnt_produce_scalar_event(self): A, B = self._fixture() - sess = Session() + sess = fixture_session() b1 = B() sess.add(b1) @@ -379,7 +377,7 @@ class M2ODontOverwriteFKTest(fixtures.MappedTest): def test_init_doesnt_produce_collection_event(self): A, B = self._fixture(uselist=True) - sess = Session() + sess = fixture_session() b1 = B() sess.add(b1) @@ -394,7 +392,7 @@ class M2ODontOverwriteFKTest(fixtures.MappedTest): def test_scalar_relationship_overrides_fk(self): A, B = self._fixture() - sess = Session() + sess = fixture_session() b1 = B() sess.add(b1) @@ -409,7 +407,7 @@ class M2ODontOverwriteFKTest(fixtures.MappedTest): def test_collection_relationship_overrides_fk(self): A, B = self._fixture(uselist=True) - sess = Session() + sess = fixture_session() b1 = B() sess.add(b1) @@ -506,7 +504,7 @@ class DirectSelfRefFKTest(fixtures.MappedTest, AssertsCompiledSQL): def _fixture(self): Entity = self.classes.Entity - sess = Session() + sess = fixture_session() sess.add_all( [ Entity("/foo"), @@ -625,7 +623,7 @@ class DirectSelfRefFKTest(fixtures.MappedTest, AssertsCompiledSQL): def test_plain_join_descendants(self): self._descendants_fixture(data=False) Entity = self.classes.Entity - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(Entity).join(Entity.descendants, aliased=True), "SELECT entity.path AS entity_path FROM entity JOIN entity AS " @@ -718,7 +716,7 @@ class OverlappingFksSiblingTest(fixtures.TestBase): __mapper_args__ = {"polymorphic_identity": "bsub2"} configure_mappers() - self.metadata.create_all() + self.metadata.create_all(testing.db) return A, AMember, B, BSub1, BSub2 @@ -1271,7 +1269,7 @@ class CompositeSelfRefFKTest(fixtures.MappedTest, AssertsCompiledSQL): def _test_no_warning(self, overwrites=False): configure_mappers() self._test_relationships() - sess = Session() + sess = fixture_session() self._setup_data(sess) self._test_lazy_relations(sess) self._test_join_aliasing(sess) @@ -1486,7 +1484,7 @@ class SynonymsAsFKsTest(fixtures.MappedTest): }, ) - sess = create_session() + sess = fixture_session() b = B(id=0) a = A(id=0, b=b) @@ -1556,15 +1554,15 @@ class FKsAsPksTest(fixtures.MappedTest): configure_mappers() assert A.b.property.strategy.use_get - sess = create_session() + with fixture_session() as sess: + a1 = A() + sess.add(a1) + sess.commit() - a1 = A() - sess.add(a1) - sess.flush() - sess.close() - a1 = sess.query(A).first() - a1.b = B() - sess.flush() + with fixture_session() as sess: + a1 = sess.query(A).first() + a1.b = B() + sess.commit() def test_no_delete_PK_AtoB(self): """A cant be deleted without B because B would have no PK value.""" @@ -1585,19 +1583,17 @@ class FKsAsPksTest(fixtures.MappedTest): a1 = A() a1.bs.append(B()) - sess = create_session() - sess.add(a1) - sess.flush() - - sess.delete(a1) - try: + with fixture_session() as sess: + sess.add(a1) sess.flush() - assert False - except AssertionError as e: - startswith_( - str(e), + + sess.delete(a1) + + assert_raises_message( + AssertionError, "Dependency rule tried to blank-out " "primary key column 'tableB.id' on instance ", + sess.flush, ) def test_no_delete_PK_BtoA(self): @@ -1616,40 +1612,37 @@ class FKsAsPksTest(fixtures.MappedTest): b1 = B() a1 = A() b1.a = a1 - sess = create_session() - sess.add(b1) - sess.flush() - b1.a = None - try: + with fixture_session() as sess: + sess.add(b1) sess.flush() - assert False - except AssertionError as e: - startswith_( - str(e), + b1.a = None + assert_raises_message( + AssertionError, "Dependency rule tried to blank-out " "primary key column 'tableB.id' on instance ", + sess.flush, ) @testing.fails_on_everything_except( "sqlite", testing.requires.mysql_non_strict ) - def test_nullPKsOK_BtoA(self): + def test_nullPKsOK_BtoA(self, metadata, connection): A, tableA = self.classes.A, self.tables.tableA # postgresql cant handle a nullable PK column...? tableC = Table( "tablec", - tableA.metadata, + metadata, Column("id", Integer, primary_key=True), Column( "a_id", Integer, - ForeignKey("tableA.id"), + ForeignKey(tableA.c.id), primary_key=True, nullable=True, ), ) - tableC.create() + tableC.create(connection) class C(fixtures.BasicEntity): pass @@ -1662,10 +1655,10 @@ class FKsAsPksTest(fixtures.MappedTest): c1 = C() c1.id = 5 c1.a = None - sess = create_session() - sess.add(c1) - # test that no error is raised. - sess.flush() + with Session(connection) as sess: + sess.add(c1) + # test that no error is raised. + sess.flush() def test_delete_cascade_BtoA(self): """No 'blank the PK' error when the child is to @@ -1695,14 +1688,14 @@ class FKsAsPksTest(fixtures.MappedTest): b1 = B() a1 = A() b1.a = a1 - sess = create_session() - sess.add(b1) - sess.flush() - sess.delete(b1) - sess.flush() - assert a1 not in sess - assert b1 not in sess - sess.expunge_all() + with fixture_session() as sess: + sess.add(b1) + sess.flush() + sess.delete(b1) + sess.flush() + assert a1 not in sess + assert b1 not in sess + sa.orm.clear_mappers() def test_delete_cascade_AtoB(self): @@ -1729,15 +1722,15 @@ class FKsAsPksTest(fixtures.MappedTest): a1 = A() b1 = B() a1.bs.append(b1) - sess = create_session() - sess.add(a1) - sess.flush() + with fixture_session() as sess: + sess.add(a1) + sess.flush() + + sess.delete(a1) + sess.flush() + assert a1 not in sess + assert b1 not in sess - sess.delete(a1) - sess.flush() - assert a1 not in sess - assert b1 not in sess - sess.expunge_all() sa.orm.clear_mappers() def test_delete_manual_AtoB(self): @@ -1754,17 +1747,16 @@ class FKsAsPksTest(fixtures.MappedTest): a1 = A() b1 = B() a1.bs.append(b1) - sess = create_session() - sess.add(a1) - sess.add(b1) - sess.flush() + with fixture_session() as sess: + sess.add(a1) + sess.add(b1) + sess.flush() - sess.delete(a1) - sess.delete(b1) - sess.flush() - assert a1 not in sess - assert b1 not in sess - sess.expunge_all() + sess.delete(a1) + sess.delete(b1) + sess.flush() + assert a1 not in sess + assert b1 not in sess def test_delete_manual_BtoA(self): tableB, A, B, tableA = ( @@ -1780,15 +1772,15 @@ class FKsAsPksTest(fixtures.MappedTest): b1 = B() a1 = A() b1.a = a1 - sess = create_session() - sess.add(b1) - sess.add(a1) - sess.flush() - sess.delete(b1) - sess.delete(a1) - sess.flush() - assert a1 not in sess - assert b1 not in sess + with fixture_session() as sess: + sess.add(b1) + sess.add(a1) + sess.flush() + sess.delete(b1) + sess.delete(a1) + sess.flush() + assert a1 not in sess + assert b1 not in sess class UniqueColReferenceSwitchTest(fixtures.MappedTest): @@ -1840,7 +1832,7 @@ class UniqueColReferenceSwitchTest(fixtures.MappedTest): mapper(A, table_a) mapper(B, table_b, properties={"a": relationship(A, backref="bs")}) - session = create_session() + session = fixture_session() a1, a2 = A(ident="uuid1"), A(ident="uuid2") session.add_all([a1, a2]) a1.bs = [B(), B()] @@ -1926,7 +1918,7 @@ class RelationshipToSelectableTest(fixtures.MappedTest): ), ) - session = create_session() + session = fixture_session() con = Container() con.policyNum = "99" con.policyEffDate = datetime.date.today() @@ -2003,7 +1995,7 @@ class FKEquatedToConstantTest(fixtures.MappedTest): mapper(TagInstance, tag_foo) - sess = create_session() + sess = fixture_session() t1 = Tag(data="some tag") t1.foo.append(TagInstance(data="iplc_case")) t1.foo.append(TagInstance(data="not_iplc_case")) @@ -2075,7 +2067,7 @@ class BackrefPropagatesForwardsArgs(fixtures.MappedTest): ) mapper(Address, addresses) - sess = sessionmaker()() + sess = fixture_session() u1 = User(name="u1", addresses=[Address(email="a1")]) sess.add(u1) sess.commit() @@ -2153,7 +2145,7 @@ class AmbiguousJoinInterpretedAsSelfRef(fixtures.MappedTest): def test_mapping(self): Subscriber, Address = self.classes.Subscriber, self.classes.Address - sess = create_session() + sess = fixture_session() assert Subscriber.addresses.property.direction is ONETOMANY assert Address.customer.property.direction is MANYTOONE @@ -2210,7 +2202,7 @@ class ManualBackrefTest(_fixtures.FixtureTest): }, ) - sess = create_session() + sess = fixture_session() u1 = User(name="u1") a1 = Address(email_address="foo") @@ -2550,7 +2542,7 @@ class TypeMatchTest(fixtures.MappedTest): c1 = C() a1.bs.append(b1) a1.bs.append(c1) - sess = create_session() + sess = fixture_session() try: sess.add(a1) assert False @@ -2582,7 +2574,7 @@ class TypeMatchTest(fixtures.MappedTest): c1 = C() a1.bs.append(b1) a1.bs.append(c1) - sess = create_session() + sess = fixture_session() sess.add(a1) sess.add(b1) sess.add(c1) @@ -2611,7 +2603,7 @@ class TypeMatchTest(fixtures.MappedTest): c1 = C() a1.bs.append(b1) a1.bs.append(c1) - sess = create_session() + sess = fixture_session() sess.add(a1) sess.add(b1) sess.add(c1) @@ -2637,7 +2629,7 @@ class TypeMatchTest(fixtures.MappedTest): b1 = B() d1 = D() d1.a = b1 - sess = create_session() + sess = fixture_session() sess.add(b1) sess.add(d1) assert_raises_message( @@ -2662,7 +2654,7 @@ class TypeMatchTest(fixtures.MappedTest): b1 = B() d1 = D() d1.a = b1 - sess = create_session() + sess = fixture_session() assert_raises_message( AssertionError, "doesn't handle objects of type", sess.add, d1 ) @@ -2725,16 +2717,22 @@ class TypedAssociationTable(fixtures.MappedTest): c.col1 = "cid" a.t2s.append(b) a.t2s.append(c) - sess = create_session() + sess = fixture_session() sess.add(a) sess.flush() - eq_(select(func.count("*")).select_from(t3).scalar(), 2) + eq_( + sess.connection().scalar(select(func.count("*")).select_from(t3)), + 2, + ) a.t2s.remove(c) sess.flush() - eq_(select(func.count("*")).select_from(t3).scalar(), 1) + eq_( + sess.connection().scalar(select(func.count("*")).select_from(t3)), + 1, + ) class CustomOperatorTest(fixtures.MappedTest, AssertsCompiledSQL): @@ -2782,7 +2780,7 @@ class CustomOperatorTest(fixtures.MappedTest, AssertsCompiledSQL): ) mapper(B, self.tables.b) self.assert_compile( - Session().query(A).join(A.bs), + fixture_session().query(A).join(A.bs), "SELECT a.id AS a_id, a.foo AS a_foo " "FROM a JOIN b ON a.foo &* b.foo", ) @@ -2976,7 +2974,7 @@ class ViewOnlyM2MBackrefTest(fixtures.MappedTest): configure_mappers() - sess = create_session() + sess = fixture_session() a1 = A() b1 = B(as_=[a1]) @@ -3069,7 +3067,7 @@ class ViewOnlyOverlappingNames(fixtures.MappedTest): c3 = C3() c3.data = "c1data" c3.t2 = c2b - sess = create_session() + sess = fixture_session() sess.add(c1) sess.add(c3) sess.flush() @@ -3330,7 +3328,7 @@ class ViewOnlyUniqueNames(fixtures.MappedTest): c3 = C3() c3.data = "c1data" c3.t2 = c2b - sess = create_session() + sess = fixture_session() sess.add_all((c1, c3)) sess.flush() @@ -3418,20 +3416,20 @@ class ViewOnlyNonEquijoin(fixtures.MappedTest): mapper(Bar, bars) - sess = create_session() - sess.add_all( - ( - Foo(id=4), - Foo(id=9), - Bar(id=1, fid=2), - Bar(id=2, fid=3), - Bar(id=3, fid=6), - Bar(id=4, fid=7), + with fixture_session() as sess: + sess.add_all( + ( + Foo(id=4), + Foo(id=9), + Bar(id=1, fid=2), + Bar(id=2, fid=3), + Bar(id=3, fid=6), + Bar(id=4, fid=7), + ) ) - ) - sess.flush() + sess.commit() - sess = create_session() + sess = fixture_session() eq_( sess.query(Foo).filter_by(id=4).one(), Foo(id=4, bars=[Bar(fid=2), Bar(fid=3)]), @@ -3492,7 +3490,7 @@ class ViewOnlyRepeatedRemoteColumn(fixtures.MappedTest): ) mapper(Bar, bars) - sess = create_session() + sess = fixture_session() b1 = Bar(id=1, data="b1") b2 = Bar(id=2, data="b2") b3 = Bar(id=3, data="b3") @@ -3566,7 +3564,7 @@ class ViewOnlyRepeatedLocalColumn(fixtures.MappedTest): ) mapper(Bar, bars) - sess = create_session() + sess = fixture_session() f1 = Foo(id=1, data="f1") f2 = Foo(id=2, data="f2") b1 = Bar(fid1=1, data="b1") @@ -3678,7 +3676,7 @@ class ViewOnlyComplexJoin(_RelationshipErrors, fixtures.MappedTest): ) mapper(T3, t3) - sess = create_session() + sess = fixture_session() sess.add(T2(data="t2", t1=T1(data="t1"), t3s=[T3(data="t3")])) sess.flush() sess.expunge_all() @@ -3766,7 +3764,7 @@ class FunctionAsPrimaryJoinTest(fixtures.DeclarativeMappedTest): def test_lazyload(self): Venue = self.classes.Venue - s = Session() + s = fixture_session() v1 = s.query(Venue).filter_by(name="parent1").one() eq_( [d.name for d in v1.descendants], @@ -3775,7 +3773,7 @@ class FunctionAsPrimaryJoinTest(fixtures.DeclarativeMappedTest): def test_joinedload(self): Venue = self.classes.Venue - s = Session() + s = fixture_session() def go(): v1 = ( @@ -3951,7 +3949,7 @@ class ExplicitLocalRemoteTest(fixtures.MappedTest): ) is_(T1.t2s.property.direction, ONETOMANY) eq_(T1.t2s.property.local_remote_pairs, [(t1.c.id, t2.c.t1id)]) - sess = create_session() + sess = fixture_session() a1 = T1(id="number1", data="a1") a2 = T1(id="number2", data="a2") b1 = T2(data="b1", t1id="NuMbEr1") @@ -3996,7 +3994,7 @@ class ExplicitLocalRemoteTest(fixtures.MappedTest): }, ) - sess = create_session() + sess = fixture_session() a1 = T1(id="number1", data="a1") a2 = T1(id="number2", data="a2") b1 = T2(data="b1", t1id="NuMbEr1") @@ -4036,7 +4034,7 @@ class ExplicitLocalRemoteTest(fixtures.MappedTest): ) mapper(T2, t2) - sess = create_session() + sess = fixture_session() a1 = T1(id="NuMbeR1", data="a1") a2 = T1(id="NuMbeR2", data="a2") b1 = T2(data="b1", t1id="number1") @@ -4081,7 +4079,7 @@ class ExplicitLocalRemoteTest(fixtures.MappedTest): }, ) - sess = create_session() + sess = fixture_session() a1 = T1(id="NuMbeR1", data="a1") a2 = T1(id="NuMbeR2", data="a2") b1 = T2(data="b1", t1id="number1") @@ -4525,7 +4523,7 @@ class SecondaryNestedJoinTest( def test_render_join(self): A = self.classes.A - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(A).join(A.d), "SELECT a.id AS a_id, a.name AS a_name, a.b_id AS a_b_id " @@ -4537,7 +4535,7 @@ class SecondaryNestedJoinTest( def test_render_joinedload(self): A = self.classes.A - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(A).options(joinedload(A.d)), "SELECT a.id AS a_id, a.name AS a_name, a.b_id AS a_b_id, " @@ -4551,7 +4549,7 @@ class SecondaryNestedJoinTest( def test_render_lazyload(self): A = self.classes.A - sess = Session() + sess = fixture_session() a1 = sess.query(A).filter(A.name == "a1").first() def go(): @@ -4577,14 +4575,14 @@ class SecondaryNestedJoinTest( def test_join(self): A, D = self.classes.A, self.classes.D - sess = Session() + sess = fixture_session() for a, d in sess.query(A, D).outerjoin(A.d): eq_(self.mapping[a.name], d.name if d is not None else None) def test_joinedload(self): A = self.classes.A - sess = Session() + sess = fixture_session() for a in sess.query(A).options(joinedload(A.d)): d = a.d @@ -4592,7 +4590,7 @@ class SecondaryNestedJoinTest( def test_lazyload(self): A = self.classes.A - sess = Session() + sess = fixture_session() for a in sess.query(A): d = a.d @@ -5343,7 +5341,7 @@ class ActiveHistoryFlagTest(_fixtures.FixtureTest): run_deletes = None def _test_attribute(self, obj, attrname, newvalue): - sess = Session() + sess = fixture_session() sess.add(obj) oldvalue = getattr(obj, attrname) sess.commit() @@ -5627,7 +5625,7 @@ class InactiveHistoryNoRaiseTest(_fixtures.FixtureTest): }, ) - s = Session() + s = fixture_session() a1 = Address(email_address="a1") u1 = User(name="u1", addresses=[a1]) @@ -5727,7 +5725,7 @@ class RaiseLoadTest(_fixtures.FixtureTest): users, properties=dict(addresses=relationship(Address, lazy="raise")), ) - q = create_session().query(User) + q = fixture_session().query(User) result = [None] def go(): @@ -5753,7 +5751,7 @@ class RaiseLoadTest(_fixtures.FixtureTest): mapper(Address, addresses) mapper(User, users, properties=dict(addresses=relationship(Address))) - q = create_session().query(User) + q = fixture_session().query(User) result = [None] def go(): @@ -5787,7 +5785,7 @@ class RaiseLoadTest(_fixtures.FixtureTest): users, properties=dict(addresses=relationship(Address, lazy="raise")), ) - q = create_session().query(User).options(sa.orm.lazyload("addresses")) + q = fixture_session().query(User).options(sa.orm.lazyload("addresses")) result = [None] def go(): @@ -5810,7 +5808,7 @@ class RaiseLoadTest(_fixtures.FixtureTest): ) mapper(Address, addresses, properties={"user": relationship(User)}) mapper(User, users) - s = Session() + s = fixture_session() a1 = ( s.query(Address) .filter_by(id=1) @@ -5836,7 +5834,7 @@ class RaiseLoadTest(_fixtures.FixtureTest): ) mapper(Address, addresses, properties={"user": relationship(User)}) mapper(User, users) - s = Session() + s = fixture_session() a1 = ( s.query(Address) .filter_by(id=1) @@ -5886,7 +5884,7 @@ class RaiseLoadTest(_fixtures.FixtureTest): }, ) mapper(User, users) - s = Session() + s = fixture_session() u1 = s.query(User).first() # noqa a1 = ( s.query(Address) @@ -5917,7 +5915,7 @@ class RaiseLoadTest(_fixtures.FixtureTest): properties=dict(addresses=relationship(Address, backref="user")), ) q = ( - create_session() + fixture_session() .query(User, Address) .join(Address, User.id == Address.user_id) ) @@ -5955,7 +5953,7 @@ class RaiseLoadTest(_fixtures.FixtureTest): properties=dict(addresses=relationship(Address, backref="user")), ) q = ( - create_session() + fixture_session() .query(User, Address) .join(Address, User.id == Address.user_id) ) @@ -6052,7 +6050,7 @@ class RelationDeprecationTest(fixtures.MappedTest): ) mapper(Address, addresses_table) - session = create_session() + session = fixture_session() session.query(User).filter( User.addresses.any(Address.email_address == "ed@foo.bar") @@ -6123,7 +6121,7 @@ class SecondaryIncludesLocalColsTest(fixtures.MappedTest): def test_query_join(self): A, B = self.classes("A", "B") - s = Session() + s = fixture_session() with assert_engine(testing.db) as asserter_: rows = s.query(A.id, B.id).join(A.bs).order_by(A.id, B.id).all() @@ -6143,7 +6141,7 @@ class SecondaryIncludesLocalColsTest(fixtures.MappedTest): def test_eager_join(self): A, B = self.classes("A", "B") - s = Session() + s = fixture_session() with assert_engine(testing.db) as asserter_: a2 = ( @@ -6166,7 +6164,7 @@ class SecondaryIncludesLocalColsTest(fixtures.MappedTest): def test_exists(self): A, B = self.classes("A", "B") - s = Session() + s = fixture_session() with assert_engine(testing.db) as asserter_: eq_(set(id_ for id_, in s.query(A.id).filter(A.bs.any())), {1, 2}) @@ -6184,7 +6182,7 @@ class SecondaryIncludesLocalColsTest(fixtures.MappedTest): def test_eager_selectin(self): A, B = self.classes("A", "B") - s = Session() + s = fixture_session() with assert_engine(testing.db) as asserter_: a2 = ( diff --git a/test/orm/test_scoping.py b/test/orm/test_scoping.py index d1ed9acc1..5386fd112 100644 --- a/test/orm/test_scoping.py +++ b/test/orm/test_scoping.py @@ -39,7 +39,7 @@ class ScopedSessionTest(fixtures.MappedTest): def test_basic(self): table2, table1 = self.tables.table2, self.tables.table1 - Session = scoped_session(sa.orm.sessionmaker()) + Session = scoped_session(sa.orm.sessionmaker(testing.db)) class CustomQuery(query.Query): pass diff --git a/test/orm/test_selectable.py b/test/orm/test_selectable.py index 502df314a..c22391b44 100644 --- a/test/orm/test_selectable.py +++ b/test/orm/test_selectable.py @@ -12,6 +12,7 @@ from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -44,7 +45,7 @@ class SelectableNoFromsTest(fixtures.MappedTest, AssertsCompiledSQL): mapper(Subset, selectable, primary_key=[selectable.c.x]) self.assert_compile( - Session().query(Subset), + fixture_session().query(Subset), "SELECT anon_1.x AS anon_1_x, anon_1.y AS anon_1_y, " "anon_1.z AS anon_1_z FROM (SELECT x, y, z) AS anon_1", use_default_dialect=True, diff --git a/test/orm/test_selectin_relations.py b/test/orm/test_selectin_relations.py index de6282da5..5535fe5d6 100644 --- a/test/orm/test_selectin_relations.py +++ b/test/orm/test_selectin_relations.py @@ -8,7 +8,6 @@ from sqlalchemy import String from sqlalchemy import testing from sqlalchemy.orm import aliased from sqlalchemy.orm import clear_mappers -from sqlalchemy.orm import create_session from sqlalchemy.orm import defaultload from sqlalchemy.orm import defer from sqlalchemy.orm import deferred @@ -31,6 +30,7 @@ from sqlalchemy.testing import mock from sqlalchemy.testing.assertsql import AllOf from sqlalchemy.testing.assertsql import assert_engine from sqlalchemy.testing.assertsql import CompiledSQL +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from test.orm import _fixtures @@ -64,7 +64,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) }, ) - sess = create_session() + sess = fixture_session() q = sess.query(User).options(selectinload(User.addresses)) @@ -122,7 +122,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): def go(): - sess = create_session() + sess = fixture_session() u = aliased(User) @@ -148,7 +148,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): for i in range(3): def go(): - sess = create_session() + sess = fixture_session() u = aliased(User) @@ -165,7 +165,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): for i in range(3): def go(): - sess = create_session() + sess = fixture_session() u = aliased(User) @@ -217,7 +217,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) }, ) - sess = create_session() + sess = fixture_session() q = sess.query(User).options(selectinload(User.addresses)) @@ -249,7 +249,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) }, ) - sess = create_session() + sess = fixture_session() q = sess.query(User).options(selectinload(User.addresses)) @@ -280,7 +280,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): properties={"addresses": relationship(Address, lazy="dynamic")}, ) mapper(Address, addresses) - sess = create_session() + sess = fixture_session() # previously this would not raise, but would emit # the query needlessly and put the result nowhere. @@ -314,7 +314,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ), ) - q = create_session().query(Item).order_by(Item.id) + q = fixture_session().query(Item).order_by(Item.id) def go(): eq_(self.static.item_keyword_result, q.all()) @@ -344,7 +344,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ), ) - q = create_session().query(Item).order_by(Item.id) + q = fixture_session().query(Item).order_by(Item.id) def go(): eq_( @@ -377,7 +377,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ), ) - q = create_session().query(Item).order_by(Item.id) + q = fixture_session().query(Item).order_by(Item.id) def go(): ka = aliased(Keyword) @@ -407,7 +407,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) }, ) - q = create_session().query(User) + q = fixture_session().query(User) eq_( [ User(id=7, addresses=[Address(id=1)]), @@ -444,7 +444,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) }, ) - q = create_session().query(User) + q = fixture_session().query(User) eq_( [ User(id=7, addresses=[Address(id=1)]), @@ -484,7 +484,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ), ) - q = create_session().query(User) + q = fixture_session().query(User) result = ( q.filter(User.id == Address.user_id) .order_by(Address.email_address) @@ -527,7 +527,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) ), ) - sess = create_session() + sess = fixture_session() eq_( [ User(id=7, addresses=[Address(id=1)]), @@ -713,7 +713,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): def _do_query_tests(self, opts, count): Order, User = self.classes.Order, self.classes.User - sess = create_session() + sess = fixture_session() def go(): eq_( @@ -732,7 +732,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.static.user_item_keyword_result[2:3], ) - sess = create_session() + sess = fixture_session() eq_( sess.query(User) .options(*opts) @@ -772,7 +772,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) is_(sa.orm.class_mapper(Address).get_property("user").lazy, "selectin") - sess = create_session() + sess = fixture_session() eq_( self.static.user_address_result, sess.query(User).order_by(User.id).all(), @@ -810,7 +810,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) is_(sa.orm.class_mapper(Address).get_property("user").lazy, "selectin") - sess = create_session() + sess = fixture_session() eq_( self.static.user_address_result, sess.query(User).order_by(User.id).all(), @@ -1034,7 +1034,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): User, Address, Order, Item = self.classes( "User", "Address", "Order", "Item" ) - q = create_session().query(User).order_by(User.id) + q = fixture_session().query(User).order_by(User.id) def items(*ids): if no_items: @@ -1133,7 +1133,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): }, ) - sess = create_session() + sess = fixture_session() q = sess.query(User) result = q.order_by(User.id).limit(2).offset(1).all() @@ -1159,7 +1159,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) ), ) - q = create_session().query(User) + q = fixture_session().query(User) def go(): result = q.filter(users.c.id == 7).all() @@ -1184,7 +1184,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) ), ) - q = create_session().query(User) + q = fixture_session().query(User) def go(): result = q.filter(users.c.id == 10).all() @@ -1207,7 +1207,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): user=relationship(mapper(User, users), lazy="selectin") ), ) - sess = create_session() + sess = fixture_session() q = sess.query(Address) def go(): @@ -1233,7 +1233,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) mapper(Address, addresses) - sess = create_session() + sess = fixture_session(autoflush=False) q = sess.query(Order).filter(Order.id.in_([4, 5])).order_by(Order.id) o4, o5 = q.all() @@ -1269,7 +1269,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) mapper(Address, addresses) - sess = create_session() + sess = fixture_session() q = sess.query(Order).filter(Order.id.in_([4, 5])).order_by(Order.id) o4, o5 = q.all() @@ -1293,7 +1293,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) ), ) - q = create_session().query(User) + q = fixture_session().query(User) result = q.filter(users.c.id == 10).all() u1 = result[0] @@ -1334,7 +1334,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): }, ) - q = create_session().query(User) + q = fixture_session().query(User) def go(): eq_( @@ -1374,7 +1374,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): properties={"order": relationship(Order, uselist=False)}, ) mapper(Order, orders) - s = create_session() + s = fixture_session() assert_raises( sa.exc.SAWarning, s.query(User).options(selectinload(User.order)).all, @@ -1405,7 +1405,7 @@ class LoadOnExistingTest(_fixtures.FixtureTest): ) mapper(Dingaling, self.tables.dingalings) - sess = Session(autoflush=False) + sess = fixture_session(autoflush=False) return User, Address, Dingaling, sess def _collection_to_collection_fixture(self): @@ -1426,7 +1426,7 @@ class LoadOnExistingTest(_fixtures.FixtureTest): ) mapper(Item, self.tables.items) - sess = Session(autoflush=False) + sess = fixture_session(autoflush=False) return User, Order, Item, sess def _eager_config_fixture(self): @@ -1437,7 +1437,7 @@ class LoadOnExistingTest(_fixtures.FixtureTest): properties={"addresses": relationship(Address, lazy="selectin")}, ) mapper(Address, self.tables.addresses) - sess = Session(autoflush=False) + sess = fixture_session(autoflush=False) return User, Address, sess def _deferred_config_fixture(self): @@ -1451,7 +1451,7 @@ class LoadOnExistingTest(_fixtures.FixtureTest): }, ) mapper(Address, self.tables.addresses) - sess = Session(autoflush=False) + sess = fixture_session(autoflush=False) return User, Address, sess def test_runs_query_on_refresh(self): @@ -1622,7 +1622,7 @@ class OrderBySecondaryTest(fixtures.MappedTest): ) mapper(B, b) - sess = create_session() + sess = fixture_session() def go(): eq_( @@ -1730,12 +1730,12 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): Paperwork(description="tps report #2"), ] e2.paperwork = [Paperwork(description="tps report #3")] - sess = create_session(connection) + sess = Session(connection) sess.add_all([e1, e2]) sess.flush() def test_correct_select_nofrom(self): - sess = create_session() + sess = fixture_session() # use Person.paperwork here just to give the least # amount of context q = ( @@ -1778,7 +1778,7 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): ) def test_correct_select_existingfrom(self): - sess = create_session() + sess = fixture_session() # use Person.paperwork here just to give the least # amount of context q = ( @@ -1829,7 +1829,7 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): def test_correct_select_with_polymorphic_no_alias(self): # test #3106 - sess = create_session() + sess = fixture_session() wp = with_polymorphic(Person, [Engineer]) q = ( @@ -1875,7 +1875,7 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): def test_correct_select_with_polymorphic_alias(self): # test #3106 - sess = create_session() + sess = fixture_session() wp = with_polymorphic(Person, [Engineer], aliased=True) q = ( @@ -1929,7 +1929,7 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): def test_correct_select_with_polymorphic_flat_alias(self): # test #3106 - sess = create_session() + sess = fixture_session() wp = with_polymorphic(Person, [Engineer], aliased=True, flat=True) q = ( @@ -2061,7 +2061,7 @@ class HeterogeneousSubtypesTest(fixtures.DeclarativeMappedTest): Company, Programmer, Manager, GolfSwing, Language = self.classes( "Company", "Programmer", "Manager", "GolfSwing", "Language" ) - sess = Session() + sess = fixture_session() company = ( sess.query(Company) .filter(Company.id == 1) @@ -2082,7 +2082,7 @@ class HeterogeneousSubtypesTest(fixtures.DeclarativeMappedTest): Company, Programmer, Manager, GolfSwing, Language = self.classes( "Company", "Programmer", "Manager", "GolfSwing", "Language" ) - sess = Session() + sess = fixture_session() company = ( sess.query(Company) .filter(Company.id == 2) @@ -2105,7 +2105,7 @@ class HeterogeneousSubtypesTest(fixtures.DeclarativeMappedTest): Company, Programmer, Manager, GolfSwing, Language = self.classes( "Company", "Programmer", "Manager", "GolfSwing", "Language" ) - sess = Session() + sess = fixture_session() rows = ( sess.query(Company) .options( @@ -2169,7 +2169,7 @@ class TupleTest(fixtures.DeclarativeMappedTest): def test_load_o2m(self): A, B = self.classes("A", "B") - session = Session() + session = fixture_session() def go(): q = ( @@ -2205,7 +2205,7 @@ class TupleTest(fixtures.DeclarativeMappedTest): def test_load_m2o(self): A, B = self.classes("A", "B") - session = Session() + session = fixture_session() def go(): q = session.query(B).options(selectinload(B.a)).order_by(B.id) @@ -2277,7 +2277,7 @@ class ChunkingTest(fixtures.DeclarativeMappedTest): def test_odd_number_chunks(self): A, B = self.classes("A", "B") - session = Session() + session = fixture_session() def go(): with mock.patch( @@ -2320,7 +2320,7 @@ class ChunkingTest(fixtures.DeclarativeMappedTest): import random - session = Session() + session = fixture_session() yield_per = random.randint(8, 105) offset = random.randint(0, 19) @@ -2350,7 +2350,7 @@ class ChunkingTest(fixtures.DeclarativeMappedTest): def test_dont_emit_for_redundant_m2o(self): A, B = self.classes("A", "B") - session = Session() + session = fixture_session() def go(): with mock.patch( @@ -2498,7 +2498,7 @@ class SubRelationFromJoinedSubclassMultiLevelTest(_Polymorphic): @classmethod def insert_data(cls, connection): c1 = cls._fixture() - sess = create_session(connection) + sess = Session(connection) sess.add(c1) sess.flush() @@ -2526,7 +2526,7 @@ class SubRelationFromJoinedSubclassMultiLevelTest(_Polymorphic): ) def test_chained_selectin_subclass(self): - s = Session() + s = fixture_session() q = s.query(Company).options( selectinload(Company.employees.of_type(Engineer)) .selectinload(Engineer.machines) @@ -2568,7 +2568,7 @@ class SelfReferentialTest(fixtures.MappedTest): ) }, ) - sess = create_session() + sess = fixture_session() n1 = Node(data="n1") n1.append(Node(data="n11")) n1.append(Node(data="n12")) @@ -2644,7 +2644,7 @@ class SelfReferentialTest(fixtures.MappedTest): ) }, ) - sess = create_session() + sess = fixture_session() n1 = Node(data="n1") n1.append(Node(data="n11")) n1.append(Node(data="n12")) @@ -2691,7 +2691,7 @@ class SelfReferentialTest(fixtures.MappedTest): "data": deferred(nodes.c.data), }, ) - sess = create_session() + sess = fixture_session() n1 = Node(data="n1") n1.append(Node(data="n11")) n1.append(Node(data="n12")) @@ -2744,7 +2744,7 @@ class SelfReferentialTest(fixtures.MappedTest): nodes, properties={"children": relationship(Node, order_by=nodes.c.id)}, ) - sess = create_session() + sess = fixture_session() n1 = Node(data="n1") n1.append(Node(data="n11")) n1.append(Node(data="n12")) @@ -2799,7 +2799,7 @@ class SelfReferentialTest(fixtures.MappedTest): nodes, properties={"children": relationship(Node, lazy="selectin")}, ) - sess = create_session() + sess = fixture_session() n1 = Node(data="n1") n1.append(Node(data="n11")) n1.append(Node(data="n12")) @@ -2891,7 +2891,7 @@ class SelfRefInheritanceAliasedTest( attr1 = Foo.foo.of_type(r) attr2 = r.foo - s = Session() + s = fixture_session() q = ( s.query(Foo) .filter(Foo.id == 2) @@ -2996,7 +2996,7 @@ class TestExistingRowPopulation(fixtures.DeclarativeMappedTest): def test_o2m(self): A, A2, B, C1o2m, C2o2m = self.classes("A", "A2", "B", "C1o2m", "C2o2m") - s = Session() + s = fixture_session() # A -J-> B -L-> C1 # A -J-> B -S-> C2 @@ -3017,7 +3017,7 @@ class TestExistingRowPopulation(fixtures.DeclarativeMappedTest): def test_m2o(self): A, A2, B, C1m2o, C2m2o = self.classes("A", "A2", "B", "C1m2o", "C2m2o") - s = Session() + s = fixture_session() # A -J-> B -L-> C1 # A -J-> B -S-> C2 @@ -3070,7 +3070,7 @@ class SingleInhSubclassTest( def test_load(self): (EmployerUser,) = self.classes("EmployerUser") - s = Session() + s = fixture_session() q = s.query(EmployerUser) @@ -3132,7 +3132,7 @@ class MissingForeignTest( def test_missing_rec(self): A, B = self.classes("A", "B") - s = Session() + s = fixture_session() eq_( s.query(A).options(selectinload(A.b)).order_by(A.id).all(), [ @@ -3191,7 +3191,7 @@ class M2OWDegradeTest( def test_use_join_parent_criteria(self): A, B = self.classes("A", "B") - s = Session() + s = fixture_session() q = ( s.query(A) .filter(A.id.in_([1, 3])) @@ -3220,7 +3220,7 @@ class M2OWDegradeTest( def test_use_join_parent_criteria_degrade_on_defer(self): A, B = self.classes("A", "B") - s = Session() + s = fixture_session() q = ( s.query(A) .filter(A.id.in_([1, 3])) @@ -3255,7 +3255,7 @@ class M2OWDegradeTest( def test_use_join(self): A, B = self.classes("A", "B") - s = Session() + s = fixture_session() q = s.query(A).options(selectinload(A.b)).order_by(A.id) results = self.assert_sql_execution( testing.db, @@ -3286,7 +3286,7 @@ class M2OWDegradeTest( def test_use_join_omit_join_false(self): A, B = self.classes("A", "B") - s = Session() + s = fixture_session() q = s.query(A).options(selectinload(A.b_no_omit_join)).order_by(A.id) results = self.assert_sql_execution( testing.db, @@ -3318,7 +3318,7 @@ class M2OWDegradeTest( def test_use_join_parent_degrade_on_defer(self): A, B = self.classes("A", "B") - s = Session() + s = fixture_session() q = s.query(A).options(defer(A.b_id), selectinload(A.b)).order_by(A.id) results = self.assert_sql_execution( testing.db, @@ -3423,7 +3423,7 @@ class SameNamePolymorphicTest(fixtures.DeclarativeMappedTest): GenericParent, ParentA, ParentB, ChildA, ChildB = self.classes( "GenericParent", "ParentA", "ParentB", "ChildA", "ChildB" ) - session = Session() + session = fixture_session() parent_types = with_polymorphic(GenericParent, [ParentA, ParentB]) @@ -3516,7 +3516,7 @@ class TestBakedCancelsCorrectly(fixtures.DeclarativeMappedTest): # the cache spoil did not use full=True which kept the lead # entities around. - sess = Session() + sess = fixture_session() foo_polymorphic = with_polymorphic(Foo, [SubFoo], aliased=True) credit_adjustment_load = selectinload( diff --git a/test/orm/test_session.py b/test/orm/test_session.py index a9e962cde..20c4752b8 100644 --- a/test/orm/test_session.py +++ b/test/orm/test_session.py @@ -35,7 +35,7 @@ from sqlalchemy.testing import is_not from sqlalchemy.testing import is_true from sqlalchemy.testing import mock from sqlalchemy.testing import pickleable -from sqlalchemy.testing.fixtures import create_session +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from sqlalchemy.testing.util import gc_collect @@ -203,11 +203,11 @@ class SessionUtilTest(_fixtures.FixtureTest): mapper(User, users) - s1 = Session() + s1 = fixture_session() u1 = User() s1.add(u1) - s2 = Session() + s2 = fixture_session() u2 = User() s2.add(u2) @@ -224,11 +224,11 @@ class SessionUtilTest(_fixtures.FixtureTest): mapper(User, users) - s1 = Session() + s1 = fixture_session() u1 = User() s1.add(u1) - s2 = Session() + s2 = fixture_session() u2 = User() s2.add(u2) @@ -255,7 +255,7 @@ class SessionUtilTest(_fixtures.FixtureTest): users, User = self.tables.users, self.classes.User mapper(User, users) - sess = create_session() + sess = fixture_session(autoflush=False) sess.add(User(name="test")) sess.flush() @@ -303,7 +303,7 @@ class SessionUtilTest(_fixtures.FixtureTest): users, User = self.tables.users, self.classes.User mapper(User, users) - sess = Session() + sess = fixture_session() u1 = User(name="test") sess.add(u1) sess.commit() @@ -318,7 +318,7 @@ class SessionUtilTest(_fixtures.FixtureTest): users, User = self.tables.users, self.classes.User mapper(User, users) - sess = Session() + sess = fixture_session() u1 = User(id=1, name="test") sess.add(u1) sess.commit() @@ -334,7 +334,7 @@ class SessionUtilTest(_fixtures.FixtureTest): users, User = self.tables.users, self.classes.User mapper(User, users) - sess = Session() + sess = fixture_session() u1 = User(id=1, name="test") sess.add(u1) assert_raises_message( @@ -348,7 +348,7 @@ class SessionUtilTest(_fixtures.FixtureTest): users, User = self.tables.users, self.classes.User mapper(User, users) - sess = Session() + sess = fixture_session() u1 = User(id=1, name="test") sess.add(u1) sess.commit() @@ -367,7 +367,7 @@ class SessionStateTest(_fixtures.FixtureTest): __prefer_requires__ = ("independent_connections",) def test_info(self): - s = Session() + s = fixture_session() eq_(s.info, {}) maker = sessionmaker(info={"global": True, "s1": 5}) @@ -392,7 +392,7 @@ class SessionStateTest(_fixtures.FixtureTest): def test_autoflush(self): User, users = self.classes.User, self.tables.users - bind = self.metadata.bind + bind = testing.db mapper(User, users) conn1 = bind.connect() conn2 = bind.connect() @@ -419,7 +419,7 @@ class SessionStateTest(_fixtures.FixtureTest): User, users = self.classes.User, self.tables.users mapper(User, users) - sess = Session() + sess = fixture_session() u = User() u.name = "ed" @@ -473,7 +473,7 @@ class SessionStateTest(_fixtures.FixtureTest): mapper(User, users) - sess = sessionmaker()() + sess = fixture_session() u1 = User(name="u1") sess.add(u1) @@ -504,7 +504,7 @@ class SessionStateTest(_fixtures.FixtureTest): mapper(User, users) - sess = Session() + sess = fixture_session() u1 = User(name="u1") sess.add(u1) sess.commit() @@ -537,7 +537,7 @@ class SessionStateTest(_fixtures.FixtureTest): User, users = self.classes.User, self.tables.users mapper(User, users) - with create_session(autocommit=False, autoflush=True) as sess: + with fixture_session(autocommit=False, autoflush=True) as sess: u = User() u.name = "ed" sess.add(u) @@ -653,7 +653,7 @@ class SessionStateTest(_fixtures.FixtureTest): self.tables.users, ) - s = create_session() + s = fixture_session() mapper( User, users, @@ -692,7 +692,7 @@ class SessionStateTest(_fixtures.FixtureTest): assert user in s assert user not in s.dirty - s2 = create_session() + s2 = fixture_session() assert_raises_message( sa.exc.InvalidRequestError, "is already attached to session", @@ -722,8 +722,8 @@ class SessionStateTest(_fixtures.FixtureTest): users = self.tables.users mapper(User, users) - s1 = Session() - s2 = Session() + s1 = fixture_session() + s2 = fixture_session() u1 = User(id=1, name="u1") make_transient_to_detached(u1) # shorthand for actually persisting it @@ -743,7 +743,7 @@ class SessionStateTest(_fixtures.FixtureTest): mapper(User, users) - with create_session() as s: + with fixture_session() as s: s.execute(users.delete()) u1 = User(name="ed") s.add(u1) @@ -775,7 +775,7 @@ class SessionStateTest(_fixtures.FixtureTest): ) mapper(Address, addresses) - session = Session() + session = fixture_session() @event.listens_for(session, "after_flush") def load_collections(session, flush_context): @@ -806,8 +806,8 @@ class SessionStateTest(_fixtures.FixtureTest): users, User = self.tables.users, pickleable.User mapper(User, users) - sess1 = create_session() - sess2 = create_session() + sess1 = fixture_session() + sess2 = fixture_session() u1 = User(name="u1") sess1.add(u1) assert_raises_message( @@ -824,7 +824,7 @@ class SessionStateTest(_fixtures.FixtureTest): mapper(User, users) Session = sessionmaker() - sess = Session() + sess = fixture_session() u1 = User(name="u1") sess.add(u1) @@ -870,7 +870,7 @@ class SessionStateTest(_fixtures.FixtureTest): def test_no_double_save(self): users = self.tables.users - sess = create_session() + sess = fixture_session() class Foo(object): def __init__(self): @@ -893,7 +893,7 @@ class SessionStateTest(_fixtures.FixtureTest): mapper(User, users) - sess = Session() + sess = fixture_session() sess.add_all([User(name="u1"), User(name="u2"), User(name="u3")]) sess.commit() @@ -911,7 +911,7 @@ class SessionStateTest(_fixtures.FixtureTest): users, User = self.tables.users, self.classes.User m = mapper(User, users) - s = Session() + s = fixture_session() @event.listens_for(m, "after_update") def e(mapper, conn, target): @@ -955,7 +955,7 @@ class SessionStateTest(_fixtures.FixtureTest): users, User = self.tables.users, self.classes.User mapper(User, users) - sess = Session() + sess = fixture_session() sess.add(User(name="x")) sess.commit() @@ -976,7 +976,7 @@ class SessionStateTest(_fixtures.FixtureTest): users, User = self.tables.users, self.classes.User mapper(User, users) - sess = Session() + sess = fixture_session() u1 = User(name="x") sess.add(u1) @@ -995,7 +995,7 @@ class SessionStateTest(_fixtures.FixtureTest): users, User = self.tables.users, self.classes.User mapper(User, users) - sess = Session() + sess = fixture_session() sess.add(User(name="x")) sess.commit() @@ -1047,7 +1047,7 @@ class DeferredRelationshipExpressionTest(_fixtures.FixtureTest): """ User, Address = self.classes("User", "Address") - sess = create_session(autoflush=True, autocommit=False) + sess = fixture_session(autoflush=True, autocommit=False) u = User(name="ed", addresses=[Address(email_address="foo")]) sess.add(u) eq_( @@ -1058,7 +1058,9 @@ class DeferredRelationshipExpressionTest(_fixtures.FixtureTest): def test_deferred_expression_obj_was_gced(self): User, Address = self.classes("User", "Address") - sess = create_session(autoflush=True, autocommit=False) + sess = fixture_session( + autoflush=True, autocommit=False, expire_on_commit=False + ) u = User(name="ed", addresses=[Address(email_address="foo")]) sess.add(u) @@ -1078,7 +1080,9 @@ class DeferredRelationshipExpressionTest(_fixtures.FixtureTest): User, Address = self.classes("User", "Address") - sess = create_session(autoflush=True, autocommit=False) + sess = fixture_session( + autoflush=True, autocommit=False, expire_on_commit=False + ) u = User(name="ed", addresses=[Address(email_address="foo")]) sess.add(u) sess.commit() @@ -1091,7 +1095,7 @@ class DeferredRelationshipExpressionTest(_fixtures.FixtureTest): def test_deferred_expression_obj_was_never_flushed(self): User, Address = self.classes("User", "Address") - sess = create_session(autoflush=True, autocommit=False) + sess = fixture_session(autoflush=True, autocommit=False) u = User(name="ed", addresses=[Address(email_address="foo")]) assert_raises_message( @@ -1120,7 +1124,7 @@ class DeferredRelationshipExpressionTest(_fixtures.FixtureTest): def test_deferred_expression_unflushed_obj_became_detached_unexpired(self): User, Address = self.classes("User", "Address") - sess = create_session(autoflush=True, autocommit=False) + sess = fixture_session(autoflush=True, autocommit=False) u = User(name="ed", addresses=[Address(email_address="foo")]) q = sess.query(Address).filter(Address.user == u) @@ -1134,7 +1138,7 @@ class DeferredRelationshipExpressionTest(_fixtures.FixtureTest): def test_deferred_expression_unflushed_obj_became_detached_expired(self): User, Address = self.classes("User", "Address") - sess = create_session(autoflush=True, autocommit=False) + sess = fixture_session(autoflush=True, autocommit=False) u = User(name="ed", addresses=[Address(email_address="foo")]) q = sess.query(Address).filter(Address.user == u) @@ -1149,7 +1153,7 @@ class DeferredRelationshipExpressionTest(_fixtures.FixtureTest): def test_deferred_expr_unflushed_obj_became_detached_expired_by_key(self): User, Address = self.classes("User", "Address") - sess = create_session(autoflush=True, autocommit=False) + sess = fixture_session(autoflush=True, autocommit=False) u = User(name="ed", addresses=[Address(email_address="foo")]) q = sess.query(Address).filter(Address.user == u) @@ -1164,7 +1168,7 @@ class DeferredRelationshipExpressionTest(_fixtures.FixtureTest): def test_deferred_expression_expired_obj_became_detached_expired(self): User, Address = self.classes("User", "Address") - sess = create_session( + sess = fixture_session( autoflush=True, autocommit=False, expire_on_commit=True ) u = User(name="ed", addresses=[Address(email_address="foo")]) @@ -1207,7 +1211,7 @@ class SessionStateWFixtureTest(_fixtures.FixtureTest): mapper(Address, addresses) mapper(User, users, properties={"addresses": relationship(Address)}) - sess = create_session(autocommit=False, autoflush=True) + sess = fixture_session(autocommit=False, autoflush=True) u = sess.query(User).get(8) newad = Address(email_address="a new address") u.addresses.append(newad) @@ -1244,7 +1248,7 @@ class SessionStateWFixtureTest(_fixtures.FixtureTest): }, ) - session = create_session() + session = fixture_session() u = session.query(User).filter_by(id=7).one() # get everything to load in both directions @@ -1288,7 +1292,7 @@ class NoCyclesOnTransientDetachedTest(_fixtures.FixtureTest): User = self.classes.User u1 = User() u1.name = "ed" - sess = Session() + sess = fixture_session() sess.add(u1) sess.flush() return sess, u1 @@ -1306,7 +1310,7 @@ class NoCyclesOnTransientDetachedTest(_fixtures.FixtureTest): u1.name = "ed" self._assert_modified(u1) self._assert_no_cycle(u1) - sess = Session() + sess = fixture_session() sess.add(u1) self._assert_cycle(u1) sess.flush() @@ -1366,7 +1370,7 @@ class NoCyclesOnTransientDetachedTest(_fixtures.FixtureTest): def test_move_persistent_clean(self): sess, u1 = self._persistent_fixture() sess.close() - s2 = Session() + s2 = fixture_session() s2.add(u1) self._assert_no_cycle(u1) self._assert_not_modified(u1) @@ -1378,7 +1382,7 @@ class NoCyclesOnTransientDetachedTest(_fixtures.FixtureTest): self._assert_modified(u1) sess.close() self._assert_no_cycle(u1) - s2 = Session() + s2 = fixture_session() s2.add(u1) self._assert_cycle(u1) self._assert_modified(u1) @@ -1392,7 +1396,7 @@ class NoCyclesOnTransientDetachedTest(_fixtures.FixtureTest): del sess gc_collect() self._assert_cycle(u1) - s2 = Session() + s2 = fixture_session() s2.add(u1) self._assert_cycle(u1) self._assert_modified(u1) @@ -1417,7 +1421,7 @@ class WeakIdentityMapTest(_fixtures.FixtureTest): users, User = self.tables.users, self.classes.User - s = create_session() + s = fixture_session() mapper(User, users) s.add(User(name="ed")) @@ -1449,7 +1453,7 @@ class WeakIdentityMapTest(_fixtures.FixtureTest): def test_weakref_pickled(self): users, User = self.tables.users, pickleable.User - s = create_session() + s = fixture_session() mapper(User, users) s.add(User(name="ed")) @@ -1486,7 +1490,7 @@ class WeakIdentityMapTest(_fixtures.FixtureTest): self.classes.User, ) - s = sessionmaker()() + s = fixture_session() mapper( User, users, @@ -1524,7 +1528,7 @@ class WeakIdentityMapTest(_fixtures.FixtureTest): self.classes.User, ) - s = sessionmaker()() + s = fixture_session() mapper( User, users, @@ -1561,7 +1565,7 @@ class WeakIdentityMapTest(_fixtures.FixtureTest): mapper(User, users) - sess = Session() + sess = fixture_session() u1 = User(name="u1") sess.add(u1) @@ -1569,7 +1573,7 @@ class WeakIdentityMapTest(_fixtures.FixtureTest): # can't add u1 to Session, # already belongs to u2 - s2 = Session() + s2 = fixture_session() assert_raises_message( sa.exc.InvalidRequestError, r".*is already attached to session", @@ -1592,7 +1596,7 @@ class WeakIdentityMapTest(_fixtures.FixtureTest): mapper(User, users) - sess = Session() + sess = fixture_session() u1 = User(name="u1") sess.add(u1) @@ -1629,7 +1633,7 @@ class IsModifiedTest(_fixtures.FixtureTest): def test_is_modified(self): User, Address = self._default_mapping_fixture() - s = create_session() + s = fixture_session() # save user u = User(name="fred") @@ -1662,7 +1666,7 @@ class IsModifiedTest(_fixtures.FixtureTest): User, Address = self._default_mapping_fixture() - s = Session() + s = fixture_session() u = User(name="fred", addresses=[Address(email_address="foo")]) s.add(u) s.commit() @@ -1687,7 +1691,7 @@ class IsModifiedTest(_fixtures.FixtureTest): def test_is_modified_syn(self): User, users = self.classes.User, self.tables.users - s = sessionmaker()() + s = fixture_session() mapper(User, users, properties={"uname": sa.orm.synonym("name")}) u = User(uname="fred") @@ -1850,7 +1854,7 @@ class SessionInterface(fixtures.TestBase): ) def raises_(method, *args, **kw): - x_raises_(create_session(), method, *args, **kw) + x_raises_(fixture_session(), method, *args, **kw) for name in [ "__contains__", @@ -1874,7 +1878,7 @@ class SessionInterface(fixtures.TestBase): self._map_it(OK) - s = create_session() + s = fixture_session() s.add(OK()) x_raises_(s, "flush", (user_arg,)) @@ -1903,7 +1907,10 @@ class SessionInterface(fixtures.TestBase): def raises_(method, *args, **kw): watchdog.add(method) - callable_ = getattr(Session(), method) + callable_ = getattr( + Session(), + method, + ) if is_class: assert_raises( sa.orm.exc.UnmappedClassError, callable_, *args, **kw @@ -1976,7 +1983,7 @@ class SessionInterface(fixtures.TestBase): self._map_it(Mapped) m1 = Mapped() - s = create_session() + s = fixture_session() with mock.patch.object(s, "_validate_persistent"): assert_raises_message( @@ -2107,7 +2114,7 @@ class FlushWarningsTest(fixtures.MappedTest): User = self.classes.User Address = self.classes.Address - s = Session() + s = fixture_session() event.listen(User, "after_insert", fn) u1 = User(name="u1", addresses=[Address(name="a1")]) diff --git a/test/orm/test_subquery_relations.py b/test/orm/test_subquery_relations.py index 280a4355f..fe20442a3 100644 --- a/test/orm/test_subquery_relations.py +++ b/test/orm/test_subquery_relations.py @@ -10,7 +10,6 @@ from sqlalchemy.orm import aliased from sqlalchemy.orm import backref from sqlalchemy.orm import clear_mappers from sqlalchemy.orm import close_all_sessions -from sqlalchemy.orm import create_session from sqlalchemy.orm import deferred from sqlalchemy.orm import joinedload from sqlalchemy.orm import mapper @@ -28,6 +27,7 @@ from sqlalchemy.testing import is_not from sqlalchemy.testing import is_true from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.testing.entities import ComparableEntity +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from test.orm import _fixtures @@ -62,7 +62,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) }, ) - sess = create_session() + sess = fixture_session() q = sess.query(User).options(subqueryload(User.addresses)) @@ -106,7 +106,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): }, ) query_cache = {} - sess = create_session() + sess = fixture_session() u1 = ( sess.query(User) @@ -155,7 +155,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): User, Dingaling, Address = self.user_dingaling_fixture() for i in range(3): - sess = create_session() + sess = fixture_session() u = aliased(User) @@ -180,7 +180,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): User, Dingaling, Address = self.user_dingaling_fixture() for i in range(3): - sess = create_session() + sess = fixture_session() u = aliased(User) @@ -195,7 +195,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): User, Dingaling, Address = self.user_dingaling_fixture() for i in range(3): - sess = create_session() + sess = fixture_session() u = aliased(User) q = sess.query(u).options( @@ -248,7 +248,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) }, ) - sess = create_session() + sess = fixture_session() q = sess.query(User).options(subqueryload(User.addresses)) @@ -280,7 +280,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) }, ) - sess = create_session() + sess = fixture_session() q = sess.query(User).options(subqueryload(User.addresses)) @@ -311,7 +311,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): properties={"addresses": relationship(Address, lazy="dynamic")}, ) mapper(Address, addresses) - sess = create_session() + sess = fixture_session() # previously this would not raise, but would emit # the query needlessly and put the result nowhere. @@ -345,7 +345,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ), ) - q = create_session().query(Item).order_by(Item.id) + q = fixture_session().query(Item).order_by(Item.id) def go(): eq_(self.static.item_keyword_result, q.all()) @@ -375,7 +375,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ), ) - q = create_session().query(Item).order_by(Item.id) + q = fixture_session().query(Item).order_by(Item.id) def go(): eq_( @@ -408,7 +408,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ), ) - q = create_session().query(Item).order_by(Item.id) + q = fixture_session().query(Item).order_by(Item.id) def go(): ka = aliased(Keyword) @@ -438,7 +438,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) }, ) - q = create_session().query(User) + q = fixture_session().query(User) eq_( [ User(id=7, addresses=[Address(id=1)]), @@ -475,7 +475,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) }, ) - q = create_session().query(User) + q = fixture_session().query(User) eq_( [ User(id=7, addresses=[Address(id=1)]), @@ -515,7 +515,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ), ) - q = create_session().query(User) + q = fixture_session().query(User) result = ( q.filter(User.id == Address.user_id) .order_by(Address.email_address) @@ -558,7 +558,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) ), ) - sess = create_session() + sess = fixture_session() eq_( [ User(id=7, addresses=[Address(id=1)]), @@ -734,7 +734,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): def _do_query_tests(self, opts, count): Order, User = self.classes.Order, self.classes.User - sess = create_session() + sess = fixture_session() def go(): eq_( @@ -753,7 +753,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.static.user_item_keyword_result[2:3], ) - sess = create_session() + sess = fixture_session() eq_( sess.query(User) .options(*opts) @@ -793,7 +793,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) is_(sa.orm.class_mapper(Address).get_property("user").lazy, "subquery") - sess = create_session() + sess = fixture_session() eq_( self.static.user_address_result, sess.query(User).order_by(User.id).all(), @@ -831,7 +831,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) is_(sa.orm.class_mapper(Address).get_property("user").lazy, "subquery") - sess = create_session() + sess = fixture_session() eq_( self.static.user_address_result, sess.query(User).order_by(User.id).all(), @@ -852,7 +852,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): properties=dict(addresses=relationship(Address, lazy="subquery")), ) - sess = create_session() + sess = fixture_session() self.assert_compile( sess.query(User, literal_column("1")), @@ -1071,7 +1071,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): User, Address, Order, Item = self.classes( "User", "Address", "Order", "Item" ) - q = create_session().query(User).order_by(User.id) + q = fixture_session().query(User).order_by(User.id) def items(*ids): if no_items: @@ -1170,7 +1170,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): }, ) - sess = create_session() + sess = fixture_session() q = sess.query(User) result = q.order_by(User.id).limit(2).offset(1).all() @@ -1200,7 +1200,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): }, ) - q = create_session().query(User) + q = fixture_session().query(User) eq_( [ User(id=7, addresses=[Address(id=1)]), @@ -1235,7 +1235,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) ), ) - q = create_session().query(User) + q = fixture_session().query(User) def go(): result = q.filter(users.c.id == 7).all() @@ -1258,7 +1258,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): user=relationship(mapper(User, users), lazy="subquery") ), ) - sess = create_session() + sess = fixture_session() q = sess.query(Address) def go(): @@ -1304,7 +1304,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): }, ) - q = create_session().query(User) + q = fixture_session().query(User) def go(): eq_( @@ -1344,7 +1344,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): properties={"order": relationship(Order, uselist=False)}, ) mapper(Order, orders) - s = create_session() + s = fixture_session() assert_raises( sa.exc.SAWarning, s.query(User).options(subqueryload(User.order)).all, @@ -1375,7 +1375,7 @@ class LoadOnExistingTest(_fixtures.FixtureTest): ) mapper(Dingaling, self.tables.dingalings) - sess = Session(autoflush=False) + sess = fixture_session(autoflush=False) return User, Address, Dingaling, sess def _collection_to_collection_fixture(self): @@ -1396,7 +1396,7 @@ class LoadOnExistingTest(_fixtures.FixtureTest): ) mapper(Item, self.tables.items) - sess = Session(autoflush=False) + sess = fixture_session(autoflush=False) return User, Order, Item, sess def _eager_config_fixture(self): @@ -1407,7 +1407,7 @@ class LoadOnExistingTest(_fixtures.FixtureTest): properties={"addresses": relationship(Address, lazy="subquery")}, ) mapper(Address, self.tables.addresses) - sess = Session(autoflush=False) + sess = fixture_session(autoflush=False) return User, Address, sess def _deferred_config_fixture(self): @@ -1421,7 +1421,7 @@ class LoadOnExistingTest(_fixtures.FixtureTest): }, ) mapper(Address, self.tables.addresses) - sess = Session(autoflush=False) + sess = fixture_session(autoflush=False) return User, Address, sess def test_runs_query_on_refresh(self): @@ -1592,7 +1592,7 @@ class OrderBySecondaryTest(fixtures.MappedTest): ) mapper(B, b) - sess = create_session() + sess = fixture_session() def go(): eq_( @@ -1721,12 +1721,12 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): ), ] e2.paperwork = [Paperwork(description="tps report #3")] - sess = create_session(connection) + sess = Session(connection) sess.add_all([e1, e2]) sess.flush() def test_correct_subquery_nofrom(self): - sess = create_session() + sess = fixture_session() # use Person.paperwork here just to give the least # amount of context q = ( @@ -1779,7 +1779,7 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): ) def test_correct_subquery_existingfrom(self): - sess = create_session() + sess = fixture_session() # use Person.paperwork here just to give the least # amount of context q = ( @@ -1839,7 +1839,7 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): ) def test_correct_subquery_multilevel(self): - sess = create_session() + sess = fixture_session() # use Person.paperwork here just to give the least # amount of context q = ( @@ -1917,7 +1917,7 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): def test_correct_subquery_with_polymorphic_no_alias(self): # test #3106 - sess = create_session() + sess = fixture_session() wp = with_polymorphic(Person, [Engineer]) q = ( @@ -1966,7 +1966,7 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): def test_correct_subquery_with_polymorphic_alias(self): # test #3106 - sess = create_session() + sess = fixture_session() wp = with_polymorphic(Person, [Engineer], aliased=True) q = ( @@ -2033,7 +2033,7 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): def test_correct_subquery_with_polymorphic_flat_alias(self): # test #3106 - sess = create_session() + sess = fixture_session() wp = with_polymorphic(Person, [Engineer], aliased=True, flat=True) q = ( @@ -2193,7 +2193,7 @@ class SubRelationFromJoinedSubclassMultiLevelTest(_Polymorphic): @classmethod def insert_data(cls, connection): c1 = cls._fixture() - sess = create_session(connection) + sess = Session(connection) sess.add(c1) sess.flush() @@ -2221,7 +2221,7 @@ class SubRelationFromJoinedSubclassMultiLevelTest(_Polymorphic): ) def test_chained_subq_subclass(self): - s = Session() + s = fixture_session() q = s.query(Company).options( subqueryload(Company.employees.of_type(Engineer)) .subqueryload(Engineer.machines) @@ -2263,7 +2263,7 @@ class SelfReferentialTest(fixtures.MappedTest): ) }, ) - sess = create_session() + sess = fixture_session() n1 = Node(data="n1") n1.append(Node(data="n11")) n1.append(Node(data="n12")) @@ -2339,7 +2339,7 @@ class SelfReferentialTest(fixtures.MappedTest): ) }, ) - sess = create_session() + sess = fixture_session() n1 = Node(data="n1") n1.append(Node(data="n11")) n1.append(Node(data="n12")) @@ -2386,7 +2386,7 @@ class SelfReferentialTest(fixtures.MappedTest): "data": deferred(nodes.c.data), }, ) - sess = create_session() + sess = fixture_session() n1 = Node(data="n1") n1.append(Node(data="n11")) n1.append(Node(data="n12")) @@ -2439,7 +2439,7 @@ class SelfReferentialTest(fixtures.MappedTest): nodes, properties={"children": relationship(Node, order_by=nodes.c.id)}, ) - sess = create_session() + sess = fixture_session() n1 = Node(data="n1") n1.append(Node(data="n11")) n1.append(Node(data="n12")) @@ -2494,7 +2494,7 @@ class SelfReferentialTest(fixtures.MappedTest): nodes, properties={"children": relationship(Node, lazy="subquery")}, ) - sess = create_session() + sess = fixture_session() n1 = Node(data="n1") n1.append(Node(data="n11")) n1.append(Node(data="n12")) @@ -2725,7 +2725,7 @@ class CyclicalInheritingEagerTestOne(fixtures.MappedTest): mapper(SubT2, None, inherits=T2, polymorphic_identity="subt2") # testing a particular endless loop condition in eager load setup - create_session().query(SubT).all() + fixture_session().query(SubT).all() class CyclicalInheritingEagerTestTwo( @@ -2765,7 +2765,7 @@ class CyclicalInheritingEagerTestTwo( def test_from_subclass(self): Director = self.classes.Director - s = create_session() + s = fixture_session() with self.sql_execution_asserter(testing.db) as asserter: s.query(Director).options(subqueryload("*")).all() @@ -2869,7 +2869,7 @@ class SubqueryloadDistinctTest( Movie(title="Manhattan", credits=[Credit(), Credit()]), Movie(title="Sweet and Lowdown", credits=[Credit()]), ] - sess = create_session(connection) + sess = Session(connection) sess.add_all([d]) sess.flush() @@ -2897,7 +2897,7 @@ class SubqueryloadDistinctTest( # Director.photos expect_distinct = director_strategy_level in (True, None) - s = create_session(testing.db) + s = fixture_session() with self.sql_execution_asserter(testing.db) as asserter: result = ( @@ -2963,7 +2963,7 @@ class SubqueryloadDistinctTest( Movie = self.classes.Movie Credit = self.classes.Credit - s = create_session(testing.db) + s = fixture_session() with self.sql_execution_asserter(testing.db) as asserter: result = ( @@ -3047,7 +3047,7 @@ class JoinedNoLoadConflictTest(fixtures.DeclarativeMappedTest): Parent = self.classes.Parent Child = self.classes.Child - s = Session() + s = fixture_session() # here we have # Parent->subqueryload->Child->joinedload->parent->noload->children. @@ -3101,7 +3101,7 @@ class SelfRefInheritanceAliasedTest( attr1 = Foo.foo.of_type(r) attr2 = r.foo - s = Session() + s = fixture_session() q = ( s.query(Foo) .filter(Foo.id == 2) @@ -3211,7 +3211,7 @@ class TestExistingRowPopulation(fixtures.DeclarativeMappedTest): def test_o2m(self): A, A2, B, C1o2m, C2o2m = self.classes("A", "A2", "B", "C1o2m", "C2o2m") - s = Session() + s = fixture_session() # A -J-> B -L-> C1 # A -J-> B -S-> C2 @@ -3232,7 +3232,7 @@ class TestExistingRowPopulation(fixtures.DeclarativeMappedTest): def test_m2o(self): A, A2, B, C1m2o, C2m2o = self.classes("A", "A2", "B", "C1m2o", "C2m2o") - s = Session() + s = fixture_session() # A -J-> B -L-> C1 # A -J-> B -S-> C2 @@ -3323,7 +3323,7 @@ class FromSubqTest(fixtures.DeclarativeMappedTest): def test_subq_w_from_self_one(self): A, B, C = self.classes("A", "B", "C") - s = Session() + s = fixture_session() cache = {} @@ -3409,7 +3409,7 @@ class FromSubqTest(fixtures.DeclarativeMappedTest): A, B, C = self.classes("A", "B", "C") - s = Session() + s = fixture_session() cache = {} for i in range(3): diff --git a/test/orm/test_sync.py b/test/orm/test_sync.py index 880f0bd18..76cd7f758 100644 --- a/test/orm/test_sync.py +++ b/test/orm/test_sync.py @@ -3,7 +3,6 @@ from sqlalchemy import Integer from sqlalchemy import testing from sqlalchemy.orm import attributes from sqlalchemy.orm import class_mapper -from sqlalchemy.orm import create_session from sqlalchemy.orm import exc as orm_exc from sqlalchemy.orm import mapper from sqlalchemy.orm import sync @@ -11,6 +10,7 @@ from sqlalchemy.orm import unitofwork from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -61,7 +61,7 @@ class SyncTest( def _fixture(self): A, B = self.classes.A, self.classes.B - session = create_session() + session = fixture_session() uowcommit = self._get_test_uow(session) a_mapper = class_mapper(A) b_mapper = class_mapper(B) diff --git a/test/orm/test_transaction.py b/test/orm/test_transaction.py index 6eda6fbb6..550cf6535 100644 --- a/test/orm/test_transaction.py +++ b/test/orm/test_transaction.py @@ -14,7 +14,6 @@ from sqlalchemy import testing from sqlalchemy import text from sqlalchemy.future import Engine from sqlalchemy.orm import attributes -from sqlalchemy.orm import create_session from sqlalchemy.orm import exc as orm_exc from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship @@ -33,6 +32,7 @@ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import is_not from sqlalchemy.testing import mock +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.util import gc_collect from test.orm._fixtures import FixtureTest @@ -44,10 +44,9 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): def test_no_close_transaction_on_flush(self): User, users = self.classes.User, self.tables.users - c = testing.db.connect() - try: + with testing.db.connect() as c: mapper(User, users) - s = create_session(bind=c) + s = Session(bind=c) s.begin() tran = s._legacy_transaction() s.add(User(name="first")) @@ -61,8 +60,6 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): s.flush() assert s._legacy_transaction() is tran tran.close() - finally: - c.close() @engines.close_open_connections def test_subtransaction_on_external_subtrans(self): @@ -71,7 +68,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): mapper(User, users) conn = testing.db.connect() trans = conn.begin() - sess = create_session(bind=conn, autocommit=False, autoflush=True) + sess = Session(bind=conn, autocommit=False, autoflush=True) sess.begin(subtransactions=True) u = User(name="ed") sess.add(u) @@ -88,7 +85,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): mapper(User, users) conn = testing.db.connect() trans = conn.begin() - sess = create_session(bind=conn, autocommit=False, autoflush=True) + sess = Session(bind=conn, autocommit=False, autoflush=True) u = User(name="ed") sess.add(u) sess.flush() @@ -106,7 +103,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): try: conn = testing.db.connect() trans = conn.begin() - sess = create_session(bind=conn, autocommit=False, autoflush=True) + sess = Session(bind=conn, autocommit=False, autoflush=True) u1 = User(name="u1") sess.add(u1) sess.flush() @@ -134,7 +131,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): conn = engine.connect() conn.begin() - sess = create_session(bind=conn, autocommit=False, autoflush=True) + sess = Session(bind=conn, autocommit=False, autoflush=True) u = User(name="ed") sess.add(u) sess.flush() @@ -154,7 +151,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): conn = engine.connect() conn.begin() - sess = create_session(bind=conn, autocommit=False, autoflush=True) + sess = Session(bind=conn, autocommit=False, autoflush=True) u = User(name="ed") sess.add(u) sess.flush() @@ -174,7 +171,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): with engine.connect() as conn: conn.begin() - sess = create_session(bind=conn, autocommit=False, autoflush=True) + sess = Session(bind=conn, autocommit=False, autoflush=True) u1 = User(name="u1") sess.add(u1) sess.flush() @@ -194,7 +191,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): mapper(User, users) - session = create_session(bind=testing.db) + session = fixture_session() session.begin() session.begin_nested() u1 = User(name="u1") @@ -210,7 +207,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): mapper(User, users) - session = create_session(bind=testing.db) + session = fixture_session() session.begin() u1 = User(name="u1") session.add(u1) @@ -230,7 +227,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): def test_heavy_nesting(self): users = self.tables.users - session = create_session(bind=testing.db) + session = fixture_session() session.begin() session.connection().execute(users.insert().values(name="user1")) session.begin(subtransactions=True) @@ -263,7 +260,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): users = self.tables.users engine = Engine._future_facade(testing.db) - session = create_session(engine, autocommit=False) + session = Session(engine, autocommit=False) session.begin() session.connection().execute(users.insert().values(name="user1")) @@ -354,8 +351,8 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): mapper(User, users) - s1 = create_session(bind=testing.db, autocommit=False) - s2 = create_session(bind=testing.db, autocommit=False) + s1 = fixture_session(autocommit=False) + s2 = fixture_session(autocommit=False) u1 = User(name="u1") s1.add(u1) s1.flush() @@ -377,7 +374,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): mapper(Address, addresses) engine2 = engines.testing_engine() - sess = create_session(autocommit=True, autoflush=False, twophase=True) + sess = fixture_session(autocommit=True, autoflush=False, twophase=True) sess.bind_mapper(User, testing.db) sess.bind_mapper(Address, engine2) sess.begin() @@ -387,14 +384,15 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): sess.commit() sess.close() engine2.dispose() - eq_(select(func.count("*")).select_from(users).scalar(), 1) - eq_(select(func.count("*")).select_from(addresses).scalar(), 1) + with testing.db.connect() as conn: + eq_(conn.scalar(select(func.count("*")).select_from(users)), 1) + eq_(conn.scalar(select(func.count("*")).select_from(addresses)), 1) @testing.requires.independent_connections def test_invalidate(self): User, users = self.classes.User, self.tables.users mapper(User, users) - sess = Session() + sess = fixture_session() u = User(name="u1") sess.add(u) sess.flush() @@ -423,7 +421,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): User, users = self.classes.User, self.tables.users mapper(User, users) - sess = create_session(autocommit=False, autoflush=True) + sess = fixture_session(autocommit=False, autoflush=True) sess.begin(subtransactions=True) u = User(name="u1") sess.add(u) @@ -438,7 +436,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): User, users = self.classes.User, self.tables.users mapper(User, users) - sess = create_session() + sess = fixture_session() sess.begin() u = User(name="u1") @@ -462,7 +460,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): User, users = self.classes.User, self.tables.users mapper(User, users) - sess = create_session(autocommit=False) + sess = fixture_session(autocommit=False) u = User(name="u1") sess.add(u) sess.flush() @@ -484,7 +482,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): User, users = self.classes.User, self.tables.users mapper(User, users) - sess = create_session(testing.db, autocommit=False, future=True) + sess = fixture_session(autocommit=False, future=True) u = User(name="u1") sess.add(u) sess.flush() @@ -507,7 +505,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): mapper(User, users) - sess = create_session(autocommit=True) + sess = fixture_session(autocommit=True) sess.begin() sess.begin_nested() @@ -543,7 +541,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): mapper(User, users) - sess = create_session(autocommit=True) + sess = fixture_session(autocommit=True) sess.begin() sess.begin_nested() @@ -575,7 +573,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): mapper(User, users) - sess = create_session(autocommit=False) + sess = fixture_session(autocommit=False) sess.begin_nested() @@ -653,7 +651,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): users, User = self.tables.users, self.classes.User mapper(User, users) - sess = Session() + sess = fixture_session() to_flush = [User(name="ed"), User(name="jack"), User(name="wendy")] @@ -677,7 +675,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): users, User = self.tables.users, self.classes.User mapper(User, users) - sess = Session() + sess = fixture_session() @event.listens_for(sess, "after_flush_postexec") def add_another_user(session, ctx): @@ -694,7 +692,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): users, User = self.tables.users, self.classes.User mapper(User, users) - sess = create_session(autocommit=True) + sess = fixture_session(autocommit=True) sess.begin() sess.begin(subtransactions=True) sess.add(User(name="u1")) @@ -711,7 +709,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): sess.close() def test_no_sql_during_commit(self): - sess = create_session(bind=testing.db, autocommit=False) + sess = fixture_session(autocommit=False) @event.listens_for(sess, "after_commit") def go(session): @@ -725,7 +723,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): ) def test_no_sql_during_prepare(self): - sess = create_session(bind=testing.db, autocommit=False, twophase=True) + sess = fixture_session(autocommit=False, twophase=True) sess.prepare() @@ -738,7 +736,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): ) def test_no_sql_during_rollback(self): - sess = create_session(bind=testing.db, autocommit=False) + sess = fixture_session(autocommit=False) sess.connection() @@ -820,7 +818,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): eq_(session.is_active, True) def test_no_prepare_wo_twophase(self): - sess = create_session(bind=testing.db, autocommit=False) + sess = fixture_session(autocommit=False) assert_raises_message( sa_exc.InvalidRequestError, @@ -830,7 +828,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): ) def test_closed_status_check(self): - sess = create_session() + sess = fixture_session() trans = sess.begin() trans.rollback() assert_raises_message( @@ -845,7 +843,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): ) def test_deactive_status_check(self): - sess = create_session() + sess = fixture_session() trans = sess.begin() trans2 = sess.begin(subtransactions=True) trans2.rollback() @@ -858,7 +856,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): ) def test_deactive_status_check_w_exception(self): - sess = create_session() + sess = fixture_session() trans = sess.begin() trans2 = sess.begin(subtransactions=True) try: @@ -878,7 +876,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): users, User = self.tables.users, self.classes.User mapper(User, users) - sess = Session() + sess = fixture_session() u1 = User(id=1, name="u1") sess.add(u1) sess.commit() @@ -984,7 +982,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): users, User = self.tables.users, self.classes.User mapper(User, users) - sess = Session() + sess = fixture_session() u1 = User(id=1, name="u1") sess.add(u1) sess.commit() @@ -1027,7 +1025,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): User, users = self.classes.User, self.tables.users mapper(User, users) - session = create_session(autocommit=False) + session = fixture_session(autocommit=False) session.add(User(name="ed")) session._legacy_transaction().commit() @@ -1037,7 +1035,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): User, users = self.classes.User, self.tables.users mapper(User, users) - session = create_session(testing.db, autocommit=False, future=True) + session = fixture_session(autocommit=False, future=True) session.add(User(name="ed")) session._legacy_transaction().commit() @@ -1073,7 +1071,6 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): class _LocalFixture(FixtureTest): run_setup_mappers = "once" run_inserts = None - session = sessionmaker() @classmethod def setup_mappers(cls): @@ -1351,7 +1348,7 @@ class FixtureDataTest(_LocalFixture): def test_attrs_on_rollback(self): User = self.classes.User - sess = self.session() + sess = fixture_session() u1 = sess.query(User).get(7) u1.name = "ed" sess.rollback() @@ -1359,7 +1356,7 @@ class FixtureDataTest(_LocalFixture): def test_commit_persistent(self): User = self.classes.User - sess = self.session() + sess = fixture_session() u1 = sess.query(User).get(7) u1.name = "ed" sess.flush() @@ -1368,12 +1365,12 @@ class FixtureDataTest(_LocalFixture): def test_concurrent_commit_persistent(self): User = self.classes.User - s1 = self.session() + s1 = fixture_session() u1 = s1.query(User).get(7) u1.name = "ed" s1.commit() - s2 = self.session() + s2 = fixture_session() u2 = s2.query(User).get(7) assert u2.name == "ed" u2.name = "will" @@ -1455,7 +1452,7 @@ class AutoExpireTest(_LocalFixture): def test_expunge_pending_on_rollback(self): User = self.classes.User - sess = self.session() + sess = fixture_session() u2 = User(name="newuser") sess.add(u2) assert u2 in sess @@ -1464,7 +1461,7 @@ class AutoExpireTest(_LocalFixture): def test_trans_pending_cleared_on_commit(self): User = self.classes.User - sess = self.session() + sess = fixture_session() u2 = User(name="newuser") sess.add(u2) assert u2 in sess @@ -1478,7 +1475,7 @@ class AutoExpireTest(_LocalFixture): def test_update_deleted_on_rollback(self): User = self.classes.User - s = self.session() + s = fixture_session() u1 = User(name="ed") s.add(u1) s.commit() @@ -1496,7 +1493,7 @@ class AutoExpireTest(_LocalFixture): def test_gced_delete_on_rollback(self): User, users = self.classes.User, self.tables.users - s = self.session() + s = fixture_session() u1 = User(name="ed") s.add(u1) s.commit() @@ -1531,7 +1528,7 @@ class AutoExpireTest(_LocalFixture): def test_trans_deleted_cleared_on_rollback(self): User = self.classes.User - s = self.session() + s = fixture_session() u1 = User(name="ed") s.add(u1) s.commit() @@ -1545,7 +1542,7 @@ class AutoExpireTest(_LocalFixture): def test_update_deleted_on_rollback_cascade(self): User, Address = self.classes.User, self.classes.Address - s = self.session() + s = fixture_session() u1 = User(name="ed", addresses=[Address(email_address="foo")]) s.add(u1) s.commit() @@ -1561,7 +1558,7 @@ class AutoExpireTest(_LocalFixture): def test_update_deleted_on_rollback_orphan(self): User, Address = self.classes.User, self.classes.Address - s = self.session() + s = fixture_session() u1 = User(name="ed", addresses=[Address(email_address="foo")]) s.add(u1) s.commit() @@ -1577,7 +1574,7 @@ class AutoExpireTest(_LocalFixture): def test_commit_pending(self): User = self.classes.User - sess = self.session() + sess = fixture_session() u1 = User(name="newuser") sess.add(u1) sess.flush() @@ -1586,12 +1583,12 @@ class AutoExpireTest(_LocalFixture): def test_concurrent_commit_pending(self): User = self.classes.User - s1 = self.session() + s1 = fixture_session() u1 = User(name="edward") s1.add(u1) s1.commit() - s2 = self.session() + s2 = fixture_session() u2 = s2.query(User).filter(User.name == "edward").one() u2.name = "will" s2.commit() @@ -1605,7 +1602,7 @@ class TwoPhaseTest(_LocalFixture): @testing.requires.two_phase_transactions def test_rollback_on_prepare(self): User = self.classes.User - s = self.session(twophase=True) + s = fixture_session(twophase=True) u = User(name="ed") s.add(u) @@ -1620,7 +1617,7 @@ class RollbackRecoverTest(_LocalFixture): def test_pk_violation(self): User, Address = self.classes.User, self.classes.Address - s = self.session() + s = fixture_session() a1 = Address(email_address="foo") u1 = User(id=1, name="ed", addresses=[a1]) @@ -1662,7 +1659,7 @@ class RollbackRecoverTest(_LocalFixture): @testing.requires.savepoints def test_pk_violation_with_savepoint(self): User, Address = self.classes.User, self.classes.Address - s = self.session() + s = fixture_session() a1 = Address(email_address="foo") u1 = User(id=1, name="ed", addresses=[a1]) s.add(u1) @@ -1704,7 +1701,7 @@ class SavepointTest(_LocalFixture): @testing.requires.savepoints def test_savepoint_rollback(self): User = self.classes.User - s = self.session() + s = fixture_session() u1 = User(name="ed") u2 = User(name="jack") s.add_all([u1, u2]) @@ -1731,7 +1728,7 @@ class SavepointTest(_LocalFixture): @testing.requires.savepoints def test_savepoint_delete(self): User = self.classes.User - s = self.session() + s = fixture_session() u1 = User(name="ed") s.add(u1) s.commit() @@ -1745,7 +1742,7 @@ class SavepointTest(_LocalFixture): @testing.requires.savepoints def test_savepoint_commit(self): User = self.classes.User - s = self.session() + s = fixture_session() u1 = User(name="ed") u2 = User(name="jack") s.add_all([u1, u2]) @@ -1781,7 +1778,7 @@ class SavepointTest(_LocalFixture): @testing.requires.savepoints def test_savepoint_rollback_collections(self): User, Address = self.classes.User, self.classes.Address - s = self.session() + s = fixture_session() u1 = User(name="ed", addresses=[Address(email_address="foo")]) s.add(u1) s.commit() @@ -1834,7 +1831,7 @@ class SavepointTest(_LocalFixture): @testing.requires.savepoints def test_savepoint_commit_collections(self): User, Address = self.classes.User, self.classes.Address - s = self.session() + s = fixture_session() u1 = User(name="ed", addresses=[Address(email_address="foo")]) s.add(u1) s.commit() @@ -1889,7 +1886,7 @@ class SavepointTest(_LocalFixture): @testing.requires.savepoints def test_expunge_pending_on_rollback(self): User = self.classes.User - sess = self.session() + sess = fixture_session() sess.begin_nested() u2 = User(name="newuser") @@ -1901,7 +1898,7 @@ class SavepointTest(_LocalFixture): @testing.requires.savepoints def test_update_deleted_on_rollback(self): User = self.classes.User - s = self.session() + s = fixture_session() u1 = User(name="ed") s.add(u1) s.commit() @@ -1916,7 +1913,7 @@ class SavepointTest(_LocalFixture): @testing.requires.savepoints_w_release def test_savepoint_lost_still_runs(self): User = self.classes.User - s = self.session(bind=self.bind) + s = fixture_session() trans = s.begin_nested() s.connection() u1 = User(name="ed") @@ -1951,7 +1948,7 @@ class AccountingFlagsTest(_LocalFixture): def test_no_expire_on_commit(self): User, users = self.classes.User, self.tables.users - sess = sessionmaker(expire_on_commit=False)() + sess = fixture_session(expire_on_commit=False) u1 = User(name="ed") sess.add(u1) sess.commit() @@ -1967,12 +1964,12 @@ class AutoCommitTest(_LocalFixture): __backend__ = True def test_begin_nested_requires_trans(self): - sess = create_session(autocommit=True) + sess = fixture_session(autocommit=True) assert_raises(sa_exc.InvalidRequestError, sess.begin_nested) def test_begin_preflush(self): User = self.classes.User - sess = create_session(autocommit=True) + sess = fixture_session(autocommit=True) u1 = User(name="ed") sess.add(u1) @@ -1987,7 +1984,7 @@ class AutoCommitTest(_LocalFixture): def test_accounting_commit_fails_add(self): User = self.classes.User - sess = create_session(autocommit=True) + sess = fixture_session(autocommit=True) fail = False @@ -2016,7 +2013,7 @@ class AutoCommitTest(_LocalFixture): def test_accounting_commit_fails_delete(self): User = self.classes.User - sess = create_session(autocommit=True) + sess = fixture_session(autocommit=True) fail = False @@ -2047,7 +2044,7 @@ class AutoCommitTest(_LocalFixture): when autocommit=True/expire_on_commit=True.""" User = self.classes.User - sess = create_session(autocommit=True, expire_on_commit=True) + sess = fixture_session(autocommit=True, expire_on_commit=True) u1 = User(id=1, name="ed") sess.add(u1) @@ -2071,7 +2068,7 @@ class ContextManagerPlusFutureTest(FixtureTest): mapper(User, users) - sess = Session() + sess = fixture_session() def go(): with sess.begin_nested(): @@ -2091,7 +2088,7 @@ class ContextManagerPlusFutureTest(FixtureTest): mapper(User, users) - sess = Session() + sess = fixture_session() with sess.begin(): sess.add(User(name="u1")) @@ -2103,7 +2100,7 @@ class ContextManagerPlusFutureTest(FixtureTest): mapper(User, users) - sess = Session() + sess = fixture_session() def go(): with sess.begin(): @@ -2513,7 +2510,7 @@ class NaturalPKRollbackTest(fixtures.MappedTest): mapper(User, users) - session = sessionmaker()() + session = fixture_session() u1, u2, u3 = User(name="u1"), User(name="u2"), User(name="u3") @@ -2544,7 +2541,7 @@ class NaturalPKRollbackTest(fixtures.MappedTest): u1 = User(name="u1") - s = Session() + s = fixture_session() s.add(u1) s.flush() del u1 @@ -2568,7 +2565,7 @@ class NaturalPKRollbackTest(fixtures.MappedTest): u1 = User(name="u1") u2 = User(name="u2") - s = Session() + s = fixture_session() s.add_all([u1, u2]) s.commit() @@ -2595,7 +2592,7 @@ class NaturalPKRollbackTest(fixtures.MappedTest): u1 = User(name="u1") - s = Session() + s = fixture_session() s.add(u1) s.commit() @@ -2624,7 +2621,7 @@ class NaturalPKRollbackTest(fixtures.MappedTest): u2 = User(name="u2") u3 = User(name="u3") - s = Session() + s = fixture_session() s.add_all([u1, u2, u3]) s.commit() @@ -2655,7 +2652,7 @@ class NaturalPKRollbackTest(fixtures.MappedTest): u1 = User(name="u1") - s = Session() + s = fixture_session() s.add(u1) s.commit() diff --git a/test/orm/test_unitofwork.py b/test/orm/test_unitofwork.py index 7583b9d22..84373b2dc 100644 --- a/test/orm/test_unitofwork.py +++ b/test/orm/test_unitofwork.py @@ -17,7 +17,6 @@ from sqlalchemy import String from sqlalchemy import testing from sqlalchemy.inspection import inspect from sqlalchemy.orm import column_property -from sqlalchemy.orm import create_session from sqlalchemy.orm import exc as orm_exc from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship @@ -30,6 +29,7 @@ from sqlalchemy.testing import fixtures from sqlalchemy.testing.assertsql import AllOf from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.testing.assertsql import Conditional +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from sqlalchemy.util import OrderedDict @@ -70,7 +70,7 @@ class HistoryTest(_fixtures.FixtureTest): ), ) - session = create_session(autocommit=False) + session = fixture_session(autocommit=False) u = User(name="u1") a = Address(email_address="u1@e") @@ -129,7 +129,7 @@ class UnicodeTest(fixtures.MappedTest): t1 = Test(id=1, txt=txt) self.assert_(t1.txt == txt) - session = create_session(autocommit=False) + session = fixture_session(autocommit=False) session.add(t1) session.commit() @@ -150,12 +150,12 @@ class UnicodeTest(fixtures.MappedTest): t1 = Test(txt=txt) t1.t2s.append(Test2()) t1.t2s.append(Test2()) - session = create_session(autocommit=False) + session = fixture_session(autocommit=False, expire_on_commit=False) session.add(t1) session.commit() session.close() - session = create_session() + session = fixture_session() t1 = session.query(Test).filter_by(id=t1.id).one() assert len(t1.t2s) == 2 @@ -227,7 +227,7 @@ class UnicodeSchemaTest(fixtures.MappedTest): b1 = B() a1.t2s.append(b1) - session = create_session() + session = fixture_session() session.add(a1) session.flush() session.expunge_all() @@ -265,7 +265,7 @@ class UnicodeSchemaTest(fixtures.MappedTest): a1 = A(b=5) b1 = B(e=7) - session = create_session() + session = fixture_session() session.add_all((a1, b1)) session.flush() session.expunge_all() @@ -302,7 +302,7 @@ class BinaryHistTest(fixtures.MappedTest, testing.AssertsExecutionResults): mapper(Foo, t1) - s = create_session() + s = fixture_session() f1 = Foo(data=data) s.add(f1) @@ -367,7 +367,7 @@ class PKTest(fixtures.MappedTest): e = Entry(name="entry1", value="this is entry 1", multi_rev=2) - session = create_session() + session = fixture_session() session.add(e) session.flush() session.expunge_all() @@ -386,7 +386,7 @@ class PKTest(fixtures.MappedTest): e = Entry(pk_col_1="pk1", pk_col_2="pk1_related", data="im the data") - session = create_session() + session = fixture_session() session.add(e) session.flush() @@ -402,7 +402,7 @@ class PKTest(fixtures.MappedTest): data="some more data", ) - session = create_session() + session = fixture_session() session.add(e) session.flush() @@ -463,20 +463,21 @@ class ForeignPKTest(fixtures.MappedTest): ps = PersonSite(site="asdf") p.sites.append(ps) - session = create_session() + session = fixture_session() session.add(p) session.flush() - p_count = ( - select(func.count("*")) - .where(people.c.person == "im the key") - .scalar() + conn = session.connection() + p_count = conn.scalar( + select(func.count("*")).where(people.c.person == "im the key") ) eq_(p_count, 1) eq_( - select(func.count("*")) - .where(peoplesites.c.person == "im the key") - .scalar(), + conn.scalar( + select(func.count("*")).where( + peoplesites.c.person == "im the key" + ) + ), 1, ) @@ -539,7 +540,7 @@ class ClauseAttributesTest(fixtures.MappedTest): u = User(name="test") - session = create_session() + session = fixture_session() session.add(u) session.flush() @@ -557,7 +558,7 @@ class ClauseAttributesTest(fixtures.MappedTest): u = User(name="test") - session = create_session() + session = fixture_session() session.add(u) session.flush() @@ -582,7 +583,7 @@ class ClauseAttributesTest(fixtures.MappedTest): u = User(name="test", counter=sa.select(5).scalar_subquery()) - session = create_session() + session = fixture_session() session.add(u) session.flush() @@ -593,7 +594,7 @@ class ClauseAttributesTest(fixtures.MappedTest): PkDefault = self.classes.PkDefault pk = PkDefault(id=literal(5) + 10, data="some data") - session = Session() + session = fixture_session() session.add(pk) session.flush() @@ -613,7 +614,7 @@ class ClauseAttributesTest(fixtures.MappedTest): bool, None == sa.false(), # noqa ) - s = create_session() + s = fixture_session() hb = HasBoolean(value=None) s.add(hb) s.flush() @@ -638,7 +639,7 @@ class ClauseAttributesTest(fixtures.MappedTest): u = User(id=5, name="test", counter=Thing(3)) - session = create_session() + session = fixture_session() session.add(u) session.flush() @@ -706,24 +707,32 @@ class PassiveDeletesTest(fixtures.MappedTest): ) }, ) - session = create_session() - mc = MyClass() - mc.children.append(MyOtherClass()) - mc.children.append(MyOtherClass()) - mc.children.append(MyOtherClass()) - mc.children.append(MyOtherClass()) + with fixture_session() as session: + mc = MyClass() + mc.children.append(MyOtherClass()) + mc.children.append(MyOtherClass()) + mc.children.append(MyOtherClass()) + mc.children.append(MyOtherClass()) - session.add(mc) - session.flush() - session.expunge_all() + session.add(mc) + session.flush() + session.expunge_all() - eq_(select(func.count("*")).select_from(myothertable).scalar(), 4) - mc = session.query(MyClass).get(mc.id) - session.delete(mc) - session.flush() + conn = session.connection() + + eq_( + conn.scalar(select(func.count("*")).select_from(myothertable)), + 4, + ) + mc = session.query(MyClass).get(mc.id) + session.delete(mc) + session.flush() - eq_(select(func.count("*")).select_from(mytable).scalar(), 0) - eq_(select(func.count("*")).select_from(myothertable).scalar(), 0) + eq_(conn.scalar(select(func.count("*")).select_from(mytable)), 0) + eq_( + conn.scalar(select(func.count("*")).select_from(myothertable)), + 0, + ) @testing.emits_warning( r".*'passive_deletes' is normally configured on one-to-many" @@ -754,7 +763,7 @@ class PassiveDeletesTest(fixtures.MappedTest): ) mapper(MyClass, mytable) - session = Session() + session = fixture_session() mc = MyClass() mco = MyOtherClass() mco.myclass = mc @@ -882,20 +891,24 @@ class ExtraPassiveDeletesTest(fixtures.MappedTest): }, ) - session = create_session() - mc = MyClass() - mc.children.append(MyOtherClass()) - mc.children.append(MyOtherClass()) - mc.children.append(MyOtherClass()) - mc.children.append(MyOtherClass()) - session.add(mc) - session.flush() - session.expunge_all() + with fixture_session(expire_on_commit=False) as session: + mc = MyClass() + mc.children.append(MyOtherClass()) + mc.children.append(MyOtherClass()) + mc.children.append(MyOtherClass()) + mc.children.append(MyOtherClass()) + session.add(mc) + session.commit() - eq_(select(func.count("*")).select_from(myothertable).scalar(), 4) - mc = session.query(MyClass).get(mc.id) - session.delete(mc) - assert_raises(sa.exc.DBAPIError, session.flush) + with fixture_session(expire_on_commit=False) as session: + conn = session.connection() + eq_( + conn.scalar(select(func.count("*")).select_from(myothertable)), + 4, + ) + mc = session.query(MyClass).get(mc.id) + session.delete(mc) + assert_raises(sa.exc.DBAPIError, session.flush) def test_extra_passive_2(self): myothertable, MyClass, MyOtherClass, mytable = ( @@ -916,19 +929,23 @@ class ExtraPassiveDeletesTest(fixtures.MappedTest): }, ) - session = create_session() - mc = MyClass() - mc.children.append(MyOtherClass()) - session.add(mc) - session.flush() - session.expunge_all() + with fixture_session(expire_on_commit=False) as session: + mc = MyClass() + mc.children.append(MyOtherClass()) + session.add(mc) + session.commit() - eq_(select(func.count("*")).select_from(myothertable).scalar(), 1) + with fixture_session(autoflush=False) as session: + conn = session.connection() + eq_( + conn.scalar(select(func.count("*")).select_from(myothertable)), + 1, + ) - mc = session.query(MyClass).get(mc.id) - session.delete(mc) - mc.children[0].data = "some new data" - assert_raises(sa.exc.DBAPIError, session.flush) + mc = session.query(MyClass).get(mc.id) + session.delete(mc) + mc.children[0].data = "some new data" + assert_raises(sa.exc.DBAPIError, session.flush) def test_extra_passive_obj_removed_o2m(self): myothertable, MyClass, MyOtherClass, mytable = ( @@ -947,7 +964,7 @@ class ExtraPassiveDeletesTest(fixtures.MappedTest): }, ) - session = create_session() + session = fixture_session() mc = MyClass() moc1 = MyOtherClass() moc2 = MyOtherClass() @@ -982,7 +999,7 @@ class ExtraPassiveDeletesTest(fixtures.MappedTest): ) }, ) - session = Session() + session = fixture_session() mc = MyClass() session.add(mc) session.commit() @@ -1016,16 +1033,18 @@ class ColumnCollisionTest(fixtures.MappedTest): pass mapper(Book, book) - sess = create_session() + with fixture_session() as sess: - b1 = Book(book_id="abc", title="def") - sess.add(b1) - sess.flush() + b1 = Book(book_id="abc", title="def") + sess.add(b1) + sess.flush() - b1.title = "ghi" - sess.flush() - sess.close() - eq_(sess.query(Book).first(), Book(book_id="abc", title="ghi")) + b1.title = "ghi" + sess.flush() + sess.commit() + + with fixture_session() as sess: + eq_(sess.query(Book).first(), Book(book_id="abc", title="ghi")) class DefaultTest(fixtures.MappedTest): @@ -1133,7 +1152,7 @@ class DefaultTest(fixtures.MappedTest): h4 = Hoho() h5 = Hoho(foober="im the new foober") - session = create_session(autocommit=False) + session = fixture_session(autocommit=False, expire_on_commit=False) session.add_all((h1, h2, h3, h4, h5)) session.commit() @@ -1193,7 +1212,7 @@ class DefaultTest(fixtures.MappedTest): mapper(Secondary, self.tables.secondary_table) h1 = Hoho() - session = create_session() + session = fixture_session() session.add(h1) if testing.db.dialect.implicit_returning: @@ -1218,7 +1237,7 @@ class DefaultTest(fixtures.MappedTest): mapper(Hoho, default_t) h1 = Hoho(hoho="15", counter=15) - session = create_session() + session = fixture_session() session.add(h1) session.flush() @@ -1236,7 +1255,7 @@ class DefaultTest(fixtures.MappedTest): mapper(Hoho, default_t) h1 = Hoho() - session = create_session() + session = fixture_session() session.add(h1) session.flush() @@ -1272,7 +1291,7 @@ class DefaultTest(fixtures.MappedTest): s1 = Secondary(data="s1") h1.secondaries.append(s1) - session = create_session() + session = fixture_session() session.add(h1) session.flush() session.expunge_all() @@ -1428,7 +1447,7 @@ class ColumnPropertyTest(fixtures.MappedTest): ) mapper(SubData, subdata, inherits=Data) - sess = create_session() + sess = fixture_session() sd1 = SubData(a="hello", b="there", c="hi") sess.add(sd1) sess.flush() @@ -1437,25 +1456,27 @@ class ColumnPropertyTest(fixtures.MappedTest): def _test(self, expect_expiry, expect_deferred_load=False): Data = self.classes.Data - sess = create_session() + with fixture_session() as sess: - d1 = Data(a="hello", b="there") - sess.add(d1) - sess.flush() + d1 = Data(a="hello", b="there") + sess.add(d1) + sess.flush() - eq_(d1.aplusb, "hello there") - - d1.b = "bye" - sess.flush() - if expect_expiry: - eq_(d1.aplusb, "hello bye") - else: eq_(d1.aplusb, "hello there") - d1.b = "foobar" - d1.aplusb = "im setting this explicitly" - sess.flush() - eq_(d1.aplusb, "im setting this explicitly") + d1.b = "bye" + sess.flush() + if expect_expiry: + eq_(d1.aplusb, "hello bye") + else: + eq_(d1.aplusb, "hello there") + + d1.b = "foobar" + d1.aplusb = "im setting this explicitly" + sess.flush() + eq_(d1.aplusb, "im setting this explicitly") + + sess.commit() # test issue #3984. # NOTE: if we only expire_all() here rather than start with brand new @@ -1463,18 +1484,18 @@ class ColumnPropertyTest(fixtures.MappedTest): # "undeferred". this is questionable but not as severe as the never- # loaded attribute being loaded during an unexpire. - sess.close() - d1 = sess.query(Data).first() + with fixture_session() as sess: + d1 = sess.query(Data).first() - d1.b = "so long" - sess.flush() - sess.expire_all() - eq_(d1.b, "so long") - if expect_deferred_load: - eq_("aplusb" in d1.__dict__, False) - else: - eq_("aplusb" in d1.__dict__, True) - eq_(d1.aplusb, "hello so long") + d1.b = "so long" + sess.flush() + sess.expire_all() + eq_(d1.b, "so long") + if expect_deferred_load: + eq_("aplusb" in d1.__dict__, False) + else: + eq_("aplusb" in d1.__dict__, True) + eq_(d1.aplusb, "hello so long") class OneToManyTest(_fixtures.FixtureTest): @@ -1506,21 +1527,22 @@ class OneToManyTest(_fixtures.FixtureTest): a2 = Address(email_address="lala@test.org") u.addresses.append(a2) - session = create_session() + session = fixture_session() session.add(u) session.flush() - user_rows = users.select(users.c.id.in_([u.id])).execute().fetchall() + conn = session.connection() + user_rows = conn.execute( + users.select(users.c.id.in_([u.id])) + ).fetchall() eq_(list(user_rows[0]), [u.id, "one2manytester"]) - address_rows = ( + address_rows = conn.execute( addresses.select( addresses.c.id.in_([a.id, a2.id]), order_by=[addresses.c.email_address], ) - .execute() - .fetchall() - ) + ).fetchall() eq_(list(address_rows[0]), [a2.id, u.id, "lala@test.org"]) eq_(list(address_rows[1]), [a.id, u.id, "one2many@test.org"]) @@ -1531,9 +1553,9 @@ class OneToManyTest(_fixtures.FixtureTest): session.flush() - address_rows = ( - addresses.select(addresses.c.id == addressid).execute().fetchall() - ) + address_rows = conn.execute( + addresses.select(addresses.c.id == addressid) + ).fetchall() eq_(list(address_rows[0]), [addressid, userid, "somethingnew@foo.com"]) self.assert_(u.id == userid and a2.id == addressid) @@ -1569,7 +1591,7 @@ class OneToManyTest(_fixtures.FixtureTest): a3 = Address(email_address="emailaddress3") - session = create_session() + session = fixture_session() session.add_all((u1, u2, a3)) session.flush() @@ -1631,7 +1653,7 @@ class OneToManyTest(_fixtures.FixtureTest): a = Address(email_address="address1") u1.addresses.append(a) - session = create_session() + session = fixture_session() session.add_all((u1, u2)) session.flush() @@ -1668,7 +1690,7 @@ class OneToManyTest(_fixtures.FixtureTest): a = Address(email_address="address1") u1.addresses.append(a) - session = create_session() + session = fixture_session() session.add_all((u1, u2)) session.flush() @@ -1703,7 +1725,7 @@ class OneToManyTest(_fixtures.FixtureTest): a = Address(email_address="myonlyaddress@foo.com") u.address = a - session = create_session() + session = fixture_session() session.add(u) session.flush() @@ -1738,7 +1760,7 @@ class OneToManyTest(_fixtures.FixtureTest): u = User(name="one2onetester") u.address = Address(email_address="myonlyaddress@foo.com") - session = create_session() + session = fixture_session() session.add(u) session.flush() @@ -1768,7 +1790,7 @@ class OneToManyTest(_fixtures.FixtureTest): u = User(name="test") Address(email_address="testaddress", user=u) - session = create_session() + session = fixture_session() session.add(u) session.flush() session.delete(u) @@ -1812,7 +1834,7 @@ class OneToManyTest(_fixtures.FixtureTest): u.boston_addresses.append(a) u.newyork_addresses.append(b) - session = create_session() + session = fixture_session() session.add(u) session.flush() @@ -1829,40 +1851,42 @@ class SaveTest(_fixtures.FixtureTest): u = User(name="savetester") u2 = User(name="savetester2") - session = create_session() - session.add_all((u, u2)) - session.flush() + with fixture_session() as session: + session.add_all((u, u2)) + session.flush() - # assert the first one retrieves the same from the identity map - nu = session.query(m).get(u.id) - assert u is nu + # assert the first one retrieves the same from the identity map + nu = session.query(m).get(u.id) + assert u is nu - # clear out the identity map, so next get forces a SELECT - session.expunge_all() + # clear out the identity map, so next get forces a SELECT + session.expunge_all() - # check it again, identity should be different but ids the same - nu = session.query(m).get(u.id) - assert u is not nu and u.id == nu.id and nu.name == "savetester" + # check it again, identity should be different but ids the same + nu = session.query(m).get(u.id) + assert u is not nu and u.id == nu.id and nu.name == "savetester" + + session.commit() # change first users name and save - session = create_session() - session.add(u) - u.name = "modifiedname" - assert u in session.dirty - session.flush() + with fixture_session() as session: + session.add(u) + u.name = "modifiedname" + assert u in session.dirty + session.flush() - # select both - userlist = ( - session.query(User) - .filter(users.c.id.in_([u.id, u2.id])) - .order_by(users.c.name) - .all() - ) + # select both + userlist = ( + session.query(User) + .filter(users.c.id.in_([u.id, u2.id])) + .order_by(users.c.name) + .all() + ) - eq_(u.id, userlist[0].id) - eq_(userlist[0].name, "modifiedname") - eq_(u2.id, userlist[1].id) - eq_(userlist[1].name, "savetester2") + eq_(u.id, userlist[0].id) + eq_(userlist[0].name, "modifiedname") + eq_(u2.id, userlist[1].id) + eq_(userlist[1].name, "savetester2") def test_synonym(self): users = self.tables.users @@ -1881,7 +1905,7 @@ class SaveTest(_fixtures.FixtureTest): u = SUser(syn_name="some name") eq_(u.syn_name, "User:some name:User") - session = create_session() + session = fixture_session() session.add(u) session.flush() session.expunge_all() @@ -1916,7 +1940,7 @@ class SaveTest(_fixtures.FixtureTest): u.addresses.append(Address(email_address="u1@e3")) u.addresses.append(Address(email_address="u1@e4")) - session = create_session() + session = fixture_session() session.add(u) session.flush() session.expunge_all() @@ -1951,7 +1975,7 @@ class SaveTest(_fixtures.FixtureTest): au = AddressUser(name="u", email_address="u@e") - session = create_session() + session = fixture_session() session.add(au) session.flush() session.expunge_all() @@ -1973,7 +1997,7 @@ class SaveTest(_fixtures.FixtureTest): # don't set deferred attribute, commit session o = Order(id=42) - session = create_session(autocommit=False) + session = fixture_session(autocommit=False) session.add(o) session.commit() @@ -2022,7 +2046,7 @@ class SaveTest(_fixtures.FixtureTest): mapper(User, users) u = User(name="") - session = create_session() + session = fixture_session() session.add(u) session.flush() session.expunge_all() @@ -2058,7 +2082,7 @@ class SaveTest(_fixtures.FixtureTest): ) u = User(name="multitester", email="multi@test.org") - session = create_session() + session = fixture_session() session.add(u) session.flush() session.expunge_all() @@ -2068,26 +2092,27 @@ class SaveTest(_fixtures.FixtureTest): u = session.query(User).get(id_) assert u.name == "multitester" - user_rows = ( - users.select(users.c.id.in_([u.foo_id])).execute().fetchall() - ) + conn = session.connection() + user_rows = conn.execute( + users.select(users.c.id.in_([u.foo_id])) + ).fetchall() eq_(list(user_rows[0]), [u.foo_id, "multitester"]) - address_rows = ( - addresses.select(addresses.c.id.in_([u.id])).execute().fetchall() - ) + address_rows = conn.execute( + addresses.select(addresses.c.id.in_([u.id])) + ).fetchall() eq_(list(address_rows[0]), [u.id, u.foo_id, "multi@test.org"]) u.email = "lala@hey.com" u.name = "imnew" session.flush() - user_rows = ( - users.select(users.c.id.in_([u.foo_id])).execute().fetchall() - ) + user_rows = conn.execute( + users.select(users.c.id.in_([u.foo_id])) + ).fetchall() eq_(list(user_rows[0]), [u.foo_id, "imnew"]) - address_rows = ( - addresses.select(addresses.c.id.in_([u.id])).execute().fetchall() - ) + address_rows = conn.execute( + addresses.select(addresses.c.id.in_([u.id])) + ).fetchall() eq_(list(address_rows[0]), [u.id, u.foo_id, "lala@hey.com"]) session.expunge_all() @@ -2118,7 +2143,7 @@ class SaveTest(_fixtures.FixtureTest): u = User(name="u1") u.addresses.append(Address(email_address="u1@e1")) u.addresses.append(Address(email_address="u1@e2")) - session = create_session() + session = fixture_session() session.add(u) session.flush() session.expunge_all() @@ -2126,8 +2151,18 @@ class SaveTest(_fixtures.FixtureTest): u = session.query(User).get(u.id) session.delete(u) session.flush() - eq_(select(func.count("*")).select_from(users).scalar(), 0) - eq_(select(func.count("*")).select_from(addresses).scalar(), 0) + eq_( + session.connection().scalar( + select(func.count("*")).select_from(users) + ), + 0, + ) + eq_( + session.connection().scalar( + select(func.count("*")).select_from(addresses) + ), + 0, + ) def test_batch_mode(self): """The 'batch=False' flag on mapper()""" @@ -2153,7 +2188,7 @@ class SaveTest(_fixtures.FixtureTest): u1 = User(name="user1") u2 = User(name="user2") - session = create_session() + session = fixture_session() session.add_all((u1, u2)) session.flush() @@ -2202,7 +2237,7 @@ class ManyToOneTest(_fixtures.FixtureTest): ), ) - session = create_session() + session = fixture_session() data = [ {"name": "thesub", "email_address": "bar@foo.com"}, @@ -2253,14 +2288,13 @@ class ManyToOneTest(_fixtures.FixtureTest): ), ) - result = ( - sa.select(users, addresses) - .where( + conn = session.connection() + result = conn.execute( + sa.select(users, addresses).where( sa.and_( users.c.id == addresses.c.user_id, addresses.c.id == a.id ), ) - .execute() ) eq_( list(result.first()), @@ -2287,7 +2321,7 @@ class ManyToOneTest(_fixtures.FixtureTest): u1 = User(name="user1") a1.user = u1 - session = create_session() + session = fixture_session() session.add(a1) session.flush() session.expunge_all() @@ -2324,7 +2358,7 @@ class ManyToOneTest(_fixtures.FixtureTest): u1 = User(name="user1") a1.user = u1 - session = create_session() + session = fixture_session() session.add_all((a1, a2)) session.flush() session.expunge_all() @@ -2366,7 +2400,7 @@ class ManyToOneTest(_fixtures.FixtureTest): u2 = User(name="user2") a1.user = u1 - session = create_session() + session = fixture_session() session.add_all((a1, u1, u2)) session.flush() session.expunge_all() @@ -2408,7 +2442,7 @@ class ManyToOneTest(_fixtures.FixtureTest): a1 = Address(email_address="e1") a1.user = u1 - session = create_session() + session = fixture_session() session.add(u1) session.flush() session.expunge_all() @@ -2496,7 +2530,7 @@ class ManyToManyTest(_fixtures.FixtureTest): }, ] - session = create_session() + session = fixture_session() objects = [] _keywords = dict([(k.name, k) for k in session.query(Keyword)]) @@ -2600,14 +2634,15 @@ class ManyToManyTest(_fixtures.FixtureTest): i.keywords.append(k1) i.keywords.append(k2) - session = create_session() + session = fixture_session() session.add(i) session.flush() - eq_(select(func.count("*")).select_from(item_keywords).scalar(), 2) + conn = session.connection() + eq_(conn.scalar(select(func.count("*")).select_from(item_keywords)), 2) i.keywords = [] session.flush() - eq_(select(func.count("*")).select_from(item_keywords).scalar(), 0) + eq_(conn.scalar(select(func.count("*")).select_from(item_keywords)), 0) def test_scalar(self): """sa.dependency won't delete an m2m relationship referencing None.""" @@ -2633,7 +2668,7 @@ class ManyToManyTest(_fixtures.FixtureTest): ) i = Item(description="x") - session = create_session() + session = fixture_session() session.add(i) session.flush() session.delete(i) @@ -2671,7 +2706,7 @@ class ManyToManyTest(_fixtures.FixtureTest): item = Item(description="item 1") item.keywords.extend([k1, k2, k3]) - session = create_session() + session = fixture_session() session.add(item) session.flush() @@ -2727,7 +2762,7 @@ class ManyToManyTest(_fixtures.FixtureTest): ), ) - session = create_session() + session = fixture_session() def fixture(): _kw = dict([(k.name, k) for k in session.query(Keyword)]) @@ -2786,7 +2821,7 @@ class SaveTest2(_fixtures.FixtureTest): ), ) - session = create_session() + session = fixture_session() def fixture(): return [ @@ -2909,14 +2944,24 @@ class SaveTest3(fixtures.MappedTest): i.keywords.append(k1) i.keywords.append(k2) - session = create_session() + session = fixture_session() session.add(i) session.flush() - eq_(select(func.count("*")).select_from(assoc).scalar(), 2) + eq_( + session.connection().scalar( + select(func.count("*")).select_from(assoc) + ), + 2, + ) i.keywords = [] session.flush() - eq_(select(func.count("*")).select_from(assoc).scalar(), 0) + eq_( + session.connection().scalar( + select(func.count("*")).select_from(assoc) + ), + 0, + ) class BooleanColTest(fixtures.MappedTest): @@ -2941,7 +2986,7 @@ class BooleanColTest(fixtures.MappedTest): mapper(T, t1_t) - sess = create_session() + sess = fixture_session() t1 = T(value=True, name="t1") t2 = T(value=False, name="t2") t3 = T(value=True, name="t3") @@ -3060,7 +3105,7 @@ class RowSwitchTest(fixtures.MappedTest): ) mapper(T6, t6) - sess = create_session() + sess = fixture_session() o5 = T5(data="some t5", id=1) o5.t6s.append(T6(data="some t6", id=1)) @@ -3108,7 +3153,7 @@ class RowSwitchTest(fixtures.MappedTest): ) mapper(T7, t7) - sess = create_session() + sess = fixture_session() o5 = T5(data="some t5", id=1) o5.t7s.append(T7(data="some t7", id=1)) @@ -3159,7 +3204,7 @@ class RowSwitchTest(fixtures.MappedTest): mapper(T6, t6, properties={"t5": relationship(T5)}) mapper(T5, t5) - sess = create_session() + sess = fixture_session() o5 = T6(data="some t6", id=1) o5.t5 = T5(data="some t5", id=1) @@ -3222,7 +3267,7 @@ class InheritingRowSwitchTest(fixtures.MappedTest): mapper(P, parent) mapper(C, child, inherits=P) - sess = create_session() + sess = fixture_session() c1 = C(pid=1, cid=1, pdata="c1", cdata="c1") sess.add(c1) sess.flush() @@ -3253,67 +3298,6 @@ class InheritingRowSwitchTest(fixtures.MappedTest): ) -class TransactionTest(fixtures.MappedTest): - __requires__ = ("deferrable_or_no_constraints",) - - @classmethod - def define_tables(cls, metadata): - Table("t1", metadata, Column("id", Integer, primary_key=True)) - - Table( - "t2", - metadata, - Column("id", Integer, primary_key=True), - Column( - "t1_id", - Integer, - ForeignKey("t1.id", deferrable=True, initially="deferred"), - ), - ) - - @classmethod - def setup_classes(cls): - class T1(cls.Comparable): - pass - - class T2(cls.Comparable): - pass - - @classmethod - def setup_mappers(cls): - T2, T1, t2, t1 = ( - cls.classes.T2, - cls.classes.T1, - cls.tables.t2, - cls.tables.t1, - ) - - mapper(T1, t1) - mapper(T2, t2) - - def test_close_transaction_on_commit_fail(self): - T2, t1 = self.classes.T2, self.tables.t1 - - session = create_session(autocommit=True) - - # with a deferred constraint, this fails at COMMIT time instead - # of at INSERT time. - session.add(T2(t1_id=123)) - - try: - session.flush() - assert False - except Exception: - # Flush needs to rollback also when commit fails - assert session._legacy_transaction() is None - - # todo: on 8.3 at least, the failed commit seems to close the cursor? - # needs investigation. leaving in the DDL above now to help verify - # that the new deferrable support on FK isn't involved in this issue. - if testing.against("postgresql"): - t1.bind.engine.dispose() - - class PartialNullPKTest(fixtures.MappedTest): # sqlite totally fine with NULLs in pk columns. # no other DB is like this. @@ -3340,7 +3324,7 @@ class PartialNullPKTest(fixtures.MappedTest): def test_key_switch(self): T1 = self.classes.T1 - s = Session() + s = fixture_session() s.add(T1(col1="1", col2=None)) t1 = s.query(T1).first() @@ -3354,7 +3338,7 @@ class PartialNullPKTest(fixtures.MappedTest): def test_plain_update(self): T1 = self.classes.T1 - s = Session() + s = fixture_session() s.add(T1(col1="1", col2=None)) t1 = s.query(T1).first() @@ -3368,7 +3352,7 @@ class PartialNullPKTest(fixtures.MappedTest): def test_delete(self): T1 = self.classes.T1 - s = Session() + s = fixture_session() s.add(T1(col1="1", col2=None)) t1 = s.query(T1).first() @@ -3382,7 +3366,7 @@ class PartialNullPKTest(fixtures.MappedTest): def test_total_null(self): T1 = self.classes.T1 - s = Session() + s = fixture_session() s.add(T1(col1=None, col2=None)) assert_raises_message( orm_exc.FlushError, @@ -3394,7 +3378,7 @@ class PartialNullPKTest(fixtures.MappedTest): def test_dont_complain_if_no_update(self): T1 = self.classes.T1 - s = Session() + s = fixture_session() t = T1(col1="1", col2=None) s.add(t) s.commit() @@ -3494,7 +3478,7 @@ class EnsurePKSortableTest(fixtures.MappedTest): mapper(cls.classes.T3, cls.tables.t3) def test_exception_persistent_flush_py3k(self): - s = Session() + s = fixture_session() a, b = self.classes.T2(id=self.three), self.classes.T2(id=self.four) s.add_all([a, b]) @@ -3520,7 +3504,7 @@ class EnsurePKSortableTest(fixtures.MappedTest): s.close() def test_persistent_flush_sortable(self): - s = Session() + s = fixture_session() a, b = self.classes.T1(id=self.one), self.classes.T1(id=self.two) s.add_all([a, b]) @@ -3531,7 +3515,7 @@ class EnsurePKSortableTest(fixtures.MappedTest): s.commit() def test_pep435_custom_sort_key(self): - s = Session() + s = fixture_session() a = self.classes.T3(id=self.three, value=1) b = self.classes.T3(id=self.four, value=2) diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py index 97b7b9edd..4e713627c 100644 --- a/test/orm/test_unitofworkv2.py +++ b/test/orm/test_unitofworkv2.py @@ -15,7 +15,6 @@ from sqlalchemy import text from sqlalchemy import util from sqlalchemy.orm import attributes from sqlalchemy.orm import backref -from sqlalchemy.orm import create_session from sqlalchemy.orm import exc as orm_exc from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship @@ -30,6 +29,7 @@ from sqlalchemy.testing import is_ from sqlalchemy.testing.assertsql import AllOf from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.testing.assertsql import Conditional +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.mock import Mock from sqlalchemy.testing.mock import patch from sqlalchemy.testing.schema import Column @@ -73,7 +73,7 @@ class RudimentaryFlushTest(UOWTest): mapper(User, users, properties={"addresses": relationship(Address)}) mapper(Address, addresses) - sess = create_session() + sess = fixture_session() a1, a2 = Address(email_address="a1"), Address(email_address="a2") u1 = User(name="u1", addresses=[a1, a2]) @@ -122,7 +122,7 @@ class RudimentaryFlushTest(UOWTest): mapper(User, users, properties={"addresses": relationship(Address)}) mapper(Address, addresses) - sess = create_session() + sess = fixture_session() a1, a2 = Address(email_address="a1"), Address(email_address="a2") u1 = User(name="u1", addresses=[a1, a2]) sess.add(u1) @@ -153,7 +153,7 @@ class RudimentaryFlushTest(UOWTest): mapper(User, users, properties={"addresses": relationship(Address)}) mapper(Address, addresses) - sess = create_session() + sess = fixture_session() a1, a2 = Address(email_address="a1"), Address(email_address="a2") u1 = User(name="u1", addresses=[a1, a2]) sess.add(u1) @@ -186,7 +186,7 @@ class RudimentaryFlushTest(UOWTest): mapper(User, users) mapper(Address, addresses, properties={"user": relationship(User)}) - sess = create_session() + sess = fixture_session() u1 = User(name="u1") a1, a2 = ( @@ -238,7 +238,7 @@ class RudimentaryFlushTest(UOWTest): mapper(User, users) mapper(Address, addresses, properties={"user": relationship(User)}) - sess = create_session() + sess = fixture_session() u1 = User(name="u1") a1, a2 = ( @@ -273,7 +273,7 @@ class RudimentaryFlushTest(UOWTest): mapper(User, users) mapper(Address, addresses, properties={"user": relationship(User)}) - sess = create_session() + sess = fixture_session() u1 = User(name="u1") a1, a2 = ( @@ -318,7 +318,7 @@ class RudimentaryFlushTest(UOWTest): Address(email_address="c2", parent=parent), ) - session = Session() + session = fixture_session() session.add_all([c1, c2]) session.add(parent) @@ -399,7 +399,7 @@ class RudimentaryFlushTest(UOWTest): Address(email_address="c2", parent=parent), ) - session = Session() + session = fixture_session() session.add_all([c1, c2]) session.add(parent) @@ -464,7 +464,7 @@ class RudimentaryFlushTest(UOWTest): Address(email_address="c2", parent=parent), ) - session = Session() + session = fixture_session() session.add_all([c1, c2]) session.add(parent) @@ -521,7 +521,7 @@ class RudimentaryFlushTest(UOWTest): mapper(User, users) mapper(Address, addresses, properties={"user": relationship(User)}) - sess = create_session() + sess = fixture_session() u1 = User(name="u1") a1, a2 = ( @@ -552,7 +552,7 @@ class RudimentaryFlushTest(UOWTest): mapper(User, users) mapper(Address, addresses, properties={"user": relationship(User)}) - sess = create_session() + sess = fixture_session() u1 = User(name="u1") a1, a2 = ( @@ -593,7 +593,7 @@ class RudimentaryFlushTest(UOWTest): mapper(User, users) mapper(Address, addresses, properties={"parent": relationship(User)}) - sess = create_session() + sess = fixture_session() u1 = User(id=1, name="u1") a1 = Address(id=1, user_id=1, email_address="a2") @@ -632,7 +632,7 @@ class RudimentaryFlushTest(UOWTest): mapper(Node, nodes, properties={"children": relationship(Node)}) - sess = create_session() + sess = fixture_session() n1 = Node(id=1) n2 = Node(id=2, parent_id=1) @@ -674,7 +674,7 @@ class RudimentaryFlushTest(UOWTest): ) mapper(Keyword, keywords) - sess = create_session() + sess = fixture_session() k1 = Keyword(name="k1") i1 = Item(description="i1", keywords=[k1]) sess.add(i1) @@ -723,7 +723,7 @@ class RudimentaryFlushTest(UOWTest): addresses, properties={"user": relationship(User, passive_updates=True)}, ) - sess = create_session() + sess = fixture_session() u1 = User(name="ed") sess.add(u1) self._assert_uow_size(sess, 2) @@ -739,35 +739,35 @@ class RudimentaryFlushTest(UOWTest): mapper(User, users, properties={"addresses": relationship(Address)}) mapper(Address, addresses) - sess = create_session() - u1 = User(name="ed") - sess.add(u1) - self._assert_uow_size(sess, 2) + with fixture_session(autoflush=False) as sess: + u1 = User(name="ed") + sess.add(u1) + self._assert_uow_size(sess, 2) - sess.flush() + sess.flush() - u1.name = "jack" + u1.name = "jack" - self._assert_uow_size(sess, 2) - sess.flush() + self._assert_uow_size(sess, 2) + sess.flush() - a1 = Address(email_address="foo") - sess.add(a1) - sess.flush() + a1 = Address(email_address="foo") + sess.add(a1) + sess.flush() - u1.addresses.append(a1) + u1.addresses.append(a1) - self._assert_uow_size(sess, 6) + self._assert_uow_size(sess, 6) - sess.flush() + sess.commit() - sess = create_session() - u1 = sess.query(User).first() - u1.name = "ed" - self._assert_uow_size(sess, 2) + with fixture_session(autoflush=False) as sess: + u1 = sess.query(User).first() + u1.name = "ed" + self._assert_uow_size(sess, 2) - u1.addresses - self._assert_uow_size(sess, 6) + u1.addresses + self._assert_uow_size(sess, 6) class SingleCycleTest(UOWTest): @@ -784,7 +784,7 @@ class SingleCycleTest(UOWTest): Node, nodes = self.classes.Node, self.tables.nodes mapper(Node, nodes, properties={"children": relationship(Node)}) - sess = create_session() + sess = fixture_session() n2, n3 = Node(data="n2"), Node(data="n3") n1 = Node(data="n1", children=[n2, n3]) @@ -832,7 +832,7 @@ class SingleCycleTest(UOWTest): Node, nodes = self.classes.Node, self.tables.nodes mapper(Node, nodes, properties={"children": relationship(Node)}) - sess = create_session() + sess = fixture_session() n2, n3 = Node(data="n2", children=[]), Node(data="n3", children=[]) n1 = Node(data="n1", children=[n2, n3]) @@ -860,7 +860,7 @@ class SingleCycleTest(UOWTest): Node, nodes = self.classes.Node, self.tables.nodes mapper(Node, nodes, properties={"children": relationship(Node)}) - sess = create_session() + sess = fixture_session() n2, n3 = Node(data="n2", children=[]), Node(data="n3", children=[]) n1 = Node(data="n1", children=[n2, n3]) @@ -894,7 +894,7 @@ class SingleCycleTest(UOWTest): nodes, properties={"parent": relationship(Node, remote_side=nodes.c.id)}, ) - sess = create_session() + sess = fixture_session() n1 = Node(data="n1") n2, n3 = Node(data="n2", parent=n1), Node(data="n3", parent=n1) @@ -946,7 +946,7 @@ class SingleCycleTest(UOWTest): nodes, properties={"parent": relationship(Node, remote_side=nodes.c.id)}, ) - sess = create_session() + sess = fixture_session() n1 = Node(data="n1") n2, n3 = Node(data="n2", parent=n1), Node(data="n3", parent=n1) @@ -978,30 +978,30 @@ class SingleCycleTest(UOWTest): nodes, properties={"parent": relationship(Node, remote_side=nodes.c.id)}, ) - sess = create_session() - n1 = Node(data="n1") - n2 = Node(data="n2", parent=n1) - sess.add_all([n1, n2]) - sess.flush() - sess.close() + with fixture_session() as sess: + n1 = Node(data="n1") + n2 = Node(data="n2", parent=n1) + sess.add_all([n1, n2]) + sess.commit() - n2 = sess.query(Node).filter_by(data="n2").one() - n2.parent = None - self.assert_sql_execution( - testing.db, - sess.flush, - CompiledSQL( - "UPDATE nodes SET parent_id=:parent_id WHERE " - "nodes.id = :nodes_id", - lambda ctx: {"parent_id": None, "nodes_id": n2.id}, - ), - ) + with fixture_session() as sess: + n2 = sess.query(Node).filter_by(data="n2").one() + n2.parent = None + self.assert_sql_execution( + testing.db, + sess.flush, + CompiledSQL( + "UPDATE nodes SET parent_id=:parent_id WHERE " + "nodes.id = :nodes_id", + lambda ctx: {"parent_id": None, "nodes_id": n2.id}, + ), + ) def test_cycle_rowswitch(self): Node, nodes = self.classes.Node, self.tables.nodes mapper(Node, nodes, properties={"children": relationship(Node)}) - sess = create_session() + sess = fixture_session() n2, n3 = Node(data="n2", children=[]), Node(data="n3", children=[]) n1 = Node(data="n1", children=[n2]) @@ -1025,7 +1025,7 @@ class SingleCycleTest(UOWTest): ) }, ) - sess = create_session() + sess = fixture_session() n2, n3 = Node(data="n2", children=[]), Node(data="n3", children=[]) n1 = Node(data="n1", children=[n2]) @@ -1051,7 +1051,7 @@ class SingleCycleTest(UOWTest): ) }, ) - sess = create_session() + sess = fixture_session() n1 = Node(data="n1") n1.children.append(Node(data="n11")) n12 = Node(data="n12") @@ -1137,29 +1137,29 @@ class SingleCycleTest(UOWTest): Node, nodes = self.classes.Node, self.tables.nodes mapper(Node, nodes, properties={"children": relationship(Node)}) - sess = create_session() - n1 = Node(data="ed") - sess.add(n1) - self._assert_uow_size(sess, 2) + with fixture_session() as sess: + n1 = Node(data="ed") + sess.add(n1) + self._assert_uow_size(sess, 2) - sess.flush() + sess.flush() - n1.data = "jack" + n1.data = "jack" - self._assert_uow_size(sess, 2) - sess.flush() + self._assert_uow_size(sess, 2) + sess.flush() - n2 = Node(data="foo") - sess.add(n2) - sess.flush() + n2 = Node(data="foo") + sess.add(n2) + sess.flush() - n1.children.append(n2) + n1.children.append(n2) - self._assert_uow_size(sess, 3) + self._assert_uow_size(sess, 3) - sess.flush() + sess.commit() - sess = create_session() + sess = fixture_session(autoflush=False) n1 = sess.query(Node).first() n1.data = "ed" self._assert_uow_size(sess, 2) @@ -1179,7 +1179,7 @@ class SingleCycleTest(UOWTest): parent = Node() c1, c2 = Node(parent=parent), Node(parent=parent) - session = Session() + session = fixture_session() session.add_all([c1, c2]) session.add(parent) @@ -1285,7 +1285,7 @@ class SingleCyclePlusAttributeTest( ) mapper(FooBar, foobars) - sess = create_session() + sess = fixture_session() n1 = Node(data="n1") n2 = Node(data="n2") n1.children.append(n2) @@ -1355,110 +1355,111 @@ class SingleCycleM2MTest( }, ) - sess = create_session() - n1 = Node(data="n1") - n2 = Node(data="n2") - n3 = Node(data="n3") - n4 = Node(data="n4") - n5 = Node(data="n5") + with fixture_session(autoflush=False) as sess: + n1 = Node(data="n1") + n2 = Node(data="n2") + n3 = Node(data="n3") + n4 = Node(data="n4") + n5 = Node(data="n5") - n4.favorite = n3 - n1.favorite = n5 - n5.favorite = n2 + n4.favorite = n3 + n1.favorite = n5 + n5.favorite = n2 - n1.children = [n2, n3, n4] - n2.children = [n3, n5] - n3.children = [n5, n4] + n1.children = [n2, n3, n4] + n2.children = [n3, n5] + n3.children = [n5, n4] - sess.add_all([n1, n2, n3, n4, n5]) + sess.add_all([n1, n2, n3, n4, n5]) - # can't really assert the SQL on this easily - # since there's too many ways to insert the rows. - # so check the end result - sess.flush() - eq_( - sess.query( - node_to_nodes.c.left_node_id, node_to_nodes.c.right_node_id - ) - .order_by( - node_to_nodes.c.left_node_id, node_to_nodes.c.right_node_id + # can't really assert the SQL on this easily + # since there's too many ways to insert the rows. + # so check the end result + sess.flush() + eq_( + sess.query( + node_to_nodes.c.left_node_id, node_to_nodes.c.right_node_id + ) + .order_by( + node_to_nodes.c.left_node_id, node_to_nodes.c.right_node_id + ) + .all(), + sorted( + [ + (n1.id, n2.id), + (n1.id, n3.id), + (n1.id, n4.id), + (n2.id, n3.id), + (n2.id, n5.id), + (n3.id, n5.id), + (n3.id, n4.id), + ] + ), ) - .all(), - sorted( - [ - (n1.id, n2.id), - (n1.id, n3.id), - (n1.id, n4.id), - (n2.id, n3.id), - (n2.id, n5.id), - (n3.id, n5.id), - (n3.id, n4.id), - ] - ), - ) - sess.delete(n1) + sess.delete(n1) - self.assert_sql_execution( - testing.db, - sess.flush, - # this is n1.parents firing off, as it should, since - # passive_deletes is False for n1.parents - CompiledSQL( - "SELECT nodes.id AS nodes_id, nodes.data AS nodes_data, " - "nodes.favorite_node_id AS nodes_favorite_node_id FROM " - "nodes, node_to_nodes WHERE :param_1 = " - "node_to_nodes.right_node_id AND nodes.id = " - "node_to_nodes.left_node_id", - lambda ctx: {"param_1": n1.id}, - ), - CompiledSQL( - "DELETE FROM node_to_nodes WHERE " - "node_to_nodes.left_node_id = :left_node_id AND " - "node_to_nodes.right_node_id = :right_node_id", - lambda ctx: [ - {"right_node_id": n2.id, "left_node_id": n1.id}, - {"right_node_id": n3.id, "left_node_id": n1.id}, - {"right_node_id": n4.id, "left_node_id": n1.id}, - ], - ), - CompiledSQL( - "DELETE FROM nodes WHERE nodes.id = :id", - lambda ctx: {"id": n1.id}, - ), - ) + self.assert_sql_execution( + testing.db, + sess.flush, + # this is n1.parents firing off, as it should, since + # passive_deletes is False for n1.parents + CompiledSQL( + "SELECT nodes.id AS nodes_id, nodes.data AS nodes_data, " + "nodes.favorite_node_id AS nodes_favorite_node_id FROM " + "nodes, node_to_nodes WHERE :param_1 = " + "node_to_nodes.right_node_id AND nodes.id = " + "node_to_nodes.left_node_id", + lambda ctx: {"param_1": n1.id}, + ), + CompiledSQL( + "DELETE FROM node_to_nodes WHERE " + "node_to_nodes.left_node_id = :left_node_id AND " + "node_to_nodes.right_node_id = :right_node_id", + lambda ctx: [ + {"right_node_id": n2.id, "left_node_id": n1.id}, + {"right_node_id": n3.id, "left_node_id": n1.id}, + {"right_node_id": n4.id, "left_node_id": n1.id}, + ], + ), + CompiledSQL( + "DELETE FROM nodes WHERE nodes.id = :id", + lambda ctx: {"id": n1.id}, + ), + ) - for n in [n2, n3, n4, n5]: - sess.delete(n) + for n in [n2, n3, n4, n5]: + sess.delete(n) - # load these collections - # outside of the flush() below - n4.children - n5.children + # load these collections + # outside of the flush() below + n4.children + n5.children - self.assert_sql_execution( - testing.db, - sess.flush, - CompiledSQL( - "DELETE FROM node_to_nodes WHERE node_to_nodes.left_node_id " - "= :left_node_id AND node_to_nodes.right_node_id = " - ":right_node_id", - lambda ctx: [ - {"right_node_id": n5.id, "left_node_id": n3.id}, - {"right_node_id": n4.id, "left_node_id": n3.id}, - {"right_node_id": n3.id, "left_node_id": n2.id}, - {"right_node_id": n5.id, "left_node_id": n2.id}, - ], - ), - CompiledSQL( - "DELETE FROM nodes WHERE nodes.id = :id", - lambda ctx: [{"id": n4.id}, {"id": n5.id}], - ), - CompiledSQL( - "DELETE FROM nodes WHERE nodes.id = :id", - lambda ctx: [{"id": n2.id}, {"id": n3.id}], - ), - ) + self.assert_sql_execution( + testing.db, + sess.flush, + CompiledSQL( + "DELETE FROM node_to_nodes " + "WHERE node_to_nodes.left_node_id " + "= :left_node_id AND node_to_nodes.right_node_id = " + ":right_node_id", + lambda ctx: [ + {"right_node_id": n5.id, "left_node_id": n3.id}, + {"right_node_id": n4.id, "left_node_id": n3.id}, + {"right_node_id": n3.id, "left_node_id": n2.id}, + {"right_node_id": n5.id, "left_node_id": n2.id}, + ], + ), + CompiledSQL( + "DELETE FROM nodes WHERE nodes.id = :id", + lambda ctx: [{"id": n4.id}, {"id": n5.id}], + ), + CompiledSQL( + "DELETE FROM nodes WHERE nodes.id = :id", + lambda ctx: [{"id": n2.id}, {"id": n3.id}], + ), + ) class RowswitchAccountingTest(fixtures.MappedTest): @@ -1504,7 +1505,7 @@ class RowswitchAccountingTest(fixtures.MappedTest): def test_switch_on_update(self): Parent, Child = self._fixture() - sess = create_session(autocommit=False) + sess = fixture_session(autocommit=False) p1 = Parent(id=1, child=Child()) sess.add(p1) @@ -1535,7 +1536,7 @@ class RowswitchAccountingTest(fixtures.MappedTest): def test_switch_on_delete(self): Parent, Child = self._fixture() - sess = Session() + sess = fixture_session() p1 = Parent(id=1, data=2, child=None) sess.add(p1) sess.flush() @@ -1603,7 +1604,7 @@ class RowswitchM2OTest(fixtures.MappedTest): # change that previously showed up as nothing. A, B, C = self._fixture() - sess = Session() + sess = fixture_session() sess.add(A(id=1, bs=[B(id=1, c=C(id=1))])) sess.commit() @@ -1615,7 +1616,7 @@ class RowswitchM2OTest(fixtures.MappedTest): def test_set_none_w_get_replaces_m2o(self): A, B, C = self._fixture() - sess = Session() + sess = fixture_session() sess.add(A(id=1, bs=[B(id=1, c=C(id=1))])) sess.commit() @@ -1634,7 +1635,7 @@ class RowswitchM2OTest(fixtures.MappedTest): # shows, we can't rely on this - the get of None will blow # away the history. A, B, C = self._fixture() - sess = Session() + sess = fixture_session() sess.add(A(id=1, bs=[B(id=1, data="somedata")])) sess.commit() @@ -1646,7 +1647,7 @@ class RowswitchM2OTest(fixtures.MappedTest): def test_set_none_w_get_replaces_scalar(self): A, B, C = self._fixture() - sess = Session() + sess = fixture_session() sess.add(A(id=1, bs=[B(id=1, data="somedata")])) sess.commit() @@ -1706,7 +1707,7 @@ class BasicStaleChecksTest(fixtures.MappedTest): @testing.requires.sane_rowcount def test_update_single_missing(self): Parent, Child = self._fixture() - sess = Session() + sess = fixture_session() p1 = Parent(id=1, data=2) sess.add(p1) sess.flush() @@ -1737,7 +1738,7 @@ class BasicStaleChecksTest(fixtures.MappedTest): "sqlalchemy.engine.cursor.CursorResult.rowcount", rowcount ): Parent, Child = self._fixture() - sess = Session() + sess = fixture_session() p1 = Parent(id=1, data=2) sess.add(p1) sess.flush() @@ -1767,7 +1768,7 @@ class BasicStaleChecksTest(fixtures.MappedTest): "sqlalchemy.engine.cursor.CursorResult.rowcount", rowcount ): Parent, Child = self._fixture() - sess = Session() + sess = fixture_session() p1 = Parent(id=1, data=2) p2 = Parent(id=2, data=3) sess.add_all([p1, p2]) @@ -1797,7 +1798,7 @@ class BasicStaleChecksTest(fixtures.MappedTest): "sqlalchemy.engine.cursor.CursorResult.rowcount", rowcount ): Parent, Child = self._fixture() - sess = Session() + sess = fixture_session() p1 = Parent(id=1, data=1) sess.add(p1) sess.flush() @@ -1815,7 +1816,7 @@ class BasicStaleChecksTest(fixtures.MappedTest): @testing.requires.sane_rowcount def test_delete_twice(self): Parent, Child = self._fixture() - sess = Session() + sess = fixture_session() p1 = Parent(id=1, data=2, child=None) sess.add(p1) sess.commit() @@ -1835,7 +1836,7 @@ class BasicStaleChecksTest(fixtures.MappedTest): @testing.requires.sane_multi_rowcount def test_delete_multi_missing_warning(self): Parent, Child = self._fixture() - sess = Session() + sess = fixture_session() p1 = Parent(id=1, data=2, child=None) p2 = Parent(id=2, data=3, child=None) sess.add_all([p1, p2]) @@ -1856,7 +1857,7 @@ class BasicStaleChecksTest(fixtures.MappedTest): # raise occurs for single row UPDATE that misses even if # supports_sane_multi_rowcount is False Parent, Child = self._fixture() - sess = Session() + sess = fixture_session() p1 = Parent(id=1, data=2, child=None) sess.add(p1) sess.flush() @@ -1879,7 +1880,7 @@ class BasicStaleChecksTest(fixtures.MappedTest): # supports_sane_multi_rowcount is False, even if rowcount is still # correct Parent, Child = self._fixture() - sess = Session() + sess = fixture_session() p1 = Parent(id=1, data=2, child=None) p2 = Parent(id=2, data=3, child=None) sess.add_all([p1, p2]) @@ -1897,7 +1898,7 @@ class BasicStaleChecksTest(fixtures.MappedTest): def test_delete_single_broken_multi_rowcount_still_warns(self): Parent, Child = self._fixture() - sess = Session() + sess = fixture_session() p1 = Parent(id=1, data=2, child=None) sess.add(p1) sess.flush() @@ -1919,7 +1920,7 @@ class BasicStaleChecksTest(fixtures.MappedTest): def test_delete_multi_broken_multi_rowcount_doesnt_warn(self): Parent, Child = self._fixture() - sess = Session() + sess = fixture_session() p1 = Parent(id=1, data=2, child=None) p2 = Parent(id=2, data=3, child=None) sess.add_all([p1, p2]) @@ -1941,7 +1942,7 @@ class BasicStaleChecksTest(fixtures.MappedTest): def test_delete_multi_missing_allow(self): Parent, Child = self._fixture(confirm_deleted_rows=False) - sess = Session() + sess = fixture_session() p1 = Parent(id=1, data=2, child=None) p2 = Parent(id=2, data=3, child=None) sess.add_all([p1, p2]) @@ -1979,7 +1980,7 @@ class BatchInsertsTest(fixtures.MappedTest, testing.AssertsExecutionResults): pass mapper(T, t) - sess = Session() + sess = fixture_session() sess.add_all( [ T(data="t1"), @@ -2078,7 +2079,7 @@ class LoadersUsingCommittedTest(UOWTest): }, ) mapper(Address, addresses) - return create_session(autocommit=False) + return fixture_session(expire_on_commit=False) def test_before_update_m2o(self): """Expect normal many to one attribute load behavior @@ -2225,7 +2226,7 @@ class NoAttrEventInFlushTest(fixtures.MappedTest): event.listen(Thing.prefetch_val, "set", mock.prefetch_val) event.listen(Thing.returning_val, "set", mock.prefetch_val) t1 = Thing() - s = Session() + s = fixture_session() s.add(t1) s.flush() @@ -2275,7 +2276,7 @@ class EagerDefaultsTest(fixtures.MappedTest): def test_insert_defaults_present(self): Thing = self.classes.Thing - s = Session() + s = fixture_session() t1, t2 = (Thing(id=1, foo=5), Thing(id=2, foo=10)) @@ -2298,7 +2299,7 @@ class EagerDefaultsTest(fixtures.MappedTest): def test_insert_defaults_present_as_expr(self): Thing = self.classes.Thing - s = Session() + s = fixture_session() t1, t2 = ( Thing(id=1, foo=text("2 + 5")), @@ -2357,7 +2358,7 @@ class EagerDefaultsTest(fixtures.MappedTest): def test_insert_defaults_nonpresent(self): Thing = self.classes.Thing - s = Session() + s = fixture_session() t1, t2 = (Thing(id=1), Thing(id=2)) @@ -2416,7 +2417,7 @@ class EagerDefaultsTest(fixtures.MappedTest): def test_update_defaults_nonpresent(self): Thing2 = self.classes.Thing2 - s = Session() + s = fixture_session() t1, t2, t3, t4 = ( Thing2(id=1, foo=1, bar=2), @@ -2511,7 +2512,7 @@ class EagerDefaultsTest(fixtures.MappedTest): def test_update_defaults_present_as_expr(self): Thing2 = self.classes.Thing2 - s = Session() + s = fixture_session() t1, t2, t3, t4 = ( Thing2(id=1, foo=1, bar=2), @@ -2612,7 +2613,7 @@ class EagerDefaultsTest(fixtures.MappedTest): def test_insert_defaults_bulk_insert(self): Thing = self.classes.Thing - s = Session() + s = fixture_session() mappings = [{"id": 1}, {"id": 2}] @@ -2626,7 +2627,7 @@ class EagerDefaultsTest(fixtures.MappedTest): def test_update_defaults_bulk_update(self): Thing2 = self.classes.Thing2 - s = Session() + s = fixture_session() t1, t2, t3, t4 = ( Thing2(id=1, foo=1, bar=2), @@ -2665,7 +2666,7 @@ class EagerDefaultsTest(fixtures.MappedTest): def test_update_defaults_present(self): Thing2 = self.classes.Thing2 - s = Session() + s = fixture_session() t1, t2 = (Thing2(id=1, foo=1, bar=2), Thing2(id=2, foo=2, bar=3)) @@ -2687,7 +2688,7 @@ class EagerDefaultsTest(fixtures.MappedTest): def test_insert_dont_fetch_nondefaults(self): Thing2 = self.classes.Thing2 - s = Session() + s = fixture_session() t1 = Thing2(id=1, bar=2) @@ -2704,7 +2705,7 @@ class EagerDefaultsTest(fixtures.MappedTest): def test_update_dont_fetch_nondefaults(self): Thing2 = self.classes.Thing2 - s = Session() + s = fixture_session() t1 = Thing2(id=1, bar=2) @@ -2783,7 +2784,7 @@ class TypeWoBoolTest(fixtures.MappedTest, testing.AssertsExecutionResults): def test_update_against_none(self): Thing = self.classes.Thing - s = Session() + s = fixture_session() s.add(Thing(value=self.MyWidget("foo"))) s.commit() @@ -2796,7 +2797,7 @@ class TypeWoBoolTest(fixtures.MappedTest, testing.AssertsExecutionResults): def test_update_against_something_else(self): Thing = self.classes.Thing - s = Session() + s = fixture_session() s.add(Thing(value=self.MyWidget("foo"))) s.commit() @@ -2809,7 +2810,7 @@ class TypeWoBoolTest(fixtures.MappedTest, testing.AssertsExecutionResults): def test_no_update_no_change(self): Thing = self.classes.Thing - s = Session() + s = fixture_session() s.add(Thing(value=self.MyWidget("foo"), unrelated="unrelated")) s.commit() @@ -2923,7 +2924,7 @@ class NullEvaluatingTest(fixtures.MappedTest, testing.AssertsExecutionResults): def _assert_col(self, name, value): Thing, AltNameThing = self.classes.Thing, self.classes.AltNameThing - s = Session() + s = fixture_session() col = getattr(Thing, name) obj = s.query(col).filter(col == value).one() @@ -2936,7 +2937,7 @@ class NullEvaluatingTest(fixtures.MappedTest, testing.AssertsExecutionResults): def _test_insert(self, attr, expected): Thing, AltNameThing = self.classes.Thing, self.classes.AltNameThing - s = Session() + s = fixture_session() t1 = Thing(**{attr: None}) s.add(t1) @@ -2950,7 +2951,7 @@ class NullEvaluatingTest(fixtures.MappedTest, testing.AssertsExecutionResults): def _test_bulk_insert(self, attr, expected): Thing, AltNameThing = self.classes.Thing, self.classes.AltNameThing - s = Session() + s = fixture_session() s.bulk_insert_mappings(Thing, [{attr: None}]) s.bulk_insert_mappings(AltNameThing, [{"_foo_" + attr: None}]) s.commit() @@ -2960,7 +2961,7 @@ class NullEvaluatingTest(fixtures.MappedTest, testing.AssertsExecutionResults): def _test_insert_novalue(self, attr, expected): Thing, AltNameThing = self.classes.Thing, self.classes.AltNameThing - s = Session() + s = fixture_session() t1 = Thing() s.add(t1) @@ -2974,7 +2975,7 @@ class NullEvaluatingTest(fixtures.MappedTest, testing.AssertsExecutionResults): def _test_bulk_insert_novalue(self, attr, expected): Thing, AltNameThing = self.classes.Thing, self.classes.AltNameThing - s = Session() + s = fixture_session() s.bulk_insert_mappings(Thing, [{}]) s.bulk_insert_mappings(AltNameThing, [{}]) s.commit() @@ -3059,7 +3060,7 @@ class NullEvaluatingTest(fixtures.MappedTest, testing.AssertsExecutionResults): def test_json_none_as_null(self): JSONThing = self.classes.JSONThing - s = Session() + s = fixture_session() f1 = JSONThing(data=None, data_null=None) s.add(f1) s.commit() diff --git a/test/orm/test_update_delete.py b/test/orm/test_update_delete.py index 0b0c9cea7..e350ee018 100644 --- a/test/orm/test_update_delete.py +++ b/test/orm/test_update_delete.py @@ -29,6 +29,7 @@ from sqlalchemy.testing import fixtures from sqlalchemy.testing import in_ from sqlalchemy.testing import not_in from sqlalchemy.testing.assertsql import CompiledSQL +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -96,7 +97,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_illegal_eval(self): User = self.classes.User - s = Session() + s = fixture_session() assert_raises_message( exc.ArgumentError, "Valid strategies for session synchronization " @@ -110,7 +111,7 @@ class UpdateDeleteTest(fixtures.MappedTest): User = self.classes.User Address = self.classes.Address - s = Session() + s = fixture_session() for q, mname in ( (s.query(User).limit(2), r"limit\(\)"), @@ -197,7 +198,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def __clause_element__(self): return User.name.__clause_element__() - s = Session() + s = fixture_session() jill = s.query(User).get(3) s.query(User).update( {Thing(): "moonbeam"}, synchronize_session="evaluate" @@ -211,7 +212,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def __clause_element__(self): return 5 - s = Session() + s = fixture_session() assert_raises_message( exc.ArgumentError, @@ -224,7 +225,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_evaluate_unmapped_col(self): User = self.classes.User - s = Session() + s = fixture_session() jill = s.query(User).get(3) s.query(User).update( {column("name"): "moonbeam"}, synchronize_session="evaluate" @@ -239,7 +240,7 @@ class UpdateDeleteTest(fixtures.MappedTest): mapper(Foo, self.tables.users, properties={"uname": synonym("name")}) - s = Session() + s = fixture_session() jill = s.query(Foo).get(3) s.query(Foo).update( {"uname": "moonbeam"}, synchronize_session="evaluate" @@ -252,7 +253,7 @@ class UpdateDeleteTest(fixtures.MappedTest): mapper(Foo, self.tables.users, properties={"uname": synonym("name")}) - s = Session() + s = fixture_session() jill = s.query(Foo).get(3) s.query(Foo).update( {Foo.uname: "moonbeam"}, synchronize_session="evaluate" @@ -269,7 +270,7 @@ class UpdateDeleteTest(fixtures.MappedTest): properties={"uname": synonym("name"), "ufoo": synonym("uname")}, ) - s = Session() + s = fixture_session() jill = s.query(Foo).get(3) s.query(Foo).update( {Foo.ufoo: "moonbeam"}, synchronize_session="evaluate" @@ -287,7 +288,7 @@ class UpdateDeleteTest(fixtures.MappedTest): ): User = self.classes.User - sess = Session() + sess = fixture_session() john, jack, jill, jane = sess.query(User).order_by(User.id).all() @@ -392,7 +393,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_fetch_dont_refresh_expired_objects(self): User = self.classes.User - sess = Session() + sess = fixture_session() john, jack, jill, jane = sess.query(User).order_by(User.id).all() @@ -455,7 +456,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_delete(self): User = self.classes.User - sess = Session() + sess = fixture_session() john, jack, jill, jane = sess.query(User).order_by(User.id).all() sess.query(User).filter( @@ -470,14 +471,14 @@ class UpdateDeleteTest(fixtures.MappedTest): User = self.classes.User users = self.tables.users - sess = Session() + sess = fixture_session() sess.query(users).delete(synchronize_session=False) eq_(sess.query(User).count(), 0) def test_delete_with_bindparams(self): User = self.classes.User - sess = Session() + sess = fixture_session() john, jack, jill, jane = sess.query(User).order_by(User.id).all() sess.query(User).filter(text("name = :name")).params( @@ -490,7 +491,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_delete_rollback(self): User = self.classes.User - sess = Session() + sess = fixture_session() john, jack, jill, jane = sess.query(User).order_by(User.id).all() sess.query(User).filter( or_(User.name == "john", User.name == "jill") @@ -502,7 +503,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_delete_rollback_with_fetch(self): User = self.classes.User - sess = Session() + sess = fixture_session() john, jack, jill, jane = sess.query(User).order_by(User.id).all() sess.query(User).filter( or_(User.name == "john", User.name == "jill") @@ -514,7 +515,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_delete_without_session_sync(self): User = self.classes.User - sess = Session() + sess = fixture_session() john, jack, jill, jane = sess.query(User).order_by(User.id).all() sess.query(User).filter( @@ -528,7 +529,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_delete_with_fetch_strategy(self): User = self.classes.User - sess = Session() + sess = fixture_session() john, jack, jill, jane = sess.query(User).order_by(User.id).all() sess.query(User).filter( @@ -543,7 +544,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_delete_invalid_evaluation(self): User = self.classes.User - sess = Session() + sess = fixture_session() john, jack, jill, jane = sess.query(User).order_by(User.id).all() @@ -566,7 +567,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_update(self): User, users = self.classes.User, self.tables.users - sess = Session() + sess = fixture_session() john, jack, jill, jane = sess.query(User).order_by(User.id).all() sess.query(User).filter(User.age > 29).update( @@ -779,7 +780,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_update_against_table_col(self): User, users = self.classes.User, self.tables.users - sess = Session() + sess = fixture_session() john, jack, jill, jane = sess.query(User).order_by(User.id).all() eq_([john.age, jack.age, jill.age, jane.age], [25, 47, 29, 37]) sess.query(User).filter(User.age > 27).update( @@ -790,7 +791,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_update_against_metadata(self): User, users = self.classes.User, self.tables.users - sess = Session() + sess = fixture_session() sess.query(users).update( {users.c.age_int: 29}, synchronize_session=False @@ -803,7 +804,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_update_with_bindparams(self): User = self.classes.User - sess = Session() + sess = fixture_session() john, jack, jill, jane = sess.query(User).order_by(User.id).all() @@ -820,7 +821,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_update_fetch_returning(self): User = self.classes.User - sess = Session() + sess = fixture_session() john, jack, jill, jane = sess.query(User).order_by(User.id).all() @@ -905,7 +906,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_delete_fetch_returning(self): User = self.classes.User - sess = Session() + sess = fixture_session() john, jack, jill, jane = sess.query(User).order_by(User.id).all() @@ -994,7 +995,7 @@ class UpdateDeleteTest(fixtures.MappedTest): User = self.classes.User - sess = Session() + sess = fixture_session() assert_raises( exc.ArgumentError, lambda: sess.query(User.name == "filter").update( @@ -1005,7 +1006,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_update_without_load(self): User = self.classes.User - sess = Session() + sess = fixture_session() sess.query(User).filter(User.id == 3).update( {"age": 44}, synchronize_session="fetch" @@ -1018,7 +1019,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_update_changes_resets_dirty(self): User = self.classes.User - sess = Session(autoflush=False) + sess = fixture_session(autoflush=False) john, jack, jill, jane = sess.query(User).order_by(User.id).all() @@ -1047,7 +1048,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_update_changes_with_autoflush(self): User = self.classes.User - sess = Session() + sess = fixture_session() john, jack, jill, jane = sess.query(User).order_by(User.id).all() @@ -1073,7 +1074,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_update_with_expire_strategy(self): User = self.classes.User - sess = Session() + sess = fixture_session() john, jack, jill, jane = sess.query(User).order_by(User.id).all() sess.query(User).filter(User.age > 29).update( @@ -1090,7 +1091,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_update_returns_rowcount(self): User = self.classes.User - sess = Session() + sess = fixture_session() rowcount = ( sess.query(User) @@ -1110,7 +1111,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_delete_returns_rowcount(self): User = self.classes.User - sess = Session() + sess = fixture_session() rowcount = ( sess.query(User) @@ -1122,7 +1123,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_update_all(self): User = self.classes.User - sess = Session() + sess = fixture_session() john, jack, jill, jane = sess.query(User).order_by(User.id).all() sess.query(User).update({"age": 42}, synchronize_session="evaluate") @@ -1136,7 +1137,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_delete_all(self): User = self.classes.User - sess = Session() + sess = fixture_session() john, jack, jill, jane = sess.query(User).order_by(User.id).all() sess.query(User).delete(synchronize_session="evaluate") @@ -1149,7 +1150,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_autoflush_before_evaluate_update(self): User = self.classes.User - sess = Session() + sess = fixture_session() john = sess.query(User).filter_by(name="john").one() john.name = "j2" @@ -1161,7 +1162,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_autoflush_before_fetch_update(self): User = self.classes.User - sess = Session() + sess = fixture_session() john = sess.query(User).filter_by(name="john").one() john.name = "j2" @@ -1173,7 +1174,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_autoflush_before_evaluate_delete(self): User = self.classes.User - sess = Session() + sess = fixture_session() john = sess.query(User).filter_by(name="john").one() john.name = "j2" @@ -1185,7 +1186,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_autoflush_before_fetch_delete(self): User = self.classes.User - sess = Session() + sess = fixture_session() john = sess.query(User).filter_by(name="john").one() john.name = "j2" @@ -1197,7 +1198,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_evaluate_before_update(self): User = self.classes.User - sess = Session() + sess = fixture_session() john = sess.query(User).filter_by(name="john").one() sess.expire(john, ["age"]) @@ -1213,7 +1214,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_fetch_before_update(self): User = self.classes.User - sess = Session() + sess = fixture_session() john = sess.query(User).filter_by(name="john").one() sess.expire(john, ["age"]) @@ -1226,7 +1227,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_evaluate_before_delete(self): User = self.classes.User - sess = Session() + sess = fixture_session() john = sess.query(User).filter_by(name="john").one() sess.expire(john, ["age"]) @@ -1238,7 +1239,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_fetch_before_delete(self): User = self.classes.User - sess = Session() + sess = fixture_session() john = sess.query(User).filter_by(name="john").one() sess.expire(john, ["age"]) @@ -1249,7 +1250,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_update_unordered_dict(self): User = self.classes.User - session = Session() + session = fixture_session() # Do an update using unordered dict and check that the parameters used # are ordered in table order @@ -1266,7 +1267,7 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_update_preserve_parameter_order_query(self): User = self.classes.User - session = Session() + session = fixture_session() # Do update using a tuple and check that order is preserved @@ -1444,7 +1445,7 @@ class UpdateDeleteIgnoresLoadersTest(fixtures.MappedTest): def test_update_with_eager_relationships(self): Document = self.classes.Document - sess = Session() + sess = fixture_session() foo, bar, baz = sess.query(Document).order_by(Document.id).all() sess.query(Document).filter(Document.user_id == 1).update( @@ -1461,7 +1462,7 @@ class UpdateDeleteIgnoresLoadersTest(fixtures.MappedTest): def test_update_with_explicit_joinedload(self): User = self.classes.User - sess = Session() + sess = fixture_session() john, jack, jill, jane = sess.query(User).order_by(User.id).all() sess.query(User).options(joinedload(User.documents)).filter( @@ -1477,7 +1478,7 @@ class UpdateDeleteIgnoresLoadersTest(fixtures.MappedTest): def test_delete_with_eager_relationships(self): Document = self.classes.Document - sess = Session() + sess = fixture_session() sess.query(Document).filter(Document.user_id == 1).delete( synchronize_session=False @@ -1556,7 +1557,7 @@ class UpdateDeleteFromTest(fixtures.MappedTest): @testing.requires.update_from def test_update_from_joined_subq_test(self): Document = self.classes.Document - s = Session() + s = fixture_session() subq = ( s.query(func.max(Document.title).label("title")) @@ -1585,7 +1586,7 @@ class UpdateDeleteFromTest(fixtures.MappedTest): @testing.requires.delete_from def test_delete_from_joined_subq_test(self): Document = self.classes.Document - s = Session() + s = fixture_session() subq = ( s.query(func.max(Document.title).label("title")) @@ -1606,7 +1607,7 @@ class UpdateDeleteFromTest(fixtures.MappedTest): User = self.classes.User Document = self.classes.Document - s = Session() + s = fixture_session() q = s.query(User).filter(User.id == Document.user_id) assert_raises_message( @@ -1619,7 +1620,7 @@ class UpdateDeleteFromTest(fixtures.MappedTest): @testing.requires.update_where_target_in_subquery def test_update_using_in(self): Document = self.classes.Document - s = Session() + s = fixture_session() subq = ( s.query(func.max(Document.title).label("title")) @@ -1649,7 +1650,7 @@ class UpdateDeleteFromTest(fixtures.MappedTest): @testing.requires.standalone_binds def test_update_using_case(self): Document = self.classes.Document - s = Session() + s = fixture_session() subq = ( s.query(func.max(Document.title).label("title")) @@ -1684,7 +1685,7 @@ class UpdateDeleteFromTest(fixtures.MappedTest): Document = self.classes.Document User = self.classes.User - s = Session() + s = fixture_session() s.query(Document).filter(User.id == Document.user_id).filter( User.id == 2 @@ -1735,7 +1736,7 @@ class ExpressionUpdateTest(fixtures.MappedTest): Data = self.classes.Data d1 = Data() - sess = Session() + sess = fixture_session() sess.add(d1) sess.commit() eq_(d1.cnt, 0) @@ -1753,7 +1754,7 @@ class ExpressionUpdateTest(fixtures.MappedTest): def test_update_args(self): Data = self.classes.Data - session = Session() + session = fixture_session() update_args = {"mysql_limit": 1} m1 = testing.mock.Mock() @@ -1937,7 +1938,7 @@ class SingleTablePolymorphicTest(fixtures.DeclarativeMappedTest): def test_update(self, fetchstyle, future): Staff, Sales, Support = self.classes("Staff", "Sales", "Support") - sess = Session() + sess = fixture_session() en1, en2 = ( sess.execute(select(Sales).order_by(Sales.sales_stats)) @@ -1987,7 +1988,7 @@ class SingleTablePolymorphicTest(fixtures.DeclarativeMappedTest): def test_delete(self, fetchstyle, future): Staff, Sales, Support = self.classes("Staff", "Sales", "Support") - sess = Session() + sess = fixture_session() en1, en2 = sess.query(Sales).order_by(Sales.sales_stats).all() mn1, mn2 = sess.query(Support).order_by(Support.support_stats).all() diff --git a/test/orm/test_utils.py b/test/orm/test_utils.py index d3082accd..260cae37b 100644 --- a/test/orm/test_utils.py +++ b/test/orm/test_utils.py @@ -8,9 +8,7 @@ from sqlalchemy import util from sqlalchemy.ext.hybrid import hybrid_method from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import aliased -from sqlalchemy.orm import create_session from sqlalchemy.orm import mapper -from sqlalchemy.orm import Session from sqlalchemy.orm import synonym from sqlalchemy.orm import util as orm_util from sqlalchemy.orm import with_polymorphic @@ -22,6 +20,7 @@ from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.util import compat from test.orm import _fixtures from .inheritance import _poly_fixtures @@ -177,7 +176,7 @@ class AliasedClassTest(fixtures.TestBase, AssertsCompiledSQL): self._fixture(Point) alias = aliased(Point) - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(alias).filter(alias.left_of(Point)), @@ -203,7 +202,7 @@ class AliasedClassTest(fixtures.TestBase, AssertsCompiledSQL): eq_(str(Point.double_x.__clause_element__()), "point.x * :x_1") eq_(str(alias.double_x.__clause_element__()), "point_1.x * :x_1") - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(alias).filter(alias.double_x > Point.x), @@ -262,7 +261,7 @@ class AliasedClassTest(fixtures.TestBase, AssertsCompiledSQL): a2 = aliased(Point) eq_(str(a2.x_alone == alias.x), "point_1.x = point_2.x") - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(alias).filter(alias.x_alone > Point.x), @@ -282,7 +281,7 @@ class AliasedClassTest(fixtures.TestBase, AssertsCompiledSQL): eq_(str(Point.x_syn), "Point.x_syn") eq_(str(alias.x_syn), "AliasedClass_Point.x_syn") - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(alias.x_syn).filter(alias.x_syn > Point.x_syn), "SELECT point_1.x AS point_1_x FROM point AS point_1, point " @@ -321,7 +320,7 @@ class AliasedClassTest(fixtures.TestBase, AssertsCompiledSQL): a2 = aliased(Point) eq_(str(a2.x_syn == alias.x), "point_1.x = point_2.x") - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(alias).filter(alias.x_syn > Point.x), @@ -350,7 +349,7 @@ class AliasedClassTest(fixtures.TestBase, AssertsCompiledSQL): eq_(str(Point.double_x.__clause_element__()), "point.x * :x_1") eq_(str(alias.double_x.__clause_element__()), "point_1.x * :x_1") - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(alias).filter(alias.double_x > Point.x), @@ -380,7 +379,7 @@ class AliasedClassTest(fixtures.TestBase, AssertsCompiledSQL): eq_(str(Point.double_x.__clause_element__()), "point.x * :x_1") eq_(str(alias.double_x.__clause_element__()), "point_1.x * :x_1") - sess = Session() + sess = fixture_session() self.assert_compile( sess.query(alias).filter(alias.double_x > Point.x), @@ -469,7 +468,7 @@ class IdentityKeyTest(_fixtures.FixtureTest): users, User = self.tables.users, self.classes.User mapper(User, users) - s = create_session() + s = fixture_session() u = User(name="u1") s.add(u) s.flush() diff --git a/test/orm/test_validators.py b/test/orm/test_validators.py index 547815745..887ff7754 100644 --- a/test/orm/test_validators.py +++ b/test/orm/test_validators.py @@ -2,13 +2,13 @@ from sqlalchemy import exc from sqlalchemy.orm import collections from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship -from sqlalchemy.orm import Session from sqlalchemy.orm import validates from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import ne_ +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.mock import call from sqlalchemy.testing.mock import Mock from test.orm import _fixtures @@ -27,7 +27,7 @@ class ValidatorTest(_fixtures.FixtureTest): return name + " modified" mapper(User, users) - sess = Session() + sess = fixture_session() u1 = User(name="ed") eq_(u1.name, "ed modified") assert_raises(AssertionError, setattr, u1, "name", "fred") @@ -60,7 +60,7 @@ class ValidatorTest(_fixtures.FixtureTest): mapper(User, users, properties={"addresses": relationship(Address)}) mapper(Address, addresses) - sess = Session() + sess = fixture_session() u1 = User(name="edward") a0 = Address(email_address="noemail") assert_raises(AssertionError, u1.addresses.append, a0) diff --git a/test/orm/test_versioning.py b/test/orm/test_versioning.py index 32f18e47c..c185c59d0 100644 --- a/test/orm/test_versioning.py +++ b/test/orm/test_versioning.py @@ -15,12 +15,10 @@ from sqlalchemy import testing from sqlalchemy import TypeDecorator from sqlalchemy import util from sqlalchemy.orm import configure_mappers -from sqlalchemy.orm import create_session from sqlalchemy.orm import exc as orm_exc from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship from sqlalchemy.orm import Session -from sqlalchemy.orm import sessionmaker from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import config @@ -31,6 +29,7 @@ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_false from sqlalchemy.testing import is_true from sqlalchemy.testing.assertsql import CompiledSQL +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.mock import patch from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -98,7 +97,7 @@ class NullVersionIdTest(fixtures.MappedTest): version_id_generator=False, ) - s1 = Session() + s1 = fixture_session() return s1 def test_null_version_id_insert(self): @@ -175,7 +174,7 @@ class VersioningTest(fixtures.MappedTest): Foo, version_table = self.classes.Foo, self.tables.version_table mapper(Foo, version_table, version_id_col=version_table.c.version_id) - s1 = Session() + s1 = fixture_session() return s1 @engines.close_open_connections @@ -211,7 +210,7 @@ class VersioningTest(fixtures.MappedTest): ): s1.commit() - s2 = create_session(autocommit=False) + s2 = fixture_session(autocommit=False) f1_s = s2.query(Foo).get(f1.id) f1_s.value = "f1rev3" with conditional_sane_rowcount_warnings( @@ -371,7 +370,7 @@ class VersioningTest(fixtures.MappedTest): s1.add(f1s1) s1.commit() - s2 = create_session(autocommit=False) + s2 = fixture_session(autocommit=False) f1s2 = s2.query(Foo).get(f1s1.id) f1s2.value = "f1 new value" with conditional_sane_rowcount_warnings( @@ -406,7 +405,7 @@ class VersioningTest(fixtures.MappedTest): version_table = self.tables.version_table mapper(Foo, version_table) - s1 = Session() + s1 = fixture_session() f1s1 = Foo(value="f1 value", version_id=1) s1.add(f1s1) s1.commit() @@ -425,7 +424,7 @@ class VersioningTest(fixtures.MappedTest): s1.add(f1s1) s1.commit() - s2 = create_session(autocommit=False) + s2 = fixture_session(autocommit=False) f1s2 = s2.query(Foo).get(f1s1.id) # not sure if I like this API s2.refresh(f1s2, with_for_update=True) @@ -502,13 +501,13 @@ class VersioningTest(fixtures.MappedTest): Foo, version_table = self.classes.Foo, self.tables.version_table - s1 = create_session(autocommit=False) + s1 = fixture_session(autocommit=False) mapper(Foo, version_table) f1s1 = Foo(value="foo", version_id=0) s1.add(f1s1) s1.commit() - s2 = create_session(autocommit=False) + s2 = fixture_session(autocommit=False) f1s2 = s2.query(Foo).with_for_update(read=True).get(f1s1.id) assert f1s2.id == f1s1.id assert f1s2.value == f1s1.value @@ -647,7 +646,7 @@ class VersionOnPostUpdateTest(fixtures.MappedTest): version_id_col=node.c.version_id, ) - s = Session() + s = fixture_session() n1 = Node(id=1) n2 = Node(id=2) @@ -827,7 +826,7 @@ class NoBumpOnRelationshipTest(fixtures.MappedTest): def _run_test(self, auto_version_counter=True): A, B = self.classes("A", "B") - s = Session() + s = fixture_session() if auto_version_counter: a1 = A() else: @@ -918,7 +917,7 @@ class ColumnTypeTest(fixtures.MappedTest): Foo, version_table = self.classes.Foo, self.tables.version_table mapper(Foo, version_table, version_id_col=version_table.c.version_id) - s1 = Session() + s1 = fixture_session() return s1 @engines.close_open_connections @@ -984,7 +983,7 @@ class RowSwitchTest(fixtures.MappedTest): def test_row_switch(self): P = self.classes.P - session = sessionmaker()() + session = fixture_session() session.add(P(id="P1", data="P version 1")) session.commit() session.close() @@ -1002,7 +1001,7 @@ class RowSwitchTest(fixtures.MappedTest): assert P.c.property.strategy.use_get - session = sessionmaker()() + session = fixture_session() session.add(P(id="P1", data="P version 1")) session.commit() session.close() @@ -1073,7 +1072,7 @@ class AlternateGeneratorTest(fixtures.MappedTest): def test_row_switch(self): P = self.classes.P - session = sessionmaker()() + session = fixture_session() session.add(P(id="P1", data="P version 1")) session.commit() session.close() @@ -1091,7 +1090,7 @@ class AlternateGeneratorTest(fixtures.MappedTest): assert P.c.property.strategy.use_get - session = sessionmaker()() + session = fixture_session() session.add(P(id="P1", data="P version 1")) session.commit() session.close() @@ -1111,19 +1110,17 @@ class AlternateGeneratorTest(fixtures.MappedTest): def test_child_row_switch_two(self): P = self.classes.P - Session = sessionmaker() - # TODO: not sure this test is # testing exactly what its looking for - sess1 = Session() + sess1 = fixture_session() sess1.add(P(id="P1", data="P version 1")) sess1.commit() sess1.close() p1 = sess1.query(P).first() - sess2 = Session() + sess2 = fixture_session() p2 = sess2.query(P).first() sess1.delete(p1) @@ -1185,7 +1182,7 @@ class PlainInheritanceTest(fixtures.MappedTest): mapper(Base, base, version_id_col=base.c.version_id) mapper(Sub, sub, inherits=Base) - s = Session() + s = fixture_session() s1 = Sub(data="b", sub_data="s") s.add(s1) s.commit() @@ -1245,13 +1242,13 @@ class InheritanceTwoVersionIdsTest(fixtures.MappedTest): mapper(Base, base, version_id_col=base.c.version_id) mapper(Sub, sub, inherits=Base) - session = Session() + session = fixture_session() b1 = Base(data="b1") session.add(b1) session.commit() eq_(b1.version_id, 1) # base is populated - eq_(select(base.c.version_id).scalar(), 1) + eq_(session.connection().scalar(select(base.c.version_id)), 1) def test_sub_both(self): Base, sub, base, Sub = ( @@ -1264,16 +1261,16 @@ class InheritanceTwoVersionIdsTest(fixtures.MappedTest): mapper(Base, base, version_id_col=base.c.version_id) mapper(Sub, sub, inherits=Base) - session = Session() + session = fixture_session() s1 = Sub(data="s1", sub_data="s1") session.add(s1) session.commit() # table is populated - eq_(select(sub.c.version_id).scalar(), 1) + eq_(session.connection().scalar(select(sub.c.version_id)), 1) # base is populated - eq_(select(base.c.version_id).scalar(), 1) + eq_(session.connection().scalar(select(base.c.version_id)), 1) def test_sub_only(self): Base, sub, base, Sub = ( @@ -1286,16 +1283,16 @@ class InheritanceTwoVersionIdsTest(fixtures.MappedTest): mapper(Base, base) mapper(Sub, sub, inherits=Base, version_id_col=sub.c.version_id) - session = Session() + session = fixture_session() s1 = Sub(data="s1", sub_data="s1") session.add(s1) session.commit() # table is populated - eq_(select(sub.c.version_id).scalar(), 1) + eq_(session.connection().scalar(select(sub.c.version_id)), 1) # base is not - eq_(select(base.c.version_id).scalar(), None) + eq_(session.connection().scalar(select(base.c.version_id)), None) def test_mismatch_version_col_warning(self): Base, sub, base, Sub = ( @@ -1384,7 +1381,7 @@ class ServerVersioningTest(fixtures.MappedTest): eager_defaults=eager_defaults, ) - s1 = Session(expire_on_commit=expire_on_commit) + s1 = fixture_session(expire_on_commit=expire_on_commit) return s1 def test_insert_col(self): @@ -1625,7 +1622,7 @@ class ServerVersioningTest(fixtures.MappedTest): f1.value - s2 = Session() + s2 = fixture_session() f2 = s2.query(self.classes.Foo).first() f2.value = "f2" s2.commit() @@ -1652,7 +1649,7 @@ class ServerVersioningTest(fixtures.MappedTest): # a SELECT for it within the flush. f1.value - s2 = Session(expire_on_commit=False) + s2 = fixture_session(expire_on_commit=False) f2 = s2.query(self.classes.Foo).first() f2.value = "f2" s2.commit() @@ -1698,7 +1695,7 @@ class ManualVersionTest(fixtures.MappedTest): ) def test_insert(self): - sess = Session() + sess = fixture_session() a1 = self.classes.A() a1.vid = 1 @@ -1708,7 +1705,7 @@ class ManualVersionTest(fixtures.MappedTest): eq_(a1.vid, 1) def test_update(self): - sess = Session() + sess = fixture_session() a1 = self.classes.A() a1.vid = 1 @@ -1728,7 +1725,7 @@ class ManualVersionTest(fixtures.MappedTest): @testing.requires.sane_rowcount_w_returning def test_update_concurrent_check(self): - sess = Session() + sess = fixture_session() a1 = self.classes.A() a1.vid = 1 @@ -1742,7 +1739,7 @@ class ManualVersionTest(fixtures.MappedTest): assert_raises(orm_exc.StaleDataError, sess.commit) def test_update_version_conditional(self): - sess = Session() + sess = fixture_session() a1 = self.classes.A() a1.vid = 1 @@ -1814,7 +1811,7 @@ class ManualInheritanceVersionTest(fixtures.MappedTest): mapper(cls.classes.B, cls.tables.b, inherits=cls.classes.A) def test_no_increment(self): - sess = Session() + sess = fixture_session() b1 = self.classes.B() b1.vid = 1 @@ -1874,7 +1871,7 @@ class VersioningMappedSelectTest(fixtures.MappedTest): ) mapper(Foo, current, version_id_col=version_table.c.version_id) - s1 = Session() + s1 = fixture_session() return s1 def _explicit_version_fixture(self): @@ -1892,7 +1889,7 @@ class VersioningMappedSelectTest(fixtures.MappedTest): version_id_col=version_table.c.version_id, version_id_generator=False, ) - s1 = Session() + s1 = fixture_session() return s1 def test_implicit(self): diff --git a/test/requirements.py b/test/requirements.py index cb2f4840f..d5a718372 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -120,7 +120,7 @@ class DefaultRequirements(SuiteRequirements): def deferrable_fks(self): """target database must support deferrable fks""" - return only_on(["oracle"]) + return only_on(["oracle", "postgresql"]) @property def foreign_key_constraint_option_reflection_ondelete(self): @@ -1330,7 +1330,10 @@ class DefaultRequirements(SuiteRequirements): """dialect makes use of await_() to invoke operations on the DBAPI.""" return only_on( - ["postgresql+asyncpg", "mysql+aiomysql", "mariadb+aiomysql"] + LambdaPredicate( + lambda config: config.db.dialect.is_async, + "Async dialect required", + ) ) @property diff --git a/test/sql/test_constraints.py b/test/sql/test_constraints.py index 019409ba3..8c1fa5424 100644 --- a/test/sql/test_constraints.py +++ b/test/sql/test_constraints.py @@ -59,7 +59,7 @@ class ConstraintGenTest(fixtures.TestBase, AssertsExecutionResults): ) self.assert_sql_execution( testing.db, - lambda: metadata.create_all(checkfirst=False), + lambda: metadata.create_all(testing.db, checkfirst=False), CompiledSQL( "CREATE TABLE employees (" "id INTEGER NOT NULL, " @@ -292,7 +292,7 @@ class ConstraintGenTest(fixtures.TestBase, AssertsExecutionResults): assertions.append(AllOf(*fk_assertions)) with self.sql_execution_asserter() as asserter: - metadata.create_all(checkfirst=False) + metadata.create_all(testing.db, checkfirst=False) asserter.assert_(*assertions) assertions = [ @@ -302,7 +302,7 @@ class ConstraintGenTest(fixtures.TestBase, AssertsExecutionResults): ] with self.sql_execution_asserter() as asserter: - metadata.drop_all(checkfirst=False), + metadata.drop_all(testing.db, checkfirst=False), asserter.assert_(*assertions) def _assert_cyclic_constraint_no_alter( @@ -356,7 +356,7 @@ class ConstraintGenTest(fixtures.TestBase, AssertsExecutionResults): assertions = [AllOf(*table_assertions)] with self.sql_execution_asserter() as asserter: - metadata.create_all(checkfirst=False) + metadata.create_all(testing.db, checkfirst=False) asserter.assert_(*assertions) assertions = [ @@ -366,15 +366,15 @@ class ConstraintGenTest(fixtures.TestBase, AssertsExecutionResults): if sqlite_warning: with expect_warnings("Can't sort tables for DROP; "): with self.sql_execution_asserter() as asserter: - metadata.drop_all(checkfirst=False), + metadata.drop_all(testing.db, checkfirst=False), else: with self.sql_execution_asserter() as asserter: - metadata.drop_all(checkfirst=False), + metadata.drop_all(testing.db, checkfirst=False), asserter.assert_(*assertions) @testing.force_drop_names("a", "b") def test_cycle_unnamed_fks(self): - metadata = MetaData(testing.db) + metadata = MetaData() Table( "a", @@ -417,7 +417,7 @@ class ConstraintGenTest(fixtures.TestBase, AssertsExecutionResults): ), ] with self.sql_execution_asserter() as asserter: - metadata.create_all(checkfirst=False) + metadata.create_all(testing.db, checkfirst=False) if testing.db.dialect.supports_alter: asserter.assert_(*assertions) @@ -431,6 +431,7 @@ class ConstraintGenTest(fixtures.TestBase, AssertsExecutionResults): "cycle have names so that they can be dropped using " "DROP CONSTRAINT.", metadata.drop_all, + testing.db, checkfirst=False, ) else: @@ -439,7 +440,7 @@ class ConstraintGenTest(fixtures.TestBase, AssertsExecutionResults): "foreign key dependency exists between tables" ): with self.sql_execution_asserter() as asserter: - metadata.drop_all(checkfirst=False) + metadata.drop_all(testing.db, checkfirst=False) asserter.assert_( AllOf(CompiledSQL("DROP TABLE b"), CompiledSQL("DROP TABLE a")) @@ -447,7 +448,7 @@ class ConstraintGenTest(fixtures.TestBase, AssertsExecutionResults): @testing.force_drop_names("a", "b") def test_cycle_named_fks(self): - metadata = MetaData(testing.db) + metadata = MetaData() Table( "a", @@ -491,13 +492,13 @@ class ConstraintGenTest(fixtures.TestBase, AssertsExecutionResults): ), ] with self.sql_execution_asserter() as asserter: - metadata.create_all(checkfirst=False) + metadata.create_all(testing.db, checkfirst=False) if testing.db.dialect.supports_alter: asserter.assert_(*assertions) with self.sql_execution_asserter() as asserter: - metadata.drop_all(checkfirst=False) + metadata.drop_all(testing.db, checkfirst=False) asserter.assert_( CompiledSQL("ALTER TABLE b DROP CONSTRAINT aidfk"), @@ -507,7 +508,7 @@ class ConstraintGenTest(fixtures.TestBase, AssertsExecutionResults): ) else: with self.sql_execution_asserter() as asserter: - metadata.drop_all(checkfirst=False) + metadata.drop_all(testing.db, checkfirst=False) asserter.assert_( AllOf(CompiledSQL("DROP TABLE b"), CompiledSQL("DROP TABLE a")) @@ -536,7 +537,7 @@ class ConstraintGenTest(fixtures.TestBase, AssertsExecutionResults): self.assert_sql_execution( testing.db, - lambda: metadata.create_all(checkfirst=False), + lambda: metadata.create_all(testing.db, checkfirst=False), AllOf( CompiledSQL( "CREATE TABLE foo (" @@ -579,7 +580,7 @@ class ConstraintGenTest(fixtures.TestBase, AssertsExecutionResults): self.assert_sql_execution( testing.db, - lambda: metadata.create_all(checkfirst=False), + lambda: metadata.create_all(testing.db, checkfirst=False), AllOf( CompiledSQL( "CREATE TABLE foo (" @@ -628,7 +629,7 @@ class ConstraintGenTest(fixtures.TestBase, AssertsExecutionResults): self.assert_sql_execution( testing.db, - lambda: metadata.create_all(checkfirst=False), + lambda: metadata.create_all(testing.db, checkfirst=False), RegexSQL("^CREATE TABLE"), AllOf( CompiledSQL( @@ -665,7 +666,7 @@ class ConstraintGenTest(fixtures.TestBase, AssertsExecutionResults): self.assert_sql_execution( testing.db, - lambda: metadata.create_all(checkfirst=False), + lambda: metadata.create_all(testing.db, checkfirst=False), RegexSQL("^CREATE TABLE"), AllOf( CompiledSQL( diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py index 1722a1e69..91076f9c3 100644 --- a/test/sql/test_functions.py +++ b/test/sql/test_functions.py @@ -1064,14 +1064,10 @@ class ExecuteTest(fixtures.TestBase): @testing.fails_on_everything_except("postgresql") def test_as_from(self, connection): # TODO: shouldn't this work on oracle too ? - x = connection.execute(func.current_date(bind=testing.db)).scalar() - y = connection.execute( - func.current_date(bind=testing.db).select() - ).scalar() - z = connection.scalar(func.current_date(bind=testing.db)) - w = connection.scalar( - select("*").select_from(func.current_date(bind=testing.db)) - ) + x = connection.execute(func.current_date()).scalar() + y = connection.execute(func.current_date().select()).scalar() + z = connection.scalar(func.current_date()) + w = connection.scalar(select("*").select_from(func.current_date())) assert x == y == z == w diff --git a/test/sql/test_query.py b/test/sql/test_query.py index 6d26f7975..3047b5d09 100644 --- a/test/sql/test_query.py +++ b/test/sql/test_query.py @@ -27,7 +27,6 @@ from sqlalchemy import union_all from sqlalchemy import VARCHAR from sqlalchemy.engine import default from sqlalchemy.testing import assert_raises_message -from sqlalchemy.testing import engines from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ @@ -35,22 +34,13 @@ from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table -# ongoing - these are old tests. those which are of general use -# to test a dialect are being slowly migrated to -# sqlalhcemy.testing.suite - -users = users2 = addresses = metadata = None - - -class QueryTest(fixtures.TestBase): +class QueryTest(fixtures.TablesTest): __backend__ = True @classmethod - def setup_class(cls): - global users, users2, addresses, metadata - metadata = MetaData(testing.db) - users = Table( - "query_users", + def define_tables(cls, metadata): + Table( + "users", metadata, Column( "user_id", INT, primary_key=True, test_needs_autoincrement=True @@ -58,8 +48,8 @@ class QueryTest(fixtures.TestBase): Column("user_name", VARCHAR(20)), test_needs_acid=True, ) - addresses = Table( - "query_addresses", + Table( + "addresses", metadata, Column( "address_id", @@ -67,12 +57,12 @@ class QueryTest(fixtures.TestBase): primary_key=True, test_needs_autoincrement=True, ), - Column("user_id", Integer, ForeignKey("query_users.user_id")), + Column("user_id", Integer, ForeignKey("users.user_id")), Column("address", String(30)), test_needs_acid=True, ) - users2 = Table( + Table( "u2", metadata, Column("user_id", INT, primary_key=True), @@ -80,19 +70,6 @@ class QueryTest(fixtures.TestBase): test_needs_acid=True, ) - metadata.create_all() - - @engines.close_first - def teardown(self): - with testing.db.begin() as conn: - conn.execute(addresses.delete()) - conn.execute(users.delete()) - conn.execute(users2.delete()) - - @classmethod - def teardown_class(cls): - metadata.drop_all() - @testing.fails_on( "firebird", "kinterbasdb doesn't send full type information" ) @@ -105,6 +82,8 @@ class QueryTest(fixtures.TestBase): """ + users = self.tables.users + connection.execute( users.insert(), {"user_id": 7, "user_name": "jack"}, @@ -133,6 +112,7 @@ class QueryTest(fixtures.TestBase): @testing.requires.order_by_label_with_expression def test_order_by_label_compound(self, connection): + users = self.tables.users connection.execute( users.insert(), {"user_id": 7, "user_name": "jack"}, @@ -174,6 +154,7 @@ class QueryTest(fixtures.TestBase): assert row.y == False # noqa def test_select_tuple(self, connection): + users = self.tables.users connection.execute( users.insert(), {"user_id": 1, "user_name": "apples"}, @@ -187,6 +168,7 @@ class QueryTest(fixtures.TestBase): ) def test_like_ops(self, connection): + users = self.tables.users connection.execute( users.insert(), {"user_id": 1, "user_name": "apples"}, @@ -238,6 +220,7 @@ class QueryTest(fixtures.TestBase): eq_(connection.scalar(expr), result) def test_ilike(self, connection): + users = self.tables.users connection.execute( users.insert(), {"user_id": 1, "user_name": "one"}, @@ -279,15 +262,25 @@ class QueryTest(fixtures.TestBase): ) def test_compiled_execute(self, connection): + users = self.tables.users connection.execute(users.insert(), user_id=7, user_name="jack") - s = select(users).where(users.c.user_id == bindparam("id")).compile() + s = ( + select(users) + .where(users.c.user_id == bindparam("id")) + .compile(connection) + ) eq_(connection.execute(s, id=7).first()._mapping["user_id"], 7) def test_compiled_insert_execute(self, connection): + users = self.tables.users connection.execute( - users.insert().compile(), user_id=7, user_name="jack" + users.insert().compile(connection), user_id=7, user_name="jack" + ) + s = ( + select(users) + .where(users.c.user_id == bindparam("id")) + .compile(connection) ) - s = select(users).where(users.c.user_id == bindparam("id")).compile() eq_(connection.execute(s, id=7).first()._mapping["user_id"], 7) def test_repeated_bindparams(self, connection): @@ -296,6 +289,7 @@ class QueryTest(fixtures.TestBase): This should be run for DB-APIs with both positional and named paramstyles. """ + users = self.tables.users connection.execute(users.insert(), user_id=7, user_name="jack") connection.execute(users.insert(), user_id=8, user_name="fred") @@ -369,6 +363,8 @@ class QueryTest(fixtures.TestBase): Tests simple, compound, aliased and DESC clauses. """ + users = self.tables.users + connection.execute(users.insert(), user_id=1, user_name="c") connection.execute(users.insert(), user_id=2, user_name="b") connection.execute(users.insert(), user_id=3, user_name="a") @@ -469,6 +465,8 @@ class QueryTest(fixtures.TestBase): Tests simple, compound, aliased and DESC clauses. """ + users = self.tables.users + connection.execute(users.insert(), user_id=1) connection.execute(users.insert(), user_id=2, user_name="b") connection.execute(users.insert(), user_id=3, user_name="a") @@ -563,6 +561,7 @@ class QueryTest(fixtures.TestBase): def test_in_filtering(self, connection): """test the behavior of the in_() function.""" + users = self.tables.users connection.execute(users.insert(), user_id=7, user_name="jack") connection.execute(users.insert(), user_id=8, user_name="fred") @@ -587,6 +586,7 @@ class QueryTest(fixtures.TestBase): assert len(r) == 0 def test_expanding_in(self, connection): + users = self.tables.users connection.execute( users.insert(), [ @@ -626,6 +626,7 @@ class QueryTest(fixtures.TestBase): @testing.requires.no_quoting_special_bind_names def test_expanding_in_special_chars(self, connection): + users = self.tables.users connection.execute( users.insert(), [ @@ -663,6 +664,8 @@ class QueryTest(fixtures.TestBase): ) def test_expanding_in_multiple(self, connection): + users = self.tables.users + connection.execute( users.insert(), [ @@ -687,6 +690,8 @@ class QueryTest(fixtures.TestBase): ) def test_expanding_in_repeated(self, connection): + users = self.tables.users + connection.execute( users.insert(), [ @@ -727,6 +732,8 @@ class QueryTest(fixtures.TestBase): @testing.requires.tuple_in def test_expanding_in_composite(self, connection): + users = self.tables.users + connection.execute( users.insert(), [ @@ -768,7 +775,7 @@ class QueryTest(fixtures.TestBase): return value[3:] users = Table( - "query_users", + "users", MetaData(), Column("user_id", Integer, primary_key=True), Column("user_name", NameWithProcess()), @@ -812,6 +819,8 @@ class QueryTest(fixtures.TestBase): """ + users = self.tables.users + connection.execute(users.insert(), user_id=7, user_name="jack") connection.execute(users.insert(), user_id=8, user_name="fred") connection.execute(users.insert(), user_id=9, user_name=None) @@ -827,6 +836,8 @@ class QueryTest(fixtures.TestBase): def test_literal_in(self, connection): """similar to test_bind_in but use a bind with a value.""" + users = self.tables.users + connection.execute(users.insert(), user_id=7, user_name="jack") connection.execute(users.insert(), user_id=8, user_name="fred") connection.execute(users.insert(), user_id=9, user_name=None) @@ -842,6 +853,7 @@ class QueryTest(fixtures.TestBase): that a proper boolean value is generated. """ + users = self.tables.users connection.execute( users.insert(), @@ -932,63 +944,60 @@ class RequiredBindTest(fixtures.TablesTest): is_(bindparam("foo", callable_=c, required=False).required, False) -class LimitTest(fixtures.TestBase): +class LimitTest(fixtures.TablesTest): __backend__ = True @classmethod - def setup_class(cls): - global users, addresses, metadata - metadata = MetaData(testing.db) - users = Table( - "query_users", + def define_tables(cls, metadata): + Table( + "users", metadata, Column("user_id", INT, primary_key=True), Column("user_name", VARCHAR(20)), ) - addresses = Table( - "query_addresses", + Table( + "addresses", metadata, Column("address_id", Integer, primary_key=True), - Column("user_id", Integer, ForeignKey("query_users.user_id")), + Column("user_id", Integer, ForeignKey("users.user_id")), Column("address", String(30)), ) - metadata.create_all() - - with testing.db.begin() as conn: - conn.execute(users.insert(), user_id=1, user_name="john") - conn.execute( - addresses.insert(), address_id=1, user_id=1, address="addr1" - ) - conn.execute(users.insert(), user_id=2, user_name="jack") - conn.execute( - addresses.insert(), address_id=2, user_id=2, address="addr1" - ) - conn.execute(users.insert(), user_id=3, user_name="ed") - conn.execute( - addresses.insert(), address_id=3, user_id=3, address="addr2" - ) - conn.execute(users.insert(), user_id=4, user_name="wendy") - conn.execute( - addresses.insert(), address_id=4, user_id=4, address="addr3" - ) - conn.execute(users.insert(), user_id=5, user_name="laura") - conn.execute( - addresses.insert(), address_id=5, user_id=5, address="addr4" - ) - conn.execute(users.insert(), user_id=6, user_name="ralph") - conn.execute( - addresses.insert(), address_id=6, user_id=6, address="addr5" - ) - conn.execute(users.insert(), user_id=7, user_name="fido") - conn.execute( - addresses.insert(), address_id=7, user_id=7, address="addr5" - ) @classmethod - def teardown_class(cls): - metadata.drop_all() + def insert_data(cls, connection): + users, addresses = cls.tables("users", "addresses") + conn = connection + conn.execute(users.insert(), user_id=1, user_name="john") + conn.execute( + addresses.insert(), address_id=1, user_id=1, address="addr1" + ) + conn.execute(users.insert(), user_id=2, user_name="jack") + conn.execute( + addresses.insert(), address_id=2, user_id=2, address="addr1" + ) + conn.execute(users.insert(), user_id=3, user_name="ed") + conn.execute( + addresses.insert(), address_id=3, user_id=3, address="addr2" + ) + conn.execute(users.insert(), user_id=4, user_name="wendy") + conn.execute( + addresses.insert(), address_id=4, user_id=4, address="addr3" + ) + conn.execute(users.insert(), user_id=5, user_name="laura") + conn.execute( + addresses.insert(), address_id=5, user_id=5, address="addr4" + ) + conn.execute(users.insert(), user_id=6, user_name="ralph") + conn.execute( + addresses.insert(), address_id=6, user_id=6, address="addr5" + ) + conn.execute(users.insert(), user_id=7, user_name="fido") + conn.execute( + addresses.insert(), address_id=7, user_id=7, address="addr5" + ) def test_select_limit(self, connection): + users, addresses = self.tables("users", "addresses") r = connection.execute( users.select(limit=3, order_by=[users.c.user_id]) ).fetchall() @@ -998,6 +1007,8 @@ class LimitTest(fixtures.TestBase): def test_select_limit_offset(self, connection): """Test the interaction between limit and offset""" + users, addresses = self.tables("users", "addresses") + r = connection.execute( users.select(limit=3, offset=2, order_by=[users.c.user_id]) ).fetchall() @@ -1010,6 +1021,8 @@ class LimitTest(fixtures.TestBase): def test_select_distinct_limit(self, connection): """Test the interaction between limit and distinct""" + users, addresses = self.tables("users", "addresses") + r = sorted( [ x[0] @@ -1025,6 +1038,8 @@ class LimitTest(fixtures.TestBase): def test_select_distinct_offset(self, connection): """Test the interaction between distinct and offset""" + users, addresses = self.tables("users", "addresses") + r = sorted( [ x[0] @@ -1043,6 +1058,8 @@ class LimitTest(fixtures.TestBase): def test_select_distinct_limit_offset(self, connection): """Test the interaction between limit and limit/offset""" + users, addresses = self.tables("users", "addresses") + r = connection.execute( select(addresses.c.address) .order_by(addresses.c.address) @@ -1054,18 +1071,18 @@ class LimitTest(fixtures.TestBase): self.assert_(r[0] != r[1] and r[1] != r[2], repr(r)) -class CompoundTest(fixtures.TestBase): +class CompoundTest(fixtures.TablesTest): """test compound statements like UNION, INTERSECT, particularly their ability to nest on different databases.""" __backend__ = True + run_inserts = "each" + @classmethod - def setup_class(cls): - global metadata, t1, t2, t3 - metadata = MetaData(testing.db) - t1 = Table( + def define_tables(cls, metadata): + Table( "t1", metadata, Column( @@ -1078,7 +1095,7 @@ class CompoundTest(fixtures.TestBase): Column("col3", String(40)), Column("col4", String(30)), ) - t2 = Table( + Table( "t2", metadata, Column( @@ -1091,7 +1108,7 @@ class CompoundTest(fixtures.TestBase): Column("col3", String(40)), Column("col4", String(30)), ) - t3 = Table( + Table( "t3", metadata, Column( @@ -1104,47 +1121,42 @@ class CompoundTest(fixtures.TestBase): Column("col3", String(40)), Column("col4", String(30)), ) - metadata.create_all() - - with testing.db.begin() as conn: - conn.execute( - t1.insert(), - [ - dict(col2="t1col2r1", col3="aaa", col4="aaa"), - dict(col2="t1col2r2", col3="bbb", col4="bbb"), - dict(col2="t1col2r3", col3="ccc", col4="ccc"), - ], - ) - conn.execute( - t2.insert(), - [ - dict(col2="t2col2r1", col3="aaa", col4="bbb"), - dict(col2="t2col2r2", col3="bbb", col4="ccc"), - dict(col2="t2col2r3", col3="ccc", col4="aaa"), - ], - ) - conn.execute( - t3.insert(), - [ - dict(col2="t3col2r1", col3="aaa", col4="ccc"), - dict(col2="t3col2r2", col3="bbb", col4="aaa"), - dict(col2="t3col2r3", col3="ccc", col4="bbb"), - ], - ) - - @engines.close_first - def teardown(self): - pass @classmethod - def teardown_class(cls): - metadata.drop_all() + def insert_data(cls, connection): + t1, t2, t3 = cls.tables("t1", "t2", "t3") + conn = connection + conn.execute( + t1.insert(), + [ + dict(col2="t1col2r1", col3="aaa", col4="aaa"), + dict(col2="t1col2r2", col3="bbb", col4="bbb"), + dict(col2="t1col2r3", col3="ccc", col4="ccc"), + ], + ) + conn.execute( + t2.insert(), + [ + dict(col2="t2col2r1", col3="aaa", col4="bbb"), + dict(col2="t2col2r2", col3="bbb", col4="ccc"), + dict(col2="t2col2r3", col3="ccc", col4="aaa"), + ], + ) + conn.execute( + t3.insert(), + [ + dict(col2="t3col2r1", col3="aaa", col4="ccc"), + dict(col2="t3col2r2", col3="bbb", col4="aaa"), + dict(col2="t3col2r3", col3="ccc", col4="bbb"), + ], + ) def _fetchall_sorted(self, executed): return sorted([tuple(row) for row in executed.fetchall()]) @testing.requires.subqueries def test_union(self, connection): + t1, t2, t3 = self.tables("t1", "t2", "t3") (s1, s2) = ( select(t1.c.col3.label("col3"), t1.c.col4.label("col4")).where( t1.c.col2.in_(["t1col2r1", "t1col2r2"]), @@ -1171,6 +1183,8 @@ class CompoundTest(fixtures.TestBase): @testing.fails_on("firebird", "doesn't like ORDER BY with UNIONs") def test_union_ordered(self, connection): + t1, t2, t3 = self.tables("t1", "t2", "t3") + (s1, s2) = ( select(t1.c.col3.label("col3"), t1.c.col4.label("col4")).where( t1.c.col2.in_(["t1col2r1", "t1col2r2"]), @@ -1192,6 +1206,8 @@ class CompoundTest(fixtures.TestBase): @testing.fails_on("firebird", "doesn't like ORDER BY with UNIONs") @testing.requires.subqueries def test_union_ordered_alias(self, connection): + t1, t2, t3 = self.tables("t1", "t2", "t3") + (s1, s2) = ( select(t1.c.col3.label("col3"), t1.c.col4.label("col4")).where( t1.c.col2.in_(["t1col2r1", "t1col2r2"]), @@ -1220,6 +1236,8 @@ class CompoundTest(fixtures.TestBase): ) @testing.fails_on("sqlite", "FIXME: unknown") def test_union_all(self, connection): + t1, t2, t3 = self.tables("t1", "t2", "t3") + e = union_all( select(t1.c.col3), union(select(t1.c.col3), select(t1.c.col3)), @@ -1241,6 +1259,8 @@ class CompoundTest(fixtures.TestBase): """ + t1, t2, t3 = self.tables("t1", "t2", "t3") + u = union(select(t1.c.col3), select(t1.c.col3)).alias() e = union_all(select(t1.c.col3), select(u.c.col3)) @@ -1256,6 +1276,8 @@ class CompoundTest(fixtures.TestBase): @testing.requires.intersect def test_intersect(self, connection): + t1, t2, t3 = self.tables("t1", "t2", "t3") + i = intersect( select(t2.c.col3, t2.c.col4), select(t2.c.col3, t2.c.col4).where(t2.c.col4 == t3.c.col3), @@ -1274,6 +1296,8 @@ class CompoundTest(fixtures.TestBase): @testing.requires.except_ @testing.fails_on("sqlite", "Can't handle this style of nesting") def test_except_style1(self, connection): + t1, t2, t3 = self.tables("t1", "t2", "t3") + e = except_( union( select(t1.c.col3, t1.c.col4), @@ -1300,6 +1324,8 @@ class CompoundTest(fixtures.TestBase): # same as style1, but add alias().select() to the except_(). # sqlite can handle it now. + t1, t2, t3 = self.tables("t1", "t2", "t3") + e = except_( union( select(t1.c.col3, t1.c.col4), @@ -1333,6 +1359,8 @@ class CompoundTest(fixtures.TestBase): @testing.requires.except_ def test_except_style3(self, connection): # aaa, bbb, ccc - (aaa, bbb, ccc - (ccc)) = ccc + t1, t2, t3 = self.tables("t1", "t2", "t3") + e = except_( select(t1.c.col3), # aaa, bbb, ccc except_( @@ -1346,6 +1374,8 @@ class CompoundTest(fixtures.TestBase): @testing.requires.except_ def test_except_style4(self, connection): # aaa, bbb, ccc - (aaa, bbb, ccc - (ccc)) = ccc + t1, t2, t3 = self.tables("t1", "t2", "t3") + e = except_( select(t1.c.col3), # aaa, bbb, ccc except_( @@ -1365,6 +1395,8 @@ class CompoundTest(fixtures.TestBase): "sqlite can't handle leading parenthesis", ) def test_intersect_unions(self, connection): + t1, t2, t3 = self.tables("t1", "t2", "t3") + u = intersect( union(select(t1.c.col3, t1.c.col4), select(t3.c.col3, t3.c.col4)), union(select(t2.c.col3, t2.c.col4), select(t3.c.col3, t3.c.col4)) @@ -1378,6 +1410,8 @@ class CompoundTest(fixtures.TestBase): @testing.requires.intersect def test_intersect_unions_2(self, connection): + t1, t2, t3 = self.tables("t1", "t2", "t3") + u = intersect( union(select(t1.c.col3, t1.c.col4), select(t3.c.col3, t3.c.col4)) .alias() @@ -1393,6 +1427,8 @@ class CompoundTest(fixtures.TestBase): @testing.requires.intersect def test_intersect_unions_3(self, connection): + t1, t2, t3 = self.tables("t1", "t2", "t3") + u = intersect( select(t2.c.col3, t2.c.col4), union( @@ -1410,6 +1446,8 @@ class CompoundTest(fixtures.TestBase): @testing.requires.intersect def test_composite_alias(self, connection): + t1, t2, t3 = self.tables("t1", "t2", "t3") + ua = intersect( select(t2.c.col3, t2.c.col4), union( @@ -1426,10 +1464,7 @@ class CompoundTest(fixtures.TestBase): eq_(found, wanted) -t1 = t2 = t3 = None - - -class JoinTest(fixtures.TestBase): +class JoinTest(fixtures.TablesTest): """Tests join execution. @@ -1443,56 +1478,48 @@ class JoinTest(fixtures.TestBase): __backend__ = True @classmethod - def setup_class(cls): - global metadata - global t1, t2, t3 - - metadata = MetaData(testing.db) - t1 = Table( + def define_tables(cls, metadata): + Table( "t1", metadata, Column("t1_id", Integer, primary_key=True), Column("name", String(32)), ) - t2 = Table( + Table( "t2", metadata, Column("t2_id", Integer, primary_key=True), Column("t1_id", Integer, ForeignKey("t1.t1_id")), Column("name", String(32)), ) - t3 = Table( + Table( "t3", metadata, Column("t3_id", Integer, primary_key=True), Column("t2_id", Integer, ForeignKey("t2.t2_id")), Column("name", String(32)), ) - metadata.drop_all() - metadata.create_all() - - with testing.db.begin() as conn: - # t1.10 -> t2.20 -> t3.30 - # t1.11 -> t2.21 - # t1.12 - conn.execute( - t1.insert(), - {"t1_id": 10, "name": "t1 #10"}, - {"t1_id": 11, "name": "t1 #11"}, - {"t1_id": 12, "name": "t1 #12"}, - ) - conn.execute( - t2.insert(), - {"t2_id": 20, "t1_id": 10, "name": "t2 #20"}, - {"t2_id": 21, "t1_id": 11, "name": "t2 #21"}, - ) - conn.execute( - t3.insert(), {"t3_id": 30, "t2_id": 20, "name": "t3 #30"} - ) @classmethod - def teardown_class(cls): - metadata.drop_all() + def insert_data(cls, connection): + conn = connection + # t1.10 -> t2.20 -> t3.30 + # t1.11 -> t2.21 + # t1.12 + t1, t2, t3 = cls.tables("t1", "t2", "t3") + + conn.execute( + t1.insert(), + {"t1_id": 10, "name": "t1 #10"}, + {"t1_id": 11, "name": "t1 #11"}, + {"t1_id": 12, "name": "t1 #12"}, + ) + conn.execute( + t2.insert(), + {"t2_id": 20, "t1_id": 10, "name": "t2 #20"}, + {"t2_id": 21, "t1_id": 11, "name": "t2 #21"}, + ) + conn.execute(t3.insert(), {"t3_id": 30, "t2_id": 20, "name": "t3 #30"}) def assertRows(self, statement, expected): """Execute a statement and assert that rows returned equal expected.""" @@ -1504,6 +1531,7 @@ class JoinTest(fixtures.TestBase): def test_join_x1(self): """Joins t1->t2.""" + t1, t2, t3 = self.tables("t1", "t2", "t3") for criteria in (t1.c.t1_id == t2.c.t1_id, t2.c.t1_id == t1.c.t1_id): expr = select(t1.c.t1_id, t2.c.t2_id).select_from( @@ -1513,6 +1541,7 @@ class JoinTest(fixtures.TestBase): def test_join_x2(self): """Joins t1->t2->t3.""" + t1, t2, t3 = self.tables("t1", "t2", "t3") for criteria in (t1.c.t1_id == t2.c.t1_id, t2.c.t1_id == t1.c.t1_id): expr = select(t1.c.t1_id, t2.c.t2_id).select_from( @@ -1522,6 +1551,7 @@ class JoinTest(fixtures.TestBase): def test_outerjoin_x1(self): """Outer joins t1->t2.""" + t1, t2, t3 = self.tables("t1", "t2", "t3") for criteria in (t2.c.t2_id == t3.c.t2_id, t3.c.t2_id == t2.c.t2_id): expr = select(t1.c.t1_id, t2.c.t2_id).select_from( @@ -1531,6 +1561,7 @@ class JoinTest(fixtures.TestBase): def test_outerjoin_x2(self): """Outer joins t1->t2,t3.""" + t1, t2, t3 = self.tables("t1", "t2", "t3") for criteria in (t2.c.t2_id == t3.c.t2_id, t3.c.t2_id == t2.c.t2_id): expr = select(t1.c.t1_id, t2.c.t2_id, t3.c.t3_id).select_from( @@ -1544,6 +1575,7 @@ class JoinTest(fixtures.TestBase): def test_outerjoin_where_x2_t1(self): """Outer joins t1->t2,t3, where on t1.""" + t1, t2, t3 = self.tables("t1", "t2", "t3") for criteria in (t2.c.t2_id == t3.c.t2_id, t3.c.t2_id == t2.c.t2_id): expr = ( @@ -1574,6 +1606,7 @@ class JoinTest(fixtures.TestBase): def test_outerjoin_where_x2_t2(self): """Outer joins t1->t2,t3, where on t2.""" + t1, t2, t3 = self.tables("t1", "t2", "t3") for criteria in (t2.c.t2_id == t3.c.t2_id, t3.c.t2_id == t2.c.t2_id): expr = ( @@ -1604,6 +1637,7 @@ class JoinTest(fixtures.TestBase): def test_outerjoin_where_x2_t3(self): """Outer joins t1->t2,t3, where on t3.""" + t1, t2, t3 = self.tables("t1", "t2", "t3") for criteria in (t2.c.t2_id == t3.c.t2_id, t3.c.t2_id == t2.c.t2_id): expr = ( @@ -1635,6 +1669,8 @@ class JoinTest(fixtures.TestBase): def test_outerjoin_where_x2_t1t3(self): """Outer joins t1->t2,t3, where on t1 and t3.""" + t1, t2, t3 = self.tables("t1", "t2", "t3") + for criteria in (t2.c.t2_id == t3.c.t2_id, t3.c.t2_id == t2.c.t2_id): expr = ( select(t1.c.t1_id, t2.c.t2_id, t3.c.t3_id) @@ -1664,6 +1700,8 @@ class JoinTest(fixtures.TestBase): def test_outerjoin_where_x2_t1t2(self): """Outer joins t1->t2,t3, where on t1 and t2.""" + t1, t2, t3 = self.tables("t1", "t2", "t3") + for criteria in (t2.c.t2_id == t3.c.t2_id, t3.c.t2_id == t2.c.t2_id): expr = ( select(t1.c.t1_id, t2.c.t2_id, t3.c.t3_id) @@ -1693,6 +1731,7 @@ class JoinTest(fixtures.TestBase): def test_outerjoin_where_x2_t1t2t3(self): """Outer joins t1->t2,t3, where on t1, t2 and t3.""" + t1, t2, t3 = self.tables("t1", "t2", "t3") for criteria in (t2.c.t2_id == t3.c.t2_id, t3.c.t2_id == t2.c.t2_id): expr = ( @@ -1729,6 +1768,7 @@ class JoinTest(fixtures.TestBase): def test_mixed(self): """Joins t1->t2, outer t2->t3.""" + t1, t2, t3 = self.tables("t1", "t2", "t3") for criteria in (t2.c.t2_id == t3.c.t2_id, t3.c.t2_id == t2.c.t2_id): expr = select(t1.c.t1_id, t2.c.t2_id, t3.c.t3_id).select_from( @@ -1739,6 +1779,7 @@ class JoinTest(fixtures.TestBase): def test_mixed_where(self): """Joins t1->t2, outer t2->t3, plus a where on each table in turn.""" + t1, t2, t3 = self.tables("t1", "t2", "t3") for criteria in (t2.c.t2_id == t3.c.t2_id, t3.c.t2_id == t2.c.t2_id): expr = ( @@ -1800,17 +1841,12 @@ class JoinTest(fixtures.TestBase): self.assertRows(expr, [(10, 20, 30)]) -metadata = flds = None - - -class OperatorTest(fixtures.TestBase): +class OperatorTest(fixtures.TablesTest): __backend__ = True @classmethod - def setup_class(cls): - global metadata, flds - metadata = MetaData(testing.db) - flds = Table( + def define_tables(cls, metadata): + Table( "flds", metadata, Column( @@ -1822,20 +1858,19 @@ class OperatorTest(fixtures.TestBase): Column("intcol", Integer), Column("strcol", String(50)), ) - metadata.create_all() - - with testing.db.begin() as conn: - conn.execute( - flds.insert(), - [dict(intcol=5, strcol="foo"), dict(intcol=13, strcol="bar")], - ) @classmethod - def teardown_class(cls): - metadata.drop_all() + def insert_data(cls, connection): + flds = cls.tables.flds + connection.execute( + flds.insert(), + [dict(intcol=5, strcol="foo"), dict(intcol=13, strcol="bar")], + ) # TODO: seems like more tests warranted for this setup. def test_modulo(self, connection): + flds = self.tables.flds + eq_( connection.execute( select(flds.c.intcol % 3).order_by(flds.c.idcol) @@ -1845,6 +1880,8 @@ class OperatorTest(fixtures.TestBase): @testing.requires.window_functions def test_over(self, connection): + flds = self.tables.flds + eq_( connection.execute( select( diff --git a/test/sql/test_quote.py b/test/sql/test_quote.py index a78d6c16b..c743918c8 100644 --- a/test/sql/test_quote.py +++ b/test/sql/test_quote.py @@ -187,31 +187,6 @@ class QuoteExecTest(fixtures.TablesTest): class QuoteTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "default" - @classmethod - def setup_class(cls): - # TODO: figure out which databases/which identifiers allow special - # characters to be used, such as: spaces, quote characters, - # punctuation characters, set up tests for those as well. - - global table1, table2 - metadata = MetaData(testing.db) - - table1 = Table( - "WorstCase1", - metadata, - Column("lowercase", Integer, primary_key=True), - Column("UPPERCASE", Integer), - Column("MixedCase", Integer), - Column("ASC", Integer, key="a123"), - ) - table2 = Table( - "WorstCase2", - metadata, - Column("desc", Integer, primary_key=True, key="d123"), - Column("Union", Integer, key="u123"), - Column("MixedCase", Integer), - ) - @testing.crashes("oracle", "FIXME: unknown, verify not fails_on") @testing.requires.subqueries def test_labels(self): @@ -234,6 +209,23 @@ class QuoteTest(fixtures.TestBase, AssertsCompiledSQL): where the "UPPERCASE" column of "LaLa" doesn't exist. """ + metadata = MetaData() + table1 = Table( + "WorstCase1", + metadata, + Column("lowercase", Integer, primary_key=True), + Column("UPPERCASE", Integer), + Column("MixedCase", Integer), + Column("ASC", Integer, key="a123"), + ) + Table( + "WorstCase2", + metadata, + Column("desc", Integer, primary_key=True, key="d123"), + Column("Union", Integer, key="u123"), + Column("MixedCase", Integer), + ) + self.assert_compile( table1.select(distinct=True).alias("LaLa").select(), "SELECT " diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py index 9f2afd7b7..187cf0dd0 100644 --- a/test/sql/test_returning.py +++ b/test/sql/test_returning.py @@ -11,6 +11,7 @@ from sqlalchemy import select from sqlalchemy import Sequence from sqlalchemy import String from sqlalchemy import testing +from sqlalchemy import type_coerce from sqlalchemy import update from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL @@ -121,6 +122,7 @@ class ReturningTest(fixtures.TablesTest, AssertsExecutionResults): Column("persons", Integer), Column("full", Boolean), Column("goofy", GoofyType(50)), + Column("strval", String(50)), ) def test_column_targeting(self, connection): @@ -197,6 +199,88 @@ class ReturningTest(fixtures.TablesTest, AssertsExecutionResults): ) eq_(result2.fetchall(), [(1, True), (2, False)]) + @testing.fails_on( + "mssql", + "driver has unknown issue with string concatenation " + "in INSERT RETURNING", + ) + def test_insert_returning_w_expression_one(self, connection): + table = self.tables.tables + result = connection.execute( + table.insert().returning(table.c.strval + "hi"), + {"persons": 5, "full": False, "strval": "str1"}, + ) + + eq_(result.fetchall(), [("str1hi",)]) + + result2 = connection.execute( + select(table.c.id, table.c.strval).order_by(table.c.id) + ) + eq_(result2.fetchall(), [(1, "str1")]) + + def test_insert_returning_w_type_coerce_expression(self, connection): + table = self.tables.tables + result = connection.execute( + table.insert().returning(type_coerce(table.c.goofy, String)), + {"persons": 5, "goofy": "somegoofy"}, + ) + + eq_(result.fetchall(), [("FOOsomegoofy",)]) + + result2 = connection.execute( + select(table.c.id, table.c.goofy).order_by(table.c.id) + ) + eq_(result2.fetchall(), [(1, "FOOsomegoofyBAR")]) + + def test_update_returning_w_expression_one(self, connection): + table = self.tables.tables + connection.execute( + table.insert(), + [ + {"persons": 5, "full": False, "strval": "str1"}, + {"persons": 3, "full": False, "strval": "str2"}, + ], + ) + + result = connection.execute( + table.update() + .where(table.c.persons > 4) + .values(full=True) + .returning(table.c.strval + "hi") + ) + eq_(result.fetchall(), [("str1hi",)]) + + result2 = connection.execute( + select(table.c.id, table.c.strval).order_by(table.c.id) + ) + eq_(result2.fetchall(), [(1, "str1"), (2, "str2")]) + + def test_update_returning_w_type_coerce_expression(self, connection): + table = self.tables.tables + connection.execute( + table.insert(), + [ + {"persons": 5, "goofy": "somegoofy1"}, + {"persons": 3, "goofy": "somegoofy2"}, + ], + ) + + result = connection.execute( + table.update() + .where(table.c.persons > 4) + .values(goofy="newgoofy") + .returning(type_coerce(table.c.goofy, String)) + ) + eq_(result.fetchall(), [("FOOnewgoofy",)]) + + result2 = connection.execute( + select(table.c.id, table.c.goofy).order_by(table.c.id) + ) + eq_( + result2.fetchall(), + [(1, "FOOnewgoofyBAR"), (2, "FOOsomegoofy2BAR")], + ) + @testing.requires.full_returning def test_update_full_returning(self, connection): table = self.tables.tables diff --git a/test/sql/test_sequences.py b/test/sql/test_sequences.py index 1809e0cca..65325aa6f 100644 --- a/test/sql/test_sequences.py +++ b/test/sql/test_sequences.py @@ -520,7 +520,7 @@ class SequenceAsServerDefaultTest( def test_drop_ordering(self): with self.sql_execution_asserter(testing.db) as asserter: - self.metadata.drop_all(checkfirst=False) + self.tables_test_metadata.drop_all(testing.db, checkfirst=False) asserter.assert_( AllOf( diff --git a/test/sql/test_types.py b/test/sql/test_types.py index 919c4b4f9..0e1147800 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -584,6 +584,9 @@ class _UserDefinedTypeFixture(object): def copy(self): return MyUnicodeType(self.impl.length) + class MyDecOfDec(types.TypeDecorator): + impl = MyNewIntType + Table( "users", metadata, @@ -596,6 +599,7 @@ class _UserDefinedTypeFixture(object): Column("goofy7", MyNewUnicodeType(50), nullable=False), Column("goofy8", MyNewIntType, nullable=False), Column("goofy9", MyNewIntSubClass, nullable=False), + Column("goofy10", MyDecOfDec, nullable=False), ) @@ -614,6 +618,7 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest): goofy7=util.u("jack"), goofy8=12, goofy9=12, + goofy10=12, ), ) connection.execute( @@ -626,6 +631,7 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest): goofy7=util.u("lala"), goofy8=15, goofy9=15, + goofy10=15, ), ) connection.execute( @@ -638,6 +644,7 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest): goofy7=util.u("fred"), goofy8=9, goofy9=9, + goofy10=9, ), ) @@ -665,7 +672,7 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest): for col in row[3], row[4]: assert isinstance(col, util.text_type) - def test_plain_in(self, connection): + def test_plain_in_typedec(self, connection): users = self.tables.users self._data_fixture(connection) @@ -677,7 +684,19 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest): result = connection.execute(stmt, {"goofy": [15, 9]}) eq_(result.fetchall(), [(3, 1500), (4, 900)]) - def test_expanding_in(self, connection): + def test_plain_in_typedec_of_typedec(self, connection): + users = self.tables.users + self._data_fixture(connection) + + stmt = ( + select(users.c.user_id, users.c.goofy10) + .where(users.c.goofy10.in_([15, 9])) + .order_by(users.c.user_id) + ) + result = connection.execute(stmt, {"goofy": [15, 9]}) + eq_(result.fetchall(), [(3, 1500), (4, 900)]) + + def test_expanding_in_typedec(self, connection): users = self.tables.users self._data_fixture(connection) @@ -689,6 +708,18 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest): result = connection.execute(stmt, {"goofy": [15, 9]}) eq_(result.fetchall(), [(3, 1500), (4, 900)]) + def test_expanding_in_typedec_of_typedec(self, connection): + users = self.tables.users + self._data_fixture(connection) + + stmt = ( + select(users.c.user_id, users.c.goofy10) + .where(users.c.goofy10.in_(bindparam("goofy", expanding=True))) + .order_by(users.c.user_id) + ) + result = connection.execute(stmt, {"goofy": [15, 9]}) + eq_(result.fetchall(), [(3, 1500), (4, 900)]) + class UserDefinedTest( _UserDefinedTypeFixture, fixtures.TablesTest, AssertsCompiledSQL @@ -1177,6 +1208,172 @@ class TypeCoerceCastTest(fixtures.TablesTest): ) +class VariantBackendTest(fixtures.TestBase, AssertsCompiledSQL): + __backend__ = True + + @testing.fixture + def variant_roundtrip(self, metadata, connection): + def run(datatype, data, assert_data): + t = Table( + "t", + metadata, + Column("data", datatype), + ) + t.create(connection) + + connection.execute(t.insert(), [{"data": elem} for elem in data]) + eq_( + connection.execute(select(t).order_by(t.c.data)).all(), + [(elem,) for elem in assert_data], + ) + + eq_( + # test an IN, which in 1.4 is an expanding + connection.execute( + select(t).where(t.c.data.in_(data)).order_by(t.c.data) + ).all(), + [(elem,) for elem in assert_data], + ) + + return run + + def test_type_decorator_variant_one_roundtrip(self, variant_roundtrip): + class Foo(TypeDecorator): + impl = String(50) + + if testing.against("postgresql"): + data = [5, 6, 10] + else: + data = ["five", "six", "ten"] + variant_roundtrip( + Foo().with_variant(Integer, "postgresql"), data, data + ) + + def test_type_decorator_variant_two(self, variant_roundtrip): + class UTypeOne(types.UserDefinedType): + def get_col_spec(self): + return "VARCHAR(50)" + + def bind_processor(self, dialect): + def process(value): + return value + "UONE" + + return process + + class UTypeTwo(types.UserDefinedType): + def get_col_spec(self): + return "VARCHAR(50)" + + def bind_processor(self, dialect): + def process(value): + return value + "UTWO" + + return process + + variant = UTypeOne() + for db in ["postgresql", "mysql", "mariadb"]: + variant = variant.with_variant(UTypeTwo(), db) + + class Foo(TypeDecorator): + impl = variant + + if testing.against("postgresql"): + data = assert_data = [5, 6, 10] + elif testing.against("mysql") or testing.against("mariadb"): + data = ["five", "six", "ten"] + assert_data = ["fiveUTWO", "sixUTWO", "tenUTWO"] + else: + data = ["five", "six", "ten"] + assert_data = ["fiveUONE", "sixUONE", "tenUONE"] + + variant_roundtrip( + Foo().with_variant(Integer, "postgresql"), data, assert_data + ) + + def test_type_decorator_variant_three(self, variant_roundtrip): + class Foo(TypeDecorator): + impl = String + + if testing.against("postgresql"): + data = ["five", "six", "ten"] + else: + data = [5, 6, 10] + + variant_roundtrip( + Integer().with_variant(Foo(), "postgresql"), data, data + ) + + def test_type_decorator_compile_variant_one(self): + class Foo(TypeDecorator): + impl = String + + self.assert_compile( + Foo().with_variant(Integer, "sqlite"), + "INTEGER", + dialect=dialects.sqlite.dialect(), + ) + + self.assert_compile( + Foo().with_variant(Integer, "sqlite"), + "VARCHAR", + dialect=dialects.postgresql.dialect(), + ) + + def test_type_decorator_compile_variant_two(self): + class UTypeOne(types.UserDefinedType): + def get_col_spec(self): + return "UTYPEONE" + + def bind_processor(self, dialect): + def process(value): + return value + "UONE" + + return process + + class UTypeTwo(types.UserDefinedType): + def get_col_spec(self): + return "UTYPETWO" + + def bind_processor(self, dialect): + def process(value): + return value + "UTWO" + + return process + + variant = UTypeOne().with_variant(UTypeTwo(), "postgresql") + + class Foo(TypeDecorator): + impl = variant + + self.assert_compile( + Foo().with_variant(Integer, "sqlite"), + "INTEGER", + dialect=dialects.sqlite.dialect(), + ) + + self.assert_compile( + Foo().with_variant(Integer, "sqlite"), + "UTYPETWO", + dialect=dialects.postgresql.dialect(), + ) + + def test_type_decorator_compile_variant_three(self): + class Foo(TypeDecorator): + impl = String + + self.assert_compile( + Integer().with_variant(Foo(), "postgresql"), + "INTEGER", + dialect=dialects.sqlite.dialect(), + ) + + self.assert_compile( + Integer().with_variant(Foo(), "postgresql"), + "VARCHAR", + dialect=dialects.postgresql.dialect(), + ) + + class VariantTest(fixtures.TestBase, AssertsCompiledSQL): def setup(self): class UTypeOne(types.UserDefinedType): @@ -2539,6 +2736,9 @@ class ExpressionTest( def process_result_value(self, value, dialect): return value + "BIND_OUT" + class MyDecOfDec(types.TypeDecorator): + impl = MyTypeDec + Table( "test", metadata, @@ -2547,6 +2747,7 @@ class ExpressionTest( Column("atimestamp", Date), Column("avalue", MyCustomType), Column("bvalue", MyTypeDec(50)), + Column("cvalue", MyDecOfDec(50)), ) @classmethod @@ -2560,6 +2761,7 @@ class ExpressionTest( "atimestamp": datetime.date(2007, 10, 15), "avalue": 25, "bvalue": "foo", + "cvalue": "foo", }, ) @@ -2579,6 +2781,7 @@ class ExpressionTest( datetime.date(2007, 10, 15), 25, "BIND_INfooBIND_OUT", + "BIND_INfooBIND_OUT", ) ], ) @@ -2617,6 +2820,7 @@ class ExpressionTest( datetime.date(2007, 10, 15), 25, "BIND_INfooBIND_OUT", + "BIND_INfooBIND_OUT", ) ], ) @@ -2635,6 +2839,7 @@ class ExpressionTest( datetime.date(2007, 10, 15), 25, "BIND_INfooBIND_OUT", + "BIND_INfooBIND_OUT", ) ], ) @@ -3505,34 +3710,26 @@ class PickleTest(fixtures.TestBase): assert p1.compare_values(p1.copy_value(obj), obj) -meta = None - - class CallableTest(fixtures.TestBase): - @classmethod - def setup_class(cls): - global meta - meta = MetaData(testing.db) - - @classmethod - def teardown_class(cls): - meta.drop_all() - - def test_callable_as_arg(self): + @testing.provide_metadata + def test_callable_as_arg(self, connection): ucode = util.partial(Unicode) - thing_table = Table("thing", meta, Column("name", ucode(20))) + thing_table = Table("thing", self.metadata, Column("name", ucode(20))) assert isinstance(thing_table.c.name.type, Unicode) - thing_table.create() + thing_table.create(connection) - def test_callable_as_kwarg(self): + @testing.provide_metadata + def test_callable_as_kwarg(self, connection): ucode = util.partial(Unicode) thang_table = Table( - "thang", meta, Column("name", type_=ucode(20), primary_key=True) + "thang", + self.metadata, + Column("name", type_=ucode(20), primary_key=True), ) assert isinstance(thang_table.c.name.type, Unicode) - thang_table.create() + thang_table.create(connection) class LiteralTest(fixtures.TestBase): @@ -25,7 +25,7 @@ deps=pytest>=4.6.11 # this can be 6.x once we are on python 3 only postgresql: .[postgresql_pg8000]; python_version >= '3' mysql: .[mysql] mysql: .[pymysql] - mysql: .[aiomysql]; python_version >= '3' + mysql: git+https://github.com/sqlalchemy/aiomysql@sqlalchemy_tox; python_version >= '3' mysql: .[mariadb_connector]; python_version >= '3' # we should probably try to get mysql_connector back in the mix @@ -74,11 +74,11 @@ setenv= sqlite_file: SQLITE={env:TOX_SQLITE_FILE:--db sqlite_file} postgresql: POSTGRESQL={env:TOX_POSTGRESQL:--db postgresql} - py3{,5,6,7,8,9,10,11}-postgresql: EXTRA_PG_DRIVERS={env:EXTRA_PG_DRIVERS:--dbdriver psycopg2 --dbdriver asyncpg?async_fallback=true --dbdriver pg8000} + py3{,5,6,7,8,9,10,11}-postgresql: EXTRA_PG_DRIVERS={env:EXTRA_PG_DRIVERS:--dbdriver psycopg2 --dbdriver asyncpg --dbdriver pg8000} mysql: MYSQL={env:TOX_MYSQL:--db mysql} mysql: EXTRA_MYSQL_DRIVERS={env:EXTRA_MYSQL_DRIVERS:--dbdriver mysqldb --dbdriver pymysql} - py3{,5,6,7,8,9,10,11}-mysql: EXTRA_MYSQL_DRIVERS={env:EXTRA_MYSQL_DRIVERS:--dbdriver mysqldb --dbdriver pymysql --dbdriver mariadbconnector --dbdriver aiomysql?async_fallback=true} + py3{,5,6,7,8,9,10,11}-mysql: EXTRA_MYSQL_DRIVERS={env:EXTRA_MYSQL_DRIVERS:--dbdriver mysqldb --dbdriver pymysql --dbdriver mariadbconnector --dbdriver aiomysql} mssql: MSSQL={env:TOX_MSSQL:--db mssql} |
