summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/engine/default.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-03-17 16:18:55 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-03-19 23:15:15 -0400
commit6c3d738757d7be32dc9f99d8e1c6b5c81c596d5f (patch)
treeae142d45de71d1ebd43df1a38e54e1d3cf1063ec /lib/sqlalchemy/engine/default.py
parentc2fe4a264003933ff895c51f5d07a8456ac86382 (diff)
downloadsqlalchemy-6c3d738757d7be32dc9f99d8e1c6b5c81c596d5f.tar.gz
pep 484 for types
strict types type_api.py, including TypeDecorator, NativeForEmulated, etc. Change-Id: Ib2eba26de0981324a83733954cb7044a29bbd7db
Diffstat (limited to 'lib/sqlalchemy/engine/default.py')
-rw-r--r--lib/sqlalchemy/engine/default.py81
1 files changed, 52 insertions, 29 deletions
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index ba34a0d42..4a833d2e5 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -57,6 +57,8 @@ from ..sql.compiler import SQLCompiler
from ..sql.elements import quoted_name
if typing.TYPE_CHECKING:
+ from types import ModuleType
+
from .base import Connection
from .base import Engine
from .characteristics import ConnectionCharacteristic
@@ -67,8 +69,10 @@ if typing.TYPE_CHECKING:
from .interfaces import _DBAPIMultiExecuteParams
from .interfaces import _DBAPISingleExecuteParams
from .interfaces import _ExecuteOptions
+ from .interfaces import _IsolationLevel
from .interfaces import _MutableCoreSingleExecuteParams
- from .result import _ProcessorType
+ from .interfaces import _ParamStyle
+ from .interfaces import DBAPIConnection
from .row import Row
from .url import URL
from ..event import _ListenerFnType
@@ -76,12 +80,16 @@ if typing.TYPE_CHECKING:
from ..pool import PoolProxiedConnection
from ..sql import Executable
from ..sql.compiler import Compiled
+ from ..sql.compiler import Linting
from ..sql.compiler import ResultColumnsEntry
from ..sql.compiler import TypeCompiler
from ..sql.dml import DMLState
+ from ..sql.dml import UpdateBase
from ..sql.elements import BindParameter
from ..sql.schema import Column
from ..sql.schema import ColumnDefault
+ from ..sql.type_api import _BindProcessorType
+ from ..sql.type_api import _ResultProcessorType
from ..sql.type_api import TypeEngine
# When we're handed literal SQL, ensure it's a SELECT query
@@ -102,10 +110,7 @@ class DefaultDialect(Dialect):
statement_compiler = compiler.SQLCompiler
ddl_compiler = compiler.DDLCompiler
- if typing.TYPE_CHECKING:
- type_compiler: TypeCompiler
- else:
- type_compiler = compiler.GenericTypeCompiler
+ type_compiler_cls = compiler.GenericTypeCompiler
preparer = compiler.IdentifierPreparer
supports_alter = True
@@ -253,20 +258,19 @@ class DefaultDialect(Dialect):
)
def __init__(
self,
- paramstyle=None,
- isolation_level=None,
- dbapi=None,
- implicit_returning=None,
- supports_native_boolean=None,
- max_identifier_length=None,
- label_length=None,
- # int() is because the @deprecated_params decorator cannot accommodate
- # the direct reference to the "NO_LINTING" object
- compiler_linting=int(compiler.NO_LINTING),
- server_side_cursors=False,
- **kwargs,
+ paramstyle: Optional[_ParamStyle] = None,
+ isolation_level: Optional[_IsolationLevel] = None,
+ dbapi: Optional[ModuleType] = None,
+ implicit_returning: Optional[bool] = None,
+ supports_native_boolean: Optional[bool] = None,
+ max_identifier_length: Optional[int] = None,
+ label_length: Optional[int] = None,
+ # util.deprecated_params decorator cannot render the
+ # Linting.NO_LINTING constant
+ compiler_linting: Linting = int(compiler.NO_LINTING), # type: ignore
+ server_side_cursors: bool = False,
+ **kwargs: Any,
):
-
if server_side_cursors:
if not self.supports_server_side_cursors:
raise exc.ArgumentError(
@@ -286,7 +290,9 @@ class DefaultDialect(Dialect):
self.positional = False
self._ischema = None
+
self.dbapi = dbapi
+
if paramstyle is not None:
self.paramstyle = paramstyle
elif self.dbapi is not None:
@@ -299,11 +305,17 @@ class DefaultDialect(Dialect):
self.identifier_preparer = self.preparer(self)
self._on_connect_isolation_level = isolation_level
- tt_callable = cast(
- Type[compiler.GenericTypeCompiler],
- self.type_compiler,
- )
- self.type_compiler = tt_callable(self)
+ legacy_tt_callable = getattr(self, "type_compiler", None)
+ if legacy_tt_callable is not None:
+ tt_callable = cast(
+ Type[compiler.GenericTypeCompiler],
+ self.type_compiler,
+ )
+ else:
+ tt_callable = self.type_compiler_cls
+
+ self.type_compiler_instance = self.type_compiler = tt_callable(self)
+
if supports_native_boolean is not None:
self.supports_native_boolean = supports_native_boolean
@@ -316,6 +328,15 @@ class DefaultDialect(Dialect):
self.compiler_linting = compiler_linting
@util.memoized_property
+ def loaded_dbapi(self) -> ModuleType:
+ if self.dbapi is None:
+ raise exc.InvalidRequestError(
+ f"Dialect {self} does not have a Python DBAPI established "
+ "and cannot be used for actual database interaction"
+ )
+ return self.dbapi
+
+ @util.memoized_property
def _bind_typing_render_casts(self):
return self.bind_typing is interfaces.BindTyping.RENDER_CASTS
@@ -495,7 +516,7 @@ class DefaultDialect(Dialect):
def connect(self, *cargs, **cparams):
# inherits the docstring from interfaces.Dialect.connect
- return self.dbapi.connect(*cargs, **cparams)
+ return self.loaded_dbapi.connect(*cargs, **cparams)
def create_connect_args(self, url):
# inherits the docstring from interfaces.Dialect.create_connect_args
@@ -584,7 +605,7 @@ class DefaultDialect(Dialect):
def _dialect_specific_select_one(self):
return str(expression.select(1).compile(dialect=self))
- def do_ping(self, dbapi_connection):
+ def do_ping(self, dbapi_connection: DBAPIConnection) -> bool:
cursor = None
try:
cursor = dbapi_connection.cursor()
@@ -592,7 +613,7 @@ class DefaultDialect(Dialect):
cursor.execute(self._dialect_specific_select_one)
finally:
cursor.close()
- except self.dbapi.Error as err:
+ except self.loaded_dbapi.Error as err:
if self.is_disconnect(err, dbapi_connection, cursor):
return False
else:
@@ -747,7 +768,7 @@ class StrCompileDialect(DefaultDialect):
statement_compiler = compiler.StrSQLCompiler
ddl_compiler = compiler.DDLCompiler
- type_compiler = compiler.StrSQLTypeCompiler # type: ignore
+ type_compiler_cls = compiler.StrSQLTypeCompiler
preparer = compiler.IdentifierPreparer
supports_statement_cache = True
@@ -906,6 +927,8 @@ class DefaultExecutionContext(ExecutionContext):
self.is_text = compiled.isplaintext
if self.isinsert or self.isupdate or self.isdelete:
+ if TYPE_CHECKING:
+ assert isinstance(compiled.statement, UpdateBase)
self.is_crud = True
self._is_explicit_returning = bool(compiled.statement._returning)
self._is_implicit_returning = bool(
@@ -943,7 +966,7 @@ class DefaultExecutionContext(ExecutionContext):
processors = compiled._bind_processors
flattened_processors: Mapping[
- str, _ProcessorType
+ str, _BindProcessorType[Any]
] = processors # type: ignore[assignment]
if compiled.literal_execute_params or compiled.post_compile_params:
@@ -1354,7 +1377,7 @@ class DefaultExecutionContext(ExecutionContext):
type_ = bindparam.type
impl_type = type_.dialect_impl(self.dialect)
- dbapi_type = impl_type.get_dbapi_type(self.dialect.dbapi)
+ dbapi_type = impl_type.get_dbapi_type(self.dialect.loaded_dbapi)
result_processor = impl_type.result_processor(
self.dialect, dbapi_type
)