diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-03-17 16:18:55 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-03-19 23:15:15 -0400 |
| commit | 6c3d738757d7be32dc9f99d8e1c6b5c81c596d5f (patch) | |
| tree | ae142d45de71d1ebd43df1a38e54e1d3cf1063ec /lib/sqlalchemy/engine/default.py | |
| parent | c2fe4a264003933ff895c51f5d07a8456ac86382 (diff) | |
| download | sqlalchemy-6c3d738757d7be32dc9f99d8e1c6b5c81c596d5f.tar.gz | |
pep 484 for types
strict types type_api.py, including TypeDecorator,
NativeForEmulated, etc.
Change-Id: Ib2eba26de0981324a83733954cb7044a29bbd7db
Diffstat (limited to 'lib/sqlalchemy/engine/default.py')
| -rw-r--r-- | lib/sqlalchemy/engine/default.py | 81 |
1 files changed, 52 insertions, 29 deletions
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index ba34a0d42..4a833d2e5 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -57,6 +57,8 @@ from ..sql.compiler import SQLCompiler from ..sql.elements import quoted_name if typing.TYPE_CHECKING: + from types import ModuleType + from .base import Connection from .base import Engine from .characteristics import ConnectionCharacteristic @@ -67,8 +69,10 @@ if typing.TYPE_CHECKING: from .interfaces import _DBAPIMultiExecuteParams from .interfaces import _DBAPISingleExecuteParams from .interfaces import _ExecuteOptions + from .interfaces import _IsolationLevel from .interfaces import _MutableCoreSingleExecuteParams - from .result import _ProcessorType + from .interfaces import _ParamStyle + from .interfaces import DBAPIConnection from .row import Row from .url import URL from ..event import _ListenerFnType @@ -76,12 +80,16 @@ if typing.TYPE_CHECKING: from ..pool import PoolProxiedConnection from ..sql import Executable from ..sql.compiler import Compiled + from ..sql.compiler import Linting from ..sql.compiler import ResultColumnsEntry from ..sql.compiler import TypeCompiler from ..sql.dml import DMLState + from ..sql.dml import UpdateBase from ..sql.elements import BindParameter from ..sql.schema import Column from ..sql.schema import ColumnDefault + from ..sql.type_api import _BindProcessorType + from ..sql.type_api import _ResultProcessorType from ..sql.type_api import TypeEngine # When we're handed literal SQL, ensure it's a SELECT query @@ -102,10 +110,7 @@ class DefaultDialect(Dialect): statement_compiler = compiler.SQLCompiler ddl_compiler = compiler.DDLCompiler - if typing.TYPE_CHECKING: - type_compiler: TypeCompiler - else: - type_compiler = compiler.GenericTypeCompiler + type_compiler_cls = compiler.GenericTypeCompiler preparer = compiler.IdentifierPreparer supports_alter = True @@ -253,20 +258,19 @@ class DefaultDialect(Dialect): ) def __init__( self, - paramstyle=None, - isolation_level=None, - dbapi=None, - implicit_returning=None, - supports_native_boolean=None, - max_identifier_length=None, - label_length=None, - # int() is because the @deprecated_params decorator cannot accommodate - # the direct reference to the "NO_LINTING" object - compiler_linting=int(compiler.NO_LINTING), - server_side_cursors=False, - **kwargs, + paramstyle: Optional[_ParamStyle] = None, + isolation_level: Optional[_IsolationLevel] = None, + dbapi: Optional[ModuleType] = None, + implicit_returning: Optional[bool] = None, + supports_native_boolean: Optional[bool] = None, + max_identifier_length: Optional[int] = None, + label_length: Optional[int] = None, + # util.deprecated_params decorator cannot render the + # Linting.NO_LINTING constant + compiler_linting: Linting = int(compiler.NO_LINTING), # type: ignore + server_side_cursors: bool = False, + **kwargs: Any, ): - if server_side_cursors: if not self.supports_server_side_cursors: raise exc.ArgumentError( @@ -286,7 +290,9 @@ class DefaultDialect(Dialect): self.positional = False self._ischema = None + self.dbapi = dbapi + if paramstyle is not None: self.paramstyle = paramstyle elif self.dbapi is not None: @@ -299,11 +305,17 @@ class DefaultDialect(Dialect): self.identifier_preparer = self.preparer(self) self._on_connect_isolation_level = isolation_level - tt_callable = cast( - Type[compiler.GenericTypeCompiler], - self.type_compiler, - ) - self.type_compiler = tt_callable(self) + legacy_tt_callable = getattr(self, "type_compiler", None) + if legacy_tt_callable is not None: + tt_callable = cast( + Type[compiler.GenericTypeCompiler], + self.type_compiler, + ) + else: + tt_callable = self.type_compiler_cls + + self.type_compiler_instance = self.type_compiler = tt_callable(self) + if supports_native_boolean is not None: self.supports_native_boolean = supports_native_boolean @@ -316,6 +328,15 @@ class DefaultDialect(Dialect): self.compiler_linting = compiler_linting @util.memoized_property + def loaded_dbapi(self) -> ModuleType: + if self.dbapi is None: + raise exc.InvalidRequestError( + f"Dialect {self} does not have a Python DBAPI established " + "and cannot be used for actual database interaction" + ) + return self.dbapi + + @util.memoized_property def _bind_typing_render_casts(self): return self.bind_typing is interfaces.BindTyping.RENDER_CASTS @@ -495,7 +516,7 @@ class DefaultDialect(Dialect): def connect(self, *cargs, **cparams): # inherits the docstring from interfaces.Dialect.connect - return self.dbapi.connect(*cargs, **cparams) + return self.loaded_dbapi.connect(*cargs, **cparams) def create_connect_args(self, url): # inherits the docstring from interfaces.Dialect.create_connect_args @@ -584,7 +605,7 @@ class DefaultDialect(Dialect): def _dialect_specific_select_one(self): return str(expression.select(1).compile(dialect=self)) - def do_ping(self, dbapi_connection): + def do_ping(self, dbapi_connection: DBAPIConnection) -> bool: cursor = None try: cursor = dbapi_connection.cursor() @@ -592,7 +613,7 @@ class DefaultDialect(Dialect): cursor.execute(self._dialect_specific_select_one) finally: cursor.close() - except self.dbapi.Error as err: + except self.loaded_dbapi.Error as err: if self.is_disconnect(err, dbapi_connection, cursor): return False else: @@ -747,7 +768,7 @@ class StrCompileDialect(DefaultDialect): statement_compiler = compiler.StrSQLCompiler ddl_compiler = compiler.DDLCompiler - type_compiler = compiler.StrSQLTypeCompiler # type: ignore + type_compiler_cls = compiler.StrSQLTypeCompiler preparer = compiler.IdentifierPreparer supports_statement_cache = True @@ -906,6 +927,8 @@ class DefaultExecutionContext(ExecutionContext): self.is_text = compiled.isplaintext if self.isinsert or self.isupdate or self.isdelete: + if TYPE_CHECKING: + assert isinstance(compiled.statement, UpdateBase) self.is_crud = True self._is_explicit_returning = bool(compiled.statement._returning) self._is_implicit_returning = bool( @@ -943,7 +966,7 @@ class DefaultExecutionContext(ExecutionContext): processors = compiled._bind_processors flattened_processors: Mapping[ - str, _ProcessorType + str, _BindProcessorType[Any] ] = processors # type: ignore[assignment] if compiled.literal_execute_params or compiled.post_compile_params: @@ -1354,7 +1377,7 @@ class DefaultExecutionContext(ExecutionContext): type_ = bindparam.type impl_type = type_.dialect_impl(self.dialect) - dbapi_type = impl_type.get_dbapi_type(self.dialect.dbapi) + dbapi_type = impl_type.get_dbapi_type(self.dialect.loaded_dbapi) result_processor = impl_type.result_processor( self.dialect, dbapi_type ) |
