diff options
| author | mike bayer <mike_mp@zzzcomputing.com> | 2021-01-14 03:40:22 +0000 |
|---|---|---|
| committer | Gerrit Code Review <gerrit@bbpush.zzzcomputing.com> | 2021-01-14 03:40:22 +0000 |
| commit | 0e36c1bbe7c4fc83df13e756f04d9fa0fa8d5d39 (patch) | |
| tree | e204b02318476336a519cd6b0016ef54c9245f0f /lib | |
| parent | 0a41f9bea6602c52c59af0f7b572308b2c2b27ab (diff) | |
| parent | f1e96cb0874927a475d0c111393b7861796dd758 (diff) | |
| download | sqlalchemy-0e36c1bbe7c4fc83df13e756f04d9fa0fa8d5d39.tar.gz | |
Merge "reinvent xdist hooks in terms of pytest fixtures"
Diffstat (limited to 'lib')
26 files changed, 828 insertions, 342 deletions
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 538679fcf..0227e515d 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -2785,15 +2785,14 @@ class MSDialect(default.DefaultDialect): def has_table(self, connection, tablename, dbname, owner, schema): if tablename.startswith("#"): # temporary table tables = ischema.mssql_temp_table_columns - result = connection.execute( - sql.select(tables.c.table_name) - .where( - tables.c.table_name.like( - self._temp_table_name_like_pattern(tablename) - ) + + s = sql.select(tables.c.table_name).where( + tables.c.table_name.like( + self._temp_table_name_like_pattern(tablename) ) - .limit(1) ) + + result = connection.execute(s.limit(1)) return result.scalar() is not None else: tables = ischema.tables diff --git a/lib/sqlalchemy/dialects/mssql/provision.py b/lib/sqlalchemy/dialects/mssql/provision.py index 269eb164f..56f3305a7 100644 --- a/lib/sqlalchemy/dialects/mssql/provision.py +++ b/lib/sqlalchemy/dialects/mssql/provision.py @@ -1,6 +1,14 @@ +from sqlalchemy import inspect +from sqlalchemy import Integer from ... import create_engine from ... import exc +from ...schema import Column +from ...schema import DropConstraint +from ...schema import ForeignKeyConstraint +from ...schema import MetaData +from ...schema import Table from ...testing.provision import create_db +from ...testing.provision import drop_all_schema_objects_pre_tables from ...testing.provision import drop_db from ...testing.provision import get_temp_table_name from ...testing.provision import log @@ -38,7 +46,6 @@ def _mssql_drop_ignore(conn, ident): # "where database_id=db_id('%s')" % ident): # log.info("killing SQL server session %s", row['session_id']) # conn.exec_driver_sql("kill %s" % row['session_id']) - conn.exec_driver_sql("drop database %s" % ident) log.info("Reaped db: %s", ident) return True @@ -83,4 +90,27 @@ def _mssql_temp_table_keyword_args(cfg, eng): @get_temp_table_name.for_db("mssql") def _mssql_get_temp_table_name(cfg, eng, base_name): - return "#" + base_name + return "##" + base_name + + +@drop_all_schema_objects_pre_tables.for_db("mssql") +def drop_all_schema_objects_pre_tables(cfg, eng): + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: + inspector = inspect(conn) + for schema in (None, "dbo", cfg.test_schema, cfg.test_schema_2): + for tname in inspector.get_table_names(schema=schema): + tb = Table( + tname, + MetaData(), + Column("x", Integer), + Column("y", Integer), + schema=schema, + ) + for fk in inspect(conn).get_foreign_keys(tname, schema=schema): + conn.execute( + DropConstraint( + ForeignKeyConstraint( + [tb.c.x], [tb.c.y], name=fk["name"] + ) + ) + ) diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index 042443692..b8b4df760 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -93,6 +93,7 @@ The parameters accepted by the cx_oracle dialect are as follows: * ``encoding_errors`` - see :ref:`cx_oracle_unicode_encoding_errors` for detail. + .. _cx_oracle_unicode: Unicode diff --git a/lib/sqlalchemy/dialects/oracle/provision.py b/lib/sqlalchemy/dialects/oracle/provision.py index d51131c0b..e0dadd58e 100644 --- a/lib/sqlalchemy/dialects/oracle/provision.py +++ b/lib/sqlalchemy/dialects/oracle/provision.py @@ -6,11 +6,11 @@ from ...testing.provision import create_db from ...testing.provision import drop_db from ...testing.provision import follower_url_from_main from ...testing.provision import log +from ...testing.provision import post_configure_engine from ...testing.provision import run_reap_dbs from ...testing.provision import set_default_schema_on_connection -from ...testing.provision import stop_test_class +from ...testing.provision import stop_test_class_outside_fixtures from ...testing.provision import temp_table_keyword_args -from ...testing.provision import update_db_opts @create_db.for_db("oracle") @@ -57,21 +57,39 @@ def _oracle_drop_db(cfg, eng, ident): _ora_drop_ignore(conn, "%s_ts2" % ident) -@update_db_opts.for_db("oracle") -def _oracle_update_db_opts(db_url, db_opts): - pass +@stop_test_class_outside_fixtures.for_db("oracle") +def stop_test_class_outside_fixtures(config, db, cls): + with db.begin() as conn: + # run magic command to get rid of identity sequences + # https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/ # noqa E501 + conn.exec_driver_sql("purge recyclebin") -@stop_test_class.for_db("oracle") -def stop_test_class(config, db, cls): - """run magic command to get rid of identity sequences + # clear statement cache on all connections that were used + # https://github.com/oracle/python-cx_Oracle/issues/519 - # https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/ + for cx_oracle_conn in _all_conns: + try: + sc = cx_oracle_conn.stmtcachesize + except db.dialect.dbapi.InterfaceError: + # connection closed + pass + else: + cx_oracle_conn.stmtcachesize = 0 + cx_oracle_conn.stmtcachesize = sc + _all_conns.clear() - """ - with db.begin() as conn: - conn.exec_driver_sql("purge recyclebin") +_all_conns = set() + + +@post_configure_engine.for_db("oracle") +def _oracle_post_configure_engine(url, engine, follower_ident): + from sqlalchemy import event + + @event.listens_for(engine, "checkout") + def checkout(dbapi_con, con_record, con_proxy): + _all_conns.add(dbapi_con) @run_reap_dbs.for_db("oracle") diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 7c6e8fb02..e542c77f4 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -670,7 +670,6 @@ class AsyncAdapt_asyncpg_connection: def rollback(self): if self._started: self.await_(self._transaction.rollback()) - self._transaction = None self._started = False diff --git a/lib/sqlalchemy/dialects/postgresql/provision.py b/lib/sqlalchemy/dialects/postgresql/provision.py index d345cdfdf..70c390800 100644 --- a/lib/sqlalchemy/dialects/postgresql/provision.py +++ b/lib/sqlalchemy/dialects/postgresql/provision.py @@ -8,6 +8,7 @@ from ...testing.provision import drop_all_schema_objects_post_tables from ...testing.provision import drop_all_schema_objects_pre_tables from ...testing.provision import drop_db from ...testing.provision import log +from ...testing.provision import prepare_for_drop_tables from ...testing.provision import set_default_schema_on_connection from ...testing.provision import temp_table_keyword_args @@ -102,3 +103,23 @@ def drop_all_schema_objects_post_tables(cfg, eng): postgresql.ENUM(name=enum["name"], schema=enum["schema"]) ) ) + + +@prepare_for_drop_tables.for_db("postgresql") +def prepare_for_drop_tables(config, connection): + """Ensure there are no locks on the current username/database.""" + + result = connection.exec_driver_sql( + "select pid, state, wait_event_type, query " + # "select pg_terminate_backend(pid), state, wait_event_type " + "from pg_stat_activity where " + "usename=current_user " + "and datname=current_database() and state='idle in transaction' " + "and pid != pg_backend_pid()" + ) + rows = result.all() # noqa + assert not rows, ( + "PostgreSQL may not be able to DROP tables due to " + "idle in transaction: %s" + % ("; ".join(row._mapping["query"] for row in rows)) + ) diff --git a/lib/sqlalchemy/dialects/sqlite/provision.py b/lib/sqlalchemy/dialects/sqlite/provision.py index f26c21e22..a481be27e 100644 --- a/lib/sqlalchemy/dialects/sqlite/provision.py +++ b/lib/sqlalchemy/dialects/sqlite/provision.py @@ -7,7 +7,7 @@ from ...testing.provision import follower_url_from_main from ...testing.provision import log from ...testing.provision import post_configure_engine from ...testing.provision import run_reap_dbs -from ...testing.provision import stop_test_class +from ...testing.provision import stop_test_class_outside_fixtures from ...testing.provision import temp_table_keyword_args @@ -57,8 +57,8 @@ def _sqlite_drop_db(cfg, eng, ident): os.remove(path) -@stop_test_class.for_db("sqlite") -def stop_test_class(config, db, cls): +@stop_test_class_outside_fixtures.for_db("sqlite") +def stop_test_class_outside_fixtures(config, db, cls): with db.connect() as conn: files = [ row.file diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 50f00c025..72d66b7c8 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -2729,14 +2729,16 @@ class Engine(Connectable, log.Identified): return self.conn def __exit__(self, type_, value, traceback): - - if type_ is not None: - self.transaction.rollback() - else: - if self.transaction.is_active: - self.transaction.commit() - if not self.close_with_result: - self.conn.close() + try: + if type_ is not None: + if self.transaction.is_active: + self.transaction.rollback() + else: + if self.transaction.is_active: + self.transaction.commit() + finally: + if not self.close_with_result: + self.conn.close() def begin(self, close_with_result=False): """Return a context manager delivering a :class:`_engine.Connection` diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py index f89be1809..72d232085 100644 --- a/lib/sqlalchemy/engine/create.py +++ b/lib/sqlalchemy/engine/create.py @@ -655,9 +655,12 @@ def create_engine(url, **kwargs): c = base.Connection( engine, connection=dbapi_connection, _has_events=False ) - c._execution_options = util.immutabledict() - dialect.initialize(c) - dialect.do_rollback(c.connection) + c._execution_options = util.EMPTY_DICT + + try: + dialect.initialize(c) + finally: + dialect.do_rollback(c.connection) # previously, the "first_connect" event was used here, which was then # scaled back if the "on_connect" handler were present. now, diff --git a/lib/sqlalchemy/future/engine.py b/lib/sqlalchemy/future/engine.py index d2f609326..bfdcdfc7f 100644 --- a/lib/sqlalchemy/future/engine.py +++ b/lib/sqlalchemy/future/engine.py @@ -368,12 +368,15 @@ class Engine(_LegacyEngine): return self.conn def __exit__(self, type_, value, traceback): - if type_ is not None: - self.transaction.rollback() - else: - if self.transaction.is_active: - self.transaction.commit() - self.conn.close() + try: + if type_ is not None: + if self.transaction.is_active: + self.transaction.rollback() + else: + if self.transaction.is_active: + self.transaction.commit() + finally: + self.conn.close() def begin(self): """Return a :class:`_future.Connection` object with a transaction diff --git a/lib/sqlalchemy/pool/base.py b/lib/sqlalchemy/pool/base.py index 7c9509e45..6c3aad037 100644 --- a/lib/sqlalchemy/pool/base.py +++ b/lib/sqlalchemy/pool/base.py @@ -426,6 +426,7 @@ class _ConnectionRecord(object): rec._checkin_failed(err) echo = pool._should_log_debug() fairy = _ConnectionFairy(dbapi_connection, rec, echo) + rec.fairy_ref = weakref.ref( fairy, lambda ref: _finalize_fairy @@ -609,6 +610,15 @@ def _finalize_fairy( assert connection is None connection = connection_record.connection + dont_restore_gced = pool._is_asyncio + + if dont_restore_gced: + detach = not connection_record or ref + can_manipulate_connection = not ref + else: + detach = not connection_record + can_manipulate_connection = True + if connection is not None: if connection_record and echo: pool.logger.debug( @@ -620,13 +630,26 @@ def _finalize_fairy( connection, connection_record, echo ) assert fairy.connection is connection - fairy._reset(pool) + if can_manipulate_connection: + fairy._reset(pool) + + if detach: + if connection_record: + fairy._pool = pool + fairy.detach() + + if can_manipulate_connection: + if pool.dispatch.close_detached: + pool.dispatch.close_detached(connection) + + pool._close_connection(connection) + else: + util.warn( + "asyncio connection is being garbage " + "collected without being properly closed: %r" + % connection + ) - # Immediately close detached instances - if not connection_record: - if pool.dispatch.close_detached: - pool.dispatch.close_detached(connection) - pool._close_connection(connection) except BaseException as e: pool.logger.error( "Exception during reset or similar", exc_info=True diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py index c1afeb907..9f2d0b857 100644 --- a/lib/sqlalchemy/testing/__init__.py +++ b/lib/sqlalchemy/testing/__init__.py @@ -32,6 +32,7 @@ from .assertions import is_instance_of # noqa from .assertions import is_none # noqa from .assertions import is_not # noqa from .assertions import is_not_ # noqa +from .assertions import is_not_none # noqa from .assertions import is_true # noqa from .assertions import le_ # noqa from .assertions import ne_ # noqa diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index b2a4ac66e..40549f54c 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -236,6 +236,14 @@ def is_false(a, msg=None): is_(bool(a), False, msg=msg) +def is_none(a, msg=None): + is_(a, None, msg=msg) + + +def is_not_none(a, msg=None): + is_not(a, None, msg=msg) + + def is_(a, b, msg=None): """Assert a is b, with repr messaging on failure.""" assert a is b, msg or "%r is not %r" % (a, b) diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index f64153f33..750671f9f 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -97,6 +97,10 @@ def get_current_test_name(): return _fixture_functions.get_current_test_name() +def mark_base_test_class(): + return _fixture_functions.mark_base_test_class() + + class Config(object): def __init__(self, db, db_opts, options, file_config): self._set_name(db) diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index a4c1f3973..8b334fde2 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -7,6 +7,7 @@ from __future__ import absolute_import +import collections import re import warnings import weakref @@ -20,26 +21,29 @@ from .. import pool class ConnectionKiller(object): def __init__(self): self.proxy_refs = weakref.WeakKeyDictionary() - self.testing_engines = weakref.WeakKeyDictionary() - self.conns = set() + self.testing_engines = collections.defaultdict(set) + self.dbapi_connections = set() def add_pool(self, pool): - event.listen(pool, "connect", self.connect) - event.listen(pool, "checkout", self.checkout) - event.listen(pool, "invalidate", self.invalidate) - - def add_engine(self, engine): - self.add_pool(engine.pool) - self.testing_engines[engine] = True + event.listen(pool, "checkout", self._add_conn) + event.listen(pool, "checkin", self._remove_conn) + event.listen(pool, "close", self._remove_conn) + event.listen(pool, "close_detached", self._remove_conn) + # note we are keeping "invalidated" here, as those are still + # opened connections we would like to roll back + + def _add_conn(self, dbapi_con, con_record, con_proxy): + self.dbapi_connections.add(dbapi_con) + self.proxy_refs[con_proxy] = True - def connect(self, dbapi_conn, con_record): - self.conns.add((dbapi_conn, con_record)) + def _remove_conn(self, dbapi_conn, *arg): + self.dbapi_connections.discard(dbapi_conn) - def checkout(self, dbapi_con, con_record, con_proxy): - self.proxy_refs[con_proxy] = True + def add_engine(self, engine, scope): + self.add_pool(engine.pool) - def invalidate(self, dbapi_con, con_record, exception): - self.conns.discard((dbapi_con, con_record)) + assert scope in ("class", "global", "function", "fixture") + self.testing_engines[scope].add(engine) def _safe(self, fn): try: @@ -54,53 +58,76 @@ class ConnectionKiller(object): if rec is not None and rec.is_valid: self._safe(rec.rollback) - def close_all(self): + def checkin_all(self): + # run pool.checkin() for all ConnectionFairy instances we have + # tracked. + for rec in list(self.proxy_refs): if rec is not None and rec.is_valid: - self._safe(rec._close) - - def _after_test_ctx(self): - # this can cause a deadlock with pg8000 - pg8000 acquires - # prepared statement lock inside of rollback() - if async gc - # is collecting in finalize_fairy, deadlock. - # not sure if this should be for non-cpython only. - # note that firebird/fdb definitely needs this though - for conn, rec in list(self.conns): - if rec.connection is None: - # this is a hint that the connection is closed, which - # is causing segfaults on mysqlclient due to - # https://github.com/PyMySQL/mysqlclient-python/issues/270; - # try to work around here - continue - self._safe(conn.rollback) - - def _stop_test_ctx(self): - if config.options.low_connections: - self._stop_test_ctx_minimal() - else: - self._stop_test_ctx_aggressive() + self.dbapi_connections.discard(rec.connection) + self._safe(rec._checkin) - def _stop_test_ctx_minimal(self): - self.close_all() + # for fairy refs that were GCed and could not close the connection, + # such as asyncio, roll back those remaining connections + for con in self.dbapi_connections: + self._safe(con.rollback) + self.dbapi_connections.clear() - self.conns = set() + def close_all(self): + self.checkin_all() - for rec in list(self.testing_engines): - if rec is not config.db: - rec.dispose() + def prepare_for_drop_tables(self, connection): + # don't do aggressive checks for third party test suites + if not config.bootstrapped_as_sqlalchemy: + return - def _stop_test_ctx_aggressive(self): - self.close_all() - for conn, rec in list(self.conns): - self._safe(conn.close) - rec.connection = None + from . import provision + + provision.prepare_for_drop_tables(connection.engine.url, connection) + + def _drop_testing_engines(self, scope): + eng = self.testing_engines[scope] + for rec in list(eng): + for proxy_ref in list(self.proxy_refs): + if proxy_ref is not None and proxy_ref.is_valid: + if ( + proxy_ref._pool is not None + and proxy_ref._pool is rec.pool + ): + self._safe(proxy_ref._checkin) + rec.dispose() + eng.clear() + + def after_test(self): + self._drop_testing_engines("function") + + def after_test_outside_fixtures(self, test): + # don't do aggressive checks for third party test suites + if not config.bootstrapped_as_sqlalchemy: + return + + if test.__class__.__leave_connections_for_teardown__: + return - self.conns = set() - for rec in list(self.testing_engines): - if hasattr(rec, "sync_engine"): - rec.sync_engine.dispose() - else: - rec.dispose() + self.checkin_all() + + # on PostgreSQL, this will test for any "idle in transaction" + # connections. useful to identify tests with unusual patterns + # that can't be cleaned up correctly. + from . import provision + + with config.db.connect() as conn: + provision.prepare_for_drop_tables(conn.engine.url, conn) + + def stop_test_class_inside_fixtures(self): + self.checkin_all() + self._drop_testing_engines("function") + self._drop_testing_engines("class") + + def final_cleanup(self): + self.checkin_all() + for scope in self.testing_engines: + self._drop_testing_engines(scope) def assert_all_closed(self): for rec in self.proxy_refs: @@ -111,20 +138,6 @@ class ConnectionKiller(object): testing_reaper = ConnectionKiller() -def drop_all_tables(metadata, bind): - testing_reaper.close_all() - if hasattr(bind, "close"): - bind.close() - - if not config.db.dialect.supports_alter: - from . import assertions - - with assertions.expect_warnings("Can't sort tables", assert_=False): - metadata.drop_all(bind) - else: - metadata.drop_all(bind) - - @decorator def assert_conns_closed(fn, *args, **kw): try: @@ -147,7 +160,7 @@ def rollback_open_connections(fn, *args, **kw): def close_first(fn, *args, **kw): """Decorator that closes all connections before fn execution.""" - testing_reaper.close_all() + testing_reaper.checkin_all() fn(*args, **kw) @@ -157,7 +170,7 @@ def close_open_connections(fn, *args, **kw): try: fn(*args, **kw) finally: - testing_reaper.close_all() + testing_reaper.checkin_all() def all_dialects(exclude=None): @@ -239,12 +252,14 @@ def reconnecting_engine(url=None, options=None): return engine -def testing_engine(url=None, options=None, future=False, asyncio=False): +def testing_engine(url=None, options=None, future=None, asyncio=False): """Produce an engine configured by --options with optional overrides.""" if asyncio: from sqlalchemy.ext.asyncio import create_async_engine as create_engine - elif future or config.db and config.db._is_future: + elif future or ( + config.db and config.db._is_future and future is not False + ): from sqlalchemy.future import create_engine else: from sqlalchemy import create_engine @@ -252,8 +267,10 @@ def testing_engine(url=None, options=None, future=False, asyncio=False): if not options: use_reaper = True + scope = "function" else: use_reaper = options.pop("use_reaper", True) + scope = options.pop("scope", "function") url = url or config.db.url @@ -268,16 +285,20 @@ def testing_engine(url=None, options=None, future=False, asyncio=False): default_opt.update(options) engine = create_engine(url, **options) - if asyncio: - engine.sync_engine._has_events = True - else: - engine._has_events = True # enable event blocks, helps with profiling + + if scope == "global": + if asyncio: + engine.sync_engine._has_events = True + else: + engine._has_events = ( + True # enable event blocks, helps with profiling + ) if isinstance(engine.pool, pool.QueuePool): engine.pool._timeout = 0 - engine.pool._max_overflow = 5 + engine.pool._max_overflow = 0 if use_reaper: - testing_reaper.add_engine(engine) + testing_reaper.add_engine(engine, scope) return engine diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index ac4d3d8fa..f19b4652a 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -5,6 +5,7 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +import contextlib import re import sys @@ -12,12 +13,11 @@ import sqlalchemy as sa from . import assertions from . import config from . import schema -from .engines import drop_all_tables -from .engines import testing_engine from .entities import BasicEntity from .entities import ComparableEntity from .entities import ComparableMixin # noqa from .util import adict +from .util import drop_all_tables_from_metadata from .. import event from .. import util from ..orm import declarative_base @@ -25,10 +25,8 @@ from ..orm import registry from ..orm.decl_api import DeclarativeMeta from ..schema import sort_tables_and_constraints -# whether or not we use unittest changes things dramatically, -# as far as how pytest collection works. - +@config.mark_base_test_class() class TestBase(object): # A sequence of database names to always run, regardless of the # constraints below. @@ -48,81 +46,114 @@ class TestBase(object): # skipped. __skip_if__ = None + # if True, the testing reaper will not attempt to touch connection + # state after a test is completed and before the outer teardown + # starts + __leave_connections_for_teardown__ = False + def assert_(self, val, msg=None): assert val, msg - # apparently a handful of tests are doing this....OK - def setup(self): - if hasattr(self, "setUp"): - self.setUp() - - def teardown(self): - if hasattr(self, "tearDown"): - self.tearDown() - @config.fixture() def connection(self): - eng = getattr(self, "bind", config.db) + global _connection_fixture_connection + + eng = getattr(self, "bind", None) or config.db conn = eng.connect() trans = conn.begin() - try: - yield conn - finally: - if trans.is_active: - trans.rollback() - conn.close() + + _connection_fixture_connection = conn + yield conn + + _connection_fixture_connection = None + + if trans.is_active: + trans.rollback() + # trans would not be active here if the test is using + # the legacy @provide_metadata decorator still, as it will + # run a close all connections. + conn.close() @config.fixture() - def future_connection(self): + def future_connection(self, future_engine, connection): + # integrate the future_engine and connection fixtures so + # that users of the "connection" fixture will get at the + # "future" connection + yield connection - eng = testing_engine(future=True) - conn = eng.connect() - trans = conn.begin() - try: - yield conn - finally: - if trans.is_active: - trans.rollback() - conn.close() + @config.fixture() + def future_engine(self): + eng = getattr(self, "bind", None) or config.db + with _push_future_engine(eng): + yield + + @config.fixture() + def testing_engine(self): + from . import engines + + def gen_testing_engine( + url=None, options=None, future=False, asyncio=False + ): + if options is None: + options = {} + options["scope"] = "fixture" + return engines.testing_engine( + url=url, options=options, future=future, asyncio=asyncio + ) + + yield gen_testing_engine + + engines.testing_reaper._drop_testing_engines("fixture") @config.fixture() - def metadata(self): + def metadata(self, request): """Provide bound MetaData for a single test, dropping afterwards.""" - from . import engines from ..sql import schema metadata = schema.MetaData() - try: - yield metadata - finally: - engines.drop_all_tables(metadata, config.db) + request.instance.metadata = metadata + yield metadata + del request.instance.metadata + if ( + _connection_fixture_connection + and _connection_fixture_connection.in_transaction() + ): + trans = _connection_fixture_connection.get_transaction() + trans.rollback() + with _connection_fixture_connection.begin(): + drop_all_tables_from_metadata( + metadata, _connection_fixture_connection + ) + else: + drop_all_tables_from_metadata(metadata, config.db) -class FutureEngineMixin(object): - @classmethod - def setup_class(cls): - from ..future.engine import Engine - from sqlalchemy import testing +_connection_fixture_connection = None - facade = Engine._future_facade(config.db) - config._current.push_engine(facade, testing) - super_ = super(FutureEngineMixin, cls) - if hasattr(super_, "setup_class"): - super_.setup_class() +@contextlib.contextmanager +def _push_future_engine(engine): - @classmethod - def teardown_class(cls): - super_ = super(FutureEngineMixin, cls) - if hasattr(super_, "teardown_class"): - super_.teardown_class() + from ..future.engine import Engine + from sqlalchemy import testing + + facade = Engine._future_facade(engine) + config._current.push_engine(facade, testing) + + yield facade - from sqlalchemy import testing + config._current.pop(testing) - config._current.pop(testing) + +class FutureEngineMixin(object): + @config.fixture(autouse=True, scope="class") + def _push_future_engine(self): + eng = getattr(self, "bind", None) or config.db + with _push_future_engine(eng): + yield class TablesTest(TestBase): @@ -151,18 +182,32 @@ class TablesTest(TestBase): other = None sequences = None - @property - def tables_test_metadata(self): - return self._tables_metadata - - @classmethod - def setup_class(cls): + @config.fixture(autouse=True, scope="class") + def _setup_tables_test_class(self): + cls = self.__class__ cls._init_class() cls._setup_once_tables() cls._setup_once_inserts() + yield + + cls._teardown_once_metadata_bind() + + @config.fixture(autouse=True, scope="function") + def _setup_tables_test_instance(self): + self._setup_each_tables() + self._setup_each_inserts() + + yield + + self._teardown_each_tables() + + @property + def tables_test_metadata(self): + return self._tables_metadata + @classmethod def _init_class(cls): if cls.run_define_tables == "each": @@ -213,10 +258,10 @@ class TablesTest(TestBase): if self.run_define_tables == "each": self.tables.clear() if self.run_create_tables == "each": - drop_all_tables(self._tables_metadata, self.bind) + drop_all_tables_from_metadata(self._tables_metadata, self.bind) self._tables_metadata.clear() elif self.run_create_tables == "each": - drop_all_tables(self._tables_metadata, self.bind) + drop_all_tables_from_metadata(self._tables_metadata, self.bind) # no need to run deletes if tables are recreated on setup if ( @@ -242,17 +287,10 @@ class TablesTest(TestBase): file=sys.stderr, ) - def setup(self): - self._setup_each_tables() - self._setup_each_inserts() - - def teardown(self): - self._teardown_each_tables() - @classmethod def _teardown_once_metadata_bind(cls): if cls.run_create_tables: - drop_all_tables(cls._tables_metadata, cls.bind) + drop_all_tables_from_metadata(cls._tables_metadata, cls.bind) if cls.run_dispose_bind == "once": cls.dispose_bind(cls.bind) @@ -263,10 +301,6 @@ class TablesTest(TestBase): cls.bind = None @classmethod - def teardown_class(cls): - cls._teardown_once_metadata_bind() - - @classmethod def setup_bind(cls): return config.db @@ -332,38 +366,47 @@ class RemovesEvents(object): self._event_fns.add((target, name, fn)) event.listen(target, name, fn, **kw) - def teardown(self): + @config.fixture(autouse=True, scope="function") + def _remove_events(self): + yield for key in self._event_fns: event.remove(*key) - super_ = super(RemovesEvents, self) - if hasattr(super_, "teardown"): - super_.teardown() - - -class _ORMTest(object): - @classmethod - def teardown_class(cls): - sa.orm.session.close_all_sessions() - sa.orm.clear_mappers() -def create_session(**kw): - kw.setdefault("autoflush", False) - kw.setdefault("expire_on_commit", False) - return sa.orm.Session(config.db, **kw) +_fixture_sessions = set() def fixture_session(**kw): kw.setdefault("autoflush", True) kw.setdefault("expire_on_commit", True) - return sa.orm.Session(config.db, **kw) + sess = sa.orm.Session(config.db, **kw) + _fixture_sessions.add(sess) + return sess + + +def _close_all_sessions(): + # will close all still-referenced sessions + sa.orm.session.close_all_sessions() + _fixture_sessions.clear() + + +def stop_test_class_inside_fixtures(cls): + _close_all_sessions() + sa.orm.clear_mappers() -class ORMTest(_ORMTest, TestBase): +def after_test(): + + if _fixture_sessions: + + _close_all_sessions() + + +class ORMTest(TestBase): pass -class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): +class MappedTest(TablesTest, assertions.AssertsExecutionResults): # 'once', 'each', None run_setup_classes = "once" @@ -372,8 +415,9 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): classes = None - @classmethod - def setup_class(cls): + @config.fixture(autouse=True, scope="class") + def _setup_tables_test_class(self): + cls = self.__class__ cls._init_class() if cls.classes is None: @@ -384,18 +428,20 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): cls._setup_once_mappers() cls._setup_once_inserts() - @classmethod - def teardown_class(cls): + yield + cls._teardown_once_class() cls._teardown_once_metadata_bind() - def setup(self): + @config.fixture(autouse=True, scope="function") + def _setup_tables_test_instance(self): self._setup_each_tables() self._setup_each_classes() self._setup_each_mappers() self._setup_each_inserts() - def teardown(self): + yield + sa.orm.session.close_all_sessions() self._teardown_each_mappers() self._teardown_each_classes() @@ -404,7 +450,6 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): @classmethod def _teardown_once_class(cls): cls.classes.clear() - _ORMTest.teardown_class() @classmethod def _setup_once_classes(cls): @@ -440,6 +485,8 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): """ cls_registry = cls.classes + assert cls_registry is not None + class FindFixture(type): def __init__(cls, classname, bases, dict_): cls_registry[classname] = cls diff --git a/lib/sqlalchemy/testing/plugin/bootstrap.py b/lib/sqlalchemy/testing/plugin/bootstrap.py index a95c947e2..1f568dfc8 100644 --- a/lib/sqlalchemy/testing/plugin/bootstrap.py +++ b/lib/sqlalchemy/testing/plugin/bootstrap.py @@ -40,6 +40,11 @@ def load_file_as_module(name): if to_bootstrap == "pytest": sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base") + sys.modules["sqla_plugin_base"].bootstrapped_as_sqlalchemy = True + if sys.version_info < (3, 0): + sys.modules["sqla_reinvent_fixtures"] = load_file_as_module( + "reinvent_fixtures_py2k" + ) sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin") else: raise Exception("unknown bootstrap: %s" % to_bootstrap) # noqa diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py index 3594cd276..7851fbb3e 100644 --- a/lib/sqlalchemy/testing/plugin/plugin_base.py +++ b/lib/sqlalchemy/testing/plugin/plugin_base.py @@ -21,6 +21,9 @@ import logging import re import sys +# flag which indicates we are in the SQLAlchemy testing suite, +# and not that of Alembic or a third party dialect. +bootstrapped_as_sqlalchemy = False log = logging.getLogger("sqlalchemy.testing.plugin_base") @@ -381,7 +384,7 @@ def _init_symbols(options, file_config): @post def _set_disable_asyncio(opt, file_config): - if opt.disable_asyncio: + if opt.disable_asyncio or not py3k: from sqlalchemy.testing import asyncio asyncio.ENABLE_ASYNCIO = False @@ -458,6 +461,8 @@ def _setup_requirements(argument): config.requirements = testing.requires = req_cls() + config.bootstrapped_as_sqlalchemy = bootstrapped_as_sqlalchemy + @post def _prep_testing_database(options, file_config): @@ -566,17 +571,22 @@ def generate_sub_tests(cls, module): yield cls -def start_test_class(cls): +def start_test_class_outside_fixtures(cls): _do_skips(cls) _setup_engine(cls) def stop_test_class(cls): - # from sqlalchemy import inspect - # assert not inspect(testing.db).get_table_names() + # close sessions, immediate connections, etc. + fixtures.stop_test_class_inside_fixtures(cls) + + # close outstanding connection pool connections, dispose of + # additional engines + engines.testing_reaper.stop_test_class_inside_fixtures() - provision.stop_test_class(config, config.db, cls) - engines.testing_reaper._stop_test_ctx() + +def stop_test_class_outside_fixtures(cls): + provision.stop_test_class_outside_fixtures(config, config.db, cls) try: if not options.low_connections: assertions.global_cleanup_assertions() @@ -590,14 +600,16 @@ def _restore_engine(): def final_process_cleanup(): - engines.testing_reaper._stop_test_ctx_aggressive() + engines.testing_reaper.final_cleanup() assertions.global_cleanup_assertions() _restore_engine() def _setup_engine(cls): if getattr(cls, "__engine_options__", None): - eng = engines.testing_engine(options=cls.__engine_options__) + opts = dict(cls.__engine_options__) + opts["scope"] = "class" + eng = engines.testing_engine(options=opts) config._current.push_engine(eng, testing) @@ -614,7 +626,12 @@ def before_test(test, test_module_name, test_class, test_name): def after_test(test): - engines.testing_reaper._after_test_ctx() + fixtures.after_test() + engines.testing_reaper.after_test() + + +def after_test_fixtures(test): + engines.testing_reaper.after_test_outside_fixtures(test) def _possible_configs_for_cls(cls, reasons=None, sparse=False): @@ -748,6 +765,10 @@ class FixtureFunctions(ABC): def get_current_test_name(self): raise NotImplementedError() + @abc.abstractmethod + def mark_base_test_class(self): + raise NotImplementedError() + _fixture_fn_class = None diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index 46468a07d..4eaaecebb 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -17,6 +17,7 @@ import sys import pytest + try: import typing except ImportError: @@ -33,6 +34,14 @@ except ImportError: has_xdist = False +py2k = sys.version_info < (3, 0) +if py2k: + try: + import sqla_reinvent_fixtures as reinvent_fixtures_py2k + except ImportError: + from . import reinvent_fixtures_py2k + + def pytest_addoption(parser): group = parser.getgroup("sqlalchemy") @@ -238,6 +247,10 @@ def pytest_collection_modifyitems(session, config, items): else: newitems.append(item) + if py2k: + for item in newitems: + reinvent_fixtures_py2k.scan_for_fixtures_to_use_for_class(item) + # seems like the functions attached to a test class aren't sorted already? # is that true and why's that? (when using unittest, they're sorted) items[:] = sorted( @@ -251,7 +264,6 @@ def pytest_collection_modifyitems(session, config, items): def pytest_pycollect_makeitem(collector, name, obj): - if inspect.isclass(obj) and plugin_base.want_class(name, obj): from sqlalchemy.testing import config @@ -259,7 +271,6 @@ def pytest_pycollect_makeitem(collector, name, obj): obj = _apply_maybe_async(obj) ctor = getattr(pytest.Class, "from_parent", pytest.Class) - return [ ctor(name=parametrize_cls.__name__, parent=collector) for parametrize_cls in _parametrize_cls(collector.module, obj) @@ -287,12 +298,11 @@ def _is_wrapped_coroutine_function(fn): def _apply_maybe_async(obj, recurse=True): from sqlalchemy.testing import asyncio - setup_names = {"setup", "setup_class", "teardown", "teardown_class"} for name, value in vars(obj).items(): if ( (callable(value) or isinstance(value, classmethod)) and not getattr(value, "_maybe_async_applied", False) - and (name.startswith("test_") or name in setup_names) + and (name.startswith("test_")) and not _is_wrapped_coroutine_function(value) ): is_classmethod = False @@ -317,9 +327,6 @@ def _apply_maybe_async(obj, recurse=True): return obj -_current_class = None - - def _parametrize_cls(module, cls): """implement a class-based version of pytest parametrize.""" @@ -355,63 +362,153 @@ def _parametrize_cls(module, cls): return classes +_current_class = None + + def pytest_runtest_setup(item): from sqlalchemy.testing import asyncio - # here we seem to get called only based on what we collected - # in pytest_collection_modifyitems. So to do class-based stuff - # we have to tear that out. - global _current_class - if not isinstance(item, pytest.Function): return - # ... so we're doing a little dance here to figure it out... + # pytest_runtest_setup runs *before* pytest fixtures with scope="class". + # plugin_base.start_test_class_outside_fixtures may opt to raise SkipTest + # for the whole class and has to run things that are across all current + # databases, so we run this outside of the pytest fixture system altogether + # and ensure asyncio greenlet if any engines are async + + global _current_class + if _current_class is None: - asyncio._maybe_async(class_setup, item.parent.parent) + asyncio._maybe_async_provisioning( + plugin_base.start_test_class_outside_fixtures, + item.parent.parent.cls, + ) _current_class = item.parent.parent - # this is needed for the class-level, to ensure that the - # teardown runs after the class is completed with its own - # class-level teardown... def finalize(): global _current_class - asyncio._maybe_async(class_teardown, item.parent.parent) _current_class = None + asyncio._maybe_async_provisioning( + plugin_base.stop_test_class_outside_fixtures, + item.parent.parent.cls, + ) + item.parent.parent.addfinalizer(finalize) - asyncio._maybe_async(test_setup, item) +def pytest_runtest_call(item): + # runs inside of pytest function fixture scope + # before test function runs -def pytest_runtest_teardown(item): from sqlalchemy.testing import asyncio - # ...but this works better as the hook here rather than - # using a finalizer, as the finalizer seems to get in the way - # of the test reporting failures correctly (you get a bunch of - # pytest assertion stuff instead) - asyncio._maybe_async(test_teardown, item) + asyncio._maybe_async( + plugin_base.before_test, + item, + item.parent.module.__name__, + item.parent.cls, + item.name, + ) -def test_setup(item): - plugin_base.before_test( - item, item.parent.module.__name__, item.parent.cls, item.name - ) +def pytest_runtest_teardown(item, nextitem): + # runs inside of pytest function fixture scope + # after test function runs + from sqlalchemy.testing import asyncio -def test_teardown(item): - plugin_base.after_test(item) + asyncio._maybe_async(plugin_base.after_test, item) -def class_setup(item): +@pytest.fixture(scope="class") +def setup_class_methods(request): from sqlalchemy.testing import asyncio - asyncio._maybe_async_provisioning(plugin_base.start_test_class, item.cls) + cls = request.cls + + if hasattr(cls, "setup_test_class"): + asyncio._maybe_async(cls.setup_test_class) + + if py2k: + reinvent_fixtures_py2k.run_class_fixture_setup(request) + + yield + + if py2k: + reinvent_fixtures_py2k.run_class_fixture_teardown(request) + if hasattr(cls, "teardown_test_class"): + asyncio._maybe_async(cls.teardown_test_class) -def class_teardown(item): - plugin_base.stop_test_class(item.cls) + asyncio._maybe_async(plugin_base.stop_test_class, cls) + + +@pytest.fixture(scope="function") +def setup_test_methods(request): + from sqlalchemy.testing import asyncio + + # called for each test + + self = request.instance + + # 1. run outer xdist-style setup + if hasattr(self, "setup_test"): + asyncio._maybe_async(self.setup_test) + + # alembic test suite is using setUp and tearDown + # xdist methods; support these in the test suite + # for the near term + if hasattr(self, "setUp"): + asyncio._maybe_async(self.setUp) + + # 2. run homegrown function level "autouse" fixtures under py2k + if py2k: + reinvent_fixtures_py2k.run_fn_fixture_setup(request) + + # inside the yield: + + # 3. function level "autouse" fixtures under py3k (examples: TablesTest + # define tables / data, MappedTest define tables / mappers / data) + + # 4. function level fixtures defined on test functions themselves, + # e.g. "connection", "metadata" run next + + # 5. pytest hook pytest_runtest_call then runs + + # 6. test itself runs + + yield + + # yield finishes: + + # 7. pytest hook pytest_runtest_teardown hook runs, this is associated + # with fixtures close all sessions, provisioning.stop_test_class(), + # engines.testing_reaper -> ensure all connection pool connections + # are returned, engines created by testing_engine that aren't the + # config engine are disposed + + # 8. function level fixtures defined on test functions + # themselves, e.g. "connection" rolls back the transaction, "metadata" + # emits drop all + + # 9. function level "autouse" fixtures under py3k (examples: TablesTest / + # MappedTest delete table data, possibly drop tables and clear mappers + # depending on the flags defined by the test class) + + # 10. run homegrown function-level "autouse" fixtures under py2k + if py2k: + reinvent_fixtures_py2k.run_fn_fixture_teardown(request) + + asyncio._maybe_async(plugin_base.after_test_fixtures, self) + + # 11. run outer xdist-style teardown + if hasattr(self, "tearDown"): + asyncio._maybe_async(self.tearDown) + + if hasattr(self, "teardown_test"): + asyncio._maybe_async(self.teardown_test) def getargspec(fn): @@ -461,6 +558,8 @@ def %(name)s(%(args)s): # for the wrapped function decorated.__module__ = fn.__module__ decorated.__name__ = fn.__name__ + if hasattr(fn, "pytestmark"): + decorated.pytestmark = fn.pytestmark return decorated return decorate @@ -470,6 +569,11 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): def skip_test_exception(self, *arg, **kw): return pytest.skip.Exception(*arg, **kw) + def mark_base_test_class(self): + return pytest.mark.usefixtures( + "setup_class_methods", "setup_test_methods" + ) + _combination_id_fns = { "i": lambda obj: obj, "r": repr, @@ -647,8 +751,18 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): fn = asyncio._maybe_async_wrapper(fn) # other wrappers may be added here - # now apply FixtureFunctionMarker - fn = fixture(fn) + if py2k and "autouse" in kw: + # py2k workaround for too-slow collection of autouse fixtures + # in pytest 4.6.11. See notes in reinvent_fixtures_py2k for + # rationale. + + # comment this condition out in order to disable the + # py2k workaround entirely. + reinvent_fixtures_py2k.add_fixture(fn, fixture) + else: + # now apply FixtureFunctionMarker + fn = fixture(fn) + return fn if fn: diff --git a/lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py b/lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py new file mode 100644 index 000000000..36b68417b --- /dev/null +++ b/lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py @@ -0,0 +1,112 @@ +""" +invent a quick version of pytest autouse fixtures as pytest's unacceptably slow +collection/high memory use in pytest 4.6.11, which is the highest version that +works in py2k. + +by "too-slow" we mean the test suite can't even manage to be collected for a +single process in less than 70 seconds or so and memory use seems to be very +high as well. for two or four workers the job just times out after ten +minutes. + +so instead we have invented a very limited form of these fixtures, as our +current use of "autouse" fixtures are limited to those in fixtures.py. + +assumptions for these fixtures: + +1. we are only using "function" or "class" scope + +2. the functions must be associated with a test class + +3. the fixture functions cannot themselves use pytest fixtures + +4. the fixture functions must use yield, not return + +When py2k support is removed and we can stay on a modern pytest version, this +can all be removed. + + +""" +import collections + + +_py2k_fixture_fn_names = collections.defaultdict(set) +_py2k_class_fixtures = collections.defaultdict( + lambda: collections.defaultdict(set) +) +_py2k_function_fixtures = collections.defaultdict( + lambda: collections.defaultdict(set) +) + +_py2k_cls_fixture_stack = [] +_py2k_fn_fixture_stack = [] + + +def add_fixture(fn, fixture): + assert fixture.scope in ("class", "function") + _py2k_fixture_fn_names[fn.__name__].add((fn, fixture.scope)) + + +def scan_for_fixtures_to_use_for_class(item): + test_class = item.parent.parent.obj + + for name in _py2k_fixture_fn_names: + for fixture_fn, scope in _py2k_fixture_fn_names[name]: + meth = getattr(test_class, name, None) + if meth and meth.im_func is fixture_fn: + for sup in test_class.__mro__: + if name in sup.__dict__: + if scope == "class": + _py2k_class_fixtures[test_class][sup].add(meth) + elif scope == "function": + _py2k_function_fixtures[test_class][sup].add(meth) + break + break + + +def run_class_fixture_setup(request): + + cls = request.cls + self = cls.__new__(cls) + + fixtures_for_this_class = _py2k_class_fixtures.get(cls) + + if fixtures_for_this_class: + for sup_ in cls.__mro__: + for fn in fixtures_for_this_class.get(sup_, ()): + iter_ = fn(self) + next(iter_) + + _py2k_cls_fixture_stack.append(iter_) + + +def run_class_fixture_teardown(request): + while _py2k_cls_fixture_stack: + iter_ = _py2k_cls_fixture_stack.pop(-1) + try: + next(iter_) + except StopIteration: + pass + + +def run_fn_fixture_setup(request): + cls = request.cls + self = request.instance + + fixtures_for_this_class = _py2k_function_fixtures.get(cls) + + if fixtures_for_this_class: + for sup_ in reversed(cls.__mro__): + for fn in fixtures_for_this_class.get(sup_, ()): + iter_ = fn(self) + next(iter_) + + _py2k_fn_fixture_stack.append(iter_) + + +def run_fn_fixture_teardown(request): + while _py2k_fn_fixture_stack: + iter_ = _py2k_fn_fixture_stack.pop(-1) + try: + next(iter_) + except StopIteration: + pass diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py index 4ee0567f2..2fade1c32 100644 --- a/lib/sqlalchemy/testing/provision.py +++ b/lib/sqlalchemy/testing/provision.py @@ -67,6 +67,7 @@ def setup_config(db_url, options, file_config, follower_ident): db_url = follower_url_from_main(db_url, follower_ident) db_opts = {} update_db_opts(db_url, db_opts) + db_opts["scope"] = "global" eng = engines.testing_engine(db_url, db_opts) post_configure_engine(db_url, eng, follower_ident) eng.connect().close() @@ -264,6 +265,7 @@ def drop_all_schema_objects(cfg, eng): if config.requirements.schemas.enabled_for_config(cfg): util.drop_all_tables(eng, inspector, schema=cfg.test_schema) + util.drop_all_tables(eng, inspector, schema=cfg.test_schema_2) drop_all_schema_objects_post_tables(cfg, eng) @@ -299,7 +301,7 @@ def update_db_opts(db_url, db_opts): def post_configure_engine(url, engine, follower_ident): """Perform extra steps after configuring an engine for testing. - (For the internal dialects, currently only used by sqlite.) + (For the internal dialects, currently only used by sqlite, oracle) """ pass @@ -375,7 +377,12 @@ def temp_table_keyword_args(cfg, eng): @register.init -def stop_test_class(config, db, testcls): +def prepare_for_drop_tables(config, connection): + pass + + +@register.init +def stop_test_class_outside_fixtures(config, db, testcls): pass diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index 6c3c1005a..de157d028 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -293,7 +293,7 @@ class ComponentReflectionTest(fixtures.TablesTest): from sqlalchemy import pool return engines.testing_engine( - options=dict(poolclass=pool.StaticPool) + options=dict(poolclass=pool.StaticPool, scope="class"), ) else: return config.db diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py index e0fdbe47a..e8dd6cf2c 100644 --- a/lib/sqlalchemy/testing/suite/test_results.py +++ b/lib/sqlalchemy/testing/suite/test_results.py @@ -261,10 +261,6 @@ class ServerSideCursorsTest( ) return self.engine - def tearDown(self): - engines.testing_reaper.close_all() - self.engine.dispose() - @testing.combinations( ("global_string", True, "select 1", True), ("global_text", True, text("select 1"), True), @@ -309,24 +305,22 @@ class ServerSideCursorsTest( def test_conn_option(self): engine = self._fixture(False) - # should be enabled for this one - result = ( - engine.connect() - .execution_options(stream_results=True) - .exec_driver_sql("select 1") - ) - assert self._is_server_side(result.cursor) + with engine.connect() as conn: + # should be enabled for this one + result = conn.execution_options( + stream_results=True + ).exec_driver_sql("select 1") + assert self._is_server_side(result.cursor) def test_stmt_enabled_conn_option_disabled(self): engine = self._fixture(False) s = select(1).execution_options(stream_results=True) - # not this one - result = ( - engine.connect().execution_options(stream_results=False).execute(s) - ) - assert not self._is_server_side(result.cursor) + with engine.connect() as conn: + # not this one + result = conn.execution_options(stream_results=False).execute(s) + assert not self._is_server_side(result.cursor) def test_aliases_and_ss(self): engine = self._fixture(False) @@ -344,8 +338,7 @@ class ServerSideCursorsTest( assert not self._is_server_side(result.cursor) result.close() - @testing.provide_metadata - def test_roundtrip_fetchall(self): + def test_roundtrip_fetchall(self, metadata): md = self.metadata engine = self._fixture(True) @@ -385,8 +378,7 @@ class ServerSideCursorsTest( 0, ) - @testing.provide_metadata - def test_roundtrip_fetchmany(self): + def test_roundtrip_fetchmany(self, metadata): md = self.metadata engine = self._fixture(True) diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index 3a5e02c32..ebcceaae7 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -511,24 +511,23 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): __backend__ = True @testing.fixture - def do_numeric_test(self, metadata): + def do_numeric_test(self, metadata, connection): @testing.emits_warning( r".*does \*not\* support Decimal objects natively" ) def run(type_, input_, output, filter_=None, check_scale=False): t = Table("t", metadata, Column("x", type_)) - t.create(testing.db) - with config.db.begin() as conn: - conn.execute(t.insert(), [{"x": x} for x in input_]) - - result = {row[0] for row in conn.execute(t.select())} - output = set(output) - if filter_: - result = set(filter_(x) for x in result) - output = set(filter_(x) for x in output) - eq_(result, output) - if check_scale: - eq_([str(x) for x in result], [str(x) for x in output]) + t.create(connection) + connection.execute(t.insert(), [{"x": x} for x in input_]) + + result = {row[0] for row in connection.execute(t.select())} + output = set(output) + if filter_: + result = set(filter_(x) for x in result) + output = set(filter_(x) for x in output) + eq_(result, output) + if check_scale: + eq_([str(x) for x in result], [str(x) for x in output]) return run @@ -1165,40 +1164,39 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): }, ) - def test_eval_none_flag_orm(self): + def test_eval_none_flag_orm(self, connection): Base = declarative_base() class Data(Base): __table__ = self.tables.data_table - s = Session(testing.db) + with Session(connection) as s: + d1 = Data(name="d1", data=None, nulldata=None) + s.add(d1) + s.commit() - d1 = Data(name="d1", data=None, nulldata=None) - s.add(d1) - s.commit() - - s.bulk_insert_mappings( - Data, [{"name": "d2", "data": None, "nulldata": None}] - ) - eq_( - s.query( - cast(self.tables.data_table.c.data, String()), - cast(self.tables.data_table.c.nulldata, String), + s.bulk_insert_mappings( + Data, [{"name": "d2", "data": None, "nulldata": None}] ) - .filter(self.tables.data_table.c.name == "d1") - .first(), - ("null", None), - ) - eq_( - s.query( - cast(self.tables.data_table.c.data, String()), - cast(self.tables.data_table.c.nulldata, String), + eq_( + s.query( + cast(self.tables.data_table.c.data, String()), + cast(self.tables.data_table.c.nulldata, String), + ) + .filter(self.tables.data_table.c.name == "d1") + .first(), + ("null", None), + ) + eq_( + s.query( + cast(self.tables.data_table.c.data, String()), + cast(self.tables.data_table.c.nulldata, String), + ) + .filter(self.tables.data_table.c.name == "d2") + .first(), + ("null", None), ) - .filter(self.tables.data_table.c.name == "d2") - .first(), - ("null", None), - ) class JSONLegacyStringCastIndexTest( diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py index eb9fcd1cd..01185c284 100644 --- a/lib/sqlalchemy/testing/util.py +++ b/lib/sqlalchemy/testing/util.py @@ -14,6 +14,7 @@ import types from . import config from . import mock from .. import inspect +from ..engine import Connection from ..schema import Column from ..schema import DropConstraint from ..schema import DropTable @@ -207,11 +208,13 @@ def fail(msg): @decorator def provide_metadata(fn, *args, **kw): - """Provide bound MetaData for a single test, dropping afterwards.""" + """Provide bound MetaData for a single test, dropping afterwards. - # import cycle that only occurs with py2k's import resolver - # in py3k this can be moved top level. - from . import engines + Legacy; use the "metadata" pytest fixture. + + """ + + from . import fixtures metadata = schema.MetaData() self = args[0] @@ -220,7 +223,31 @@ def provide_metadata(fn, *args, **kw): try: return fn(*args, **kw) finally: - engines.drop_all_tables(metadata, config.db) + # close out some things that get in the way of dropping tables. + # when using the "metadata" fixture, there is a set ordering + # of things that makes sure things are cleaned up in order, however + # the simple "decorator" nature of this legacy function means + # we have to hardcode some of that cleanup ahead of time. + + # close ORM sessions + fixtures._close_all_sessions() + + # integrate with the "connection" fixture as there are many + # tests where it is used along with provide_metadata + if fixtures._connection_fixture_connection: + # TODO: this warning can be used to find all the places + # this is used with connection fixture + # warn("mixing legacy provide metadata with connection fixture") + drop_all_tables_from_metadata( + metadata, fixtures._connection_fixture_connection + ) + # as the provide_metadata fixture is often used with "testing.db", + # when we do the drop we have to commit the transaction so that + # the DB is actually updated as the CREATE would have been + # committed + fixtures._connection_fixture_connection.get_transaction().commit() + else: + drop_all_tables_from_metadata(metadata, config.db) self.metadata = prev_meta @@ -359,6 +386,29 @@ class adict(dict): get_all = __call__ +def drop_all_tables_from_metadata(metadata, engine_or_connection): + from . import engines + + def go(connection): + engines.testing_reaper.prepare_for_drop_tables(connection) + + if not connection.dialect.supports_alter: + from . import assertions + + with assertions.expect_warnings( + "Can't sort tables", assert_=False + ): + metadata.drop_all(connection) + else: + metadata.drop_all(connection) + + if not isinstance(engine_or_connection, Connection): + with engine_or_connection.begin() as connection: + go(connection) + else: + go(engine_or_connection) + + def drop_all_tables(engine, inspector, schema=None, include_names=None): if include_names is not None: diff --git a/lib/sqlalchemy/util/queue.py b/lib/sqlalchemy/util/queue.py index 99ecb4fb3..ca5a3abde 100644 --- a/lib/sqlalchemy/util/queue.py +++ b/lib/sqlalchemy/util/queue.py @@ -230,13 +230,16 @@ class AsyncAdaptedQueue: return self.put_nowait(item) try: - if timeout: + if timeout is not None: return self.await_( asyncio.wait_for(self._queue.put(item), timeout) ) else: return self.await_(self._queue.put(item)) - except asyncio.queues.QueueFull as err: + except ( + asyncio.queues.QueueFull, + asyncio.exceptions.TimeoutError, + ) as err: compat.raise_( Full(), replace_context=err, @@ -254,14 +257,18 @@ class AsyncAdaptedQueue: def get(self, block=True, timeout=None): if not block: return self.get_nowait() + try: - if timeout: + if timeout is not None: return self.await_( asyncio.wait_for(self._queue.get(), timeout) ) else: return self.await_(self._queue.get()) - except asyncio.queues.QueueEmpty as err: + except ( + asyncio.queues.QueueEmpty, + asyncio.exceptions.TimeoutError, + ) as err: compat.raise_( Empty(), replace_context=err, |
