summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/engine/mock.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/engine/mock.py')
-rw-r--r--lib/sqlalchemy/engine/mock.py57
1 files changed, 43 insertions, 14 deletions
diff --git a/lib/sqlalchemy/engine/mock.py b/lib/sqlalchemy/engine/mock.py
index 76e77a3f3..a0ba96603 100644
--- a/lib/sqlalchemy/engine/mock.py
+++ b/lib/sqlalchemy/engine/mock.py
@@ -8,40 +8,69 @@
from __future__ import annotations
from operator import attrgetter
+import typing
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Optional
+from typing import Type
+from typing import Union
from . import url as _url
from .. import util
+if typing.TYPE_CHECKING:
+ from .base import Connection
+ from .base import Engine
+ from .interfaces import _CoreAnyExecuteParams
+ from .interfaces import _ExecuteOptionsParameter
+ from .interfaces import Dialect
+ from .url import URL
+ from ..sql.base import Executable
+ from ..sql.ddl import DDLElement
+ from ..sql.ddl import SchemaDropper
+ from ..sql.ddl import SchemaGenerator
+ from ..sql.schema import HasSchemaAttr
+
+
class MockConnection:
- def __init__(self, dialect, execute):
+ def __init__(self, dialect: Dialect, execute: Callable[..., Any]):
self._dialect = dialect
- self.execute = execute
+ self._execute_impl = execute
- engine = property(lambda s: s)
- dialect = property(attrgetter("_dialect"))
- name = property(lambda s: s._dialect.name)
+ engine: Engine = cast(Any, property(lambda s: s))
+ dialect: Dialect = cast(Any, property(attrgetter("_dialect")))
+ name: str = cast(Any, property(lambda s: s._dialect.name))
- def connect(self, **kwargs):
+ def connect(self, **kwargs: Any) -> MockConnection:
return self
- def schema_for_object(self, obj):
+ def schema_for_object(self, obj: HasSchemaAttr) -> Optional[str]:
return obj.schema
- def execution_options(self, **kw):
+ def execution_options(self, **kw: Any) -> MockConnection:
return self
def _run_ddl_visitor(
- self, visitorcallable, element, connection=None, **kwargs
- ):
+ self,
+ visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]],
+ element: DDLElement,
+ **kwargs: Any,
+ ) -> None:
kwargs["checkfirst"] = False
visitorcallable(self.dialect, self, **kwargs).traverse_single(element)
- def execute(self, object_, *multiparams, **params):
- raise NotImplementedError()
+ def execute(
+ self,
+ obj: Executable,
+ parameters: Optional[_CoreAnyExecuteParams] = None,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> Any:
+ return self._execute_impl(obj, parameters)
-def create_mock_engine(url, executor, **kw):
+def create_mock_engine(url: URL, executor: Any, **kw: Any) -> MockConnection:
"""Create a "mock" engine used for echoing DDL.
This is a utility function used for debugging or storing the output of DDL
@@ -96,6 +125,6 @@ def create_mock_engine(url, executor, **kw):
dialect_args[k] = kw.pop(k)
# create dialect
- dialect = dialect_cls(**dialect_args)
+ dialect = dialect_cls(**dialect_args) # type: ignore
return MockConnection(dialect, executor)