summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/sqlalchemy/dialects/mssql/base.py23
-rw-r--r--lib/sqlalchemy/dialects/mssql/provision.py34
-rw-r--r--lib/sqlalchemy/dialects/oracle/base.py4
-rw-r--r--lib/sqlalchemy/dialects/oracle/cx_oracle.py1
-rw-r--r--lib/sqlalchemy/dialects/oracle/provision.py42
-rw-r--r--lib/sqlalchemy/dialects/postgresql/asyncpg.py25
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py14
-rw-r--r--lib/sqlalchemy/dialects/postgresql/provision.py23
-rw-r--r--lib/sqlalchemy/dialects/sqlite/provision.py6
-rw-r--r--lib/sqlalchemy/engine/base.py21
-rw-r--r--lib/sqlalchemy/engine/create.py9
-rw-r--r--lib/sqlalchemy/exc.py6
-rw-r--r--lib/sqlalchemy/ext/compiler.py16
-rw-r--r--lib/sqlalchemy/future/engine.py15
-rw-r--r--lib/sqlalchemy/orm/__init__.py1
-rw-r--r--lib/sqlalchemy/orm/session.py12
-rw-r--r--lib/sqlalchemy/pool/base.py84
-rw-r--r--lib/sqlalchemy/sql/base.py2
-rw-r--r--lib/sqlalchemy/sql/elements.py1
-rw-r--r--lib/sqlalchemy/testing/__init__.py1
-rw-r--r--lib/sqlalchemy/testing/assertions.py4
-rw-r--r--lib/sqlalchemy/testing/config.py4
-rw-r--r--lib/sqlalchemy/testing/engines.py199
-rw-r--r--lib/sqlalchemy/testing/fixtures.py251
-rw-r--r--lib/sqlalchemy/testing/plugin/bootstrap.py5
-rw-r--r--lib/sqlalchemy/testing/plugin/plugin_base.py40
-rw-r--r--lib/sqlalchemy/testing/plugin/pytestplugin.py188
-rw-r--r--lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py112
-rw-r--r--lib/sqlalchemy/testing/provision.py11
-rw-r--r--lib/sqlalchemy/testing/suite/test_insert.py24
-rw-r--r--lib/sqlalchemy/testing/suite/test_reflection.py12
-rw-r--r--lib/sqlalchemy/testing/suite/test_results.py32
-rw-r--r--lib/sqlalchemy/testing/suite/test_select.py8
-rw-r--r--lib/sqlalchemy/testing/suite/test_sequence.py8
-rw-r--r--lib/sqlalchemy/testing/suite/test_types.py74
-rw-r--r--lib/sqlalchemy/testing/suite/test_update_delete.py4
-rw-r--r--lib/sqlalchemy/testing/util.py60
-rw-r--r--lib/sqlalchemy/testing/warnings.py3
-rw-r--r--lib/sqlalchemy/util/queue.py15
39 files changed, 991 insertions, 403 deletions
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index 538679fcf..9d0e5d322 100644
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -9,6 +9,16 @@
:name: Microsoft SQL Server
+.. _mssql_external_dialects:
+
+External Dialects
+-----------------
+
+In addition to the above DBAPI layers with native SQLAlchemy support, there
+are third-party dialects for other DBAPI layers that are compatible
+with SQL Server. See the "External Dialects" list on the
+:ref:`dialect_toplevel` page.
+
.. _mssql_identity:
Auto Increment Behavior / IDENTITY Columns
@@ -2785,15 +2795,14 @@ class MSDialect(default.DefaultDialect):
def has_table(self, connection, tablename, dbname, owner, schema):
if tablename.startswith("#"): # temporary table
tables = ischema.mssql_temp_table_columns
- result = connection.execute(
- sql.select(tables.c.table_name)
- .where(
- tables.c.table_name.like(
- self._temp_table_name_like_pattern(tablename)
- )
+
+ s = sql.select(tables.c.table_name).where(
+ tables.c.table_name.like(
+ self._temp_table_name_like_pattern(tablename)
)
- .limit(1)
)
+
+ result = connection.execute(s.limit(1))
return result.scalar() is not None
else:
tables = ischema.tables
diff --git a/lib/sqlalchemy/dialects/mssql/provision.py b/lib/sqlalchemy/dialects/mssql/provision.py
index 269eb164f..56f3305a7 100644
--- a/lib/sqlalchemy/dialects/mssql/provision.py
+++ b/lib/sqlalchemy/dialects/mssql/provision.py
@@ -1,6 +1,14 @@
+from sqlalchemy import inspect
+from sqlalchemy import Integer
from ... import create_engine
from ... import exc
+from ...schema import Column
+from ...schema import DropConstraint
+from ...schema import ForeignKeyConstraint
+from ...schema import MetaData
+from ...schema import Table
from ...testing.provision import create_db
+from ...testing.provision import drop_all_schema_objects_pre_tables
from ...testing.provision import drop_db
from ...testing.provision import get_temp_table_name
from ...testing.provision import log
@@ -38,7 +46,6 @@ def _mssql_drop_ignore(conn, ident):
# "where database_id=db_id('%s')" % ident):
# log.info("killing SQL server session %s", row['session_id'])
# conn.exec_driver_sql("kill %s" % row['session_id'])
-
conn.exec_driver_sql("drop database %s" % ident)
log.info("Reaped db: %s", ident)
return True
@@ -83,4 +90,27 @@ def _mssql_temp_table_keyword_args(cfg, eng):
@get_temp_table_name.for_db("mssql")
def _mssql_get_temp_table_name(cfg, eng, base_name):
- return "#" + base_name
+ return "##" + base_name
+
+
+@drop_all_schema_objects_pre_tables.for_db("mssql")
+def drop_all_schema_objects_pre_tables(cfg, eng):
+ with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
+ inspector = inspect(conn)
+ for schema in (None, "dbo", cfg.test_schema, cfg.test_schema_2):
+ for tname in inspector.get_table_names(schema=schema):
+ tb = Table(
+ tname,
+ MetaData(),
+ Column("x", Integer),
+ Column("y", Integer),
+ schema=schema,
+ )
+ for fk in inspect(conn).get_foreign_keys(tname, schema=schema):
+ conn.execute(
+ DropConstraint(
+ ForeignKeyConstraint(
+ [tb.c.x], [tb.c.y], name=fk["name"]
+ )
+ )
+ )
diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py
index 1cc8b7aef..a496f0354 100644
--- a/lib/sqlalchemy/dialects/oracle/base.py
+++ b/lib/sqlalchemy/dialects/oracle/base.py
@@ -1731,7 +1731,7 @@ class OracleDialect(default.DefaultDialect):
sql.text(
"SELECT username FROM user_db_links " "WHERE db_link=:link"
),
- link=dblink,
+ dict(link=dblink),
)
dblink = "@" + dblink
elif not owner:
@@ -1805,7 +1805,7 @@ class OracleDialect(default.DefaultDialect):
"SELECT sequence_name FROM all_sequences "
"WHERE sequence_owner = :schema_name"
),
- schema_name=self.denormalize_name(schema),
+ dict(schema_name=self.denormalize_name(schema)),
)
return [self.normalize_name(row[0]) for row in cursor]
diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
index 042443692..b8b4df760 100644
--- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py
+++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
@@ -93,6 +93,7 @@ The parameters accepted by the cx_oracle dialect are as follows:
* ``encoding_errors`` - see :ref:`cx_oracle_unicode_encoding_errors` for detail.
+
.. _cx_oracle_unicode:
Unicode
diff --git a/lib/sqlalchemy/dialects/oracle/provision.py b/lib/sqlalchemy/dialects/oracle/provision.py
index d51131c0b..e0dadd58e 100644
--- a/lib/sqlalchemy/dialects/oracle/provision.py
+++ b/lib/sqlalchemy/dialects/oracle/provision.py
@@ -6,11 +6,11 @@ from ...testing.provision import create_db
from ...testing.provision import drop_db
from ...testing.provision import follower_url_from_main
from ...testing.provision import log
+from ...testing.provision import post_configure_engine
from ...testing.provision import run_reap_dbs
from ...testing.provision import set_default_schema_on_connection
-from ...testing.provision import stop_test_class
+from ...testing.provision import stop_test_class_outside_fixtures
from ...testing.provision import temp_table_keyword_args
-from ...testing.provision import update_db_opts
@create_db.for_db("oracle")
@@ -57,21 +57,39 @@ def _oracle_drop_db(cfg, eng, ident):
_ora_drop_ignore(conn, "%s_ts2" % ident)
-@update_db_opts.for_db("oracle")
-def _oracle_update_db_opts(db_url, db_opts):
- pass
+@stop_test_class_outside_fixtures.for_db("oracle")
+def stop_test_class_outside_fixtures(config, db, cls):
+ with db.begin() as conn:
+ # run magic command to get rid of identity sequences
+ # https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/ # noqa E501
+ conn.exec_driver_sql("purge recyclebin")
-@stop_test_class.for_db("oracle")
-def stop_test_class(config, db, cls):
- """run magic command to get rid of identity sequences
+ # clear statement cache on all connections that were used
+ # https://github.com/oracle/python-cx_Oracle/issues/519
- # https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/
+ for cx_oracle_conn in _all_conns:
+ try:
+ sc = cx_oracle_conn.stmtcachesize
+ except db.dialect.dbapi.InterfaceError:
+ # connection closed
+ pass
+ else:
+ cx_oracle_conn.stmtcachesize = 0
+ cx_oracle_conn.stmtcachesize = sc
+ _all_conns.clear()
- """
- with db.begin() as conn:
- conn.exec_driver_sql("purge recyclebin")
+_all_conns = set()
+
+
+@post_configure_engine.for_db("oracle")
+def _oracle_post_configure_engine(url, engine, follower_ident):
+ from sqlalchemy import event
+
+ @event.listens_for(engine, "checkout")
+ def checkout(dbapi_con, con_record, con_proxy):
+ _all_conns.add(dbapi_con)
@run_reap_dbs.for_db("oracle")
diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
index 7c6e8fb02..424ed0d50 100644
--- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py
+++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
@@ -615,6 +615,10 @@ class AsyncAdapt_asyncpg_connection:
return prepared_stmt, attributes
def _handle_exception(self, error):
+ if self._connection.is_closed():
+ self._transaction = None
+ self._started = False
+
if not isinstance(error, AsyncAdapt_asyncpg_dbapi.Error):
exception_mapping = self.dbapi._asyncpg_error_translate
@@ -669,16 +673,23 @@ class AsyncAdapt_asyncpg_connection:
def rollback(self):
if self._started:
- self.await_(self._transaction.rollback())
-
- self._transaction = None
- self._started = False
+ try:
+ self.await_(self._transaction.rollback())
+ except Exception as error:
+ self._handle_exception(error)
+ finally:
+ self._transaction = None
+ self._started = False
def commit(self):
if self._started:
- self.await_(self._transaction.commit())
- self._transaction = None
- self._started = False
+ try:
+ self.await_(self._transaction.commit())
+ except Exception as error:
+ self._handle_exception(error)
+ finally:
+ self._transaction = None
+ self._started = False
def close(self):
self.rollback()
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 735990a20..7a898cb8a 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -3335,7 +3335,11 @@ class PGDialect(default.DefaultDialect):
"WHERE n.nspname = :schema AND c.relkind IN (%s)"
% (", ".join("'%s'" % elem for elem in kinds))
).columns(relname=sqltypes.Unicode),
- schema=schema if schema is not None else self.default_schema_name,
+ dict(
+ schema=schema
+ if schema is not None
+ else self.default_schema_name
+ ),
)
return [name for name, in result]
@@ -3367,8 +3371,12 @@ class PGDialect(default.DefaultDialect):
"WHERE n.nspname = :schema AND c.relname = :view_name "
"AND c.relkind IN ('v', 'm')"
).columns(view_def=sqltypes.Unicode),
- schema=schema if schema is not None else self.default_schema_name,
- view_name=view_name,
+ dict(
+ schema=schema
+ if schema is not None
+ else self.default_schema_name,
+ view_name=view_name,
+ ),
)
return view_def
diff --git a/lib/sqlalchemy/dialects/postgresql/provision.py b/lib/sqlalchemy/dialects/postgresql/provision.py
index d345cdfdf..9196337ba 100644
--- a/lib/sqlalchemy/dialects/postgresql/provision.py
+++ b/lib/sqlalchemy/dialects/postgresql/provision.py
@@ -8,6 +8,7 @@ from ...testing.provision import drop_all_schema_objects_post_tables
from ...testing.provision import drop_all_schema_objects_pre_tables
from ...testing.provision import drop_db
from ...testing.provision import log
+from ...testing.provision import prepare_for_drop_tables
from ...testing.provision import set_default_schema_on_connection
from ...testing.provision import temp_table_keyword_args
@@ -61,7 +62,7 @@ def _pg_drop_db(cfg, eng, ident):
"where usename=current_user and pid != pg_backend_pid() "
"and datname=:dname"
),
- dname=ident,
+ dict(dname=ident),
)
conn.exec_driver_sql("DROP DATABASE %s" % ident)
@@ -102,3 +103,23 @@ def drop_all_schema_objects_post_tables(cfg, eng):
postgresql.ENUM(name=enum["name"], schema=enum["schema"])
)
)
+
+
+@prepare_for_drop_tables.for_db("postgresql")
+def prepare_for_drop_tables(config, connection):
+ """Ensure there are no locks on the current username/database."""
+
+ result = connection.exec_driver_sql(
+ "select pid, state, wait_event_type, query "
+ # "select pg_terminate_backend(pid), state, wait_event_type "
+ "from pg_stat_activity where "
+ "usename=current_user "
+ "and datname=current_database() and state='idle in transaction' "
+ "and pid != pg_backend_pid()"
+ )
+ rows = result.all() # noqa
+ assert not rows, (
+ "PostgreSQL may not be able to DROP tables due to "
+ "idle in transaction: %s"
+ % ("; ".join(row._mapping["query"] for row in rows))
+ )
diff --git a/lib/sqlalchemy/dialects/sqlite/provision.py b/lib/sqlalchemy/dialects/sqlite/provision.py
index f26c21e22..a481be27e 100644
--- a/lib/sqlalchemy/dialects/sqlite/provision.py
+++ b/lib/sqlalchemy/dialects/sqlite/provision.py
@@ -7,7 +7,7 @@ from ...testing.provision import follower_url_from_main
from ...testing.provision import log
from ...testing.provision import post_configure_engine
from ...testing.provision import run_reap_dbs
-from ...testing.provision import stop_test_class
+from ...testing.provision import stop_test_class_outside_fixtures
from ...testing.provision import temp_table_keyword_args
@@ -57,8 +57,8 @@ def _sqlite_drop_db(cfg, eng, ident):
os.remove(path)
-@stop_test_class.for_db("sqlite")
-def stop_test_class(config, db, cls):
+@stop_test_class_outside_fixtures.for_db("sqlite")
+def stop_test_class_outside_fixtures(config, db, cls):
with db.connect() as conn:
files = [
row.file
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 50f00c025..31bf885db 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -813,10 +813,11 @@ class Connection(Connectable):
if self._echo:
self.engine.logger.info("BEGIN (implicit)")
+ self.__in_begin = True
+
if self._has_events or self.engine._has_events:
self.dispatch.begin(self)
- self.__in_begin = True
try:
self.engine.dialect.do_begin(self.connection)
except BaseException as e:
@@ -2729,14 +2730,16 @@ class Engine(Connectable, log.Identified):
return self.conn
def __exit__(self, type_, value, traceback):
-
- if type_ is not None:
- self.transaction.rollback()
- else:
- if self.transaction.is_active:
- self.transaction.commit()
- if not self.close_with_result:
- self.conn.close()
+ try:
+ if type_ is not None:
+ if self.transaction.is_active:
+ self.transaction.rollback()
+ else:
+ if self.transaction.is_active:
+ self.transaction.commit()
+ finally:
+ if not self.close_with_result:
+ self.conn.close()
def begin(self, close_with_result=False):
"""Return a context manager delivering a :class:`_engine.Connection`
diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py
index f89be1809..72d232085 100644
--- a/lib/sqlalchemy/engine/create.py
+++ b/lib/sqlalchemy/engine/create.py
@@ -655,9 +655,12 @@ def create_engine(url, **kwargs):
c = base.Connection(
engine, connection=dbapi_connection, _has_events=False
)
- c._execution_options = util.immutabledict()
- dialect.initialize(c)
- dialect.do_rollback(c.connection)
+ c._execution_options = util.EMPTY_DICT
+
+ try:
+ dialect.initialize(c)
+ finally:
+ dialect.do_rollback(c.connection)
# previously, the "first_connect" event was used here, which was then
# scaled back if the "on_connect" handler were present. now,
diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py
index b031c1610..08b1bb060 100644
--- a/lib/sqlalchemy/exc.py
+++ b/lib/sqlalchemy/exc.py
@@ -179,10 +179,10 @@ class UnsupportedCompilationError(CompileError):
code = "l7de"
- def __init__(self, compiler, element_type):
+ def __init__(self, compiler, element_type, message=None):
super(UnsupportedCompilationError, self).__init__(
- "Compiler %r can't render element of type %s"
- % (compiler, element_type)
+ "Compiler %r can't render element of type %s%s"
+ % (compiler, element_type, ": %s" % message if message else "")
)
diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py
index 5a31173ec..47fb6720b 100644
--- a/lib/sqlalchemy/ext/compiler.py
+++ b/lib/sqlalchemy/ext/compiler.py
@@ -425,9 +425,11 @@ def compiles(class_, *specs):
return existing_dispatch(element, compiler, **kw)
except exc.UnsupportedCompilationError as uce:
util.raise_(
- exc.CompileError(
- "%s construct has no default "
- "compilation handler." % type(element)
+ exc.UnsupportedCompilationError(
+ compiler,
+ type(element),
+ message="%s construct has no default "
+ "compilation handler." % type(element),
),
from_=uce,
)
@@ -476,9 +478,11 @@ class _dispatcher(object):
fn = self.specs["default"]
except KeyError as ke:
util.raise_(
- exc.CompileError(
- "%s construct has no default "
- "compilation handler." % type(element)
+ exc.UnsupportedCompilationError(
+ compiler,
+ type(element),
+ message="%s construct has no default "
+ "compilation handler." % type(element),
),
replace_context=ke,
)
diff --git a/lib/sqlalchemy/future/engine.py b/lib/sqlalchemy/future/engine.py
index d2f609326..bfdcdfc7f 100644
--- a/lib/sqlalchemy/future/engine.py
+++ b/lib/sqlalchemy/future/engine.py
@@ -368,12 +368,15 @@ class Engine(_LegacyEngine):
return self.conn
def __exit__(self, type_, value, traceback):
- if type_ is not None:
- self.transaction.rollback()
- else:
- if self.transaction.is_active:
- self.transaction.commit()
- self.conn.close()
+ try:
+ if type_ is not None:
+ if self.transaction.is_active:
+ self.transaction.rollback()
+ else:
+ if self.transaction.is_active:
+ self.transaction.commit()
+ finally:
+ self.conn.close()
def begin(self):
"""Return a :class:`_future.Connection` object with a transaction
diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py
index 2b7ad7bbd..f6b1a2e93 100644
--- a/lib/sqlalchemy/orm/__init__.py
+++ b/lib/sqlalchemy/orm/__init__.py
@@ -264,6 +264,7 @@ def clear_mappers():
upon a fixed set of classes.
"""
+
with mapperlib._CONFIGURE_MUTEX:
while _mapper_registry:
try:
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index e8312f393..8a7a64f78 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -531,6 +531,10 @@ class SessionTransaction(object):
self._take_snapshot(autobegin=autobegin)
+ # make sure transaction is assigned before we call the
+ # dispatch
+ self.session._transaction = self
+
self.session.dispatch.after_transaction_create(self.session, self)
@property
@@ -1242,7 +1246,8 @@ class Session(_SessionClassMethods):
def _autobegin(self):
if not self.autocommit and self._transaction is None:
- self._transaction = SessionTransaction(self, autobegin=True)
+ trans = SessionTransaction(self, autobegin=True)
+ assert self._transaction is trans
return True
return False
@@ -1299,7 +1304,7 @@ class Session(_SessionClassMethods):
if self._transaction is not None:
if subtransactions or _subtrans or nested:
trans = self._transaction._begin(nested=nested)
- self._transaction = trans
+ assert self._transaction is trans
if nested:
self._nested_transaction = trans
else:
@@ -1307,7 +1312,8 @@ class Session(_SessionClassMethods):
"A transaction is already begun on this Session."
)
else:
- self._transaction = SessionTransaction(self, nested=nested)
+ trans = SessionTransaction(self, nested=nested)
+ assert self._transaction is trans
return self._transaction # needed for __enter__/__exit__ hook
def begin_nested(self):
diff --git a/lib/sqlalchemy/pool/base.py b/lib/sqlalchemy/pool/base.py
index 7c9509e45..47d9e2cba 100644
--- a/lib/sqlalchemy/pool/base.py
+++ b/lib/sqlalchemy/pool/base.py
@@ -18,7 +18,6 @@ from .. import event
from .. import exc
from .. import log
from .. import util
-from ..util import threading
reset_rollback = util.symbol("reset_rollback")
@@ -172,7 +171,6 @@ class Pool(log.Identified):
self._orig_logging_name = None
log.instance_logger(self, echoflag=echo)
- self._threadconns = threading.local()
self._creator = creator
self._recycle = recycle
self._invalidate_time = 0
@@ -423,26 +421,37 @@ class _ConnectionRecord(object):
dbapi_connection = rec.get_connection()
except Exception as err:
with util.safe_reraise():
- rec._checkin_failed(err)
+ rec._checkin_failed(err, _fairy_was_created=False)
echo = pool._should_log_debug()
fairy = _ConnectionFairy(dbapi_connection, rec, echo)
- rec.fairy_ref = weakref.ref(
+
+ rec.fairy_ref = ref = weakref.ref(
fairy,
lambda ref: _finalize_fairy
and _finalize_fairy(None, rec, pool, ref, echo),
)
+ _strong_ref_connection_records[ref] = rec
if echo:
pool.logger.debug(
"Connection %r checked out from pool", dbapi_connection
)
return fairy
- def _checkin_failed(self, err):
+ def _checkin_failed(self, err, _fairy_was_created=True):
self.invalidate(e=err)
- self.checkin(_no_fairy_ref=True)
+ self.checkin(
+ _fairy_was_created=_fairy_was_created,
+ )
- def checkin(self, _no_fairy_ref=False):
- if self.fairy_ref is None and not _no_fairy_ref:
+ def checkin(self, _fairy_was_created=True):
+ if self.fairy_ref is None and _fairy_was_created:
+ # _fairy_was_created is False for the initial get connection phase;
+ # meaning there was no _ConnectionFairy and we must unconditionally
+ # do a checkin.
+ #
+ # otherwise, if fairy_was_created==True, if fairy_ref is None here
+ # that means we were checked in already, so this looks like
+ # a double checkin.
util.warn("Double checkin attempted on %s" % self)
return
self.fairy_ref = None
@@ -453,6 +462,7 @@ class _ConnectionRecord(object):
finalizer(connection)
if pool.dispatch.checkin:
pool.dispatch.checkin(connection, self)
+
pool._return_conn(self)
@property
@@ -603,12 +613,26 @@ def _finalize_fairy(
"""
+ if ref:
+ _strong_ref_connection_records.pop(ref, None)
+ elif fairy:
+ _strong_ref_connection_records.pop(weakref.ref(fairy), None)
+
if ref is not None:
if connection_record.fairy_ref is not ref:
return
assert connection is None
connection = connection_record.connection
+ dont_restore_gced = pool._is_asyncio
+
+ if dont_restore_gced:
+ detach = not connection_record or ref
+ can_manipulate_connection = not ref
+ else:
+ detach = not connection_record
+ can_manipulate_connection = True
+
if connection is not None:
if connection_record and echo:
pool.logger.debug(
@@ -620,13 +644,26 @@ def _finalize_fairy(
connection, connection_record, echo
)
assert fairy.connection is connection
- fairy._reset(pool)
+ if can_manipulate_connection:
+ fairy._reset(pool)
+
+ if detach:
+ if connection_record:
+ fairy._pool = pool
+ fairy.detach()
+
+ if can_manipulate_connection:
+ if pool.dispatch.close_detached:
+ pool.dispatch.close_detached(connection)
+
+ pool._close_connection(connection)
+ else:
+ util.warn(
+ "asyncio connection is being garbage "
+ "collected without being properly closed: %r"
+ % connection
+ )
- # Immediately close detached instances
- if not connection_record:
- if pool.dispatch.close_detached:
- pool.dispatch.close_detached(connection)
- pool._close_connection(connection)
except BaseException as e:
pool.logger.error(
"Exception during reset or similar", exc_info=True
@@ -640,6 +677,13 @@ def _finalize_fairy(
connection_record.checkin()
+# a dictionary of the _ConnectionFairy weakrefs to _ConnectionRecord, so that
+# GC under pypy will call ConnectionFairy finalizers. linked directly to the
+# weakref that will empty itself when collected so that it should not create
+# any unmanaged memory references.
+_strong_ref_connection_records = {}
+
+
class _ConnectionFairy(object):
"""Proxies a DBAPI connection and provides return-on-dereference
@@ -774,7 +818,17 @@ class _ConnectionFairy(object):
)
except Exception as err:
with util.safe_reraise():
- fairy._connection_record._checkin_failed(err)
+ fairy._connection_record._checkin_failed(
+ err,
+ _fairy_was_created=True,
+ )
+
+ # prevent _ConnectionFairy from being carried
+ # in the stack trace. Do this after the
+ # connection record has been checked in, so that
+ # if the del triggers a finalize fairy, it won't
+ # try to checkin a second time.
+ del fairy
attempts -= 1
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index a1426b628..550111020 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -729,7 +729,7 @@ class ExecutableOption(HasCopyInternals, HasCacheKey):
return c
-class Executable(Generative):
+class Executable(roles.CoerceTextStatementRole, Generative):
"""Mark a :class:`_expression.ClauseElement` as supporting execution.
:class:`.Executable` is a superclass for all "statement" types
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index d3c767b5d..5ea3526ea 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -1502,7 +1502,6 @@ class TextClause(
roles.OrderByRole,
roles.FromClauseRole,
roles.SelectStatementRole,
- roles.CoerceTextStatementRole,
roles.BinaryElementRole,
roles.InElementRole,
Executable,
diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py
index c1afeb907..9f2d0b857 100644
--- a/lib/sqlalchemy/testing/__init__.py
+++ b/lib/sqlalchemy/testing/__init__.py
@@ -32,6 +32,7 @@ from .assertions import is_instance_of # noqa
from .assertions import is_none # noqa
from .assertions import is_not # noqa
from .assertions import is_not_ # noqa
+from .assertions import is_not_none # noqa
from .assertions import is_true # noqa
from .assertions import le_ # noqa
from .assertions import ne_ # noqa
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
index b2a4ac66e..f2ed91b79 100644
--- a/lib/sqlalchemy/testing/assertions.py
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -228,6 +228,10 @@ def is_none(a, msg=None):
is_(a, None, msg=msg)
+def is_not_none(a, msg=None):
+ is_not(a, None, msg=msg)
+
+
def is_true(a, msg=None):
is_(bool(a), True, msg=msg)
diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py
index f64153f33..750671f9f 100644
--- a/lib/sqlalchemy/testing/config.py
+++ b/lib/sqlalchemy/testing/config.py
@@ -97,6 +97,10 @@ def get_current_test_name():
return _fixture_functions.get_current_test_name()
+def mark_base_test_class():
+ return _fixture_functions.mark_base_test_class()
+
+
class Config(object):
def __init__(self, db, db_opts, options, file_config):
self._set_name(db)
diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py
index a4c1f3973..a313c298a 100644
--- a/lib/sqlalchemy/testing/engines.py
+++ b/lib/sqlalchemy/testing/engines.py
@@ -7,12 +7,14 @@
from __future__ import absolute_import
+import collections
import re
import warnings
import weakref
from . import config
from .util import decorator
+from .util import gc_collect
from .. import event
from .. import pool
@@ -20,26 +22,29 @@ from .. import pool
class ConnectionKiller(object):
def __init__(self):
self.proxy_refs = weakref.WeakKeyDictionary()
- self.testing_engines = weakref.WeakKeyDictionary()
- self.conns = set()
+ self.testing_engines = collections.defaultdict(set)
+ self.dbapi_connections = set()
def add_pool(self, pool):
- event.listen(pool, "connect", self.connect)
- event.listen(pool, "checkout", self.checkout)
- event.listen(pool, "invalidate", self.invalidate)
-
- def add_engine(self, engine):
- self.add_pool(engine.pool)
- self.testing_engines[engine] = True
+ event.listen(pool, "checkout", self._add_conn)
+ event.listen(pool, "checkin", self._remove_conn)
+ event.listen(pool, "close", self._remove_conn)
+ event.listen(pool, "close_detached", self._remove_conn)
+ # note we are keeping "invalidated" here, as those are still
+ # opened connections we would like to roll back
+
+ def _add_conn(self, dbapi_con, con_record, con_proxy):
+ self.dbapi_connections.add(dbapi_con)
+ self.proxy_refs[con_proxy] = True
- def connect(self, dbapi_conn, con_record):
- self.conns.add((dbapi_conn, con_record))
+ def _remove_conn(self, dbapi_conn, *arg):
+ self.dbapi_connections.discard(dbapi_conn)
- def checkout(self, dbapi_con, con_record, con_proxy):
- self.proxy_refs[con_proxy] = True
+ def add_engine(self, engine, scope):
+ self.add_pool(engine.pool)
- def invalidate(self, dbapi_con, con_record, exception):
- self.conns.discard((dbapi_con, con_record))
+ assert scope in ("class", "global", "function", "fixture")
+ self.testing_engines[scope].add(engine)
def _safe(self, fn):
try:
@@ -54,53 +59,89 @@ class ConnectionKiller(object):
if rec is not None and rec.is_valid:
self._safe(rec.rollback)
- def close_all(self):
+ def checkin_all(self):
+ # run pool.checkin() for all ConnectionFairy instances we have
+ # tracked.
+
for rec in list(self.proxy_refs):
if rec is not None and rec.is_valid:
- self._safe(rec._close)
-
- def _after_test_ctx(self):
- # this can cause a deadlock with pg8000 - pg8000 acquires
- # prepared statement lock inside of rollback() - if async gc
- # is collecting in finalize_fairy, deadlock.
- # not sure if this should be for non-cpython only.
- # note that firebird/fdb definitely needs this though
- for conn, rec in list(self.conns):
- if rec.connection is None:
- # this is a hint that the connection is closed, which
- # is causing segfaults on mysqlclient due to
- # https://github.com/PyMySQL/mysqlclient-python/issues/270;
- # try to work around here
- continue
- self._safe(conn.rollback)
-
- def _stop_test_ctx(self):
- if config.options.low_connections:
- self._stop_test_ctx_minimal()
- else:
- self._stop_test_ctx_aggressive()
-
- def _stop_test_ctx_minimal(self):
- self.close_all()
+ self.dbapi_connections.discard(rec.connection)
+ self._safe(rec._checkin)
- self.conns = set()
-
- for rec in list(self.testing_engines):
- if rec is not config.db:
- rec.dispose()
-
- def _stop_test_ctx_aggressive(self):
- self.close_all()
- for conn, rec in list(self.conns):
- self._safe(conn.close)
- rec.connection = None
+ # for fairy refs that were GCed and could not close the connection,
+ # such as asyncio, roll back those remaining connections
+ for con in self.dbapi_connections:
+ self._safe(con.rollback)
+ self.dbapi_connections.clear()
- self.conns = set()
- for rec in list(self.testing_engines):
- if hasattr(rec, "sync_engine"):
- rec.sync_engine.dispose()
- else:
- rec.dispose()
+ def close_all(self):
+ self.checkin_all()
+
+ def prepare_for_drop_tables(self, connection):
+ # don't do aggressive checks for third party test suites
+ if not config.bootstrapped_as_sqlalchemy:
+ return
+
+ from . import provision
+
+ provision.prepare_for_drop_tables(connection.engine.url, connection)
+
+ def _drop_testing_engines(self, scope):
+ eng = self.testing_engines[scope]
+ for rec in list(eng):
+ for proxy_ref in list(self.proxy_refs):
+ if proxy_ref is not None and proxy_ref.is_valid:
+ if (
+ proxy_ref._pool is not None
+ and proxy_ref._pool is rec.pool
+ ):
+ self._safe(proxy_ref._checkin)
+ rec.dispose()
+ eng.clear()
+
+ def after_test(self):
+ self._drop_testing_engines("function")
+
+ def after_test_outside_fixtures(self, test):
+ # don't do aggressive checks for third party test suites
+ if not config.bootstrapped_as_sqlalchemy:
+ return
+
+ if test.__class__.__leave_connections_for_teardown__:
+ return
+
+ self.checkin_all()
+
+ # on PostgreSQL, this will test for any "idle in transaction"
+ # connections. useful to identify tests with unusual patterns
+ # that can't be cleaned up correctly.
+ from . import provision
+
+ with config.db.connect() as conn:
+ provision.prepare_for_drop_tables(conn.engine.url, conn)
+
+ def stop_test_class_inside_fixtures(self):
+ self.checkin_all()
+ self._drop_testing_engines("function")
+ self._drop_testing_engines("class")
+
+ def stop_test_class_outside_fixtures(self):
+ # ensure no refs to checked out connections at all.
+
+ if pool.base._strong_ref_connection_records:
+ gc_collect()
+
+ if pool.base._strong_ref_connection_records:
+ ln = len(pool.base._strong_ref_connection_records)
+ pool.base._strong_ref_connection_records.clear()
+ assert (
+ False
+ ), "%d connection recs not cleared after test suite" % (ln)
+
+ def final_cleanup(self):
+ self.checkin_all()
+ for scope in self.testing_engines:
+ self._drop_testing_engines(scope)
def assert_all_closed(self):
for rec in self.proxy_refs:
@@ -111,20 +152,6 @@ class ConnectionKiller(object):
testing_reaper = ConnectionKiller()
-def drop_all_tables(metadata, bind):
- testing_reaper.close_all()
- if hasattr(bind, "close"):
- bind.close()
-
- if not config.db.dialect.supports_alter:
- from . import assertions
-
- with assertions.expect_warnings("Can't sort tables", assert_=False):
- metadata.drop_all(bind)
- else:
- metadata.drop_all(bind)
-
-
@decorator
def assert_conns_closed(fn, *args, **kw):
try:
@@ -147,7 +174,7 @@ def rollback_open_connections(fn, *args, **kw):
def close_first(fn, *args, **kw):
"""Decorator that closes all connections before fn execution."""
- testing_reaper.close_all()
+ testing_reaper.checkin_all()
fn(*args, **kw)
@@ -157,7 +184,7 @@ def close_open_connections(fn, *args, **kw):
try:
fn(*args, **kw)
finally:
- testing_reaper.close_all()
+ testing_reaper.checkin_all()
def all_dialects(exclude=None):
@@ -239,12 +266,14 @@ def reconnecting_engine(url=None, options=None):
return engine
-def testing_engine(url=None, options=None, future=False, asyncio=False):
+def testing_engine(url=None, options=None, future=None, asyncio=False):
"""Produce an engine configured by --options with optional overrides."""
if asyncio:
from sqlalchemy.ext.asyncio import create_async_engine as create_engine
- elif future or config.db and config.db._is_future:
+ elif future or (
+ config.db and config.db._is_future and future is not False
+ ):
from sqlalchemy.future import create_engine
else:
from sqlalchemy import create_engine
@@ -252,8 +281,10 @@ def testing_engine(url=None, options=None, future=False, asyncio=False):
if not options:
use_reaper = True
+ scope = "function"
else:
use_reaper = options.pop("use_reaper", True)
+ scope = options.pop("scope", "function")
url = url or config.db.url
@@ -268,16 +299,20 @@ def testing_engine(url=None, options=None, future=False, asyncio=False):
default_opt.update(options)
engine = create_engine(url, **options)
- if asyncio:
- engine.sync_engine._has_events = True
- else:
- engine._has_events = True # enable event blocks, helps with profiling
+
+ if scope == "global":
+ if asyncio:
+ engine.sync_engine._has_events = True
+ else:
+ engine._has_events = (
+ True # enable event blocks, helps with profiling
+ )
if isinstance(engine.pool, pool.QueuePool):
engine.pool._timeout = 0
- engine.pool._max_overflow = 5
+ engine.pool._max_overflow = 0
if use_reaper:
- testing_reaper.add_engine(engine)
+ testing_reaper.add_engine(engine, scope)
return engine
diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py
index ac4d3d8fa..45ca48444 100644
--- a/lib/sqlalchemy/testing/fixtures.py
+++ b/lib/sqlalchemy/testing/fixtures.py
@@ -5,6 +5,7 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
+import contextlib
import re
import sys
@@ -12,12 +13,11 @@ import sqlalchemy as sa
from . import assertions
from . import config
from . import schema
-from .engines import drop_all_tables
-from .engines import testing_engine
from .entities import BasicEntity
from .entities import ComparableEntity
from .entities import ComparableMixin # noqa
from .util import adict
+from .util import drop_all_tables_from_metadata
from .. import event
from .. import util
from ..orm import declarative_base
@@ -25,10 +25,8 @@ from ..orm import registry
from ..orm.decl_api import DeclarativeMeta
from ..schema import sort_tables_and_constraints
-# whether or not we use unittest changes things dramatically,
-# as far as how pytest collection works.
-
+@config.mark_base_test_class()
class TestBase(object):
# A sequence of database names to always run, regardless of the
# constraints below.
@@ -48,81 +46,122 @@ class TestBase(object):
# skipped.
__skip_if__ = None
+ # if True, the testing reaper will not attempt to touch connection
+ # state after a test is completed and before the outer teardown
+ # starts
+ __leave_connections_for_teardown__ = False
+
def assert_(self, val, msg=None):
assert val, msg
- # apparently a handful of tests are doing this....OK
- def setup(self):
- if hasattr(self, "setUp"):
- self.setUp()
-
- def teardown(self):
- if hasattr(self, "tearDown"):
- self.tearDown()
-
@config.fixture()
def connection(self):
- eng = getattr(self, "bind", config.db)
+ global _connection_fixture_connection
+
+ eng = getattr(self, "bind", None) or config.db
conn = eng.connect()
trans = conn.begin()
- try:
- yield conn
- finally:
- if trans.is_active:
- trans.rollback()
- conn.close()
+
+ _connection_fixture_connection = conn
+ yield conn
+
+ _connection_fixture_connection = None
+
+ if trans.is_active:
+ trans.rollback()
+ # trans would not be active here if the test is using
+ # the legacy @provide_metadata decorator still, as it will
+ # run a close all connections.
+ conn.close()
@config.fixture()
- def future_connection(self):
+ def future_connection(self, future_engine, connection):
+ # integrate the future_engine and connection fixtures so
+ # that users of the "connection" fixture will get at the
+ # "future" connection
+ yield connection
- eng = testing_engine(future=True)
- conn = eng.connect()
- trans = conn.begin()
- try:
- yield conn
- finally:
- if trans.is_active:
- trans.rollback()
- conn.close()
+ @config.fixture()
+ def future_engine(self):
+ eng = getattr(self, "bind", None) or config.db
+ with _push_future_engine(eng):
+ yield
+
+ @config.fixture()
+ def testing_engine(self):
+ from . import engines
+
+ def gen_testing_engine(
+ url=None, options=None, future=None, asyncio=False
+ ):
+ if options is None:
+ options = {}
+ options["scope"] = "fixture"
+ return engines.testing_engine(
+ url=url, options=options, future=future, asyncio=asyncio
+ )
+
+ yield gen_testing_engine
+
+ engines.testing_reaper._drop_testing_engines("fixture")
@config.fixture()
- def metadata(self):
+ def async_testing_engine(self, testing_engine):
+ def go(**kw):
+ kw["asyncio"] = True
+ return testing_engine(**kw)
+
+ return go
+
+ @config.fixture()
+ def metadata(self, request):
"""Provide bound MetaData for a single test, dropping afterwards."""
- from . import engines
from ..sql import schema
metadata = schema.MetaData()
- try:
- yield metadata
- finally:
- engines.drop_all_tables(metadata, config.db)
+ request.instance.metadata = metadata
+ yield metadata
+ del request.instance.metadata
+ if (
+ _connection_fixture_connection
+ and _connection_fixture_connection.in_transaction()
+ ):
+ trans = _connection_fixture_connection.get_transaction()
+ trans.rollback()
+ with _connection_fixture_connection.begin():
+ drop_all_tables_from_metadata(
+ metadata, _connection_fixture_connection
+ )
+ else:
+ drop_all_tables_from_metadata(metadata, config.db)
-class FutureEngineMixin(object):
- @classmethod
- def setup_class(cls):
- from ..future.engine import Engine
- from sqlalchemy import testing
+_connection_fixture_connection = None
- facade = Engine._future_facade(config.db)
- config._current.push_engine(facade, testing)
- super_ = super(FutureEngineMixin, cls)
- if hasattr(super_, "setup_class"):
- super_.setup_class()
+@contextlib.contextmanager
+def _push_future_engine(engine):
- @classmethod
- def teardown_class(cls):
- super_ = super(FutureEngineMixin, cls)
- if hasattr(super_, "teardown_class"):
- super_.teardown_class()
+ from ..future.engine import Engine
+ from sqlalchemy import testing
+
+ facade = Engine._future_facade(engine)
+ config._current.push_engine(facade, testing)
+
+ yield facade
- from sqlalchemy import testing
+ config._current.pop(testing)
- config._current.pop(testing)
+
+class FutureEngineMixin(object):
+ @config.fixture(autouse=True, scope="class")
+ def _push_future_engine(self):
+ eng = getattr(self, "bind", None) or config.db
+ with _push_future_engine(eng):
+ yield
class TablesTest(TestBase):
@@ -151,18 +190,32 @@ class TablesTest(TestBase):
other = None
sequences = None
- @property
- def tables_test_metadata(self):
- return self._tables_metadata
-
- @classmethod
- def setup_class(cls):
+ @config.fixture(autouse=True, scope="class")
+ def _setup_tables_test_class(self):
+ cls = self.__class__
cls._init_class()
cls._setup_once_tables()
cls._setup_once_inserts()
+ yield
+
+ cls._teardown_once_metadata_bind()
+
+ @config.fixture(autouse=True, scope="function")
+ def _setup_tables_test_instance(self):
+ self._setup_each_tables()
+ self._setup_each_inserts()
+
+ yield
+
+ self._teardown_each_tables()
+
+ @property
+ def tables_test_metadata(self):
+ return self._tables_metadata
+
@classmethod
def _init_class(cls):
if cls.run_define_tables == "each":
@@ -213,10 +266,10 @@ class TablesTest(TestBase):
if self.run_define_tables == "each":
self.tables.clear()
if self.run_create_tables == "each":
- drop_all_tables(self._tables_metadata, self.bind)
+ drop_all_tables_from_metadata(self._tables_metadata, self.bind)
self._tables_metadata.clear()
elif self.run_create_tables == "each":
- drop_all_tables(self._tables_metadata, self.bind)
+ drop_all_tables_from_metadata(self._tables_metadata, self.bind)
# no need to run deletes if tables are recreated on setup
if (
@@ -242,17 +295,10 @@ class TablesTest(TestBase):
file=sys.stderr,
)
- def setup(self):
- self._setup_each_tables()
- self._setup_each_inserts()
-
- def teardown(self):
- self._teardown_each_tables()
-
@classmethod
def _teardown_once_metadata_bind(cls):
if cls.run_create_tables:
- drop_all_tables(cls._tables_metadata, cls.bind)
+ drop_all_tables_from_metadata(cls._tables_metadata, cls.bind)
if cls.run_dispose_bind == "once":
cls.dispose_bind(cls.bind)
@@ -263,10 +309,6 @@ class TablesTest(TestBase):
cls.bind = None
@classmethod
- def teardown_class(cls):
- cls._teardown_once_metadata_bind()
-
- @classmethod
def setup_bind(cls):
return config.db
@@ -332,38 +374,45 @@ class RemovesEvents(object):
self._event_fns.add((target, name, fn))
event.listen(target, name, fn, **kw)
- def teardown(self):
+ @config.fixture(autouse=True, scope="function")
+ def _remove_events(self):
+ yield
for key in self._event_fns:
event.remove(*key)
- super_ = super(RemovesEvents, self)
- if hasattr(super_, "teardown"):
- super_.teardown()
-class _ORMTest(object):
- @classmethod
- def teardown_class(cls):
- sa.orm.session.close_all_sessions()
- sa.orm.clear_mappers()
-
-
-def create_session(**kw):
- kw.setdefault("autoflush", False)
- kw.setdefault("expire_on_commit", False)
- return sa.orm.Session(config.db, **kw)
+_fixture_sessions = set()
def fixture_session(**kw):
kw.setdefault("autoflush", True)
kw.setdefault("expire_on_commit", True)
- return sa.orm.Session(config.db, **kw)
+ sess = sa.orm.Session(config.db, **kw)
+ _fixture_sessions.add(sess)
+ return sess
+
+
+def _close_all_sessions():
+ # will close all still-referenced sessions
+ sa.orm.session.close_all_sessions()
+ _fixture_sessions.clear()
-class ORMTest(_ORMTest, TestBase):
+def stop_test_class_inside_fixtures(cls):
+ _close_all_sessions()
+ sa.orm.clear_mappers()
+
+
+def after_test():
+ if _fixture_sessions:
+ _close_all_sessions()
+
+
+class ORMTest(TestBase):
pass
-class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
+class MappedTest(TablesTest, assertions.AssertsExecutionResults):
# 'once', 'each', None
run_setup_classes = "once"
@@ -372,8 +421,9 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
classes = None
- @classmethod
- def setup_class(cls):
+ @config.fixture(autouse=True, scope="class")
+ def _setup_tables_test_class(self):
+ cls = self.__class__
cls._init_class()
if cls.classes is None:
@@ -384,18 +434,20 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
cls._setup_once_mappers()
cls._setup_once_inserts()
- @classmethod
- def teardown_class(cls):
+ yield
+
cls._teardown_once_class()
cls._teardown_once_metadata_bind()
- def setup(self):
+ @config.fixture(autouse=True, scope="function")
+ def _setup_tables_test_instance(self):
self._setup_each_tables()
self._setup_each_classes()
self._setup_each_mappers()
self._setup_each_inserts()
- def teardown(self):
+ yield
+
sa.orm.session.close_all_sessions()
self._teardown_each_mappers()
self._teardown_each_classes()
@@ -404,7 +456,6 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
@classmethod
def _teardown_once_class(cls):
cls.classes.clear()
- _ORMTest.teardown_class()
@classmethod
def _setup_once_classes(cls):
@@ -440,6 +491,8 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
"""
cls_registry = cls.classes
+ assert cls_registry is not None
+
class FindFixture(type):
def __init__(cls, classname, bases, dict_):
cls_registry[classname] = cls
diff --git a/lib/sqlalchemy/testing/plugin/bootstrap.py b/lib/sqlalchemy/testing/plugin/bootstrap.py
index a95c947e2..1f568dfc8 100644
--- a/lib/sqlalchemy/testing/plugin/bootstrap.py
+++ b/lib/sqlalchemy/testing/plugin/bootstrap.py
@@ -40,6 +40,11 @@ def load_file_as_module(name):
if to_bootstrap == "pytest":
sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base")
+ sys.modules["sqla_plugin_base"].bootstrapped_as_sqlalchemy = True
+ if sys.version_info < (3, 0):
+ sys.modules["sqla_reinvent_fixtures"] = load_file_as_module(
+ "reinvent_fixtures_py2k"
+ )
sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin")
else:
raise Exception("unknown bootstrap: %s" % to_bootstrap) # noqa
diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py
index 3594cd276..858814f91 100644
--- a/lib/sqlalchemy/testing/plugin/plugin_base.py
+++ b/lib/sqlalchemy/testing/plugin/plugin_base.py
@@ -21,6 +21,9 @@ import logging
import re
import sys
+# flag which indicates we are in the SQLAlchemy testing suite,
+# and not that of Alembic or a third party dialect.
+bootstrapped_as_sqlalchemy = False
log = logging.getLogger("sqlalchemy.testing.plugin_base")
@@ -381,7 +384,7 @@ def _init_symbols(options, file_config):
@post
def _set_disable_asyncio(opt, file_config):
- if opt.disable_asyncio:
+ if opt.disable_asyncio or not py3k:
from sqlalchemy.testing import asyncio
asyncio.ENABLE_ASYNCIO = False
@@ -458,6 +461,8 @@ def _setup_requirements(argument):
config.requirements = testing.requires = req_cls()
+ config.bootstrapped_as_sqlalchemy = bootstrapped_as_sqlalchemy
+
@post
def _prep_testing_database(options, file_config):
@@ -566,17 +571,23 @@ def generate_sub_tests(cls, module):
yield cls
-def start_test_class(cls):
+def start_test_class_outside_fixtures(cls):
_do_skips(cls)
_setup_engine(cls)
def stop_test_class(cls):
- # from sqlalchemy import inspect
- # assert not inspect(testing.db).get_table_names()
+ # close sessions, immediate connections, etc.
+ fixtures.stop_test_class_inside_fixtures(cls)
+
+ # close outstanding connection pool connections, dispose of
+ # additional engines
+ engines.testing_reaper.stop_test_class_inside_fixtures()
- provision.stop_test_class(config, config.db, cls)
- engines.testing_reaper._stop_test_ctx()
+
+def stop_test_class_outside_fixtures(cls):
+ engines.testing_reaper.stop_test_class_outside_fixtures()
+ provision.stop_test_class_outside_fixtures(config, config.db, cls)
try:
if not options.low_connections:
assertions.global_cleanup_assertions()
@@ -590,14 +601,16 @@ def _restore_engine():
def final_process_cleanup():
- engines.testing_reaper._stop_test_ctx_aggressive()
+ engines.testing_reaper.final_cleanup()
assertions.global_cleanup_assertions()
_restore_engine()
def _setup_engine(cls):
if getattr(cls, "__engine_options__", None):
- eng = engines.testing_engine(options=cls.__engine_options__)
+ opts = dict(cls.__engine_options__)
+ opts["scope"] = "class"
+ eng = engines.testing_engine(options=opts)
config._current.push_engine(eng, testing)
@@ -614,7 +627,12 @@ def before_test(test, test_module_name, test_class, test_name):
def after_test(test):
- engines.testing_reaper._after_test_ctx()
+ fixtures.after_test()
+ engines.testing_reaper.after_test()
+
+
+def after_test_fixtures(test):
+ engines.testing_reaper.after_test_outside_fixtures(test)
def _possible_configs_for_cls(cls, reasons=None, sparse=False):
@@ -748,6 +766,10 @@ class FixtureFunctions(ABC):
def get_current_test_name(self):
raise NotImplementedError()
+ @abc.abstractmethod
+ def mark_base_test_class(self):
+ raise NotImplementedError()
+
_fixture_fn_class = None
diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py
index 46468a07d..4eaaecebb 100644
--- a/lib/sqlalchemy/testing/plugin/pytestplugin.py
+++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py
@@ -17,6 +17,7 @@ import sys
import pytest
+
try:
import typing
except ImportError:
@@ -33,6 +34,14 @@ except ImportError:
has_xdist = False
+py2k = sys.version_info < (3, 0)
+if py2k:
+ try:
+ import sqla_reinvent_fixtures as reinvent_fixtures_py2k
+ except ImportError:
+ from . import reinvent_fixtures_py2k
+
+
def pytest_addoption(parser):
group = parser.getgroup("sqlalchemy")
@@ -238,6 +247,10 @@ def pytest_collection_modifyitems(session, config, items):
else:
newitems.append(item)
+ if py2k:
+ for item in newitems:
+ reinvent_fixtures_py2k.scan_for_fixtures_to_use_for_class(item)
+
# seems like the functions attached to a test class aren't sorted already?
# is that true and why's that? (when using unittest, they're sorted)
items[:] = sorted(
@@ -251,7 +264,6 @@ def pytest_collection_modifyitems(session, config, items):
def pytest_pycollect_makeitem(collector, name, obj):
-
if inspect.isclass(obj) and plugin_base.want_class(name, obj):
from sqlalchemy.testing import config
@@ -259,7 +271,6 @@ def pytest_pycollect_makeitem(collector, name, obj):
obj = _apply_maybe_async(obj)
ctor = getattr(pytest.Class, "from_parent", pytest.Class)
-
return [
ctor(name=parametrize_cls.__name__, parent=collector)
for parametrize_cls in _parametrize_cls(collector.module, obj)
@@ -287,12 +298,11 @@ def _is_wrapped_coroutine_function(fn):
def _apply_maybe_async(obj, recurse=True):
from sqlalchemy.testing import asyncio
- setup_names = {"setup", "setup_class", "teardown", "teardown_class"}
for name, value in vars(obj).items():
if (
(callable(value) or isinstance(value, classmethod))
and not getattr(value, "_maybe_async_applied", False)
- and (name.startswith("test_") or name in setup_names)
+ and (name.startswith("test_"))
and not _is_wrapped_coroutine_function(value)
):
is_classmethod = False
@@ -317,9 +327,6 @@ def _apply_maybe_async(obj, recurse=True):
return obj
-_current_class = None
-
-
def _parametrize_cls(module, cls):
"""implement a class-based version of pytest parametrize."""
@@ -355,63 +362,153 @@ def _parametrize_cls(module, cls):
return classes
+_current_class = None
+
+
def pytest_runtest_setup(item):
from sqlalchemy.testing import asyncio
- # here we seem to get called only based on what we collected
- # in pytest_collection_modifyitems. So to do class-based stuff
- # we have to tear that out.
- global _current_class
-
if not isinstance(item, pytest.Function):
return
- # ... so we're doing a little dance here to figure it out...
+ # pytest_runtest_setup runs *before* pytest fixtures with scope="class".
+ # plugin_base.start_test_class_outside_fixtures may opt to raise SkipTest
+ # for the whole class and has to run things that are across all current
+ # databases, so we run this outside of the pytest fixture system altogether
+ # and ensure asyncio greenlet if any engines are async
+
+ global _current_class
+
if _current_class is None:
- asyncio._maybe_async(class_setup, item.parent.parent)
+ asyncio._maybe_async_provisioning(
+ plugin_base.start_test_class_outside_fixtures,
+ item.parent.parent.cls,
+ )
_current_class = item.parent.parent
- # this is needed for the class-level, to ensure that the
- # teardown runs after the class is completed with its own
- # class-level teardown...
def finalize():
global _current_class
- asyncio._maybe_async(class_teardown, item.parent.parent)
_current_class = None
+ asyncio._maybe_async_provisioning(
+ plugin_base.stop_test_class_outside_fixtures,
+ item.parent.parent.cls,
+ )
+
item.parent.parent.addfinalizer(finalize)
- asyncio._maybe_async(test_setup, item)
+def pytest_runtest_call(item):
+ # runs inside of pytest function fixture scope
+ # before test function runs
-def pytest_runtest_teardown(item):
from sqlalchemy.testing import asyncio
- # ...but this works better as the hook here rather than
- # using a finalizer, as the finalizer seems to get in the way
- # of the test reporting failures correctly (you get a bunch of
- # pytest assertion stuff instead)
- asyncio._maybe_async(test_teardown, item)
+ asyncio._maybe_async(
+ plugin_base.before_test,
+ item,
+ item.parent.module.__name__,
+ item.parent.cls,
+ item.name,
+ )
-def test_setup(item):
- plugin_base.before_test(
- item, item.parent.module.__name__, item.parent.cls, item.name
- )
+def pytest_runtest_teardown(item, nextitem):
+ # runs inside of pytest function fixture scope
+ # after test function runs
+ from sqlalchemy.testing import asyncio
-def test_teardown(item):
- plugin_base.after_test(item)
+ asyncio._maybe_async(plugin_base.after_test, item)
-def class_setup(item):
+@pytest.fixture(scope="class")
+def setup_class_methods(request):
from sqlalchemy.testing import asyncio
- asyncio._maybe_async_provisioning(plugin_base.start_test_class, item.cls)
+ cls = request.cls
+
+ if hasattr(cls, "setup_test_class"):
+ asyncio._maybe_async(cls.setup_test_class)
+
+ if py2k:
+ reinvent_fixtures_py2k.run_class_fixture_setup(request)
+
+ yield
+
+ if py2k:
+ reinvent_fixtures_py2k.run_class_fixture_teardown(request)
+ if hasattr(cls, "teardown_test_class"):
+ asyncio._maybe_async(cls.teardown_test_class)
-def class_teardown(item):
- plugin_base.stop_test_class(item.cls)
+ asyncio._maybe_async(plugin_base.stop_test_class, cls)
+
+
+@pytest.fixture(scope="function")
+def setup_test_methods(request):
+ from sqlalchemy.testing import asyncio
+
+ # called for each test
+
+ self = request.instance
+
+ # 1. run outer xdist-style setup
+ if hasattr(self, "setup_test"):
+ asyncio._maybe_async(self.setup_test)
+
+ # alembic test suite is using setUp and tearDown
+ # xdist methods; support these in the test suite
+ # for the near term
+ if hasattr(self, "setUp"):
+ asyncio._maybe_async(self.setUp)
+
+ # 2. run homegrown function level "autouse" fixtures under py2k
+ if py2k:
+ reinvent_fixtures_py2k.run_fn_fixture_setup(request)
+
+ # inside the yield:
+
+ # 3. function level "autouse" fixtures under py3k (examples: TablesTest
+ # define tables / data, MappedTest define tables / mappers / data)
+
+ # 4. function level fixtures defined on test functions themselves,
+ # e.g. "connection", "metadata" run next
+
+ # 5. pytest hook pytest_runtest_call then runs
+
+ # 6. test itself runs
+
+ yield
+
+ # yield finishes:
+
+ # 7. pytest hook pytest_runtest_teardown hook runs, this is associated
+ # with fixtures close all sessions, provisioning.stop_test_class(),
+ # engines.testing_reaper -> ensure all connection pool connections
+ # are returned, engines created by testing_engine that aren't the
+ # config engine are disposed
+
+ # 8. function level fixtures defined on test functions
+ # themselves, e.g. "connection" rolls back the transaction, "metadata"
+ # emits drop all
+
+ # 9. function level "autouse" fixtures under py3k (examples: TablesTest /
+ # MappedTest delete table data, possibly drop tables and clear mappers
+ # depending on the flags defined by the test class)
+
+ # 10. run homegrown function-level "autouse" fixtures under py2k
+ if py2k:
+ reinvent_fixtures_py2k.run_fn_fixture_teardown(request)
+
+ asyncio._maybe_async(plugin_base.after_test_fixtures, self)
+
+ # 11. run outer xdist-style teardown
+ if hasattr(self, "tearDown"):
+ asyncio._maybe_async(self.tearDown)
+
+ if hasattr(self, "teardown_test"):
+ asyncio._maybe_async(self.teardown_test)
def getargspec(fn):
@@ -461,6 +558,8 @@ def %(name)s(%(args)s):
# for the wrapped function
decorated.__module__ = fn.__module__
decorated.__name__ = fn.__name__
+ if hasattr(fn, "pytestmark"):
+ decorated.pytestmark = fn.pytestmark
return decorated
return decorate
@@ -470,6 +569,11 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
def skip_test_exception(self, *arg, **kw):
return pytest.skip.Exception(*arg, **kw)
+ def mark_base_test_class(self):
+ return pytest.mark.usefixtures(
+ "setup_class_methods", "setup_test_methods"
+ )
+
_combination_id_fns = {
"i": lambda obj: obj,
"r": repr,
@@ -647,8 +751,18 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
fn = asyncio._maybe_async_wrapper(fn)
# other wrappers may be added here
- # now apply FixtureFunctionMarker
- fn = fixture(fn)
+ if py2k and "autouse" in kw:
+ # py2k workaround for too-slow collection of autouse fixtures
+ # in pytest 4.6.11. See notes in reinvent_fixtures_py2k for
+ # rationale.
+
+ # comment this condition out in order to disable the
+ # py2k workaround entirely.
+ reinvent_fixtures_py2k.add_fixture(fn, fixture)
+ else:
+ # now apply FixtureFunctionMarker
+ fn = fixture(fn)
+
return fn
if fn:
diff --git a/lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py b/lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py
new file mode 100644
index 000000000..36b68417b
--- /dev/null
+++ b/lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py
@@ -0,0 +1,112 @@
+"""
+invent a quick version of pytest autouse fixtures as pytest's unacceptably slow
+collection/high memory use in pytest 4.6.11, which is the highest version that
+works in py2k.
+
+by "too-slow" we mean the test suite can't even manage to be collected for a
+single process in less than 70 seconds or so and memory use seems to be very
+high as well. for two or four workers the job just times out after ten
+minutes.
+
+so instead we have invented a very limited form of these fixtures, as our
+current use of "autouse" fixtures are limited to those in fixtures.py.
+
+assumptions for these fixtures:
+
+1. we are only using "function" or "class" scope
+
+2. the functions must be associated with a test class
+
+3. the fixture functions cannot themselves use pytest fixtures
+
+4. the fixture functions must use yield, not return
+
+When py2k support is removed and we can stay on a modern pytest version, this
+can all be removed.
+
+
+"""
+import collections
+
+
+_py2k_fixture_fn_names = collections.defaultdict(set)
+_py2k_class_fixtures = collections.defaultdict(
+ lambda: collections.defaultdict(set)
+)
+_py2k_function_fixtures = collections.defaultdict(
+ lambda: collections.defaultdict(set)
+)
+
+_py2k_cls_fixture_stack = []
+_py2k_fn_fixture_stack = []
+
+
+def add_fixture(fn, fixture):
+ assert fixture.scope in ("class", "function")
+ _py2k_fixture_fn_names[fn.__name__].add((fn, fixture.scope))
+
+
+def scan_for_fixtures_to_use_for_class(item):
+ test_class = item.parent.parent.obj
+
+ for name in _py2k_fixture_fn_names:
+ for fixture_fn, scope in _py2k_fixture_fn_names[name]:
+ meth = getattr(test_class, name, None)
+ if meth and meth.im_func is fixture_fn:
+ for sup in test_class.__mro__:
+ if name in sup.__dict__:
+ if scope == "class":
+ _py2k_class_fixtures[test_class][sup].add(meth)
+ elif scope == "function":
+ _py2k_function_fixtures[test_class][sup].add(meth)
+ break
+ break
+
+
+def run_class_fixture_setup(request):
+
+ cls = request.cls
+ self = cls.__new__(cls)
+
+ fixtures_for_this_class = _py2k_class_fixtures.get(cls)
+
+ if fixtures_for_this_class:
+ for sup_ in cls.__mro__:
+ for fn in fixtures_for_this_class.get(sup_, ()):
+ iter_ = fn(self)
+ next(iter_)
+
+ _py2k_cls_fixture_stack.append(iter_)
+
+
+def run_class_fixture_teardown(request):
+ while _py2k_cls_fixture_stack:
+ iter_ = _py2k_cls_fixture_stack.pop(-1)
+ try:
+ next(iter_)
+ except StopIteration:
+ pass
+
+
+def run_fn_fixture_setup(request):
+ cls = request.cls
+ self = request.instance
+
+ fixtures_for_this_class = _py2k_function_fixtures.get(cls)
+
+ if fixtures_for_this_class:
+ for sup_ in reversed(cls.__mro__):
+ for fn in fixtures_for_this_class.get(sup_, ()):
+ iter_ = fn(self)
+ next(iter_)
+
+ _py2k_fn_fixture_stack.append(iter_)
+
+
+def run_fn_fixture_teardown(request):
+ while _py2k_fn_fixture_stack:
+ iter_ = _py2k_fn_fixture_stack.pop(-1)
+ try:
+ next(iter_)
+ except StopIteration:
+ pass
diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py
index 4ee0567f2..2fade1c32 100644
--- a/lib/sqlalchemy/testing/provision.py
+++ b/lib/sqlalchemy/testing/provision.py
@@ -67,6 +67,7 @@ def setup_config(db_url, options, file_config, follower_ident):
db_url = follower_url_from_main(db_url, follower_ident)
db_opts = {}
update_db_opts(db_url, db_opts)
+ db_opts["scope"] = "global"
eng = engines.testing_engine(db_url, db_opts)
post_configure_engine(db_url, eng, follower_ident)
eng.connect().close()
@@ -264,6 +265,7 @@ def drop_all_schema_objects(cfg, eng):
if config.requirements.schemas.enabled_for_config(cfg):
util.drop_all_tables(eng, inspector, schema=cfg.test_schema)
+ util.drop_all_tables(eng, inspector, schema=cfg.test_schema_2)
drop_all_schema_objects_post_tables(cfg, eng)
@@ -299,7 +301,7 @@ def update_db_opts(db_url, db_opts):
def post_configure_engine(url, engine, follower_ident):
"""Perform extra steps after configuring an engine for testing.
- (For the internal dialects, currently only used by sqlite.)
+ (For the internal dialects, currently only used by sqlite, oracle)
"""
pass
@@ -375,7 +377,12 @@ def temp_table_keyword_args(cfg, eng):
@register.init
-def stop_test_class(config, db, testcls):
+def prepare_for_drop_tables(config, connection):
+ pass
+
+
+@register.init
+def stop_test_class_outside_fixtures(config, db, testcls):
pass
diff --git a/lib/sqlalchemy/testing/suite/test_insert.py b/lib/sqlalchemy/testing/suite/test_insert.py
index da59d831f..35f3315c7 100644
--- a/lib/sqlalchemy/testing/suite/test_insert.py
+++ b/lib/sqlalchemy/testing/suite/test_insert.py
@@ -51,13 +51,15 @@ class LastrowidTest(fixtures.TablesTest):
def test_autoincrement_on_insert(self, connection):
- connection.execute(self.tables.autoinc_pk.insert(), data="some data")
+ connection.execute(
+ self.tables.autoinc_pk.insert(), dict(data="some data")
+ )
self._assert_round_trip(self.tables.autoinc_pk, connection)
def test_last_inserted_id(self, connection):
r = connection.execute(
- self.tables.autoinc_pk.insert(), data="some data"
+ self.tables.autoinc_pk.insert(), dict(data="some data")
)
pk = connection.scalar(select(self.tables.autoinc_pk.c.id))
eq_(r.inserted_primary_key, (pk,))
@@ -65,7 +67,7 @@ class LastrowidTest(fixtures.TablesTest):
@requirements.dbapi_lastrowid
def test_native_lastrowid_autoinc(self, connection):
r = connection.execute(
- self.tables.autoinc_pk.insert(), data="some data"
+ self.tables.autoinc_pk.insert(), dict(data="some data")
)
lastrowid = r.lastrowid
pk = connection.scalar(select(self.tables.autoinc_pk.c.id))
@@ -116,7 +118,9 @@ class InsertBehaviorTest(fixtures.TablesTest):
engine = config.db
with engine.begin() as conn:
- r = conn.execute(self.tables.autoinc_pk.insert(), data="some data")
+ r = conn.execute(
+ self.tables.autoinc_pk.insert(), dict(data="some data")
+ )
assert r._soft_closed
assert not r.closed
assert r.is_insert
@@ -131,7 +135,7 @@ class InsertBehaviorTest(fixtures.TablesTest):
@requirements.returning
def test_autoclose_on_insert_implicit_returning(self, connection):
r = connection.execute(
- self.tables.autoinc_pk.insert(), data="some data"
+ self.tables.autoinc_pk.insert(), dict(data="some data")
)
assert r._soft_closed
assert not r.closed
@@ -315,7 +319,7 @@ class ReturningTest(fixtures.TablesTest):
def test_explicit_returning_pk_autocommit(self, connection):
table = self.tables.autoinc_pk
r = connection.execute(
- table.insert().returning(table.c.id), data="some data"
+ table.insert().returning(table.c.id), dict(data="some data")
)
pk = r.first()[0]
fetched_pk = connection.scalar(select(table.c.id))
@@ -324,7 +328,7 @@ class ReturningTest(fixtures.TablesTest):
def test_explicit_returning_pk_no_autocommit(self, connection):
table = self.tables.autoinc_pk
r = connection.execute(
- table.insert().returning(table.c.id), data="some data"
+ table.insert().returning(table.c.id), dict(data="some data")
)
pk = r.first()[0]
fetched_pk = connection.scalar(select(table.c.id))
@@ -332,13 +336,15 @@ class ReturningTest(fixtures.TablesTest):
def test_autoincrement_on_insert_implicit_returning(self, connection):
- connection.execute(self.tables.autoinc_pk.insert(), data="some data")
+ connection.execute(
+ self.tables.autoinc_pk.insert(), dict(data="some data")
+ )
self._assert_round_trip(self.tables.autoinc_pk, connection)
def test_last_inserted_id_implicit_returning(self, connection):
r = connection.execute(
- self.tables.autoinc_pk.insert(), data="some data"
+ self.tables.autoinc_pk.insert(), dict(data="some data")
)
pk = connection.scalar(select(self.tables.autoinc_pk.c.id))
eq_(r.inserted_primary_key, (pk,))
diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py
index 6c3c1005a..916d74db3 100644
--- a/lib/sqlalchemy/testing/suite/test_reflection.py
+++ b/lib/sqlalchemy/testing/suite/test_reflection.py
@@ -293,7 +293,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
from sqlalchemy import pool
return engines.testing_engine(
- options=dict(poolclass=pool.StaticPool)
+ options=dict(poolclass=pool.StaticPool, scope="class"),
)
else:
return config.db
@@ -523,6 +523,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
insp = inspect(self.bind)
eq_(insp.default_schema_name, self.bind.dialect.default_schema_name)
+ @testing.requires.foreign_key_constraint_reflection
@testing.combinations(
(None, True, False, False),
(None, True, False, True, testing.requires.schemas),
@@ -630,8 +631,12 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.combinations(
(False, False),
(False, True, testing.requires.schemas),
- (True, False),
- (False, True, testing.requires.schemas),
+ (True, False, testing.requires.view_reflection),
+ (
+ True,
+ True,
+ testing.requires.schemas + testing.requires.view_reflection,
+ ),
argnames="use_views,use_schema",
)
def test_get_columns(self, connection, use_views, use_schema):
@@ -999,6 +1004,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
eq_(names_that_duplicate_index, idx_names)
eq_(uq_names, set())
+ @testing.requires.view_reflection
@testing.combinations(
(False,), (True, testing.requires.schemas), argnames="use_schema"
)
diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py
index e0fdbe47a..e8dd6cf2c 100644
--- a/lib/sqlalchemy/testing/suite/test_results.py
+++ b/lib/sqlalchemy/testing/suite/test_results.py
@@ -261,10 +261,6 @@ class ServerSideCursorsTest(
)
return self.engine
- def tearDown(self):
- engines.testing_reaper.close_all()
- self.engine.dispose()
-
@testing.combinations(
("global_string", True, "select 1", True),
("global_text", True, text("select 1"), True),
@@ -309,24 +305,22 @@ class ServerSideCursorsTest(
def test_conn_option(self):
engine = self._fixture(False)
- # should be enabled for this one
- result = (
- engine.connect()
- .execution_options(stream_results=True)
- .exec_driver_sql("select 1")
- )
- assert self._is_server_side(result.cursor)
+ with engine.connect() as conn:
+ # should be enabled for this one
+ result = conn.execution_options(
+ stream_results=True
+ ).exec_driver_sql("select 1")
+ assert self._is_server_side(result.cursor)
def test_stmt_enabled_conn_option_disabled(self):
engine = self._fixture(False)
s = select(1).execution_options(stream_results=True)
- # not this one
- result = (
- engine.connect().execution_options(stream_results=False).execute(s)
- )
- assert not self._is_server_side(result.cursor)
+ with engine.connect() as conn:
+ # not this one
+ result = conn.execution_options(stream_results=False).execute(s)
+ assert not self._is_server_side(result.cursor)
def test_aliases_and_ss(self):
engine = self._fixture(False)
@@ -344,8 +338,7 @@ class ServerSideCursorsTest(
assert not self._is_server_side(result.cursor)
result.close()
- @testing.provide_metadata
- def test_roundtrip_fetchall(self):
+ def test_roundtrip_fetchall(self, metadata):
md = self.metadata
engine = self._fixture(True)
@@ -385,8 +378,7 @@ class ServerSideCursorsTest(
0,
)
- @testing.provide_metadata
- def test_roundtrip_fetchmany(self):
+ def test_roundtrip_fetchmany(self, metadata):
md = self.metadata
engine = self._fixture(True)
diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py
index f8d9b3d88..0d9f08848 100644
--- a/lib/sqlalchemy/testing/suite/test_select.py
+++ b/lib/sqlalchemy/testing/suite/test_select.py
@@ -742,7 +742,7 @@ class PostCompileParamsTest(
with self.sql_execution_asserter() as asserter:
with config.db.connect() as conn:
- conn.execute(stmt, q=10)
+ conn.execute(stmt, dict(q=10))
asserter.assert_(
CursorSQL(
@@ -761,7 +761,7 @@ class PostCompileParamsTest(
with self.sql_execution_asserter() as asserter:
with config.db.connect() as conn:
- conn.execute(stmt, q=[5, 6, 7])
+ conn.execute(stmt, dict(q=[5, 6, 7]))
asserter.assert_(
CursorSQL(
@@ -783,7 +783,7 @@ class PostCompileParamsTest(
with self.sql_execution_asserter() as asserter:
with config.db.connect() as conn:
- conn.execute(stmt, q=[(5, 10), (12, 18)])
+ conn.execute(stmt, dict(q=[(5, 10), (12, 18)]))
asserter.assert_(
CursorSQL(
@@ -807,7 +807,7 @@ class PostCompileParamsTest(
with self.sql_execution_asserter() as asserter:
with config.db.connect() as conn:
- conn.execute(stmt, q=[(5, "z1"), (12, "z3")])
+ conn.execute(stmt, dict(q=[(5, "z1"), (12, "z3")]))
asserter.assert_(
CursorSQL(
diff --git a/lib/sqlalchemy/testing/suite/test_sequence.py b/lib/sqlalchemy/testing/suite/test_sequence.py
index d8c35ed0b..7445ade00 100644
--- a/lib/sqlalchemy/testing/suite/test_sequence.py
+++ b/lib/sqlalchemy/testing/suite/test_sequence.py
@@ -46,11 +46,13 @@ class SequenceTest(fixtures.TablesTest):
)
def test_insert_roundtrip(self, connection):
- connection.execute(self.tables.seq_pk.insert(), data="some data")
+ connection.execute(self.tables.seq_pk.insert(), dict(data="some data"))
self._assert_round_trip(self.tables.seq_pk, connection)
def test_insert_lastrowid(self, connection):
- r = connection.execute(self.tables.seq_pk.insert(), data="some data")
+ r = connection.execute(
+ self.tables.seq_pk.insert(), dict(data="some data")
+ )
eq_(
r.inserted_primary_key, (testing.db.dialect.default_sequence_base,)
)
@@ -62,7 +64,7 @@ class SequenceTest(fixtures.TablesTest):
@requirements.sequences_optional
def test_optional_seq(self, connection):
r = connection.execute(
- self.tables.seq_opt_pk.insert(), data="some data"
+ self.tables.seq_opt_pk.insert(), dict(data="some data")
)
eq_(r.inserted_primary_key, (1,))
diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py
index 3a5e02c32..ebcceaae7 100644
--- a/lib/sqlalchemy/testing/suite/test_types.py
+++ b/lib/sqlalchemy/testing/suite/test_types.py
@@ -511,24 +511,23 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
__backend__ = True
@testing.fixture
- def do_numeric_test(self, metadata):
+ def do_numeric_test(self, metadata, connection):
@testing.emits_warning(
r".*does \*not\* support Decimal objects natively"
)
def run(type_, input_, output, filter_=None, check_scale=False):
t = Table("t", metadata, Column("x", type_))
- t.create(testing.db)
- with config.db.begin() as conn:
- conn.execute(t.insert(), [{"x": x} for x in input_])
-
- result = {row[0] for row in conn.execute(t.select())}
- output = set(output)
- if filter_:
- result = set(filter_(x) for x in result)
- output = set(filter_(x) for x in output)
- eq_(result, output)
- if check_scale:
- eq_([str(x) for x in result], [str(x) for x in output])
+ t.create(connection)
+ connection.execute(t.insert(), [{"x": x} for x in input_])
+
+ result = {row[0] for row in connection.execute(t.select())}
+ output = set(output)
+ if filter_:
+ result = set(filter_(x) for x in result)
+ output = set(filter_(x) for x in output)
+ eq_(result, output)
+ if check_scale:
+ eq_([str(x) for x in result], [str(x) for x in output])
return run
@@ -1165,40 +1164,39 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
},
)
- def test_eval_none_flag_orm(self):
+ def test_eval_none_flag_orm(self, connection):
Base = declarative_base()
class Data(Base):
__table__ = self.tables.data_table
- s = Session(testing.db)
+ with Session(connection) as s:
+ d1 = Data(name="d1", data=None, nulldata=None)
+ s.add(d1)
+ s.commit()
- d1 = Data(name="d1", data=None, nulldata=None)
- s.add(d1)
- s.commit()
-
- s.bulk_insert_mappings(
- Data, [{"name": "d2", "data": None, "nulldata": None}]
- )
- eq_(
- s.query(
- cast(self.tables.data_table.c.data, String()),
- cast(self.tables.data_table.c.nulldata, String),
+ s.bulk_insert_mappings(
+ Data, [{"name": "d2", "data": None, "nulldata": None}]
)
- .filter(self.tables.data_table.c.name == "d1")
- .first(),
- ("null", None),
- )
- eq_(
- s.query(
- cast(self.tables.data_table.c.data, String()),
- cast(self.tables.data_table.c.nulldata, String),
+ eq_(
+ s.query(
+ cast(self.tables.data_table.c.data, String()),
+ cast(self.tables.data_table.c.nulldata, String),
+ )
+ .filter(self.tables.data_table.c.name == "d1")
+ .first(),
+ ("null", None),
+ )
+ eq_(
+ s.query(
+ cast(self.tables.data_table.c.data, String()),
+ cast(self.tables.data_table.c.nulldata, String),
+ )
+ .filter(self.tables.data_table.c.name == "d2")
+ .first(),
+ ("null", None),
)
- .filter(self.tables.data_table.c.name == "d2")
- .first(),
- ("null", None),
- )
class JSONLegacyStringCastIndexTest(
diff --git a/lib/sqlalchemy/testing/suite/test_update_delete.py b/lib/sqlalchemy/testing/suite/test_update_delete.py
index f7e6da98e..3fb51ead3 100644
--- a/lib/sqlalchemy/testing/suite/test_update_delete.py
+++ b/lib/sqlalchemy/testing/suite/test_update_delete.py
@@ -32,7 +32,9 @@ class SimpleUpdateDeleteTest(fixtures.TablesTest):
def test_update(self, connection):
t = self.tables.plain_pk
- r = connection.execute(t.update().where(t.c.id == 2), data="d2_new")
+ r = connection.execute(
+ t.update().where(t.c.id == 2), dict(data="d2_new")
+ )
assert not r.is_insert
assert not r.returns_rows
diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py
index eb9fcd1cd..01185c284 100644
--- a/lib/sqlalchemy/testing/util.py
+++ b/lib/sqlalchemy/testing/util.py
@@ -14,6 +14,7 @@ import types
from . import config
from . import mock
from .. import inspect
+from ..engine import Connection
from ..schema import Column
from ..schema import DropConstraint
from ..schema import DropTable
@@ -207,11 +208,13 @@ def fail(msg):
@decorator
def provide_metadata(fn, *args, **kw):
- """Provide bound MetaData for a single test, dropping afterwards."""
+ """Provide bound MetaData for a single test, dropping afterwards.
- # import cycle that only occurs with py2k's import resolver
- # in py3k this can be moved top level.
- from . import engines
+ Legacy; use the "metadata" pytest fixture.
+
+ """
+
+ from . import fixtures
metadata = schema.MetaData()
self = args[0]
@@ -220,7 +223,31 @@ def provide_metadata(fn, *args, **kw):
try:
return fn(*args, **kw)
finally:
- engines.drop_all_tables(metadata, config.db)
+ # close out some things that get in the way of dropping tables.
+ # when using the "metadata" fixture, there is a set ordering
+ # of things that makes sure things are cleaned up in order, however
+ # the simple "decorator" nature of this legacy function means
+ # we have to hardcode some of that cleanup ahead of time.
+
+ # close ORM sessions
+ fixtures._close_all_sessions()
+
+ # integrate with the "connection" fixture as there are many
+ # tests where it is used along with provide_metadata
+ if fixtures._connection_fixture_connection:
+ # TODO: this warning can be used to find all the places
+ # this is used with connection fixture
+ # warn("mixing legacy provide metadata with connection fixture")
+ drop_all_tables_from_metadata(
+ metadata, fixtures._connection_fixture_connection
+ )
+ # as the provide_metadata fixture is often used with "testing.db",
+ # when we do the drop we have to commit the transaction so that
+ # the DB is actually updated as the CREATE would have been
+ # committed
+ fixtures._connection_fixture_connection.get_transaction().commit()
+ else:
+ drop_all_tables_from_metadata(metadata, config.db)
self.metadata = prev_meta
@@ -359,6 +386,29 @@ class adict(dict):
get_all = __call__
+def drop_all_tables_from_metadata(metadata, engine_or_connection):
+ from . import engines
+
+ def go(connection):
+ engines.testing_reaper.prepare_for_drop_tables(connection)
+
+ if not connection.dialect.supports_alter:
+ from . import assertions
+
+ with assertions.expect_warnings(
+ "Can't sort tables", assert_=False
+ ):
+ metadata.drop_all(connection)
+ else:
+ metadata.drop_all(connection)
+
+ if not isinstance(engine_or_connection, Connection):
+ with engine_or_connection.begin() as connection:
+ go(connection)
+ else:
+ go(engine_or_connection)
+
+
def drop_all_tables(engine, inspector, schema=None, include_names=None):
if include_names is not None:
diff --git a/lib/sqlalchemy/testing/warnings.py b/lib/sqlalchemy/testing/warnings.py
index 1b078a263..5699cd035 100644
--- a/lib/sqlalchemy/testing/warnings.py
+++ b/lib/sqlalchemy/testing/warnings.py
@@ -56,9 +56,6 @@ def setup_filters():
# Core execution
#
r"The (?:Executable|Engine)\.(?:execute|scalar)\(\) method",
- r"The connection.execute\(\) method in SQLAlchemy 2.0 will accept "
- "parameters as a single dictionary or a single sequence of "
- "dictionaries only.",
r"The Connection.connect\(\) method is considered legacy",
# r".*DefaultGenerator.execute\(\)",
#
diff --git a/lib/sqlalchemy/util/queue.py b/lib/sqlalchemy/util/queue.py
index 99ecb4fb3..ca5a3abde 100644
--- a/lib/sqlalchemy/util/queue.py
+++ b/lib/sqlalchemy/util/queue.py
@@ -230,13 +230,16 @@ class AsyncAdaptedQueue:
return self.put_nowait(item)
try:
- if timeout:
+ if timeout is not None:
return self.await_(
asyncio.wait_for(self._queue.put(item), timeout)
)
else:
return self.await_(self._queue.put(item))
- except asyncio.queues.QueueFull as err:
+ except (
+ asyncio.queues.QueueFull,
+ asyncio.exceptions.TimeoutError,
+ ) as err:
compat.raise_(
Full(),
replace_context=err,
@@ -254,14 +257,18 @@ class AsyncAdaptedQueue:
def get(self, block=True, timeout=None):
if not block:
return self.get_nowait()
+
try:
- if timeout:
+ if timeout is not None:
return self.await_(
asyncio.wait_for(self._queue.get(), timeout)
)
else:
return self.await_(self._queue.get())
- except asyncio.queues.QueueEmpty as err:
+ except (
+ asyncio.queues.QueueEmpty,
+ asyncio.exceptions.TimeoutError,
+ ) as err:
compat.raise_(
Empty(),
replace_context=err,