summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/engine/default.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-03-08 17:14:41 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2022-03-13 15:29:20 -0400
commit769fa67d842035dd852ab8b6a26ea3f110a51131 (patch)
tree5c121caca336071091c6f5ea4c54743c92d6458a /lib/sqlalchemy/engine/default.py
parent77fc8216a74e6b2d0efc6591c6c735687bd10002 (diff)
downloadsqlalchemy-769fa67d842035dd852ab8b6a26ea3f110a51131.tar.gz
pep-484: sqlalchemy.sql pass one
sqlalchemy.sql will require many passes to get all modules even gradually typed. Will have to pick and choose what modules can be strictly typed vs. which can be gradual. in this patch, emphasis is on visitors.py, cache_key.py, annotations.py for strict typing, compiler.py is on gradual typing but has much more structure, in particular where it connects with the outside world. The work within compiler.py also reached back out to engine/cursor.py , default.py quite a bit. References: #6810 Change-Id: I6e8a29f6013fd216e43d45091bc193f8be0368fd
Diffstat (limited to 'lib/sqlalchemy/engine/default.py')
-rw-r--r--lib/sqlalchemy/engine/default.py28
1 files changed, 18 insertions, 10 deletions
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 2579f573c..c9fb1ebf2 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -46,6 +46,7 @@ from .interfaces import ExecutionContext
from .. import event
from .. import exc
from .. import pool
+from .. import TupleType
from .. import types as sqltypes
from .. import util
from ..sql import compiler
@@ -76,6 +77,8 @@ if typing.TYPE_CHECKING:
from ..sql.compiler import Compiled
from ..sql.compiler import ResultColumnsEntry
from ..sql.compiler import TypeCompiler
+ from ..sql.dml import DMLState
+ from ..sql.elements import BindParameter
from ..sql.schema import Column
from ..sql.type_api import TypeEngine
@@ -820,7 +823,7 @@ class DefaultExecutionContext(ExecutionContext):
cursor: DBAPICursor
compiled_parameters: List[_MutableCoreSingleExecuteParams]
parameters: _DBAPIMultiExecuteParams
- extracted_parameters: _CoreSingleExecuteParams
+ extracted_parameters: Optional[Sequence[BindParameter[Any]]]
_empty_dict_params = cast("Mapping[str, Any]", util.EMPTY_DICT)
@@ -878,7 +881,7 @@ class DefaultExecutionContext(ExecutionContext):
compiled: SQLCompiler,
parameters: _CoreMultiExecuteParams,
invoked_statement: Executable,
- extracted_parameters: _CoreSingleExecuteParams,
+ extracted_parameters: Optional[Sequence[BindParameter[Any]]],
cache_hit: CacheStats = CacheStats.CACHING_DISABLED,
) -> ExecutionContext:
"""Initialize execution context for a Compiled construct."""
@@ -1513,9 +1516,10 @@ class DefaultExecutionContext(ExecutionContext):
inputsizes, self.cursor, self.statement, self.parameters, self
)
- has_escaped_names = bool(compiled.escaped_bind_names)
- if has_escaped_names:
+ if compiled.escaped_bind_names:
escaped_bind_names = compiled.escaped_bind_names
+ else:
+ escaped_bind_names = None
if dialect.positional:
items = [
@@ -1535,17 +1539,18 @@ class DefaultExecutionContext(ExecutionContext):
if key in self._expanded_parameters:
if bindparam.type._is_tuple_type:
- num = len(bindparam.type.types)
+ tup_type = cast(TupleType, bindparam.type)
+ num = len(tup_type.types)
dbtypes = inputsizes[bindparam]
generic_inputsizes.extend(
(
(
escaped_bind_names.get(paramname, paramname)
- if has_escaped_names
+ if escaped_bind_names is not None
else paramname
),
dbtypes[idx % num],
- bindparam.type.types[idx % num],
+ tup_type.types[idx % num],
)
for idx, paramname in enumerate(
self._expanded_parameters[key]
@@ -1557,7 +1562,7 @@ class DefaultExecutionContext(ExecutionContext):
(
(
escaped_bind_names.get(paramname, paramname)
- if has_escaped_names
+ if escaped_bind_names is not None
else paramname
),
dbtype,
@@ -1570,7 +1575,7 @@ class DefaultExecutionContext(ExecutionContext):
escaped_name = (
escaped_bind_names.get(key, key)
- if has_escaped_names
+ if escaped_bind_names is not None
else key
)
@@ -1702,7 +1707,9 @@ class DefaultExecutionContext(ExecutionContext):
else:
assert column is not None
assert parameters is not None
- compile_state = cast(SQLCompiler, self.compiled).compile_state
+ compile_state = cast(
+ "DMLState", cast(SQLCompiler, self.compiled).compile_state
+ )
assert compile_state is not None
if (
isolate_multiinsert_groups
@@ -1715,6 +1722,7 @@ class DefaultExecutionContext(ExecutionContext):
else:
d = {column.key: parameters[column.key]}
index = 0
+ assert compile_state._dict_parameters is not None
keys = compile_state._dict_parameters.keys()
d.update(
(key, parameters["%s_m%d" % (key, index)]) for key in keys