diff options
Diffstat (limited to 'lib/sqlalchemy/testing')
-rw-r--r-- | lib/sqlalchemy/testing/engines.py | 33 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/suite/test_dialect.py | 6 |
2 files changed, 35 insertions, 4 deletions
diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index e17c09be7..52c2d3cbf 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -7,6 +7,10 @@ import collections import re +import typing +from typing import Any +from typing import Dict +from typing import Optional import warnings import weakref @@ -15,6 +19,13 @@ from .util import decorator from .util import gc_collect from .. import event from .. import pool +from ..util.typing import Literal + + +if typing.TYPE_CHECKING: + from ..engine import Engine + from ..engine.url import URL + from ..ext.asyncio import AsyncEngine class ConnectionKiller: @@ -264,14 +275,32 @@ def reconnecting_engine(url=None, options=None): return engine +@typing.overload +def testing_engine( + url: Optional["URL"] = None, + options: Optional[Dict[str, Any]] = None, + asyncio: Literal[False] = False, + transfer_staticpool: bool = False, +) -> "Engine": + ... + + +@typing.overload +def testing_engine( + url: Optional["URL"] = None, + options: Optional[Dict[str, Any]] = None, + asyncio: Literal[True] = True, + transfer_staticpool: bool = False, +) -> "AsyncEngine": + ... + + def testing_engine( url=None, options=None, asyncio=False, transfer_staticpool=False, ): - """Produce an engine configured by --options with optional overrides.""" - if asyncio: from sqlalchemy.ext.asyncio import create_async_engine as create_engine else: diff --git a/lib/sqlalchemy/testing/suite/test_dialect.py b/lib/sqlalchemy/testing/suite/test_dialect.py index 28fd99876..daaea085d 100644 --- a/lib/sqlalchemy/testing/suite/test_dialect.py +++ b/lib/sqlalchemy/testing/suite/test_dialect.py @@ -115,7 +115,9 @@ class IsolationLevelTest(fixtures.TestBase): eq_(conn.get_isolation_level(), non_default) - conn.dialect.reset_isolation_level(conn.connection) + conn.dialect.reset_isolation_level( + conn.connection.dbapi_connection + ) eq_(conn.get_isolation_level(), existing) @@ -223,7 +225,7 @@ class AutocommitIsolationTest(fixtures.TablesTest): c2 = conn.execution_options(isolation_level="AUTOCOMMIT") self._test_conn_autocommits(c2, True) - c2.dialect.reset_isolation_level(c2.connection) + c2.dialect.reset_isolation_level(c2.connection.dbapi_connection) self._test_conn_autocommits(conn, False) |