summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-10-15 15:20:21 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-10-16 08:47:47 -0400
commit2b966de4196c8271934769337780f7d504d431cf (patch)
tree608cf4c6400faf6dccefbaefbcdd2e0db1e9bdae /lib/sqlalchemy/sql/compiler.py
parente8da50ce0f0474bc89cee15603931760cb6c55ce (diff)
downloadsqlalchemy-2b966de4196c8271934769337780f7d504d431cf.tar.gz
accommodate arbitrary embedded params in insertmanyvalues
Fixed bug in new "insertmanyvalues" feature where INSERT that included a subquery with :func:`_sql.bindparam` inside of it would fail to render correctly in "insertmanyvalues" format. This affected psycopg2 most directly as "insertmanyvalues" is used unconditionally with this driver. Fixes: #8639 Change-Id: I67903fa86afe208899d4f23f940e0727d1be2ce3
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py61
1 files changed, 37 insertions, 24 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index dd40bfe34..efe0ea2b4 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -94,7 +94,6 @@ if typing.TYPE_CHECKING:
from .elements import BindParameter
from .elements import ColumnClause
from .elements import ColumnElement
- from .elements import KeyedColumnElement
from .elements import Label
from .functions import Function
from .selectable import AliasedReturnsRows
@@ -236,6 +235,7 @@ BIND_TEMPLATES = {
"named": ":%(name)s",
}
+
_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\]]")
_BIND_TRANSLATE_CHARS = dict(zip("%():[]", "PAZC__"))
@@ -416,7 +416,7 @@ class _InsertManyValues(NamedTuple):
is_default_expr: bool
single_values_expr: str
- insert_crud_params: List[Tuple[KeyedColumnElement[Any], str, str]]
+ insert_crud_params: List[crud._CrudParamElementStr]
num_positional_params_counted: int
@@ -2960,6 +2960,7 @@ class SQLCompiler(Compiled):
skip_bind_expression=False,
literal_execute=False,
render_postcompile=False,
+ accumulate_bind_names=None,
**kwargs,
):
if not skip_bind_expression:
@@ -2973,6 +2974,7 @@ class SQLCompiler(Compiled):
literal_binds=literal_binds,
literal_execute=literal_execute,
render_postcompile=render_postcompile,
+ accumulate_bind_names=accumulate_bind_names,
**kwargs,
)
if bindparam.expanding:
@@ -3063,6 +3065,9 @@ class SQLCompiler(Compiled):
self.binds[bindparam.key] = self.binds[name] = bindparam
+ if accumulate_bind_names is not None:
+ accumulate_bind_names.add(name)
+
# if we are given a cache key that we're going to match against,
# relate the bindparam here to one that is most likely present
# in the "extracted params" portion of the cache key. this is used
@@ -4646,13 +4651,25 @@ class SQLCompiler(Compiled):
all_keys = set(parameters[0])
- escaped_insert_crud_params: Sequence[Any] = [
- (escaped_bind_names.get(col.key, col.key), formatted)
- for col, _, formatted in insert_crud_params
- ]
+ def apply_placeholders(keys, formatted):
+ for key in keys:
+ key = escaped_bind_names.get(key, key)
+ formatted = formatted.replace(
+ self.bindtemplate % {"name": key},
+ self.bindtemplate
+ % {"name": f"{key}__EXECMANY_INDEX__"},
+ )
+ return formatted
+
+ formatted_values_clause = f"""({', '.join(
+ apply_placeholders(bind_keys, formatted)
+ for _, _, formatted, bind_keys in insert_crud_params
+ )})"""
keys_to_replace = all_keys.intersection(
- key for key, _ in escaped_insert_crud_params
+ escaped_bind_names.get(key, key)
+ for _, _, _, bind_keys in insert_crud_params
+ for key in bind_keys
)
base_parameters = {
key: parameters[0][key]
@@ -4660,7 +4677,7 @@ class SQLCompiler(Compiled):
}
executemany_values_w_comma = ""
else:
- escaped_insert_crud_params = ()
+ formatted_values_clause = ""
keys_to_replace = set()
base_parameters = {}
executemany_values_w_comma = f"({imv.single_values_expr}), "
@@ -4723,14 +4740,10 @@ class SQLCompiler(Compiled):
replaced_parameters = base_parameters.copy()
for i, param in enumerate(batch):
- new_tokens = [
- formatted.replace(key, f"{key}__{i}")
- if key in param
- else formatted
- for key, formatted in escaped_insert_crud_params
- ]
replaced_values_clauses.append(
- f"({', '.join(new_tokens)})"
+ formatted_values_clause.replace(
+ "EXECMANY_INDEX__", str(i)
+ )
)
replaced_parameters.update(
@@ -4841,7 +4854,7 @@ class SQLCompiler(Compiled):
if crud_params_single or not supports_default_values:
text += " (%s)" % ", ".join(
- [expr for _, expr, _ in crud_params_single]
+ [expr for _, expr, _, _ in crud_params_single]
)
if self.implicit_returning or insert_stmt._returning:
@@ -4902,8 +4915,7 @@ class SQLCompiler(Compiled):
True,
self.dialect.default_metavalue_token,
cast(
- "List[Tuple[KeyedColumnElement[Any], str, str]]",
- crud_params_single,
+ "List[crud._CrudParamElementStr]", crud_params_single
),
(positiontup_after - positiontup_before),
)
@@ -4911,7 +4923,7 @@ class SQLCompiler(Compiled):
text += " VALUES %s" % (
", ".join(
"(%s)"
- % (", ".join(value for _, _, value in crud_param_set))
+ % (", ".join(value for _, _, value, _ in crud_param_set))
for crud_param_set in crud_params_struct.all_multi_params
)
)
@@ -4921,8 +4933,9 @@ class SQLCompiler(Compiled):
insert_single_values_expr = ", ".join(
[
value
- for _, _, value in cast(
- "List[Tuple[Any, Any, str]]", crud_params_single
+ for _, _, value, _ in cast(
+ "List[crud._CrudParamElementStr]",
+ crud_params_single,
)
]
)
@@ -4935,7 +4948,7 @@ class SQLCompiler(Compiled):
False,
insert_single_values_expr,
cast(
- "List[Tuple[KeyedColumnElement[Any], str, str]]",
+ "List[crud._CrudParamElementStr]",
crud_params_single,
),
positiontup_after - positiontup_before,
@@ -5058,8 +5071,8 @@ class SQLCompiler(Compiled):
text += " SET "
text += ", ".join(
expr + "=" + value
- for _, expr, value in cast(
- "List[Tuple[Any, str, str]]", crud_params
+ for _, expr, value, _ in cast(
+ "List[Tuple[Any, str, str, Any]]", crud_params
)
)