diff options
Diffstat (limited to 'lib')
39 files changed, 991 insertions, 403 deletions
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 538679fcf..9d0e5d322 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -9,6 +9,16 @@ :name: Microsoft SQL Server +.. _mssql_external_dialects: + +External Dialects +----------------- + +In addition to the above DBAPI layers with native SQLAlchemy support, there +are third-party dialects for other DBAPI layers that are compatible +with SQL Server. See the "External Dialects" list on the +:ref:`dialect_toplevel` page. + .. _mssql_identity: Auto Increment Behavior / IDENTITY Columns @@ -2785,15 +2795,14 @@ class MSDialect(default.DefaultDialect): def has_table(self, connection, tablename, dbname, owner, schema): if tablename.startswith("#"): # temporary table tables = ischema.mssql_temp_table_columns - result = connection.execute( - sql.select(tables.c.table_name) - .where( - tables.c.table_name.like( - self._temp_table_name_like_pattern(tablename) - ) + + s = sql.select(tables.c.table_name).where( + tables.c.table_name.like( + self._temp_table_name_like_pattern(tablename) ) - .limit(1) ) + + result = connection.execute(s.limit(1)) return result.scalar() is not None else: tables = ischema.tables diff --git a/lib/sqlalchemy/dialects/mssql/provision.py b/lib/sqlalchemy/dialects/mssql/provision.py index 269eb164f..56f3305a7 100644 --- a/lib/sqlalchemy/dialects/mssql/provision.py +++ b/lib/sqlalchemy/dialects/mssql/provision.py @@ -1,6 +1,14 @@ +from sqlalchemy import inspect +from sqlalchemy import Integer from ... import create_engine from ... import exc +from ...schema import Column +from ...schema import DropConstraint +from ...schema import ForeignKeyConstraint +from ...schema import MetaData +from ...schema import Table from ...testing.provision import create_db +from ...testing.provision import drop_all_schema_objects_pre_tables from ...testing.provision import drop_db from ...testing.provision import get_temp_table_name from ...testing.provision import log @@ -38,7 +46,6 @@ def _mssql_drop_ignore(conn, ident): # "where database_id=db_id('%s')" % ident): # log.info("killing SQL server session %s", row['session_id']) # conn.exec_driver_sql("kill %s" % row['session_id']) - conn.exec_driver_sql("drop database %s" % ident) log.info("Reaped db: %s", ident) return True @@ -83,4 +90,27 @@ def _mssql_temp_table_keyword_args(cfg, eng): @get_temp_table_name.for_db("mssql") def _mssql_get_temp_table_name(cfg, eng, base_name): - return "#" + base_name + return "##" + base_name + + +@drop_all_schema_objects_pre_tables.for_db("mssql") +def drop_all_schema_objects_pre_tables(cfg, eng): + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: + inspector = inspect(conn) + for schema in (None, "dbo", cfg.test_schema, cfg.test_schema_2): + for tname in inspector.get_table_names(schema=schema): + tb = Table( + tname, + MetaData(), + Column("x", Integer), + Column("y", Integer), + schema=schema, + ) + for fk in inspect(conn).get_foreign_keys(tname, schema=schema): + conn.execute( + DropConstraint( + ForeignKeyConstraint( + [tb.c.x], [tb.c.y], name=fk["name"] + ) + ) + ) diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 1cc8b7aef..a496f0354 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -1731,7 +1731,7 @@ class OracleDialect(default.DefaultDialect): sql.text( "SELECT username FROM user_db_links " "WHERE db_link=:link" ), - link=dblink, + dict(link=dblink), ) dblink = "@" + dblink elif not owner: @@ -1805,7 +1805,7 @@ class OracleDialect(default.DefaultDialect): "SELECT sequence_name FROM all_sequences " "WHERE sequence_owner = :schema_name" ), - schema_name=self.denormalize_name(schema), + dict(schema_name=self.denormalize_name(schema)), ) return [self.normalize_name(row[0]) for row in cursor] diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index 042443692..b8b4df760 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -93,6 +93,7 @@ The parameters accepted by the cx_oracle dialect are as follows: * ``encoding_errors`` - see :ref:`cx_oracle_unicode_encoding_errors` for detail. + .. _cx_oracle_unicode: Unicode diff --git a/lib/sqlalchemy/dialects/oracle/provision.py b/lib/sqlalchemy/dialects/oracle/provision.py index d51131c0b..e0dadd58e 100644 --- a/lib/sqlalchemy/dialects/oracle/provision.py +++ b/lib/sqlalchemy/dialects/oracle/provision.py @@ -6,11 +6,11 @@ from ...testing.provision import create_db from ...testing.provision import drop_db from ...testing.provision import follower_url_from_main from ...testing.provision import log +from ...testing.provision import post_configure_engine from ...testing.provision import run_reap_dbs from ...testing.provision import set_default_schema_on_connection -from ...testing.provision import stop_test_class +from ...testing.provision import stop_test_class_outside_fixtures from ...testing.provision import temp_table_keyword_args -from ...testing.provision import update_db_opts @create_db.for_db("oracle") @@ -57,21 +57,39 @@ def _oracle_drop_db(cfg, eng, ident): _ora_drop_ignore(conn, "%s_ts2" % ident) -@update_db_opts.for_db("oracle") -def _oracle_update_db_opts(db_url, db_opts): - pass +@stop_test_class_outside_fixtures.for_db("oracle") +def stop_test_class_outside_fixtures(config, db, cls): + with db.begin() as conn: + # run magic command to get rid of identity sequences + # https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/ # noqa E501 + conn.exec_driver_sql("purge recyclebin") -@stop_test_class.for_db("oracle") -def stop_test_class(config, db, cls): - """run magic command to get rid of identity sequences + # clear statement cache on all connections that were used + # https://github.com/oracle/python-cx_Oracle/issues/519 - # https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/ + for cx_oracle_conn in _all_conns: + try: + sc = cx_oracle_conn.stmtcachesize + except db.dialect.dbapi.InterfaceError: + # connection closed + pass + else: + cx_oracle_conn.stmtcachesize = 0 + cx_oracle_conn.stmtcachesize = sc + _all_conns.clear() - """ - with db.begin() as conn: - conn.exec_driver_sql("purge recyclebin") +_all_conns = set() + + +@post_configure_engine.for_db("oracle") +def _oracle_post_configure_engine(url, engine, follower_ident): + from sqlalchemy import event + + @event.listens_for(engine, "checkout") + def checkout(dbapi_con, con_record, con_proxy): + _all_conns.add(dbapi_con) @run_reap_dbs.for_db("oracle") diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 7c6e8fb02..424ed0d50 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -615,6 +615,10 @@ class AsyncAdapt_asyncpg_connection: return prepared_stmt, attributes def _handle_exception(self, error): + if self._connection.is_closed(): + self._transaction = None + self._started = False + if not isinstance(error, AsyncAdapt_asyncpg_dbapi.Error): exception_mapping = self.dbapi._asyncpg_error_translate @@ -669,16 +673,23 @@ class AsyncAdapt_asyncpg_connection: def rollback(self): if self._started: - self.await_(self._transaction.rollback()) - - self._transaction = None - self._started = False + try: + self.await_(self._transaction.rollback()) + except Exception as error: + self._handle_exception(error) + finally: + self._transaction = None + self._started = False def commit(self): if self._started: - self.await_(self._transaction.commit()) - self._transaction = None - self._started = False + try: + self.await_(self._transaction.commit()) + except Exception as error: + self._handle_exception(error) + finally: + self._transaction = None + self._started = False def close(self): self.rollback() diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 735990a20..7a898cb8a 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -3335,7 +3335,11 @@ class PGDialect(default.DefaultDialect): "WHERE n.nspname = :schema AND c.relkind IN (%s)" % (", ".join("'%s'" % elem for elem in kinds)) ).columns(relname=sqltypes.Unicode), - schema=schema if schema is not None else self.default_schema_name, + dict( + schema=schema + if schema is not None + else self.default_schema_name + ), ) return [name for name, in result] @@ -3367,8 +3371,12 @@ class PGDialect(default.DefaultDialect): "WHERE n.nspname = :schema AND c.relname = :view_name " "AND c.relkind IN ('v', 'm')" ).columns(view_def=sqltypes.Unicode), - schema=schema if schema is not None else self.default_schema_name, - view_name=view_name, + dict( + schema=schema + if schema is not None + else self.default_schema_name, + view_name=view_name, + ), ) return view_def diff --git a/lib/sqlalchemy/dialects/postgresql/provision.py b/lib/sqlalchemy/dialects/postgresql/provision.py index d345cdfdf..9196337ba 100644 --- a/lib/sqlalchemy/dialects/postgresql/provision.py +++ b/lib/sqlalchemy/dialects/postgresql/provision.py @@ -8,6 +8,7 @@ from ...testing.provision import drop_all_schema_objects_post_tables from ...testing.provision import drop_all_schema_objects_pre_tables from ...testing.provision import drop_db from ...testing.provision import log +from ...testing.provision import prepare_for_drop_tables from ...testing.provision import set_default_schema_on_connection from ...testing.provision import temp_table_keyword_args @@ -61,7 +62,7 @@ def _pg_drop_db(cfg, eng, ident): "where usename=current_user and pid != pg_backend_pid() " "and datname=:dname" ), - dname=ident, + dict(dname=ident), ) conn.exec_driver_sql("DROP DATABASE %s" % ident) @@ -102,3 +103,23 @@ def drop_all_schema_objects_post_tables(cfg, eng): postgresql.ENUM(name=enum["name"], schema=enum["schema"]) ) ) + + +@prepare_for_drop_tables.for_db("postgresql") +def prepare_for_drop_tables(config, connection): + """Ensure there are no locks on the current username/database.""" + + result = connection.exec_driver_sql( + "select pid, state, wait_event_type, query " + # "select pg_terminate_backend(pid), state, wait_event_type " + "from pg_stat_activity where " + "usename=current_user " + "and datname=current_database() and state='idle in transaction' " + "and pid != pg_backend_pid()" + ) + rows = result.all() # noqa + assert not rows, ( + "PostgreSQL may not be able to DROP tables due to " + "idle in transaction: %s" + % ("; ".join(row._mapping["query"] for row in rows)) + ) diff --git a/lib/sqlalchemy/dialects/sqlite/provision.py b/lib/sqlalchemy/dialects/sqlite/provision.py index f26c21e22..a481be27e 100644 --- a/lib/sqlalchemy/dialects/sqlite/provision.py +++ b/lib/sqlalchemy/dialects/sqlite/provision.py @@ -7,7 +7,7 @@ from ...testing.provision import follower_url_from_main from ...testing.provision import log from ...testing.provision import post_configure_engine from ...testing.provision import run_reap_dbs -from ...testing.provision import stop_test_class +from ...testing.provision import stop_test_class_outside_fixtures from ...testing.provision import temp_table_keyword_args @@ -57,8 +57,8 @@ def _sqlite_drop_db(cfg, eng, ident): os.remove(path) -@stop_test_class.for_db("sqlite") -def stop_test_class(config, db, cls): +@stop_test_class_outside_fixtures.for_db("sqlite") +def stop_test_class_outside_fixtures(config, db, cls): with db.connect() as conn: files = [ row.file diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 50f00c025..31bf885db 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -813,10 +813,11 @@ class Connection(Connectable): if self._echo: self.engine.logger.info("BEGIN (implicit)") + self.__in_begin = True + if self._has_events or self.engine._has_events: self.dispatch.begin(self) - self.__in_begin = True try: self.engine.dialect.do_begin(self.connection) except BaseException as e: @@ -2729,14 +2730,16 @@ class Engine(Connectable, log.Identified): return self.conn def __exit__(self, type_, value, traceback): - - if type_ is not None: - self.transaction.rollback() - else: - if self.transaction.is_active: - self.transaction.commit() - if not self.close_with_result: - self.conn.close() + try: + if type_ is not None: + if self.transaction.is_active: + self.transaction.rollback() + else: + if self.transaction.is_active: + self.transaction.commit() + finally: + if not self.close_with_result: + self.conn.close() def begin(self, close_with_result=False): """Return a context manager delivering a :class:`_engine.Connection` diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py index f89be1809..72d232085 100644 --- a/lib/sqlalchemy/engine/create.py +++ b/lib/sqlalchemy/engine/create.py @@ -655,9 +655,12 @@ def create_engine(url, **kwargs): c = base.Connection( engine, connection=dbapi_connection, _has_events=False ) - c._execution_options = util.immutabledict() - dialect.initialize(c) - dialect.do_rollback(c.connection) + c._execution_options = util.EMPTY_DICT + + try: + dialect.initialize(c) + finally: + dialect.do_rollback(c.connection) # previously, the "first_connect" event was used here, which was then # scaled back if the "on_connect" handler were present. now, diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py index b031c1610..08b1bb060 100644 --- a/lib/sqlalchemy/exc.py +++ b/lib/sqlalchemy/exc.py @@ -179,10 +179,10 @@ class UnsupportedCompilationError(CompileError): code = "l7de" - def __init__(self, compiler, element_type): + def __init__(self, compiler, element_type, message=None): super(UnsupportedCompilationError, self).__init__( - "Compiler %r can't render element of type %s" - % (compiler, element_type) + "Compiler %r can't render element of type %s%s" + % (compiler, element_type, ": %s" % message if message else "") ) diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py index 5a31173ec..47fb6720b 100644 --- a/lib/sqlalchemy/ext/compiler.py +++ b/lib/sqlalchemy/ext/compiler.py @@ -425,9 +425,11 @@ def compiles(class_, *specs): return existing_dispatch(element, compiler, **kw) except exc.UnsupportedCompilationError as uce: util.raise_( - exc.CompileError( - "%s construct has no default " - "compilation handler." % type(element) + exc.UnsupportedCompilationError( + compiler, + type(element), + message="%s construct has no default " + "compilation handler." % type(element), ), from_=uce, ) @@ -476,9 +478,11 @@ class _dispatcher(object): fn = self.specs["default"] except KeyError as ke: util.raise_( - exc.CompileError( - "%s construct has no default " - "compilation handler." % type(element) + exc.UnsupportedCompilationError( + compiler, + type(element), + message="%s construct has no default " + "compilation handler." % type(element), ), replace_context=ke, ) diff --git a/lib/sqlalchemy/future/engine.py b/lib/sqlalchemy/future/engine.py index d2f609326..bfdcdfc7f 100644 --- a/lib/sqlalchemy/future/engine.py +++ b/lib/sqlalchemy/future/engine.py @@ -368,12 +368,15 @@ class Engine(_LegacyEngine): return self.conn def __exit__(self, type_, value, traceback): - if type_ is not None: - self.transaction.rollback() - else: - if self.transaction.is_active: - self.transaction.commit() - self.conn.close() + try: + if type_ is not None: + if self.transaction.is_active: + self.transaction.rollback() + else: + if self.transaction.is_active: + self.transaction.commit() + finally: + self.conn.close() def begin(self): """Return a :class:`_future.Connection` object with a transaction diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index 2b7ad7bbd..f6b1a2e93 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -264,6 +264,7 @@ def clear_mappers(): upon a fixed set of classes. """ + with mapperlib._CONFIGURE_MUTEX: while _mapper_registry: try: diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index e8312f393..8a7a64f78 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -531,6 +531,10 @@ class SessionTransaction(object): self._take_snapshot(autobegin=autobegin) + # make sure transaction is assigned before we call the + # dispatch + self.session._transaction = self + self.session.dispatch.after_transaction_create(self.session, self) @property @@ -1242,7 +1246,8 @@ class Session(_SessionClassMethods): def _autobegin(self): if not self.autocommit and self._transaction is None: - self._transaction = SessionTransaction(self, autobegin=True) + trans = SessionTransaction(self, autobegin=True) + assert self._transaction is trans return True return False @@ -1299,7 +1304,7 @@ class Session(_SessionClassMethods): if self._transaction is not None: if subtransactions or _subtrans or nested: trans = self._transaction._begin(nested=nested) - self._transaction = trans + assert self._transaction is trans if nested: self._nested_transaction = trans else: @@ -1307,7 +1312,8 @@ class Session(_SessionClassMethods): "A transaction is already begun on this Session." ) else: - self._transaction = SessionTransaction(self, nested=nested) + trans = SessionTransaction(self, nested=nested) + assert self._transaction is trans return self._transaction # needed for __enter__/__exit__ hook def begin_nested(self): diff --git a/lib/sqlalchemy/pool/base.py b/lib/sqlalchemy/pool/base.py index 7c9509e45..47d9e2cba 100644 --- a/lib/sqlalchemy/pool/base.py +++ b/lib/sqlalchemy/pool/base.py @@ -18,7 +18,6 @@ from .. import event from .. import exc from .. import log from .. import util -from ..util import threading reset_rollback = util.symbol("reset_rollback") @@ -172,7 +171,6 @@ class Pool(log.Identified): self._orig_logging_name = None log.instance_logger(self, echoflag=echo) - self._threadconns = threading.local() self._creator = creator self._recycle = recycle self._invalidate_time = 0 @@ -423,26 +421,37 @@ class _ConnectionRecord(object): dbapi_connection = rec.get_connection() except Exception as err: with util.safe_reraise(): - rec._checkin_failed(err) + rec._checkin_failed(err, _fairy_was_created=False) echo = pool._should_log_debug() fairy = _ConnectionFairy(dbapi_connection, rec, echo) - rec.fairy_ref = weakref.ref( + + rec.fairy_ref = ref = weakref.ref( fairy, lambda ref: _finalize_fairy and _finalize_fairy(None, rec, pool, ref, echo), ) + _strong_ref_connection_records[ref] = rec if echo: pool.logger.debug( "Connection %r checked out from pool", dbapi_connection ) return fairy - def _checkin_failed(self, err): + def _checkin_failed(self, err, _fairy_was_created=True): self.invalidate(e=err) - self.checkin(_no_fairy_ref=True) + self.checkin( + _fairy_was_created=_fairy_was_created, + ) - def checkin(self, _no_fairy_ref=False): - if self.fairy_ref is None and not _no_fairy_ref: + def checkin(self, _fairy_was_created=True): + if self.fairy_ref is None and _fairy_was_created: + # _fairy_was_created is False for the initial get connection phase; + # meaning there was no _ConnectionFairy and we must unconditionally + # do a checkin. + # + # otherwise, if fairy_was_created==True, if fairy_ref is None here + # that means we were checked in already, so this looks like + # a double checkin. util.warn("Double checkin attempted on %s" % self) return self.fairy_ref = None @@ -453,6 +462,7 @@ class _ConnectionRecord(object): finalizer(connection) if pool.dispatch.checkin: pool.dispatch.checkin(connection, self) + pool._return_conn(self) @property @@ -603,12 +613,26 @@ def _finalize_fairy( """ + if ref: + _strong_ref_connection_records.pop(ref, None) + elif fairy: + _strong_ref_connection_records.pop(weakref.ref(fairy), None) + if ref is not None: if connection_record.fairy_ref is not ref: return assert connection is None connection = connection_record.connection + dont_restore_gced = pool._is_asyncio + + if dont_restore_gced: + detach = not connection_record or ref + can_manipulate_connection = not ref + else: + detach = not connection_record + can_manipulate_connection = True + if connection is not None: if connection_record and echo: pool.logger.debug( @@ -620,13 +644,26 @@ def _finalize_fairy( connection, connection_record, echo ) assert fairy.connection is connection - fairy._reset(pool) + if can_manipulate_connection: + fairy._reset(pool) + + if detach: + if connection_record: + fairy._pool = pool + fairy.detach() + + if can_manipulate_connection: + if pool.dispatch.close_detached: + pool.dispatch.close_detached(connection) + + pool._close_connection(connection) + else: + util.warn( + "asyncio connection is being garbage " + "collected without being properly closed: %r" + % connection + ) - # Immediately close detached instances - if not connection_record: - if pool.dispatch.close_detached: - pool.dispatch.close_detached(connection) - pool._close_connection(connection) except BaseException as e: pool.logger.error( "Exception during reset or similar", exc_info=True @@ -640,6 +677,13 @@ def _finalize_fairy( connection_record.checkin() +# a dictionary of the _ConnectionFairy weakrefs to _ConnectionRecord, so that +# GC under pypy will call ConnectionFairy finalizers. linked directly to the +# weakref that will empty itself when collected so that it should not create +# any unmanaged memory references. +_strong_ref_connection_records = {} + + class _ConnectionFairy(object): """Proxies a DBAPI connection and provides return-on-dereference @@ -774,7 +818,17 @@ class _ConnectionFairy(object): ) except Exception as err: with util.safe_reraise(): - fairy._connection_record._checkin_failed(err) + fairy._connection_record._checkin_failed( + err, + _fairy_was_created=True, + ) + + # prevent _ConnectionFairy from being carried + # in the stack trace. Do this after the + # connection record has been checked in, so that + # if the del triggers a finalize fairy, it won't + # try to checkin a second time. + del fairy attempts -= 1 diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index a1426b628..550111020 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -729,7 +729,7 @@ class ExecutableOption(HasCopyInternals, HasCacheKey): return c -class Executable(Generative): +class Executable(roles.CoerceTextStatementRole, Generative): """Mark a :class:`_expression.ClauseElement` as supporting execution. :class:`.Executable` is a superclass for all "statement" types diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index d3c767b5d..5ea3526ea 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -1502,7 +1502,6 @@ class TextClause( roles.OrderByRole, roles.FromClauseRole, roles.SelectStatementRole, - roles.CoerceTextStatementRole, roles.BinaryElementRole, roles.InElementRole, Executable, diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py index c1afeb907..9f2d0b857 100644 --- a/lib/sqlalchemy/testing/__init__.py +++ b/lib/sqlalchemy/testing/__init__.py @@ -32,6 +32,7 @@ from .assertions import is_instance_of # noqa from .assertions import is_none # noqa from .assertions import is_not # noqa from .assertions import is_not_ # noqa +from .assertions import is_not_none # noqa from .assertions import is_true # noqa from .assertions import le_ # noqa from .assertions import ne_ # noqa diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index b2a4ac66e..f2ed91b79 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -228,6 +228,10 @@ def is_none(a, msg=None): is_(a, None, msg=msg) +def is_not_none(a, msg=None): + is_not(a, None, msg=msg) + + def is_true(a, msg=None): is_(bool(a), True, msg=msg) diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index f64153f33..750671f9f 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -97,6 +97,10 @@ def get_current_test_name(): return _fixture_functions.get_current_test_name() +def mark_base_test_class(): + return _fixture_functions.mark_base_test_class() + + class Config(object): def __init__(self, db, db_opts, options, file_config): self._set_name(db) diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index a4c1f3973..a313c298a 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -7,12 +7,14 @@ from __future__ import absolute_import +import collections import re import warnings import weakref from . import config from .util import decorator +from .util import gc_collect from .. import event from .. import pool @@ -20,26 +22,29 @@ from .. import pool class ConnectionKiller(object): def __init__(self): self.proxy_refs = weakref.WeakKeyDictionary() - self.testing_engines = weakref.WeakKeyDictionary() - self.conns = set() + self.testing_engines = collections.defaultdict(set) + self.dbapi_connections = set() def add_pool(self, pool): - event.listen(pool, "connect", self.connect) - event.listen(pool, "checkout", self.checkout) - event.listen(pool, "invalidate", self.invalidate) - - def add_engine(self, engine): - self.add_pool(engine.pool) - self.testing_engines[engine] = True + event.listen(pool, "checkout", self._add_conn) + event.listen(pool, "checkin", self._remove_conn) + event.listen(pool, "close", self._remove_conn) + event.listen(pool, "close_detached", self._remove_conn) + # note we are keeping "invalidated" here, as those are still + # opened connections we would like to roll back + + def _add_conn(self, dbapi_con, con_record, con_proxy): + self.dbapi_connections.add(dbapi_con) + self.proxy_refs[con_proxy] = True - def connect(self, dbapi_conn, con_record): - self.conns.add((dbapi_conn, con_record)) + def _remove_conn(self, dbapi_conn, *arg): + self.dbapi_connections.discard(dbapi_conn) - def checkout(self, dbapi_con, con_record, con_proxy): - self.proxy_refs[con_proxy] = True + def add_engine(self, engine, scope): + self.add_pool(engine.pool) - def invalidate(self, dbapi_con, con_record, exception): - self.conns.discard((dbapi_con, con_record)) + assert scope in ("class", "global", "function", "fixture") + self.testing_engines[scope].add(engine) def _safe(self, fn): try: @@ -54,53 +59,89 @@ class ConnectionKiller(object): if rec is not None and rec.is_valid: self._safe(rec.rollback) - def close_all(self): + def checkin_all(self): + # run pool.checkin() for all ConnectionFairy instances we have + # tracked. + for rec in list(self.proxy_refs): if rec is not None and rec.is_valid: - self._safe(rec._close) - - def _after_test_ctx(self): - # this can cause a deadlock with pg8000 - pg8000 acquires - # prepared statement lock inside of rollback() - if async gc - # is collecting in finalize_fairy, deadlock. - # not sure if this should be for non-cpython only. - # note that firebird/fdb definitely needs this though - for conn, rec in list(self.conns): - if rec.connection is None: - # this is a hint that the connection is closed, which - # is causing segfaults on mysqlclient due to - # https://github.com/PyMySQL/mysqlclient-python/issues/270; - # try to work around here - continue - self._safe(conn.rollback) - - def _stop_test_ctx(self): - if config.options.low_connections: - self._stop_test_ctx_minimal() - else: - self._stop_test_ctx_aggressive() - - def _stop_test_ctx_minimal(self): - self.close_all() + self.dbapi_connections.discard(rec.connection) + self._safe(rec._checkin) - self.conns = set() - - for rec in list(self.testing_engines): - if rec is not config.db: - rec.dispose() - - def _stop_test_ctx_aggressive(self): - self.close_all() - for conn, rec in list(self.conns): - self._safe(conn.close) - rec.connection = None + # for fairy refs that were GCed and could not close the connection, + # such as asyncio, roll back those remaining connections + for con in self.dbapi_connections: + self._safe(con.rollback) + self.dbapi_connections.clear() - self.conns = set() - for rec in list(self.testing_engines): - if hasattr(rec, "sync_engine"): - rec.sync_engine.dispose() - else: - rec.dispose() + def close_all(self): + self.checkin_all() + + def prepare_for_drop_tables(self, connection): + # don't do aggressive checks for third party test suites + if not config.bootstrapped_as_sqlalchemy: + return + + from . import provision + + provision.prepare_for_drop_tables(connection.engine.url, connection) + + def _drop_testing_engines(self, scope): + eng = self.testing_engines[scope] + for rec in list(eng): + for proxy_ref in list(self.proxy_refs): + if proxy_ref is not None and proxy_ref.is_valid: + if ( + proxy_ref._pool is not None + and proxy_ref._pool is rec.pool + ): + self._safe(proxy_ref._checkin) + rec.dispose() + eng.clear() + + def after_test(self): + self._drop_testing_engines("function") + + def after_test_outside_fixtures(self, test): + # don't do aggressive checks for third party test suites + if not config.bootstrapped_as_sqlalchemy: + return + + if test.__class__.__leave_connections_for_teardown__: + return + + self.checkin_all() + + # on PostgreSQL, this will test for any "idle in transaction" + # connections. useful to identify tests with unusual patterns + # that can't be cleaned up correctly. + from . import provision + + with config.db.connect() as conn: + provision.prepare_for_drop_tables(conn.engine.url, conn) + + def stop_test_class_inside_fixtures(self): + self.checkin_all() + self._drop_testing_engines("function") + self._drop_testing_engines("class") + + def stop_test_class_outside_fixtures(self): + # ensure no refs to checked out connections at all. + + if pool.base._strong_ref_connection_records: + gc_collect() + + if pool.base._strong_ref_connection_records: + ln = len(pool.base._strong_ref_connection_records) + pool.base._strong_ref_connection_records.clear() + assert ( + False + ), "%d connection recs not cleared after test suite" % (ln) + + def final_cleanup(self): + self.checkin_all() + for scope in self.testing_engines: + self._drop_testing_engines(scope) def assert_all_closed(self): for rec in self.proxy_refs: @@ -111,20 +152,6 @@ class ConnectionKiller(object): testing_reaper = ConnectionKiller() -def drop_all_tables(metadata, bind): - testing_reaper.close_all() - if hasattr(bind, "close"): - bind.close() - - if not config.db.dialect.supports_alter: - from . import assertions - - with assertions.expect_warnings("Can't sort tables", assert_=False): - metadata.drop_all(bind) - else: - metadata.drop_all(bind) - - @decorator def assert_conns_closed(fn, *args, **kw): try: @@ -147,7 +174,7 @@ def rollback_open_connections(fn, *args, **kw): def close_first(fn, *args, **kw): """Decorator that closes all connections before fn execution.""" - testing_reaper.close_all() + testing_reaper.checkin_all() fn(*args, **kw) @@ -157,7 +184,7 @@ def close_open_connections(fn, *args, **kw): try: fn(*args, **kw) finally: - testing_reaper.close_all() + testing_reaper.checkin_all() def all_dialects(exclude=None): @@ -239,12 +266,14 @@ def reconnecting_engine(url=None, options=None): return engine -def testing_engine(url=None, options=None, future=False, asyncio=False): +def testing_engine(url=None, options=None, future=None, asyncio=False): """Produce an engine configured by --options with optional overrides.""" if asyncio: from sqlalchemy.ext.asyncio import create_async_engine as create_engine - elif future or config.db and config.db._is_future: + elif future or ( + config.db and config.db._is_future and future is not False + ): from sqlalchemy.future import create_engine else: from sqlalchemy import create_engine @@ -252,8 +281,10 @@ def testing_engine(url=None, options=None, future=False, asyncio=False): if not options: use_reaper = True + scope = "function" else: use_reaper = options.pop("use_reaper", True) + scope = options.pop("scope", "function") url = url or config.db.url @@ -268,16 +299,20 @@ def testing_engine(url=None, options=None, future=False, asyncio=False): default_opt.update(options) engine = create_engine(url, **options) - if asyncio: - engine.sync_engine._has_events = True - else: - engine._has_events = True # enable event blocks, helps with profiling + + if scope == "global": + if asyncio: + engine.sync_engine._has_events = True + else: + engine._has_events = ( + True # enable event blocks, helps with profiling + ) if isinstance(engine.pool, pool.QueuePool): engine.pool._timeout = 0 - engine.pool._max_overflow = 5 + engine.pool._max_overflow = 0 if use_reaper: - testing_reaper.add_engine(engine) + testing_reaper.add_engine(engine, scope) return engine diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index ac4d3d8fa..45ca48444 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -5,6 +5,7 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +import contextlib import re import sys @@ -12,12 +13,11 @@ import sqlalchemy as sa from . import assertions from . import config from . import schema -from .engines import drop_all_tables -from .engines import testing_engine from .entities import BasicEntity from .entities import ComparableEntity from .entities import ComparableMixin # noqa from .util import adict +from .util import drop_all_tables_from_metadata from .. import event from .. import util from ..orm import declarative_base @@ -25,10 +25,8 @@ from ..orm import registry from ..orm.decl_api import DeclarativeMeta from ..schema import sort_tables_and_constraints -# whether or not we use unittest changes things dramatically, -# as far as how pytest collection works. - +@config.mark_base_test_class() class TestBase(object): # A sequence of database names to always run, regardless of the # constraints below. @@ -48,81 +46,122 @@ class TestBase(object): # skipped. __skip_if__ = None + # if True, the testing reaper will not attempt to touch connection + # state after a test is completed and before the outer teardown + # starts + __leave_connections_for_teardown__ = False + def assert_(self, val, msg=None): assert val, msg - # apparently a handful of tests are doing this....OK - def setup(self): - if hasattr(self, "setUp"): - self.setUp() - - def teardown(self): - if hasattr(self, "tearDown"): - self.tearDown() - @config.fixture() def connection(self): - eng = getattr(self, "bind", config.db) + global _connection_fixture_connection + + eng = getattr(self, "bind", None) or config.db conn = eng.connect() trans = conn.begin() - try: - yield conn - finally: - if trans.is_active: - trans.rollback() - conn.close() + + _connection_fixture_connection = conn + yield conn + + _connection_fixture_connection = None + + if trans.is_active: + trans.rollback() + # trans would not be active here if the test is using + # the legacy @provide_metadata decorator still, as it will + # run a close all connections. + conn.close() @config.fixture() - def future_connection(self): + def future_connection(self, future_engine, connection): + # integrate the future_engine and connection fixtures so + # that users of the "connection" fixture will get at the + # "future" connection + yield connection - eng = testing_engine(future=True) - conn = eng.connect() - trans = conn.begin() - try: - yield conn - finally: - if trans.is_active: - trans.rollback() - conn.close() + @config.fixture() + def future_engine(self): + eng = getattr(self, "bind", None) or config.db + with _push_future_engine(eng): + yield + + @config.fixture() + def testing_engine(self): + from . import engines + + def gen_testing_engine( + url=None, options=None, future=None, asyncio=False + ): + if options is None: + options = {} + options["scope"] = "fixture" + return engines.testing_engine( + url=url, options=options, future=future, asyncio=asyncio + ) + + yield gen_testing_engine + + engines.testing_reaper._drop_testing_engines("fixture") @config.fixture() - def metadata(self): + def async_testing_engine(self, testing_engine): + def go(**kw): + kw["asyncio"] = True + return testing_engine(**kw) + + return go + + @config.fixture() + def metadata(self, request): """Provide bound MetaData for a single test, dropping afterwards.""" - from . import engines from ..sql import schema metadata = schema.MetaData() - try: - yield metadata - finally: - engines.drop_all_tables(metadata, config.db) + request.instance.metadata = metadata + yield metadata + del request.instance.metadata + if ( + _connection_fixture_connection + and _connection_fixture_connection.in_transaction() + ): + trans = _connection_fixture_connection.get_transaction() + trans.rollback() + with _connection_fixture_connection.begin(): + drop_all_tables_from_metadata( + metadata, _connection_fixture_connection + ) + else: + drop_all_tables_from_metadata(metadata, config.db) -class FutureEngineMixin(object): - @classmethod - def setup_class(cls): - from ..future.engine import Engine - from sqlalchemy import testing +_connection_fixture_connection = None - facade = Engine._future_facade(config.db) - config._current.push_engine(facade, testing) - super_ = super(FutureEngineMixin, cls) - if hasattr(super_, "setup_class"): - super_.setup_class() +@contextlib.contextmanager +def _push_future_engine(engine): - @classmethod - def teardown_class(cls): - super_ = super(FutureEngineMixin, cls) - if hasattr(super_, "teardown_class"): - super_.teardown_class() + from ..future.engine import Engine + from sqlalchemy import testing + + facade = Engine._future_facade(engine) + config._current.push_engine(facade, testing) + + yield facade - from sqlalchemy import testing + config._current.pop(testing) - config._current.pop(testing) + +class FutureEngineMixin(object): + @config.fixture(autouse=True, scope="class") + def _push_future_engine(self): + eng = getattr(self, "bind", None) or config.db + with _push_future_engine(eng): + yield class TablesTest(TestBase): @@ -151,18 +190,32 @@ class TablesTest(TestBase): other = None sequences = None - @property - def tables_test_metadata(self): - return self._tables_metadata - - @classmethod - def setup_class(cls): + @config.fixture(autouse=True, scope="class") + def _setup_tables_test_class(self): + cls = self.__class__ cls._init_class() cls._setup_once_tables() cls._setup_once_inserts() + yield + + cls._teardown_once_metadata_bind() + + @config.fixture(autouse=True, scope="function") + def _setup_tables_test_instance(self): + self._setup_each_tables() + self._setup_each_inserts() + + yield + + self._teardown_each_tables() + + @property + def tables_test_metadata(self): + return self._tables_metadata + @classmethod def _init_class(cls): if cls.run_define_tables == "each": @@ -213,10 +266,10 @@ class TablesTest(TestBase): if self.run_define_tables == "each": self.tables.clear() if self.run_create_tables == "each": - drop_all_tables(self._tables_metadata, self.bind) + drop_all_tables_from_metadata(self._tables_metadata, self.bind) self._tables_metadata.clear() elif self.run_create_tables == "each": - drop_all_tables(self._tables_metadata, self.bind) + drop_all_tables_from_metadata(self._tables_metadata, self.bind) # no need to run deletes if tables are recreated on setup if ( @@ -242,17 +295,10 @@ class TablesTest(TestBase): file=sys.stderr, ) - def setup(self): - self._setup_each_tables() - self._setup_each_inserts() - - def teardown(self): - self._teardown_each_tables() - @classmethod def _teardown_once_metadata_bind(cls): if cls.run_create_tables: - drop_all_tables(cls._tables_metadata, cls.bind) + drop_all_tables_from_metadata(cls._tables_metadata, cls.bind) if cls.run_dispose_bind == "once": cls.dispose_bind(cls.bind) @@ -263,10 +309,6 @@ class TablesTest(TestBase): cls.bind = None @classmethod - def teardown_class(cls): - cls._teardown_once_metadata_bind() - - @classmethod def setup_bind(cls): return config.db @@ -332,38 +374,45 @@ class RemovesEvents(object): self._event_fns.add((target, name, fn)) event.listen(target, name, fn, **kw) - def teardown(self): + @config.fixture(autouse=True, scope="function") + def _remove_events(self): + yield for key in self._event_fns: event.remove(*key) - super_ = super(RemovesEvents, self) - if hasattr(super_, "teardown"): - super_.teardown() -class _ORMTest(object): - @classmethod - def teardown_class(cls): - sa.orm.session.close_all_sessions() - sa.orm.clear_mappers() - - -def create_session(**kw): - kw.setdefault("autoflush", False) - kw.setdefault("expire_on_commit", False) - return sa.orm.Session(config.db, **kw) +_fixture_sessions = set() def fixture_session(**kw): kw.setdefault("autoflush", True) kw.setdefault("expire_on_commit", True) - return sa.orm.Session(config.db, **kw) + sess = sa.orm.Session(config.db, **kw) + _fixture_sessions.add(sess) + return sess + + +def _close_all_sessions(): + # will close all still-referenced sessions + sa.orm.session.close_all_sessions() + _fixture_sessions.clear() -class ORMTest(_ORMTest, TestBase): +def stop_test_class_inside_fixtures(cls): + _close_all_sessions() + sa.orm.clear_mappers() + + +def after_test(): + if _fixture_sessions: + _close_all_sessions() + + +class ORMTest(TestBase): pass -class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): +class MappedTest(TablesTest, assertions.AssertsExecutionResults): # 'once', 'each', None run_setup_classes = "once" @@ -372,8 +421,9 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): classes = None - @classmethod - def setup_class(cls): + @config.fixture(autouse=True, scope="class") + def _setup_tables_test_class(self): + cls = self.__class__ cls._init_class() if cls.classes is None: @@ -384,18 +434,20 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): cls._setup_once_mappers() cls._setup_once_inserts() - @classmethod - def teardown_class(cls): + yield + cls._teardown_once_class() cls._teardown_once_metadata_bind() - def setup(self): + @config.fixture(autouse=True, scope="function") + def _setup_tables_test_instance(self): self._setup_each_tables() self._setup_each_classes() self._setup_each_mappers() self._setup_each_inserts() - def teardown(self): + yield + sa.orm.session.close_all_sessions() self._teardown_each_mappers() self._teardown_each_classes() @@ -404,7 +456,6 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): @classmethod def _teardown_once_class(cls): cls.classes.clear() - _ORMTest.teardown_class() @classmethod def _setup_once_classes(cls): @@ -440,6 +491,8 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): """ cls_registry = cls.classes + assert cls_registry is not None + class FindFixture(type): def __init__(cls, classname, bases, dict_): cls_registry[classname] = cls diff --git a/lib/sqlalchemy/testing/plugin/bootstrap.py b/lib/sqlalchemy/testing/plugin/bootstrap.py index a95c947e2..1f568dfc8 100644 --- a/lib/sqlalchemy/testing/plugin/bootstrap.py +++ b/lib/sqlalchemy/testing/plugin/bootstrap.py @@ -40,6 +40,11 @@ def load_file_as_module(name): if to_bootstrap == "pytest": sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base") + sys.modules["sqla_plugin_base"].bootstrapped_as_sqlalchemy = True + if sys.version_info < (3, 0): + sys.modules["sqla_reinvent_fixtures"] = load_file_as_module( + "reinvent_fixtures_py2k" + ) sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin") else: raise Exception("unknown bootstrap: %s" % to_bootstrap) # noqa diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py index 3594cd276..858814f91 100644 --- a/lib/sqlalchemy/testing/plugin/plugin_base.py +++ b/lib/sqlalchemy/testing/plugin/plugin_base.py @@ -21,6 +21,9 @@ import logging import re import sys +# flag which indicates we are in the SQLAlchemy testing suite, +# and not that of Alembic or a third party dialect. +bootstrapped_as_sqlalchemy = False log = logging.getLogger("sqlalchemy.testing.plugin_base") @@ -381,7 +384,7 @@ def _init_symbols(options, file_config): @post def _set_disable_asyncio(opt, file_config): - if opt.disable_asyncio: + if opt.disable_asyncio or not py3k: from sqlalchemy.testing import asyncio asyncio.ENABLE_ASYNCIO = False @@ -458,6 +461,8 @@ def _setup_requirements(argument): config.requirements = testing.requires = req_cls() + config.bootstrapped_as_sqlalchemy = bootstrapped_as_sqlalchemy + @post def _prep_testing_database(options, file_config): @@ -566,17 +571,23 @@ def generate_sub_tests(cls, module): yield cls -def start_test_class(cls): +def start_test_class_outside_fixtures(cls): _do_skips(cls) _setup_engine(cls) def stop_test_class(cls): - # from sqlalchemy import inspect - # assert not inspect(testing.db).get_table_names() + # close sessions, immediate connections, etc. + fixtures.stop_test_class_inside_fixtures(cls) + + # close outstanding connection pool connections, dispose of + # additional engines + engines.testing_reaper.stop_test_class_inside_fixtures() - provision.stop_test_class(config, config.db, cls) - engines.testing_reaper._stop_test_ctx() + +def stop_test_class_outside_fixtures(cls): + engines.testing_reaper.stop_test_class_outside_fixtures() + provision.stop_test_class_outside_fixtures(config, config.db, cls) try: if not options.low_connections: assertions.global_cleanup_assertions() @@ -590,14 +601,16 @@ def _restore_engine(): def final_process_cleanup(): - engines.testing_reaper._stop_test_ctx_aggressive() + engines.testing_reaper.final_cleanup() assertions.global_cleanup_assertions() _restore_engine() def _setup_engine(cls): if getattr(cls, "__engine_options__", None): - eng = engines.testing_engine(options=cls.__engine_options__) + opts = dict(cls.__engine_options__) + opts["scope"] = "class" + eng = engines.testing_engine(options=opts) config._current.push_engine(eng, testing) @@ -614,7 +627,12 @@ def before_test(test, test_module_name, test_class, test_name): def after_test(test): - engines.testing_reaper._after_test_ctx() + fixtures.after_test() + engines.testing_reaper.after_test() + + +def after_test_fixtures(test): + engines.testing_reaper.after_test_outside_fixtures(test) def _possible_configs_for_cls(cls, reasons=None, sparse=False): @@ -748,6 +766,10 @@ class FixtureFunctions(ABC): def get_current_test_name(self): raise NotImplementedError() + @abc.abstractmethod + def mark_base_test_class(self): + raise NotImplementedError() + _fixture_fn_class = None diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index 46468a07d..4eaaecebb 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -17,6 +17,7 @@ import sys import pytest + try: import typing except ImportError: @@ -33,6 +34,14 @@ except ImportError: has_xdist = False +py2k = sys.version_info < (3, 0) +if py2k: + try: + import sqla_reinvent_fixtures as reinvent_fixtures_py2k + except ImportError: + from . import reinvent_fixtures_py2k + + def pytest_addoption(parser): group = parser.getgroup("sqlalchemy") @@ -238,6 +247,10 @@ def pytest_collection_modifyitems(session, config, items): else: newitems.append(item) + if py2k: + for item in newitems: + reinvent_fixtures_py2k.scan_for_fixtures_to_use_for_class(item) + # seems like the functions attached to a test class aren't sorted already? # is that true and why's that? (when using unittest, they're sorted) items[:] = sorted( @@ -251,7 +264,6 @@ def pytest_collection_modifyitems(session, config, items): def pytest_pycollect_makeitem(collector, name, obj): - if inspect.isclass(obj) and plugin_base.want_class(name, obj): from sqlalchemy.testing import config @@ -259,7 +271,6 @@ def pytest_pycollect_makeitem(collector, name, obj): obj = _apply_maybe_async(obj) ctor = getattr(pytest.Class, "from_parent", pytest.Class) - return [ ctor(name=parametrize_cls.__name__, parent=collector) for parametrize_cls in _parametrize_cls(collector.module, obj) @@ -287,12 +298,11 @@ def _is_wrapped_coroutine_function(fn): def _apply_maybe_async(obj, recurse=True): from sqlalchemy.testing import asyncio - setup_names = {"setup", "setup_class", "teardown", "teardown_class"} for name, value in vars(obj).items(): if ( (callable(value) or isinstance(value, classmethod)) and not getattr(value, "_maybe_async_applied", False) - and (name.startswith("test_") or name in setup_names) + and (name.startswith("test_")) and not _is_wrapped_coroutine_function(value) ): is_classmethod = False @@ -317,9 +327,6 @@ def _apply_maybe_async(obj, recurse=True): return obj -_current_class = None - - def _parametrize_cls(module, cls): """implement a class-based version of pytest parametrize.""" @@ -355,63 +362,153 @@ def _parametrize_cls(module, cls): return classes +_current_class = None + + def pytest_runtest_setup(item): from sqlalchemy.testing import asyncio - # here we seem to get called only based on what we collected - # in pytest_collection_modifyitems. So to do class-based stuff - # we have to tear that out. - global _current_class - if not isinstance(item, pytest.Function): return - # ... so we're doing a little dance here to figure it out... + # pytest_runtest_setup runs *before* pytest fixtures with scope="class". + # plugin_base.start_test_class_outside_fixtures may opt to raise SkipTest + # for the whole class and has to run things that are across all current + # databases, so we run this outside of the pytest fixture system altogether + # and ensure asyncio greenlet if any engines are async + + global _current_class + if _current_class is None: - asyncio._maybe_async(class_setup, item.parent.parent) + asyncio._maybe_async_provisioning( + plugin_base.start_test_class_outside_fixtures, + item.parent.parent.cls, + ) _current_class = item.parent.parent - # this is needed for the class-level, to ensure that the - # teardown runs after the class is completed with its own - # class-level teardown... def finalize(): global _current_class - asyncio._maybe_async(class_teardown, item.parent.parent) _current_class = None + asyncio._maybe_async_provisioning( + plugin_base.stop_test_class_outside_fixtures, + item.parent.parent.cls, + ) + item.parent.parent.addfinalizer(finalize) - asyncio._maybe_async(test_setup, item) +def pytest_runtest_call(item): + # runs inside of pytest function fixture scope + # before test function runs -def pytest_runtest_teardown(item): from sqlalchemy.testing import asyncio - # ...but this works better as the hook here rather than - # using a finalizer, as the finalizer seems to get in the way - # of the test reporting failures correctly (you get a bunch of - # pytest assertion stuff instead) - asyncio._maybe_async(test_teardown, item) + asyncio._maybe_async( + plugin_base.before_test, + item, + item.parent.module.__name__, + item.parent.cls, + item.name, + ) -def test_setup(item): - plugin_base.before_test( - item, item.parent.module.__name__, item.parent.cls, item.name - ) +def pytest_runtest_teardown(item, nextitem): + # runs inside of pytest function fixture scope + # after test function runs + from sqlalchemy.testing import asyncio -def test_teardown(item): - plugin_base.after_test(item) + asyncio._maybe_async(plugin_base.after_test, item) -def class_setup(item): +@pytest.fixture(scope="class") +def setup_class_methods(request): from sqlalchemy.testing import asyncio - asyncio._maybe_async_provisioning(plugin_base.start_test_class, item.cls) + cls = request.cls + + if hasattr(cls, "setup_test_class"): + asyncio._maybe_async(cls.setup_test_class) + + if py2k: + reinvent_fixtures_py2k.run_class_fixture_setup(request) + + yield + + if py2k: + reinvent_fixtures_py2k.run_class_fixture_teardown(request) + if hasattr(cls, "teardown_test_class"): + asyncio._maybe_async(cls.teardown_test_class) -def class_teardown(item): - plugin_base.stop_test_class(item.cls) + asyncio._maybe_async(plugin_base.stop_test_class, cls) + + +@pytest.fixture(scope="function") +def setup_test_methods(request): + from sqlalchemy.testing import asyncio + + # called for each test + + self = request.instance + + # 1. run outer xdist-style setup + if hasattr(self, "setup_test"): + asyncio._maybe_async(self.setup_test) + + # alembic test suite is using setUp and tearDown + # xdist methods; support these in the test suite + # for the near term + if hasattr(self, "setUp"): + asyncio._maybe_async(self.setUp) + + # 2. run homegrown function level "autouse" fixtures under py2k + if py2k: + reinvent_fixtures_py2k.run_fn_fixture_setup(request) + + # inside the yield: + + # 3. function level "autouse" fixtures under py3k (examples: TablesTest + # define tables / data, MappedTest define tables / mappers / data) + + # 4. function level fixtures defined on test functions themselves, + # e.g. "connection", "metadata" run next + + # 5. pytest hook pytest_runtest_call then runs + + # 6. test itself runs + + yield + + # yield finishes: + + # 7. pytest hook pytest_runtest_teardown hook runs, this is associated + # with fixtures close all sessions, provisioning.stop_test_class(), + # engines.testing_reaper -> ensure all connection pool connections + # are returned, engines created by testing_engine that aren't the + # config engine are disposed + + # 8. function level fixtures defined on test functions + # themselves, e.g. "connection" rolls back the transaction, "metadata" + # emits drop all + + # 9. function level "autouse" fixtures under py3k (examples: TablesTest / + # MappedTest delete table data, possibly drop tables and clear mappers + # depending on the flags defined by the test class) + + # 10. run homegrown function-level "autouse" fixtures under py2k + if py2k: + reinvent_fixtures_py2k.run_fn_fixture_teardown(request) + + asyncio._maybe_async(plugin_base.after_test_fixtures, self) + + # 11. run outer xdist-style teardown + if hasattr(self, "tearDown"): + asyncio._maybe_async(self.tearDown) + + if hasattr(self, "teardown_test"): + asyncio._maybe_async(self.teardown_test) def getargspec(fn): @@ -461,6 +558,8 @@ def %(name)s(%(args)s): # for the wrapped function decorated.__module__ = fn.__module__ decorated.__name__ = fn.__name__ + if hasattr(fn, "pytestmark"): + decorated.pytestmark = fn.pytestmark return decorated return decorate @@ -470,6 +569,11 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): def skip_test_exception(self, *arg, **kw): return pytest.skip.Exception(*arg, **kw) + def mark_base_test_class(self): + return pytest.mark.usefixtures( + "setup_class_methods", "setup_test_methods" + ) + _combination_id_fns = { "i": lambda obj: obj, "r": repr, @@ -647,8 +751,18 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): fn = asyncio._maybe_async_wrapper(fn) # other wrappers may be added here - # now apply FixtureFunctionMarker - fn = fixture(fn) + if py2k and "autouse" in kw: + # py2k workaround for too-slow collection of autouse fixtures + # in pytest 4.6.11. See notes in reinvent_fixtures_py2k for + # rationale. + + # comment this condition out in order to disable the + # py2k workaround entirely. + reinvent_fixtures_py2k.add_fixture(fn, fixture) + else: + # now apply FixtureFunctionMarker + fn = fixture(fn) + return fn if fn: diff --git a/lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py b/lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py new file mode 100644 index 000000000..36b68417b --- /dev/null +++ b/lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py @@ -0,0 +1,112 @@ +""" +invent a quick version of pytest autouse fixtures as pytest's unacceptably slow +collection/high memory use in pytest 4.6.11, which is the highest version that +works in py2k. + +by "too-slow" we mean the test suite can't even manage to be collected for a +single process in less than 70 seconds or so and memory use seems to be very +high as well. for two or four workers the job just times out after ten +minutes. + +so instead we have invented a very limited form of these fixtures, as our +current use of "autouse" fixtures are limited to those in fixtures.py. + +assumptions for these fixtures: + +1. we are only using "function" or "class" scope + +2. the functions must be associated with a test class + +3. the fixture functions cannot themselves use pytest fixtures + +4. the fixture functions must use yield, not return + +When py2k support is removed and we can stay on a modern pytest version, this +can all be removed. + + +""" +import collections + + +_py2k_fixture_fn_names = collections.defaultdict(set) +_py2k_class_fixtures = collections.defaultdict( + lambda: collections.defaultdict(set) +) +_py2k_function_fixtures = collections.defaultdict( + lambda: collections.defaultdict(set) +) + +_py2k_cls_fixture_stack = [] +_py2k_fn_fixture_stack = [] + + +def add_fixture(fn, fixture): + assert fixture.scope in ("class", "function") + _py2k_fixture_fn_names[fn.__name__].add((fn, fixture.scope)) + + +def scan_for_fixtures_to_use_for_class(item): + test_class = item.parent.parent.obj + + for name in _py2k_fixture_fn_names: + for fixture_fn, scope in _py2k_fixture_fn_names[name]: + meth = getattr(test_class, name, None) + if meth and meth.im_func is fixture_fn: + for sup in test_class.__mro__: + if name in sup.__dict__: + if scope == "class": + _py2k_class_fixtures[test_class][sup].add(meth) + elif scope == "function": + _py2k_function_fixtures[test_class][sup].add(meth) + break + break + + +def run_class_fixture_setup(request): + + cls = request.cls + self = cls.__new__(cls) + + fixtures_for_this_class = _py2k_class_fixtures.get(cls) + + if fixtures_for_this_class: + for sup_ in cls.__mro__: + for fn in fixtures_for_this_class.get(sup_, ()): + iter_ = fn(self) + next(iter_) + + _py2k_cls_fixture_stack.append(iter_) + + +def run_class_fixture_teardown(request): + while _py2k_cls_fixture_stack: + iter_ = _py2k_cls_fixture_stack.pop(-1) + try: + next(iter_) + except StopIteration: + pass + + +def run_fn_fixture_setup(request): + cls = request.cls + self = request.instance + + fixtures_for_this_class = _py2k_function_fixtures.get(cls) + + if fixtures_for_this_class: + for sup_ in reversed(cls.__mro__): + for fn in fixtures_for_this_class.get(sup_, ()): + iter_ = fn(self) + next(iter_) + + _py2k_fn_fixture_stack.append(iter_) + + +def run_fn_fixture_teardown(request): + while _py2k_fn_fixture_stack: + iter_ = _py2k_fn_fixture_stack.pop(-1) + try: + next(iter_) + except StopIteration: + pass diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py index 4ee0567f2..2fade1c32 100644 --- a/lib/sqlalchemy/testing/provision.py +++ b/lib/sqlalchemy/testing/provision.py @@ -67,6 +67,7 @@ def setup_config(db_url, options, file_config, follower_ident): db_url = follower_url_from_main(db_url, follower_ident) db_opts = {} update_db_opts(db_url, db_opts) + db_opts["scope"] = "global" eng = engines.testing_engine(db_url, db_opts) post_configure_engine(db_url, eng, follower_ident) eng.connect().close() @@ -264,6 +265,7 @@ def drop_all_schema_objects(cfg, eng): if config.requirements.schemas.enabled_for_config(cfg): util.drop_all_tables(eng, inspector, schema=cfg.test_schema) + util.drop_all_tables(eng, inspector, schema=cfg.test_schema_2) drop_all_schema_objects_post_tables(cfg, eng) @@ -299,7 +301,7 @@ def update_db_opts(db_url, db_opts): def post_configure_engine(url, engine, follower_ident): """Perform extra steps after configuring an engine for testing. - (For the internal dialects, currently only used by sqlite.) + (For the internal dialects, currently only used by sqlite, oracle) """ pass @@ -375,7 +377,12 @@ def temp_table_keyword_args(cfg, eng): @register.init -def stop_test_class(config, db, testcls): +def prepare_for_drop_tables(config, connection): + pass + + +@register.init +def stop_test_class_outside_fixtures(config, db, testcls): pass diff --git a/lib/sqlalchemy/testing/suite/test_insert.py b/lib/sqlalchemy/testing/suite/test_insert.py index da59d831f..35f3315c7 100644 --- a/lib/sqlalchemy/testing/suite/test_insert.py +++ b/lib/sqlalchemy/testing/suite/test_insert.py @@ -51,13 +51,15 @@ class LastrowidTest(fixtures.TablesTest): def test_autoincrement_on_insert(self, connection): - connection.execute(self.tables.autoinc_pk.insert(), data="some data") + connection.execute( + self.tables.autoinc_pk.insert(), dict(data="some data") + ) self._assert_round_trip(self.tables.autoinc_pk, connection) def test_last_inserted_id(self, connection): r = connection.execute( - self.tables.autoinc_pk.insert(), data="some data" + self.tables.autoinc_pk.insert(), dict(data="some data") ) pk = connection.scalar(select(self.tables.autoinc_pk.c.id)) eq_(r.inserted_primary_key, (pk,)) @@ -65,7 +67,7 @@ class LastrowidTest(fixtures.TablesTest): @requirements.dbapi_lastrowid def test_native_lastrowid_autoinc(self, connection): r = connection.execute( - self.tables.autoinc_pk.insert(), data="some data" + self.tables.autoinc_pk.insert(), dict(data="some data") ) lastrowid = r.lastrowid pk = connection.scalar(select(self.tables.autoinc_pk.c.id)) @@ -116,7 +118,9 @@ class InsertBehaviorTest(fixtures.TablesTest): engine = config.db with engine.begin() as conn: - r = conn.execute(self.tables.autoinc_pk.insert(), data="some data") + r = conn.execute( + self.tables.autoinc_pk.insert(), dict(data="some data") + ) assert r._soft_closed assert not r.closed assert r.is_insert @@ -131,7 +135,7 @@ class InsertBehaviorTest(fixtures.TablesTest): @requirements.returning def test_autoclose_on_insert_implicit_returning(self, connection): r = connection.execute( - self.tables.autoinc_pk.insert(), data="some data" + self.tables.autoinc_pk.insert(), dict(data="some data") ) assert r._soft_closed assert not r.closed @@ -315,7 +319,7 @@ class ReturningTest(fixtures.TablesTest): def test_explicit_returning_pk_autocommit(self, connection): table = self.tables.autoinc_pk r = connection.execute( - table.insert().returning(table.c.id), data="some data" + table.insert().returning(table.c.id), dict(data="some data") ) pk = r.first()[0] fetched_pk = connection.scalar(select(table.c.id)) @@ -324,7 +328,7 @@ class ReturningTest(fixtures.TablesTest): def test_explicit_returning_pk_no_autocommit(self, connection): table = self.tables.autoinc_pk r = connection.execute( - table.insert().returning(table.c.id), data="some data" + table.insert().returning(table.c.id), dict(data="some data") ) pk = r.first()[0] fetched_pk = connection.scalar(select(table.c.id)) @@ -332,13 +336,15 @@ class ReturningTest(fixtures.TablesTest): def test_autoincrement_on_insert_implicit_returning(self, connection): - connection.execute(self.tables.autoinc_pk.insert(), data="some data") + connection.execute( + self.tables.autoinc_pk.insert(), dict(data="some data") + ) self._assert_round_trip(self.tables.autoinc_pk, connection) def test_last_inserted_id_implicit_returning(self, connection): r = connection.execute( - self.tables.autoinc_pk.insert(), data="some data" + self.tables.autoinc_pk.insert(), dict(data="some data") ) pk = connection.scalar(select(self.tables.autoinc_pk.c.id)) eq_(r.inserted_primary_key, (pk,)) diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index 6c3c1005a..916d74db3 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -293,7 +293,7 @@ class ComponentReflectionTest(fixtures.TablesTest): from sqlalchemy import pool return engines.testing_engine( - options=dict(poolclass=pool.StaticPool) + options=dict(poolclass=pool.StaticPool, scope="class"), ) else: return config.db @@ -523,6 +523,7 @@ class ComponentReflectionTest(fixtures.TablesTest): insp = inspect(self.bind) eq_(insp.default_schema_name, self.bind.dialect.default_schema_name) + @testing.requires.foreign_key_constraint_reflection @testing.combinations( (None, True, False, False), (None, True, False, True, testing.requires.schemas), @@ -630,8 +631,12 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.combinations( (False, False), (False, True, testing.requires.schemas), - (True, False), - (False, True, testing.requires.schemas), + (True, False, testing.requires.view_reflection), + ( + True, + True, + testing.requires.schemas + testing.requires.view_reflection, + ), argnames="use_views,use_schema", ) def test_get_columns(self, connection, use_views, use_schema): @@ -999,6 +1004,7 @@ class ComponentReflectionTest(fixtures.TablesTest): eq_(names_that_duplicate_index, idx_names) eq_(uq_names, set()) + @testing.requires.view_reflection @testing.combinations( (False,), (True, testing.requires.schemas), argnames="use_schema" ) diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py index e0fdbe47a..e8dd6cf2c 100644 --- a/lib/sqlalchemy/testing/suite/test_results.py +++ b/lib/sqlalchemy/testing/suite/test_results.py @@ -261,10 +261,6 @@ class ServerSideCursorsTest( ) return self.engine - def tearDown(self): - engines.testing_reaper.close_all() - self.engine.dispose() - @testing.combinations( ("global_string", True, "select 1", True), ("global_text", True, text("select 1"), True), @@ -309,24 +305,22 @@ class ServerSideCursorsTest( def test_conn_option(self): engine = self._fixture(False) - # should be enabled for this one - result = ( - engine.connect() - .execution_options(stream_results=True) - .exec_driver_sql("select 1") - ) - assert self._is_server_side(result.cursor) + with engine.connect() as conn: + # should be enabled for this one + result = conn.execution_options( + stream_results=True + ).exec_driver_sql("select 1") + assert self._is_server_side(result.cursor) def test_stmt_enabled_conn_option_disabled(self): engine = self._fixture(False) s = select(1).execution_options(stream_results=True) - # not this one - result = ( - engine.connect().execution_options(stream_results=False).execute(s) - ) - assert not self._is_server_side(result.cursor) + with engine.connect() as conn: + # not this one + result = conn.execution_options(stream_results=False).execute(s) + assert not self._is_server_side(result.cursor) def test_aliases_and_ss(self): engine = self._fixture(False) @@ -344,8 +338,7 @@ class ServerSideCursorsTest( assert not self._is_server_side(result.cursor) result.close() - @testing.provide_metadata - def test_roundtrip_fetchall(self): + def test_roundtrip_fetchall(self, metadata): md = self.metadata engine = self._fixture(True) @@ -385,8 +378,7 @@ class ServerSideCursorsTest( 0, ) - @testing.provide_metadata - def test_roundtrip_fetchmany(self): + def test_roundtrip_fetchmany(self, metadata): md = self.metadata engine = self._fixture(True) diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py index f8d9b3d88..0d9f08848 100644 --- a/lib/sqlalchemy/testing/suite/test_select.py +++ b/lib/sqlalchemy/testing/suite/test_select.py @@ -742,7 +742,7 @@ class PostCompileParamsTest( with self.sql_execution_asserter() as asserter: with config.db.connect() as conn: - conn.execute(stmt, q=10) + conn.execute(stmt, dict(q=10)) asserter.assert_( CursorSQL( @@ -761,7 +761,7 @@ class PostCompileParamsTest( with self.sql_execution_asserter() as asserter: with config.db.connect() as conn: - conn.execute(stmt, q=[5, 6, 7]) + conn.execute(stmt, dict(q=[5, 6, 7])) asserter.assert_( CursorSQL( @@ -783,7 +783,7 @@ class PostCompileParamsTest( with self.sql_execution_asserter() as asserter: with config.db.connect() as conn: - conn.execute(stmt, q=[(5, 10), (12, 18)]) + conn.execute(stmt, dict(q=[(5, 10), (12, 18)])) asserter.assert_( CursorSQL( @@ -807,7 +807,7 @@ class PostCompileParamsTest( with self.sql_execution_asserter() as asserter: with config.db.connect() as conn: - conn.execute(stmt, q=[(5, "z1"), (12, "z3")]) + conn.execute(stmt, dict(q=[(5, "z1"), (12, "z3")])) asserter.assert_( CursorSQL( diff --git a/lib/sqlalchemy/testing/suite/test_sequence.py b/lib/sqlalchemy/testing/suite/test_sequence.py index d8c35ed0b..7445ade00 100644 --- a/lib/sqlalchemy/testing/suite/test_sequence.py +++ b/lib/sqlalchemy/testing/suite/test_sequence.py @@ -46,11 +46,13 @@ class SequenceTest(fixtures.TablesTest): ) def test_insert_roundtrip(self, connection): - connection.execute(self.tables.seq_pk.insert(), data="some data") + connection.execute(self.tables.seq_pk.insert(), dict(data="some data")) self._assert_round_trip(self.tables.seq_pk, connection) def test_insert_lastrowid(self, connection): - r = connection.execute(self.tables.seq_pk.insert(), data="some data") + r = connection.execute( + self.tables.seq_pk.insert(), dict(data="some data") + ) eq_( r.inserted_primary_key, (testing.db.dialect.default_sequence_base,) ) @@ -62,7 +64,7 @@ class SequenceTest(fixtures.TablesTest): @requirements.sequences_optional def test_optional_seq(self, connection): r = connection.execute( - self.tables.seq_opt_pk.insert(), data="some data" + self.tables.seq_opt_pk.insert(), dict(data="some data") ) eq_(r.inserted_primary_key, (1,)) diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index 3a5e02c32..ebcceaae7 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -511,24 +511,23 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): __backend__ = True @testing.fixture - def do_numeric_test(self, metadata): + def do_numeric_test(self, metadata, connection): @testing.emits_warning( r".*does \*not\* support Decimal objects natively" ) def run(type_, input_, output, filter_=None, check_scale=False): t = Table("t", metadata, Column("x", type_)) - t.create(testing.db) - with config.db.begin() as conn: - conn.execute(t.insert(), [{"x": x} for x in input_]) - - result = {row[0] for row in conn.execute(t.select())} - output = set(output) - if filter_: - result = set(filter_(x) for x in result) - output = set(filter_(x) for x in output) - eq_(result, output) - if check_scale: - eq_([str(x) for x in result], [str(x) for x in output]) + t.create(connection) + connection.execute(t.insert(), [{"x": x} for x in input_]) + + result = {row[0] for row in connection.execute(t.select())} + output = set(output) + if filter_: + result = set(filter_(x) for x in result) + output = set(filter_(x) for x in output) + eq_(result, output) + if check_scale: + eq_([str(x) for x in result], [str(x) for x in output]) return run @@ -1165,40 +1164,39 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): }, ) - def test_eval_none_flag_orm(self): + def test_eval_none_flag_orm(self, connection): Base = declarative_base() class Data(Base): __table__ = self.tables.data_table - s = Session(testing.db) + with Session(connection) as s: + d1 = Data(name="d1", data=None, nulldata=None) + s.add(d1) + s.commit() - d1 = Data(name="d1", data=None, nulldata=None) - s.add(d1) - s.commit() - - s.bulk_insert_mappings( - Data, [{"name": "d2", "data": None, "nulldata": None}] - ) - eq_( - s.query( - cast(self.tables.data_table.c.data, String()), - cast(self.tables.data_table.c.nulldata, String), + s.bulk_insert_mappings( + Data, [{"name": "d2", "data": None, "nulldata": None}] ) - .filter(self.tables.data_table.c.name == "d1") - .first(), - ("null", None), - ) - eq_( - s.query( - cast(self.tables.data_table.c.data, String()), - cast(self.tables.data_table.c.nulldata, String), + eq_( + s.query( + cast(self.tables.data_table.c.data, String()), + cast(self.tables.data_table.c.nulldata, String), + ) + .filter(self.tables.data_table.c.name == "d1") + .first(), + ("null", None), + ) + eq_( + s.query( + cast(self.tables.data_table.c.data, String()), + cast(self.tables.data_table.c.nulldata, String), + ) + .filter(self.tables.data_table.c.name == "d2") + .first(), + ("null", None), ) - .filter(self.tables.data_table.c.name == "d2") - .first(), - ("null", None), - ) class JSONLegacyStringCastIndexTest( diff --git a/lib/sqlalchemy/testing/suite/test_update_delete.py b/lib/sqlalchemy/testing/suite/test_update_delete.py index f7e6da98e..3fb51ead3 100644 --- a/lib/sqlalchemy/testing/suite/test_update_delete.py +++ b/lib/sqlalchemy/testing/suite/test_update_delete.py @@ -32,7 +32,9 @@ class SimpleUpdateDeleteTest(fixtures.TablesTest): def test_update(self, connection): t = self.tables.plain_pk - r = connection.execute(t.update().where(t.c.id == 2), data="d2_new") + r = connection.execute( + t.update().where(t.c.id == 2), dict(data="d2_new") + ) assert not r.is_insert assert not r.returns_rows diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py index eb9fcd1cd..01185c284 100644 --- a/lib/sqlalchemy/testing/util.py +++ b/lib/sqlalchemy/testing/util.py @@ -14,6 +14,7 @@ import types from . import config from . import mock from .. import inspect +from ..engine import Connection from ..schema import Column from ..schema import DropConstraint from ..schema import DropTable @@ -207,11 +208,13 @@ def fail(msg): @decorator def provide_metadata(fn, *args, **kw): - """Provide bound MetaData for a single test, dropping afterwards.""" + """Provide bound MetaData for a single test, dropping afterwards. - # import cycle that only occurs with py2k's import resolver - # in py3k this can be moved top level. - from . import engines + Legacy; use the "metadata" pytest fixture. + + """ + + from . import fixtures metadata = schema.MetaData() self = args[0] @@ -220,7 +223,31 @@ def provide_metadata(fn, *args, **kw): try: return fn(*args, **kw) finally: - engines.drop_all_tables(metadata, config.db) + # close out some things that get in the way of dropping tables. + # when using the "metadata" fixture, there is a set ordering + # of things that makes sure things are cleaned up in order, however + # the simple "decorator" nature of this legacy function means + # we have to hardcode some of that cleanup ahead of time. + + # close ORM sessions + fixtures._close_all_sessions() + + # integrate with the "connection" fixture as there are many + # tests where it is used along with provide_metadata + if fixtures._connection_fixture_connection: + # TODO: this warning can be used to find all the places + # this is used with connection fixture + # warn("mixing legacy provide metadata with connection fixture") + drop_all_tables_from_metadata( + metadata, fixtures._connection_fixture_connection + ) + # as the provide_metadata fixture is often used with "testing.db", + # when we do the drop we have to commit the transaction so that + # the DB is actually updated as the CREATE would have been + # committed + fixtures._connection_fixture_connection.get_transaction().commit() + else: + drop_all_tables_from_metadata(metadata, config.db) self.metadata = prev_meta @@ -359,6 +386,29 @@ class adict(dict): get_all = __call__ +def drop_all_tables_from_metadata(metadata, engine_or_connection): + from . import engines + + def go(connection): + engines.testing_reaper.prepare_for_drop_tables(connection) + + if not connection.dialect.supports_alter: + from . import assertions + + with assertions.expect_warnings( + "Can't sort tables", assert_=False + ): + metadata.drop_all(connection) + else: + metadata.drop_all(connection) + + if not isinstance(engine_or_connection, Connection): + with engine_or_connection.begin() as connection: + go(connection) + else: + go(engine_or_connection) + + def drop_all_tables(engine, inspector, schema=None, include_names=None): if include_names is not None: diff --git a/lib/sqlalchemy/testing/warnings.py b/lib/sqlalchemy/testing/warnings.py index 1b078a263..5699cd035 100644 --- a/lib/sqlalchemy/testing/warnings.py +++ b/lib/sqlalchemy/testing/warnings.py @@ -56,9 +56,6 @@ def setup_filters(): # Core execution # r"The (?:Executable|Engine)\.(?:execute|scalar)\(\) method", - r"The connection.execute\(\) method in SQLAlchemy 2.0 will accept " - "parameters as a single dictionary or a single sequence of " - "dictionaries only.", r"The Connection.connect\(\) method is considered legacy", # r".*DefaultGenerator.execute\(\)", # diff --git a/lib/sqlalchemy/util/queue.py b/lib/sqlalchemy/util/queue.py index 99ecb4fb3..ca5a3abde 100644 --- a/lib/sqlalchemy/util/queue.py +++ b/lib/sqlalchemy/util/queue.py @@ -230,13 +230,16 @@ class AsyncAdaptedQueue: return self.put_nowait(item) try: - if timeout: + if timeout is not None: return self.await_( asyncio.wait_for(self._queue.put(item), timeout) ) else: return self.await_(self._queue.put(item)) - except asyncio.queues.QueueFull as err: + except ( + asyncio.queues.QueueFull, + asyncio.exceptions.TimeoutError, + ) as err: compat.raise_( Full(), replace_context=err, @@ -254,14 +257,18 @@ class AsyncAdaptedQueue: def get(self, block=True, timeout=None): if not block: return self.get_nowait() + try: - if timeout: + if timeout is not None: return self.await_( asyncio.wait_for(self._queue.get(), timeout) ) else: return self.await_(self._queue.get()) - except asyncio.queues.QueueEmpty as err: + except ( + asyncio.queues.QueueEmpty, + asyncio.exceptions.TimeoutError, + ) as err: compat.raise_( Empty(), replace_context=err, |
