diff options
Diffstat (limited to 'lib/sqlalchemy')
30 files changed, 848 insertions, 353 deletions
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 538679fcf..9d0e5d322 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -9,6 +9,16 @@      :name: Microsoft SQL Server +.. _mssql_external_dialects: + +External Dialects +----------------- + +In addition to the above DBAPI layers with native SQLAlchemy support, there +are third-party dialects for other DBAPI layers that are compatible +with SQL Server. See the "External Dialects" list on the +:ref:`dialect_toplevel` page.  +  .. _mssql_identity:  Auto Increment Behavior / IDENTITY Columns @@ -2785,15 +2795,14 @@ class MSDialect(default.DefaultDialect):      def has_table(self, connection, tablename, dbname, owner, schema):          if tablename.startswith("#"):  # temporary table              tables = ischema.mssql_temp_table_columns -            result = connection.execute( -                sql.select(tables.c.table_name) -                .where( -                    tables.c.table_name.like( -                        self._temp_table_name_like_pattern(tablename) -                    ) + +            s = sql.select(tables.c.table_name).where( +                tables.c.table_name.like( +                    self._temp_table_name_like_pattern(tablename)                  ) -                .limit(1)              ) + +            result = connection.execute(s.limit(1))              return result.scalar() is not None          else:              tables = ischema.tables diff --git a/lib/sqlalchemy/dialects/mssql/provision.py b/lib/sqlalchemy/dialects/mssql/provision.py index 269eb164f..56f3305a7 100644 --- a/lib/sqlalchemy/dialects/mssql/provision.py +++ b/lib/sqlalchemy/dialects/mssql/provision.py @@ -1,6 +1,14 @@ +from sqlalchemy import inspect +from sqlalchemy import Integer  from ... import create_engine  from ... import exc +from ...schema import Column +from ...schema import DropConstraint +from ...schema import ForeignKeyConstraint +from ...schema import MetaData +from ...schema import Table  from ...testing.provision import create_db +from ...testing.provision import drop_all_schema_objects_pre_tables  from ...testing.provision import drop_db  from ...testing.provision import get_temp_table_name  from ...testing.provision import log @@ -38,7 +46,6 @@ def _mssql_drop_ignore(conn, ident):          #        "where database_id=db_id('%s')" % ident):          #    log.info("killing SQL server session %s", row['session_id'])          #    conn.exec_driver_sql("kill %s" % row['session_id']) -          conn.exec_driver_sql("drop database %s" % ident)          log.info("Reaped db: %s", ident)          return True @@ -83,4 +90,27 @@ def _mssql_temp_table_keyword_args(cfg, eng):  @get_temp_table_name.for_db("mssql")  def _mssql_get_temp_table_name(cfg, eng, base_name): -    return "#" + base_name +    return "##" + base_name + + +@drop_all_schema_objects_pre_tables.for_db("mssql") +def drop_all_schema_objects_pre_tables(cfg, eng): +    with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: +        inspector = inspect(conn) +        for schema in (None, "dbo", cfg.test_schema, cfg.test_schema_2): +            for tname in inspector.get_table_names(schema=schema): +                tb = Table( +                    tname, +                    MetaData(), +                    Column("x", Integer), +                    Column("y", Integer), +                    schema=schema, +                ) +                for fk in inspect(conn).get_foreign_keys(tname, schema=schema): +                    conn.execute( +                        DropConstraint( +                            ForeignKeyConstraint( +                                [tb.c.x], [tb.c.y], name=fk["name"] +                            ) +                        ) +                    ) diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index 042443692..b8b4df760 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -93,6 +93,7 @@ The parameters accepted by the cx_oracle dialect are as follows:  * ``encoding_errors`` - see :ref:`cx_oracle_unicode_encoding_errors` for detail. +  .. _cx_oracle_unicode:  Unicode diff --git a/lib/sqlalchemy/dialects/oracle/provision.py b/lib/sqlalchemy/dialects/oracle/provision.py index d51131c0b..e0dadd58e 100644 --- a/lib/sqlalchemy/dialects/oracle/provision.py +++ b/lib/sqlalchemy/dialects/oracle/provision.py @@ -6,11 +6,11 @@ from ...testing.provision import create_db  from ...testing.provision import drop_db  from ...testing.provision import follower_url_from_main  from ...testing.provision import log +from ...testing.provision import post_configure_engine  from ...testing.provision import run_reap_dbs  from ...testing.provision import set_default_schema_on_connection -from ...testing.provision import stop_test_class +from ...testing.provision import stop_test_class_outside_fixtures  from ...testing.provision import temp_table_keyword_args -from ...testing.provision import update_db_opts  @create_db.for_db("oracle") @@ -57,21 +57,39 @@ def _oracle_drop_db(cfg, eng, ident):          _ora_drop_ignore(conn, "%s_ts2" % ident) -@update_db_opts.for_db("oracle") -def _oracle_update_db_opts(db_url, db_opts): -    pass +@stop_test_class_outside_fixtures.for_db("oracle") +def stop_test_class_outside_fixtures(config, db, cls): +    with db.begin() as conn: +        # run magic command to get rid of identity sequences +        # https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/  # noqa E501 +        conn.exec_driver_sql("purge recyclebin") -@stop_test_class.for_db("oracle") -def stop_test_class(config, db, cls): -    """run magic command to get rid of identity sequences +    # clear statement cache on all connections that were used +    # https://github.com/oracle/python-cx_Oracle/issues/519 -    # https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/ +    for cx_oracle_conn in _all_conns: +        try: +            sc = cx_oracle_conn.stmtcachesize +        except db.dialect.dbapi.InterfaceError: +            # connection closed +            pass +        else: +            cx_oracle_conn.stmtcachesize = 0 +            cx_oracle_conn.stmtcachesize = sc +    _all_conns.clear() -    """ -    with db.begin() as conn: -        conn.exec_driver_sql("purge recyclebin") +_all_conns = set() + + +@post_configure_engine.for_db("oracle") +def _oracle_post_configure_engine(url, engine, follower_ident): +    from sqlalchemy import event + +    @event.listens_for(engine, "checkout") +    def checkout(dbapi_con, con_record, con_proxy): +        _all_conns.add(dbapi_con)  @run_reap_dbs.for_db("oracle") diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 7c6e8fb02..e542c77f4 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -670,7 +670,6 @@ class AsyncAdapt_asyncpg_connection:      def rollback(self):          if self._started:              self.await_(self._transaction.rollback()) -              self._transaction = None              self._started = False diff --git a/lib/sqlalchemy/dialects/postgresql/provision.py b/lib/sqlalchemy/dialects/postgresql/provision.py index d345cdfdf..70c390800 100644 --- a/lib/sqlalchemy/dialects/postgresql/provision.py +++ b/lib/sqlalchemy/dialects/postgresql/provision.py @@ -8,6 +8,7 @@ from ...testing.provision import drop_all_schema_objects_post_tables  from ...testing.provision import drop_all_schema_objects_pre_tables  from ...testing.provision import drop_db  from ...testing.provision import log +from ...testing.provision import prepare_for_drop_tables  from ...testing.provision import set_default_schema_on_connection  from ...testing.provision import temp_table_keyword_args @@ -102,3 +103,23 @@ def drop_all_schema_objects_post_tables(cfg, eng):                      postgresql.ENUM(name=enum["name"], schema=enum["schema"])                  )              ) + + +@prepare_for_drop_tables.for_db("postgresql") +def prepare_for_drop_tables(config, connection): +    """Ensure there are no locks on the current username/database.""" + +    result = connection.exec_driver_sql( +        "select pid, state, wait_event_type, query " +        # "select pg_terminate_backend(pid), state, wait_event_type " +        "from pg_stat_activity where " +        "usename=current_user " +        "and datname=current_database() and state='idle in transaction' " +        "and pid != pg_backend_pid()" +    ) +    rows = result.all()  # noqa +    assert not rows, ( +        "PostgreSQL may not be able to DROP tables due to " +        "idle in transaction: %s" +        % ("; ".join(row._mapping["query"] for row in rows)) +    ) diff --git a/lib/sqlalchemy/dialects/sqlite/provision.py b/lib/sqlalchemy/dialects/sqlite/provision.py index f26c21e22..a481be27e 100644 --- a/lib/sqlalchemy/dialects/sqlite/provision.py +++ b/lib/sqlalchemy/dialects/sqlite/provision.py @@ -7,7 +7,7 @@ from ...testing.provision import follower_url_from_main  from ...testing.provision import log  from ...testing.provision import post_configure_engine  from ...testing.provision import run_reap_dbs -from ...testing.provision import stop_test_class +from ...testing.provision import stop_test_class_outside_fixtures  from ...testing.provision import temp_table_keyword_args @@ -57,8 +57,8 @@ def _sqlite_drop_db(cfg, eng, ident):              os.remove(path) -@stop_test_class.for_db("sqlite") -def stop_test_class(config, db, cls): +@stop_test_class_outside_fixtures.for_db("sqlite") +def stop_test_class_outside_fixtures(config, db, cls):      with db.connect() as conn:          files = [              row.file diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 50f00c025..72d66b7c8 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -2729,14 +2729,16 @@ class Engine(Connectable, log.Identified):              return self.conn          def __exit__(self, type_, value, traceback): - -            if type_ is not None: -                self.transaction.rollback() -            else: -                if self.transaction.is_active: -                    self.transaction.commit() -            if not self.close_with_result: -                self.conn.close() +            try: +                if type_ is not None: +                    if self.transaction.is_active: +                        self.transaction.rollback() +                else: +                    if self.transaction.is_active: +                        self.transaction.commit() +            finally: +                if not self.close_with_result: +                    self.conn.close()      def begin(self, close_with_result=False):          """Return a context manager delivering a :class:`_engine.Connection` diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py index f89be1809..72d232085 100644 --- a/lib/sqlalchemy/engine/create.py +++ b/lib/sqlalchemy/engine/create.py @@ -655,9 +655,12 @@ def create_engine(url, **kwargs):              c = base.Connection(                  engine, connection=dbapi_connection, _has_events=False              ) -            c._execution_options = util.immutabledict() -            dialect.initialize(c) -            dialect.do_rollback(c.connection) +            c._execution_options = util.EMPTY_DICT + +            try: +                dialect.initialize(c) +            finally: +                dialect.do_rollback(c.connection)          # previously, the "first_connect" event was used here, which was then          # scaled back if the "on_connect" handler were present.  now, diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py index b031c1610..08b1bb060 100644 --- a/lib/sqlalchemy/exc.py +++ b/lib/sqlalchemy/exc.py @@ -179,10 +179,10 @@ class UnsupportedCompilationError(CompileError):      code = "l7de" -    def __init__(self, compiler, element_type): +    def __init__(self, compiler, element_type, message=None):          super(UnsupportedCompilationError, self).__init__( -            "Compiler %r can't render element of type %s" -            % (compiler, element_type) +            "Compiler %r can't render element of type %s%s" +            % (compiler, element_type, ": %s" % message if message else "")          ) diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py index 5a31173ec..47fb6720b 100644 --- a/lib/sqlalchemy/ext/compiler.py +++ b/lib/sqlalchemy/ext/compiler.py @@ -425,9 +425,11 @@ def compiles(class_, *specs):                          return existing_dispatch(element, compiler, **kw)                      except exc.UnsupportedCompilationError as uce:                          util.raise_( -                            exc.CompileError( -                                "%s construct has no default " -                                "compilation handler." % type(element) +                            exc.UnsupportedCompilationError( +                                compiler, +                                type(element), +                                message="%s construct has no default " +                                "compilation handler." % type(element),                              ),                              from_=uce,                          ) @@ -476,9 +478,11 @@ class _dispatcher(object):                  fn = self.specs["default"]              except KeyError as ke:                  util.raise_( -                    exc.CompileError( -                        "%s construct has no default " -                        "compilation handler." % type(element) +                    exc.UnsupportedCompilationError( +                        compiler, +                        type(element), +                        message="%s construct has no default " +                        "compilation handler." % type(element),                      ),                      replace_context=ke,                  ) diff --git a/lib/sqlalchemy/future/engine.py b/lib/sqlalchemy/future/engine.py index d2f609326..bfdcdfc7f 100644 --- a/lib/sqlalchemy/future/engine.py +++ b/lib/sqlalchemy/future/engine.py @@ -368,12 +368,15 @@ class Engine(_LegacyEngine):              return self.conn          def __exit__(self, type_, value, traceback): -            if type_ is not None: -                self.transaction.rollback() -            else: -                if self.transaction.is_active: -                    self.transaction.commit() -            self.conn.close() +            try: +                if type_ is not None: +                    if self.transaction.is_active: +                        self.transaction.rollback() +                else: +                    if self.transaction.is_active: +                        self.transaction.commit() +            finally: +                self.conn.close()      def begin(self):          """Return a :class:`_future.Connection` object with a transaction diff --git a/lib/sqlalchemy/pool/base.py b/lib/sqlalchemy/pool/base.py index 7c9509e45..6c3aad037 100644 --- a/lib/sqlalchemy/pool/base.py +++ b/lib/sqlalchemy/pool/base.py @@ -426,6 +426,7 @@ class _ConnectionRecord(object):                  rec._checkin_failed(err)          echo = pool._should_log_debug()          fairy = _ConnectionFairy(dbapi_connection, rec, echo) +          rec.fairy_ref = weakref.ref(              fairy,              lambda ref: _finalize_fairy @@ -609,6 +610,15 @@ def _finalize_fairy(          assert connection is None          connection = connection_record.connection +    dont_restore_gced = pool._is_asyncio + +    if dont_restore_gced: +        detach = not connection_record or ref +        can_manipulate_connection = not ref +    else: +        detach = not connection_record +        can_manipulate_connection = True +      if connection is not None:          if connection_record and echo:              pool.logger.debug( @@ -620,13 +630,26 @@ def _finalize_fairy(                  connection, connection_record, echo              )              assert fairy.connection is connection -            fairy._reset(pool) +            if can_manipulate_connection: +                fairy._reset(pool) + +            if detach: +                if connection_record: +                    fairy._pool = pool +                    fairy.detach() + +                if can_manipulate_connection: +                    if pool.dispatch.close_detached: +                        pool.dispatch.close_detached(connection) + +                    pool._close_connection(connection) +                else: +                    util.warn( +                        "asyncio connection is being garbage " +                        "collected without being properly closed: %r" +                        % connection +                    ) -            # Immediately close detached instances -            if not connection_record: -                if pool.dispatch.close_detached: -                    pool.dispatch.close_detached(connection) -                pool._close_connection(connection)          except BaseException as e:              pool.logger.error(                  "Exception during reset or similar", exc_info=True diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index a1426b628..550111020 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -729,7 +729,7 @@ class ExecutableOption(HasCopyInternals, HasCacheKey):          return c -class Executable(Generative): +class Executable(roles.CoerceTextStatementRole, Generative):      """Mark a :class:`_expression.ClauseElement` as supporting execution.      :class:`.Executable` is a superclass for all "statement" types diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index d3c767b5d..5ea3526ea 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -1502,7 +1502,6 @@ class TextClause(      roles.OrderByRole,      roles.FromClauseRole,      roles.SelectStatementRole, -    roles.CoerceTextStatementRole,      roles.BinaryElementRole,      roles.InElementRole,      Executable, diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py index c1afeb907..9f2d0b857 100644 --- a/lib/sqlalchemy/testing/__init__.py +++ b/lib/sqlalchemy/testing/__init__.py @@ -32,6 +32,7 @@ from .assertions import is_instance_of  # noqa  from .assertions import is_none  # noqa  from .assertions import is_not  # noqa  from .assertions import is_not_  # noqa +from .assertions import is_not_none  # noqa  from .assertions import is_true  # noqa  from .assertions import le_  # noqa  from .assertions import ne_  # noqa diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index b2a4ac66e..f2ed91b79 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -228,6 +228,10 @@ def is_none(a, msg=None):      is_(a, None, msg=msg) +def is_not_none(a, msg=None): +    is_not(a, None, msg=msg) + +  def is_true(a, msg=None):      is_(bool(a), True, msg=msg) diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index f64153f33..750671f9f 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -97,6 +97,10 @@ def get_current_test_name():      return _fixture_functions.get_current_test_name() +def mark_base_test_class(): +    return _fixture_functions.mark_base_test_class() + +  class Config(object):      def __init__(self, db, db_opts, options, file_config):          self._set_name(db) diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index a4c1f3973..8b334fde2 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -7,6 +7,7 @@  from __future__ import absolute_import +import collections  import re  import warnings  import weakref @@ -20,26 +21,29 @@ from .. import pool  class ConnectionKiller(object):      def __init__(self):          self.proxy_refs = weakref.WeakKeyDictionary() -        self.testing_engines = weakref.WeakKeyDictionary() -        self.conns = set() +        self.testing_engines = collections.defaultdict(set) +        self.dbapi_connections = set()      def add_pool(self, pool): -        event.listen(pool, "connect", self.connect) -        event.listen(pool, "checkout", self.checkout) -        event.listen(pool, "invalidate", self.invalidate) - -    def add_engine(self, engine): -        self.add_pool(engine.pool) -        self.testing_engines[engine] = True +        event.listen(pool, "checkout", self._add_conn) +        event.listen(pool, "checkin", self._remove_conn) +        event.listen(pool, "close", self._remove_conn) +        event.listen(pool, "close_detached", self._remove_conn) +        # note we are keeping "invalidated" here, as those are still +        # opened connections we would like to roll back + +    def _add_conn(self, dbapi_con, con_record, con_proxy): +        self.dbapi_connections.add(dbapi_con) +        self.proxy_refs[con_proxy] = True -    def connect(self, dbapi_conn, con_record): -        self.conns.add((dbapi_conn, con_record)) +    def _remove_conn(self, dbapi_conn, *arg): +        self.dbapi_connections.discard(dbapi_conn) -    def checkout(self, dbapi_con, con_record, con_proxy): -        self.proxy_refs[con_proxy] = True +    def add_engine(self, engine, scope): +        self.add_pool(engine.pool) -    def invalidate(self, dbapi_con, con_record, exception): -        self.conns.discard((dbapi_con, con_record)) +        assert scope in ("class", "global", "function", "fixture") +        self.testing_engines[scope].add(engine)      def _safe(self, fn):          try: @@ -54,53 +58,76 @@ class ConnectionKiller(object):              if rec is not None and rec.is_valid:                  self._safe(rec.rollback) -    def close_all(self): +    def checkin_all(self): +        # run pool.checkin() for all ConnectionFairy instances we have +        # tracked. +          for rec in list(self.proxy_refs):              if rec is not None and rec.is_valid: -                self._safe(rec._close) - -    def _after_test_ctx(self): -        # this can cause a deadlock with pg8000 - pg8000 acquires -        # prepared statement lock inside of rollback() - if async gc -        # is collecting in finalize_fairy, deadlock. -        # not sure if this should be for non-cpython only. -        # note that firebird/fdb definitely needs this though -        for conn, rec in list(self.conns): -            if rec.connection is None: -                # this is a hint that the connection is closed, which -                # is causing segfaults on mysqlclient due to -                # https://github.com/PyMySQL/mysqlclient-python/issues/270; -                # try to work around here -                continue -            self._safe(conn.rollback) - -    def _stop_test_ctx(self): -        if config.options.low_connections: -            self._stop_test_ctx_minimal() -        else: -            self._stop_test_ctx_aggressive() +                self.dbapi_connections.discard(rec.connection) +                self._safe(rec._checkin) -    def _stop_test_ctx_minimal(self): -        self.close_all() +        # for fairy refs that were GCed and could not close the connection, +        # such as asyncio, roll back those remaining connections +        for con in self.dbapi_connections: +            self._safe(con.rollback) +        self.dbapi_connections.clear() -        self.conns = set() +    def close_all(self): +        self.checkin_all() -        for rec in list(self.testing_engines): -            if rec is not config.db: -                rec.dispose() +    def prepare_for_drop_tables(self, connection): +        # don't do aggressive checks for third party test suites +        if not config.bootstrapped_as_sqlalchemy: +            return -    def _stop_test_ctx_aggressive(self): -        self.close_all() -        for conn, rec in list(self.conns): -            self._safe(conn.close) -            rec.connection = None +        from . import provision + +        provision.prepare_for_drop_tables(connection.engine.url, connection) + +    def _drop_testing_engines(self, scope): +        eng = self.testing_engines[scope] +        for rec in list(eng): +            for proxy_ref in list(self.proxy_refs): +                if proxy_ref is not None and proxy_ref.is_valid: +                    if ( +                        proxy_ref._pool is not None +                        and proxy_ref._pool is rec.pool +                    ): +                        self._safe(proxy_ref._checkin) +            rec.dispose() +        eng.clear() + +    def after_test(self): +        self._drop_testing_engines("function") + +    def after_test_outside_fixtures(self, test): +        # don't do aggressive checks for third party test suites +        if not config.bootstrapped_as_sqlalchemy: +            return + +        if test.__class__.__leave_connections_for_teardown__: +            return -        self.conns = set() -        for rec in list(self.testing_engines): -            if hasattr(rec, "sync_engine"): -                rec.sync_engine.dispose() -            else: -                rec.dispose() +        self.checkin_all() + +        # on PostgreSQL, this will test for any "idle in transaction" +        # connections.   useful to identify tests with unusual patterns +        # that can't be cleaned up correctly. +        from . import provision + +        with config.db.connect() as conn: +            provision.prepare_for_drop_tables(conn.engine.url, conn) + +    def stop_test_class_inside_fixtures(self): +        self.checkin_all() +        self._drop_testing_engines("function") +        self._drop_testing_engines("class") + +    def final_cleanup(self): +        self.checkin_all() +        for scope in self.testing_engines: +            self._drop_testing_engines(scope)      def assert_all_closed(self):          for rec in self.proxy_refs: @@ -111,20 +138,6 @@ class ConnectionKiller(object):  testing_reaper = ConnectionKiller() -def drop_all_tables(metadata, bind): -    testing_reaper.close_all() -    if hasattr(bind, "close"): -        bind.close() - -    if not config.db.dialect.supports_alter: -        from . import assertions - -        with assertions.expect_warnings("Can't sort tables", assert_=False): -            metadata.drop_all(bind) -    else: -        metadata.drop_all(bind) - -  @decorator  def assert_conns_closed(fn, *args, **kw):      try: @@ -147,7 +160,7 @@ def rollback_open_connections(fn, *args, **kw):  def close_first(fn, *args, **kw):      """Decorator that closes all connections before fn execution.""" -    testing_reaper.close_all() +    testing_reaper.checkin_all()      fn(*args, **kw) @@ -157,7 +170,7 @@ def close_open_connections(fn, *args, **kw):      try:          fn(*args, **kw)      finally: -        testing_reaper.close_all() +        testing_reaper.checkin_all()  def all_dialects(exclude=None): @@ -239,12 +252,14 @@ def reconnecting_engine(url=None, options=None):      return engine -def testing_engine(url=None, options=None, future=False, asyncio=False): +def testing_engine(url=None, options=None, future=None, asyncio=False):      """Produce an engine configured by --options with optional overrides."""      if asyncio:          from sqlalchemy.ext.asyncio import create_async_engine as create_engine -    elif future or config.db and config.db._is_future: +    elif future or ( +        config.db and config.db._is_future and future is not False +    ):          from sqlalchemy.future import create_engine      else:          from sqlalchemy import create_engine @@ -252,8 +267,10 @@ def testing_engine(url=None, options=None, future=False, asyncio=False):      if not options:          use_reaper = True +        scope = "function"      else:          use_reaper = options.pop("use_reaper", True) +        scope = options.pop("scope", "function")      url = url or config.db.url @@ -268,16 +285,20 @@ def testing_engine(url=None, options=None, future=False, asyncio=False):          default_opt.update(options)      engine = create_engine(url, **options) -    if asyncio: -        engine.sync_engine._has_events = True -    else: -        engine._has_events = True  # enable event blocks, helps with profiling + +    if scope == "global": +        if asyncio: +            engine.sync_engine._has_events = True +        else: +            engine._has_events = ( +                True  # enable event blocks, helps with profiling +            )      if isinstance(engine.pool, pool.QueuePool):          engine.pool._timeout = 0 -        engine.pool._max_overflow = 5 +        engine.pool._max_overflow = 0      if use_reaper: -        testing_reaper.add_engine(engine) +        testing_reaper.add_engine(engine, scope)      return engine diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index ac4d3d8fa..f19b4652a 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -5,6 +5,7 @@  # This module is part of SQLAlchemy and is released under  # the MIT License: http://www.opensource.org/licenses/mit-license.php +import contextlib  import re  import sys @@ -12,12 +13,11 @@ import sqlalchemy as sa  from . import assertions  from . import config  from . import schema -from .engines import drop_all_tables -from .engines import testing_engine  from .entities import BasicEntity  from .entities import ComparableEntity  from .entities import ComparableMixin  # noqa  from .util import adict +from .util import drop_all_tables_from_metadata  from .. import event  from .. import util  from ..orm import declarative_base @@ -25,10 +25,8 @@ from ..orm import registry  from ..orm.decl_api import DeclarativeMeta  from ..schema import sort_tables_and_constraints -# whether or not we use unittest changes things dramatically, -# as far as how pytest collection works. - +@config.mark_base_test_class()  class TestBase(object):      # A sequence of database names to always run, regardless of the      # constraints below. @@ -48,81 +46,114 @@ class TestBase(object):      # skipped.      __skip_if__ = None +    # if True, the testing reaper will not attempt to touch connection +    # state after a test is completed and before the outer teardown +    # starts +    __leave_connections_for_teardown__ = False +      def assert_(self, val, msg=None):          assert val, msg -    # apparently a handful of tests are doing this....OK -    def setup(self): -        if hasattr(self, "setUp"): -            self.setUp() - -    def teardown(self): -        if hasattr(self, "tearDown"): -            self.tearDown() -      @config.fixture()      def connection(self): -        eng = getattr(self, "bind", config.db) +        global _connection_fixture_connection + +        eng = getattr(self, "bind", None) or config.db          conn = eng.connect()          trans = conn.begin() -        try: -            yield conn -        finally: -            if trans.is_active: -                trans.rollback() -            conn.close() + +        _connection_fixture_connection = conn +        yield conn + +        _connection_fixture_connection = None + +        if trans.is_active: +            trans.rollback() +        # trans would not be active here if the test is using +        # the legacy @provide_metadata decorator still, as it will +        # run a close all connections. +        conn.close()      @config.fixture() -    def future_connection(self): +    def future_connection(self, future_engine, connection): +        # integrate the future_engine and connection fixtures so +        # that users of the "connection" fixture will get at the +        # "future" connection +        yield connection -        eng = testing_engine(future=True) -        conn = eng.connect() -        trans = conn.begin() -        try: -            yield conn -        finally: -            if trans.is_active: -                trans.rollback() -            conn.close() +    @config.fixture() +    def future_engine(self): +        eng = getattr(self, "bind", None) or config.db +        with _push_future_engine(eng): +            yield + +    @config.fixture() +    def testing_engine(self): +        from . import engines + +        def gen_testing_engine( +            url=None, options=None, future=False, asyncio=False +        ): +            if options is None: +                options = {} +            options["scope"] = "fixture" +            return engines.testing_engine( +                url=url, options=options, future=future, asyncio=asyncio +            ) + +        yield gen_testing_engine + +        engines.testing_reaper._drop_testing_engines("fixture")      @config.fixture() -    def metadata(self): +    def metadata(self, request):          """Provide bound MetaData for a single test, dropping afterwards.""" -        from . import engines          from ..sql import schema          metadata = schema.MetaData() -        try: -            yield metadata -        finally: -            engines.drop_all_tables(metadata, config.db) +        request.instance.metadata = metadata +        yield metadata +        del request.instance.metadata +        if ( +            _connection_fixture_connection +            and _connection_fixture_connection.in_transaction() +        ): +            trans = _connection_fixture_connection.get_transaction() +            trans.rollback() +            with _connection_fixture_connection.begin(): +                drop_all_tables_from_metadata( +                    metadata, _connection_fixture_connection +                ) +        else: +            drop_all_tables_from_metadata(metadata, config.db) -class FutureEngineMixin(object): -    @classmethod -    def setup_class(cls): -        from ..future.engine import Engine -        from sqlalchemy import testing +_connection_fixture_connection = None -        facade = Engine._future_facade(config.db) -        config._current.push_engine(facade, testing) -        super_ = super(FutureEngineMixin, cls) -        if hasattr(super_, "setup_class"): -            super_.setup_class() +@contextlib.contextmanager +def _push_future_engine(engine): -    @classmethod -    def teardown_class(cls): -        super_ = super(FutureEngineMixin, cls) -        if hasattr(super_, "teardown_class"): -            super_.teardown_class() +    from ..future.engine import Engine +    from sqlalchemy import testing + +    facade = Engine._future_facade(engine) +    config._current.push_engine(facade, testing) + +    yield facade -        from sqlalchemy import testing +    config._current.pop(testing) -        config._current.pop(testing) + +class FutureEngineMixin(object): +    @config.fixture(autouse=True, scope="class") +    def _push_future_engine(self): +        eng = getattr(self, "bind", None) or config.db +        with _push_future_engine(eng): +            yield  class TablesTest(TestBase): @@ -151,18 +182,32 @@ class TablesTest(TestBase):      other = None      sequences = None -    @property -    def tables_test_metadata(self): -        return self._tables_metadata - -    @classmethod -    def setup_class(cls): +    @config.fixture(autouse=True, scope="class") +    def _setup_tables_test_class(self): +        cls = self.__class__          cls._init_class()          cls._setup_once_tables()          cls._setup_once_inserts() +        yield + +        cls._teardown_once_metadata_bind() + +    @config.fixture(autouse=True, scope="function") +    def _setup_tables_test_instance(self): +        self._setup_each_tables() +        self._setup_each_inserts() + +        yield + +        self._teardown_each_tables() + +    @property +    def tables_test_metadata(self): +        return self._tables_metadata +      @classmethod      def _init_class(cls):          if cls.run_define_tables == "each": @@ -213,10 +258,10 @@ class TablesTest(TestBase):          if self.run_define_tables == "each":              self.tables.clear()              if self.run_create_tables == "each": -                drop_all_tables(self._tables_metadata, self.bind) +                drop_all_tables_from_metadata(self._tables_metadata, self.bind)              self._tables_metadata.clear()          elif self.run_create_tables == "each": -            drop_all_tables(self._tables_metadata, self.bind) +            drop_all_tables_from_metadata(self._tables_metadata, self.bind)          # no need to run deletes if tables are recreated on setup          if ( @@ -242,17 +287,10 @@ class TablesTest(TestBase):                              file=sys.stderr,                          ) -    def setup(self): -        self._setup_each_tables() -        self._setup_each_inserts() - -    def teardown(self): -        self._teardown_each_tables() -      @classmethod      def _teardown_once_metadata_bind(cls):          if cls.run_create_tables: -            drop_all_tables(cls._tables_metadata, cls.bind) +            drop_all_tables_from_metadata(cls._tables_metadata, cls.bind)          if cls.run_dispose_bind == "once":              cls.dispose_bind(cls.bind) @@ -263,10 +301,6 @@ class TablesTest(TestBase):              cls.bind = None      @classmethod -    def teardown_class(cls): -        cls._teardown_once_metadata_bind() - -    @classmethod      def setup_bind(cls):          return config.db @@ -332,38 +366,47 @@ class RemovesEvents(object):          self._event_fns.add((target, name, fn))          event.listen(target, name, fn, **kw) -    def teardown(self): +    @config.fixture(autouse=True, scope="function") +    def _remove_events(self): +        yield          for key in self._event_fns:              event.remove(*key) -        super_ = super(RemovesEvents, self) -        if hasattr(super_, "teardown"): -            super_.teardown() - - -class _ORMTest(object): -    @classmethod -    def teardown_class(cls): -        sa.orm.session.close_all_sessions() -        sa.orm.clear_mappers() -def create_session(**kw): -    kw.setdefault("autoflush", False) -    kw.setdefault("expire_on_commit", False) -    return sa.orm.Session(config.db, **kw) +_fixture_sessions = set()  def fixture_session(**kw):      kw.setdefault("autoflush", True)      kw.setdefault("expire_on_commit", True) -    return sa.orm.Session(config.db, **kw) +    sess = sa.orm.Session(config.db, **kw) +    _fixture_sessions.add(sess) +    return sess + + +def _close_all_sessions(): +    # will close all still-referenced sessions +    sa.orm.session.close_all_sessions() +    _fixture_sessions.clear() + + +def stop_test_class_inside_fixtures(cls): +    _close_all_sessions() +    sa.orm.clear_mappers() -class ORMTest(_ORMTest, TestBase): +def after_test(): + +    if _fixture_sessions: + +        _close_all_sessions() + + +class ORMTest(TestBase):      pass -class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): +class MappedTest(TablesTest, assertions.AssertsExecutionResults):      # 'once', 'each', None      run_setup_classes = "once" @@ -372,8 +415,9 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):      classes = None -    @classmethod -    def setup_class(cls): +    @config.fixture(autouse=True, scope="class") +    def _setup_tables_test_class(self): +        cls = self.__class__          cls._init_class()          if cls.classes is None: @@ -384,18 +428,20 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):          cls._setup_once_mappers()          cls._setup_once_inserts() -    @classmethod -    def teardown_class(cls): +        yield +          cls._teardown_once_class()          cls._teardown_once_metadata_bind() -    def setup(self): +    @config.fixture(autouse=True, scope="function") +    def _setup_tables_test_instance(self):          self._setup_each_tables()          self._setup_each_classes()          self._setup_each_mappers()          self._setup_each_inserts() -    def teardown(self): +        yield +          sa.orm.session.close_all_sessions()          self._teardown_each_mappers()          self._teardown_each_classes() @@ -404,7 +450,6 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):      @classmethod      def _teardown_once_class(cls):          cls.classes.clear() -        _ORMTest.teardown_class()      @classmethod      def _setup_once_classes(cls): @@ -440,6 +485,8 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):          """          cls_registry = cls.classes +        assert cls_registry is not None +          class FindFixture(type):              def __init__(cls, classname, bases, dict_):                  cls_registry[classname] = cls diff --git a/lib/sqlalchemy/testing/plugin/bootstrap.py b/lib/sqlalchemy/testing/plugin/bootstrap.py index a95c947e2..1f568dfc8 100644 --- a/lib/sqlalchemy/testing/plugin/bootstrap.py +++ b/lib/sqlalchemy/testing/plugin/bootstrap.py @@ -40,6 +40,11 @@ def load_file_as_module(name):  if to_bootstrap == "pytest":      sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base") +    sys.modules["sqla_plugin_base"].bootstrapped_as_sqlalchemy = True +    if sys.version_info < (3, 0): +        sys.modules["sqla_reinvent_fixtures"] = load_file_as_module( +            "reinvent_fixtures_py2k" +        )      sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin")  else:      raise Exception("unknown bootstrap: %s" % to_bootstrap)  # noqa diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py index 3594cd276..7851fbb3e 100644 --- a/lib/sqlalchemy/testing/plugin/plugin_base.py +++ b/lib/sqlalchemy/testing/plugin/plugin_base.py @@ -21,6 +21,9 @@ import logging  import re  import sys +# flag which indicates we are in the SQLAlchemy testing suite, +# and not that of Alembic or a third party dialect. +bootstrapped_as_sqlalchemy = False  log = logging.getLogger("sqlalchemy.testing.plugin_base") @@ -381,7 +384,7 @@ def _init_symbols(options, file_config):  @post  def _set_disable_asyncio(opt, file_config): -    if opt.disable_asyncio: +    if opt.disable_asyncio or not py3k:          from sqlalchemy.testing import asyncio          asyncio.ENABLE_ASYNCIO = False @@ -458,6 +461,8 @@ def _setup_requirements(argument):      config.requirements = testing.requires = req_cls() +    config.bootstrapped_as_sqlalchemy = bootstrapped_as_sqlalchemy +  @post  def _prep_testing_database(options, file_config): @@ -566,17 +571,22 @@ def generate_sub_tests(cls, module):          yield cls -def start_test_class(cls): +def start_test_class_outside_fixtures(cls):      _do_skips(cls)      _setup_engine(cls)  def stop_test_class(cls): -    # from sqlalchemy import inspect -    # assert not inspect(testing.db).get_table_names() +    # close sessions, immediate connections, etc. +    fixtures.stop_test_class_inside_fixtures(cls) + +    # close outstanding connection pool connections, dispose of +    # additional engines +    engines.testing_reaper.stop_test_class_inside_fixtures() -    provision.stop_test_class(config, config.db, cls) -    engines.testing_reaper._stop_test_ctx() + +def stop_test_class_outside_fixtures(cls): +    provision.stop_test_class_outside_fixtures(config, config.db, cls)      try:          if not options.low_connections:              assertions.global_cleanup_assertions() @@ -590,14 +600,16 @@ def _restore_engine():  def final_process_cleanup(): -    engines.testing_reaper._stop_test_ctx_aggressive() +    engines.testing_reaper.final_cleanup()      assertions.global_cleanup_assertions()      _restore_engine()  def _setup_engine(cls):      if getattr(cls, "__engine_options__", None): -        eng = engines.testing_engine(options=cls.__engine_options__) +        opts = dict(cls.__engine_options__) +        opts["scope"] = "class" +        eng = engines.testing_engine(options=opts)          config._current.push_engine(eng, testing) @@ -614,7 +626,12 @@ def before_test(test, test_module_name, test_class, test_name):  def after_test(test): -    engines.testing_reaper._after_test_ctx() +    fixtures.after_test() +    engines.testing_reaper.after_test() + + +def after_test_fixtures(test): +    engines.testing_reaper.after_test_outside_fixtures(test)  def _possible_configs_for_cls(cls, reasons=None, sparse=False): @@ -748,6 +765,10 @@ class FixtureFunctions(ABC):      def get_current_test_name(self):          raise NotImplementedError() +    @abc.abstractmethod +    def mark_base_test_class(self): +        raise NotImplementedError() +  _fixture_fn_class = None diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index 46468a07d..4eaaecebb 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -17,6 +17,7 @@ import sys  import pytest +  try:      import typing  except ImportError: @@ -33,6 +34,14 @@ except ImportError:      has_xdist = False +py2k = sys.version_info < (3, 0) +if py2k: +    try: +        import sqla_reinvent_fixtures as reinvent_fixtures_py2k +    except ImportError: +        from . import reinvent_fixtures_py2k + +  def pytest_addoption(parser):      group = parser.getgroup("sqlalchemy") @@ -238,6 +247,10 @@ def pytest_collection_modifyitems(session, config, items):          else:              newitems.append(item) +    if py2k: +        for item in newitems: +            reinvent_fixtures_py2k.scan_for_fixtures_to_use_for_class(item) +      # seems like the functions attached to a test class aren't sorted already?      # is that true and why's that? (when using unittest, they're sorted)      items[:] = sorted( @@ -251,7 +264,6 @@ def pytest_collection_modifyitems(session, config, items):  def pytest_pycollect_makeitem(collector, name, obj): -      if inspect.isclass(obj) and plugin_base.want_class(name, obj):          from sqlalchemy.testing import config @@ -259,7 +271,6 @@ def pytest_pycollect_makeitem(collector, name, obj):              obj = _apply_maybe_async(obj)          ctor = getattr(pytest.Class, "from_parent", pytest.Class) -          return [              ctor(name=parametrize_cls.__name__, parent=collector)              for parametrize_cls in _parametrize_cls(collector.module, obj) @@ -287,12 +298,11 @@ def _is_wrapped_coroutine_function(fn):  def _apply_maybe_async(obj, recurse=True):      from sqlalchemy.testing import asyncio -    setup_names = {"setup", "setup_class", "teardown", "teardown_class"}      for name, value in vars(obj).items():          if (              (callable(value) or isinstance(value, classmethod))              and not getattr(value, "_maybe_async_applied", False) -            and (name.startswith("test_") or name in setup_names) +            and (name.startswith("test_"))              and not _is_wrapped_coroutine_function(value)          ):              is_classmethod = False @@ -317,9 +327,6 @@ def _apply_maybe_async(obj, recurse=True):      return obj -_current_class = None - -  def _parametrize_cls(module, cls):      """implement a class-based version of pytest parametrize.""" @@ -355,63 +362,153 @@ def _parametrize_cls(module, cls):      return classes +_current_class = None + +  def pytest_runtest_setup(item):      from sqlalchemy.testing import asyncio -    # here we seem to get called only based on what we collected -    # in pytest_collection_modifyitems.   So to do class-based stuff -    # we have to tear that out. -    global _current_class -      if not isinstance(item, pytest.Function):          return -    # ... so we're doing a little dance here to figure it out... +    # pytest_runtest_setup runs *before* pytest fixtures with scope="class". +    # plugin_base.start_test_class_outside_fixtures may opt to raise SkipTest +    # for the whole class and has to run things that are across all current +    # databases, so we run this outside of the pytest fixture system altogether +    # and ensure asyncio greenlet if any engines are async + +    global _current_class +      if _current_class is None: -        asyncio._maybe_async(class_setup, item.parent.parent) +        asyncio._maybe_async_provisioning( +            plugin_base.start_test_class_outside_fixtures, +            item.parent.parent.cls, +        )          _current_class = item.parent.parent -        # this is needed for the class-level, to ensure that the -        # teardown runs after the class is completed with its own -        # class-level teardown...          def finalize():              global _current_class -            asyncio._maybe_async(class_teardown, item.parent.parent)              _current_class = None +            asyncio._maybe_async_provisioning( +                plugin_base.stop_test_class_outside_fixtures, +                item.parent.parent.cls, +            ) +          item.parent.parent.addfinalizer(finalize) -    asyncio._maybe_async(test_setup, item) +def pytest_runtest_call(item): +    # runs inside of pytest function fixture scope +    # before test function runs -def pytest_runtest_teardown(item):      from sqlalchemy.testing import asyncio -    # ...but this works better as the hook here rather than -    # using a finalizer, as the finalizer seems to get in the way -    # of the test reporting failures correctly (you get a bunch of -    # pytest assertion stuff instead) -    asyncio._maybe_async(test_teardown, item) +    asyncio._maybe_async( +        plugin_base.before_test, +        item, +        item.parent.module.__name__, +        item.parent.cls, +        item.name, +    ) -def test_setup(item): -    plugin_base.before_test( -        item, item.parent.module.__name__, item.parent.cls, item.name -    ) +def pytest_runtest_teardown(item, nextitem): +    # runs inside of pytest function fixture scope +    # after test function runs +    from sqlalchemy.testing import asyncio -def test_teardown(item): -    plugin_base.after_test(item) +    asyncio._maybe_async(plugin_base.after_test, item) -def class_setup(item): +@pytest.fixture(scope="class") +def setup_class_methods(request):      from sqlalchemy.testing import asyncio -    asyncio._maybe_async_provisioning(plugin_base.start_test_class, item.cls) +    cls = request.cls + +    if hasattr(cls, "setup_test_class"): +        asyncio._maybe_async(cls.setup_test_class) + +    if py2k: +        reinvent_fixtures_py2k.run_class_fixture_setup(request) + +    yield + +    if py2k: +        reinvent_fixtures_py2k.run_class_fixture_teardown(request) +    if hasattr(cls, "teardown_test_class"): +        asyncio._maybe_async(cls.teardown_test_class) -def class_teardown(item): -    plugin_base.stop_test_class(item.cls) +    asyncio._maybe_async(plugin_base.stop_test_class, cls) + + +@pytest.fixture(scope="function") +def setup_test_methods(request): +    from sqlalchemy.testing import asyncio + +    # called for each test + +    self = request.instance + +    # 1. run outer xdist-style setup +    if hasattr(self, "setup_test"): +        asyncio._maybe_async(self.setup_test) + +    # alembic test suite is using setUp and tearDown +    # xdist methods; support these in the test suite +    # for the near term +    if hasattr(self, "setUp"): +        asyncio._maybe_async(self.setUp) + +    # 2. run homegrown function level "autouse" fixtures under py2k +    if py2k: +        reinvent_fixtures_py2k.run_fn_fixture_setup(request) + +    # inside the yield: + +    # 3. function level "autouse" fixtures under py3k (examples: TablesTest +    #    define tables / data, MappedTest define tables / mappers / data) + +    # 4. function level fixtures defined on test functions themselves, +    #    e.g. "connection", "metadata" run next + +    # 5. pytest hook pytest_runtest_call then runs + +    # 6. test itself runs + +    yield + +    # yield finishes: + +    # 7. pytest hook pytest_runtest_teardown hook runs, this is associated +    #    with fixtures close all sessions, provisioning.stop_test_class(), +    #    engines.testing_reaper -> ensure all connection pool connections +    #    are returned, engines created by testing_engine that aren't the +    #    config engine are disposed + +    # 8. function level fixtures defined on test functions +    #    themselves, e.g. "connection" rolls back the transaction, "metadata" +    #    emits drop all + +    # 9. function level "autouse" fixtures under py3k (examples: TablesTest / +    #    MappedTest delete table data, possibly drop tables and clear mappers +    #    depending on the flags defined by the test class) + +    # 10. run homegrown function-level "autouse" fixtures under py2k +    if py2k: +        reinvent_fixtures_py2k.run_fn_fixture_teardown(request) + +    asyncio._maybe_async(plugin_base.after_test_fixtures, self) + +    # 11. run outer xdist-style teardown +    if hasattr(self, "tearDown"): +        asyncio._maybe_async(self.tearDown) + +    if hasattr(self, "teardown_test"): +        asyncio._maybe_async(self.teardown_test)  def getargspec(fn): @@ -461,6 +558,8 @@ def %(name)s(%(args)s):              # for the wrapped function              decorated.__module__ = fn.__module__              decorated.__name__ = fn.__name__ +            if hasattr(fn, "pytestmark"): +                decorated.pytestmark = fn.pytestmark              return decorated      return decorate @@ -470,6 +569,11 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):      def skip_test_exception(self, *arg, **kw):          return pytest.skip.Exception(*arg, **kw) +    def mark_base_test_class(self): +        return pytest.mark.usefixtures( +            "setup_class_methods", "setup_test_methods" +        ) +      _combination_id_fns = {          "i": lambda obj: obj,          "r": repr, @@ -647,8 +751,18 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):                  fn = asyncio._maybe_async_wrapper(fn)              # other wrappers may be added here -            # now apply FixtureFunctionMarker -            fn = fixture(fn) +            if py2k and "autouse" in kw: +                # py2k workaround for too-slow collection of autouse fixtures +                # in pytest 4.6.11.  See notes in reinvent_fixtures_py2k for +                # rationale. + +                # comment this condition out in order to disable the +                # py2k workaround entirely. +                reinvent_fixtures_py2k.add_fixture(fn, fixture) +            else: +                # now apply FixtureFunctionMarker +                fn = fixture(fn) +              return fn          if fn: diff --git a/lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py b/lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py new file mode 100644 index 000000000..36b68417b --- /dev/null +++ b/lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py @@ -0,0 +1,112 @@ +""" +invent a quick version of pytest autouse fixtures as pytest's unacceptably slow +collection/high memory use in pytest 4.6.11, which is the highest version that +works in py2k. + +by "too-slow" we mean the test suite can't even manage to be collected for a +single process in less than 70 seconds or so and memory use seems to be very +high as well.   for two or four workers the job just times out after ten +minutes. + +so instead we have invented a very limited form of these fixtures, as our +current use of "autouse" fixtures are limited to those in fixtures.py. + +assumptions for these fixtures: + +1. we are only using "function" or "class" scope + +2. the functions must be associated with a test class + +3. the fixture functions cannot themselves use pytest fixtures + +4. the fixture functions must use yield, not return + +When py2k support is removed and we can stay on a modern pytest version, this +can all be removed. + + +""" +import collections + + +_py2k_fixture_fn_names = collections.defaultdict(set) +_py2k_class_fixtures = collections.defaultdict( +    lambda: collections.defaultdict(set) +) +_py2k_function_fixtures = collections.defaultdict( +    lambda: collections.defaultdict(set) +) + +_py2k_cls_fixture_stack = [] +_py2k_fn_fixture_stack = [] + + +def add_fixture(fn, fixture): +    assert fixture.scope in ("class", "function") +    _py2k_fixture_fn_names[fn.__name__].add((fn, fixture.scope)) + + +def scan_for_fixtures_to_use_for_class(item): +    test_class = item.parent.parent.obj + +    for name in _py2k_fixture_fn_names: +        for fixture_fn, scope in _py2k_fixture_fn_names[name]: +            meth = getattr(test_class, name, None) +            if meth and meth.im_func is fixture_fn: +                for sup in test_class.__mro__: +                    if name in sup.__dict__: +                        if scope == "class": +                            _py2k_class_fixtures[test_class][sup].add(meth) +                        elif scope == "function": +                            _py2k_function_fixtures[test_class][sup].add(meth) +                        break +                break + + +def run_class_fixture_setup(request): + +    cls = request.cls +    self = cls.__new__(cls) + +    fixtures_for_this_class = _py2k_class_fixtures.get(cls) + +    if fixtures_for_this_class: +        for sup_ in cls.__mro__: +            for fn in fixtures_for_this_class.get(sup_, ()): +                iter_ = fn(self) +                next(iter_) + +                _py2k_cls_fixture_stack.append(iter_) + + +def run_class_fixture_teardown(request): +    while _py2k_cls_fixture_stack: +        iter_ = _py2k_cls_fixture_stack.pop(-1) +        try: +            next(iter_) +        except StopIteration: +            pass + + +def run_fn_fixture_setup(request): +    cls = request.cls +    self = request.instance + +    fixtures_for_this_class = _py2k_function_fixtures.get(cls) + +    if fixtures_for_this_class: +        for sup_ in reversed(cls.__mro__): +            for fn in fixtures_for_this_class.get(sup_, ()): +                iter_ = fn(self) +                next(iter_) + +                _py2k_fn_fixture_stack.append(iter_) + + +def run_fn_fixture_teardown(request): +    while _py2k_fn_fixture_stack: +        iter_ = _py2k_fn_fixture_stack.pop(-1) +        try: +            next(iter_) +        except StopIteration: +            pass diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py index 4ee0567f2..2fade1c32 100644 --- a/lib/sqlalchemy/testing/provision.py +++ b/lib/sqlalchemy/testing/provision.py @@ -67,6 +67,7 @@ def setup_config(db_url, options, file_config, follower_ident):          db_url = follower_url_from_main(db_url, follower_ident)      db_opts = {}      update_db_opts(db_url, db_opts) +    db_opts["scope"] = "global"      eng = engines.testing_engine(db_url, db_opts)      post_configure_engine(db_url, eng, follower_ident)      eng.connect().close() @@ -264,6 +265,7 @@ def drop_all_schema_objects(cfg, eng):      if config.requirements.schemas.enabled_for_config(cfg):          util.drop_all_tables(eng, inspector, schema=cfg.test_schema) +        util.drop_all_tables(eng, inspector, schema=cfg.test_schema_2)      drop_all_schema_objects_post_tables(cfg, eng) @@ -299,7 +301,7 @@ def update_db_opts(db_url, db_opts):  def post_configure_engine(url, engine, follower_ident):      """Perform extra steps after configuring an engine for testing. -    (For the internal dialects, currently only used by sqlite.) +    (For the internal dialects, currently only used by sqlite, oracle)      """      pass @@ -375,7 +377,12 @@ def temp_table_keyword_args(cfg, eng):  @register.init -def stop_test_class(config, db, testcls): +def prepare_for_drop_tables(config, connection): +    pass + + +@register.init +def stop_test_class_outside_fixtures(config, db, testcls):      pass diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index 6c3c1005a..de157d028 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -293,7 +293,7 @@ class ComponentReflectionTest(fixtures.TablesTest):              from sqlalchemy import pool              return engines.testing_engine( -                options=dict(poolclass=pool.StaticPool) +                options=dict(poolclass=pool.StaticPool, scope="class"),              )          else:              return config.db diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py index e0fdbe47a..e8dd6cf2c 100644 --- a/lib/sqlalchemy/testing/suite/test_results.py +++ b/lib/sqlalchemy/testing/suite/test_results.py @@ -261,10 +261,6 @@ class ServerSideCursorsTest(              )          return self.engine -    def tearDown(self): -        engines.testing_reaper.close_all() -        self.engine.dispose() -      @testing.combinations(          ("global_string", True, "select 1", True),          ("global_text", True, text("select 1"), True), @@ -309,24 +305,22 @@ class ServerSideCursorsTest(      def test_conn_option(self):          engine = self._fixture(False) -        # should be enabled for this one -        result = ( -            engine.connect() -            .execution_options(stream_results=True) -            .exec_driver_sql("select 1") -        ) -        assert self._is_server_side(result.cursor) +        with engine.connect() as conn: +            # should be enabled for this one +            result = conn.execution_options( +                stream_results=True +            ).exec_driver_sql("select 1") +            assert self._is_server_side(result.cursor)      def test_stmt_enabled_conn_option_disabled(self):          engine = self._fixture(False)          s = select(1).execution_options(stream_results=True) -        # not this one -        result = ( -            engine.connect().execution_options(stream_results=False).execute(s) -        ) -        assert not self._is_server_side(result.cursor) +        with engine.connect() as conn: +            # not this one +            result = conn.execution_options(stream_results=False).execute(s) +            assert not self._is_server_side(result.cursor)      def test_aliases_and_ss(self):          engine = self._fixture(False) @@ -344,8 +338,7 @@ class ServerSideCursorsTest(              assert not self._is_server_side(result.cursor)              result.close() -    @testing.provide_metadata -    def test_roundtrip_fetchall(self): +    def test_roundtrip_fetchall(self, metadata):          md = self.metadata          engine = self._fixture(True) @@ -385,8 +378,7 @@ class ServerSideCursorsTest(                  0,              ) -    @testing.provide_metadata -    def test_roundtrip_fetchmany(self): +    def test_roundtrip_fetchmany(self, metadata):          md = self.metadata          engine = self._fixture(True) diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index 3a5e02c32..ebcceaae7 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -511,24 +511,23 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):      __backend__ = True      @testing.fixture -    def do_numeric_test(self, metadata): +    def do_numeric_test(self, metadata, connection):          @testing.emits_warning(              r".*does \*not\* support Decimal objects natively"          )          def run(type_, input_, output, filter_=None, check_scale=False):              t = Table("t", metadata, Column("x", type_)) -            t.create(testing.db) -            with config.db.begin() as conn: -                conn.execute(t.insert(), [{"x": x} for x in input_]) - -                result = {row[0] for row in conn.execute(t.select())} -                output = set(output) -                if filter_: -                    result = set(filter_(x) for x in result) -                    output = set(filter_(x) for x in output) -                eq_(result, output) -                if check_scale: -                    eq_([str(x) for x in result], [str(x) for x in output]) +            t.create(connection) +            connection.execute(t.insert(), [{"x": x} for x in input_]) + +            result = {row[0] for row in connection.execute(t.select())} +            output = set(output) +            if filter_: +                result = set(filter_(x) for x in result) +                output = set(filter_(x) for x in output) +            eq_(result, output) +            if check_scale: +                eq_([str(x) for x in result], [str(x) for x in output])          return run @@ -1165,40 +1164,39 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):                  },              ) -    def test_eval_none_flag_orm(self): +    def test_eval_none_flag_orm(self, connection):          Base = declarative_base()          class Data(Base):              __table__ = self.tables.data_table -        s = Session(testing.db) +        with Session(connection) as s: +            d1 = Data(name="d1", data=None, nulldata=None) +            s.add(d1) +            s.commit() -        d1 = Data(name="d1", data=None, nulldata=None) -        s.add(d1) -        s.commit() - -        s.bulk_insert_mappings( -            Data, [{"name": "d2", "data": None, "nulldata": None}] -        ) -        eq_( -            s.query( -                cast(self.tables.data_table.c.data, String()), -                cast(self.tables.data_table.c.nulldata, String), +            s.bulk_insert_mappings( +                Data, [{"name": "d2", "data": None, "nulldata": None}]              ) -            .filter(self.tables.data_table.c.name == "d1") -            .first(), -            ("null", None), -        ) -        eq_( -            s.query( -                cast(self.tables.data_table.c.data, String()), -                cast(self.tables.data_table.c.nulldata, String), +            eq_( +                s.query( +                    cast(self.tables.data_table.c.data, String()), +                    cast(self.tables.data_table.c.nulldata, String), +                ) +                .filter(self.tables.data_table.c.name == "d1") +                .first(), +                ("null", None), +            ) +            eq_( +                s.query( +                    cast(self.tables.data_table.c.data, String()), +                    cast(self.tables.data_table.c.nulldata, String), +                ) +                .filter(self.tables.data_table.c.name == "d2") +                .first(), +                ("null", None),              ) -            .filter(self.tables.data_table.c.name == "d2") -            .first(), -            ("null", None), -        )  class JSONLegacyStringCastIndexTest( diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py index eb9fcd1cd..01185c284 100644 --- a/lib/sqlalchemy/testing/util.py +++ b/lib/sqlalchemy/testing/util.py @@ -14,6 +14,7 @@ import types  from . import config  from . import mock  from .. import inspect +from ..engine import Connection  from ..schema import Column  from ..schema import DropConstraint  from ..schema import DropTable @@ -207,11 +208,13 @@ def fail(msg):  @decorator  def provide_metadata(fn, *args, **kw): -    """Provide bound MetaData for a single test, dropping afterwards.""" +    """Provide bound MetaData for a single test, dropping afterwards. -    # import cycle that only occurs with py2k's import resolver -    # in py3k this can be moved top level. -    from . import engines +    Legacy; use the "metadata" pytest fixture. + +    """ + +    from . import fixtures      metadata = schema.MetaData()      self = args[0] @@ -220,7 +223,31 @@ def provide_metadata(fn, *args, **kw):      try:          return fn(*args, **kw)      finally: -        engines.drop_all_tables(metadata, config.db) +        # close out some things that get in the way of dropping tables. +        # when using the "metadata" fixture, there is a set ordering +        # of things that makes sure things are cleaned up in order, however +        # the simple "decorator" nature of this legacy function means +        # we have to hardcode some of that cleanup ahead of time. + +        # close ORM sessions +        fixtures._close_all_sessions() + +        # integrate with the "connection" fixture as there are many +        # tests where it is used along with provide_metadata +        if fixtures._connection_fixture_connection: +            # TODO: this warning can be used to find all the places +            # this is used with connection fixture +            # warn("mixing legacy provide metadata with connection fixture") +            drop_all_tables_from_metadata( +                metadata, fixtures._connection_fixture_connection +            ) +            # as the provide_metadata fixture is often used with "testing.db", +            # when we do the drop we have to commit the transaction so that +            # the DB is actually updated as the CREATE would have been +            # committed +            fixtures._connection_fixture_connection.get_transaction().commit() +        else: +            drop_all_tables_from_metadata(metadata, config.db)          self.metadata = prev_meta @@ -359,6 +386,29 @@ class adict(dict):      get_all = __call__ +def drop_all_tables_from_metadata(metadata, engine_or_connection): +    from . import engines + +    def go(connection): +        engines.testing_reaper.prepare_for_drop_tables(connection) + +        if not connection.dialect.supports_alter: +            from . import assertions + +            with assertions.expect_warnings( +                "Can't sort tables", assert_=False +            ): +                metadata.drop_all(connection) +        else: +            metadata.drop_all(connection) + +    if not isinstance(engine_or_connection, Connection): +        with engine_or_connection.begin() as connection: +            go(connection) +    else: +        go(engine_or_connection) + +  def drop_all_tables(engine, inspector, schema=None, include_names=None):      if include_names is not None: diff --git a/lib/sqlalchemy/util/queue.py b/lib/sqlalchemy/util/queue.py index 99ecb4fb3..ca5a3abde 100644 --- a/lib/sqlalchemy/util/queue.py +++ b/lib/sqlalchemy/util/queue.py @@ -230,13 +230,16 @@ class AsyncAdaptedQueue:              return self.put_nowait(item)          try: -            if timeout: +            if timeout is not None:                  return self.await_(                      asyncio.wait_for(self._queue.put(item), timeout)                  )              else:                  return self.await_(self._queue.put(item)) -        except asyncio.queues.QueueFull as err: +        except ( +            asyncio.queues.QueueFull, +            asyncio.exceptions.TimeoutError, +        ) as err:              compat.raise_(                  Full(),                  replace_context=err, @@ -254,14 +257,18 @@ class AsyncAdaptedQueue:      def get(self, block=True, timeout=None):          if not block:              return self.get_nowait() +          try: -            if timeout: +            if timeout is not None:                  return self.await_(                      asyncio.wait_for(self._queue.get(), timeout)                  )              else:                  return self.await_(self._queue.get()) -        except asyncio.queues.QueueEmpty as err: +        except ( +            asyncio.queues.QueueEmpty, +            asyncio.exceptions.TimeoutError, +        ) as err:              compat.raise_(                  Empty(),                  replace_context=err,  | 
