summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/testing')
-rw-r--r--lib/sqlalchemy/testing/engines.py33
-rw-r--r--lib/sqlalchemy/testing/suite/test_dialect.py6
2 files changed, 35 insertions, 4 deletions
diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py
index e17c09be7..52c2d3cbf 100644
--- a/lib/sqlalchemy/testing/engines.py
+++ b/lib/sqlalchemy/testing/engines.py
@@ -7,6 +7,10 @@
import collections
import re
+import typing
+from typing import Any
+from typing import Dict
+from typing import Optional
import warnings
import weakref
@@ -15,6 +19,13 @@ from .util import decorator
from .util import gc_collect
from .. import event
from .. import pool
+from ..util.typing import Literal
+
+
+if typing.TYPE_CHECKING:
+ from ..engine import Engine
+ from ..engine.url import URL
+ from ..ext.asyncio import AsyncEngine
class ConnectionKiller:
@@ -264,14 +275,32 @@ def reconnecting_engine(url=None, options=None):
return engine
+@typing.overload
+def testing_engine(
+ url: Optional["URL"] = None,
+ options: Optional[Dict[str, Any]] = None,
+ asyncio: Literal[False] = False,
+ transfer_staticpool: bool = False,
+) -> "Engine":
+ ...
+
+
+@typing.overload
+def testing_engine(
+ url: Optional["URL"] = None,
+ options: Optional[Dict[str, Any]] = None,
+ asyncio: Literal[True] = True,
+ transfer_staticpool: bool = False,
+) -> "AsyncEngine":
+ ...
+
+
def testing_engine(
url=None,
options=None,
asyncio=False,
transfer_staticpool=False,
):
- """Produce an engine configured by --options with optional overrides."""
-
if asyncio:
from sqlalchemy.ext.asyncio import create_async_engine as create_engine
else:
diff --git a/lib/sqlalchemy/testing/suite/test_dialect.py b/lib/sqlalchemy/testing/suite/test_dialect.py
index 28fd99876..daaea085d 100644
--- a/lib/sqlalchemy/testing/suite/test_dialect.py
+++ b/lib/sqlalchemy/testing/suite/test_dialect.py
@@ -115,7 +115,9 @@ class IsolationLevelTest(fixtures.TestBase):
eq_(conn.get_isolation_level(), non_default)
- conn.dialect.reset_isolation_level(conn.connection)
+ conn.dialect.reset_isolation_level(
+ conn.connection.dbapi_connection
+ )
eq_(conn.get_isolation_level(), existing)
@@ -223,7 +225,7 @@ class AutocommitIsolationTest(fixtures.TablesTest):
c2 = conn.execution_options(isolation_level="AUTOCOMMIT")
self._test_conn_autocommits(c2, True)
- c2.dialect.reset_isolation_level(c2.connection)
+ c2.dialect.reset_isolation_level(c2.connection.dbapi_connection)
self._test_conn_autocommits(conn, False)