diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-10-15 15:20:21 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-10-16 08:47:47 -0400 |
commit | 2b966de4196c8271934769337780f7d504d431cf (patch) | |
tree | 608cf4c6400faf6dccefbaefbcdd2e0db1e9bdae /lib/sqlalchemy/sql/compiler.py | |
parent | e8da50ce0f0474bc89cee15603931760cb6c55ce (diff) | |
download | sqlalchemy-2b966de4196c8271934769337780f7d504d431cf.tar.gz |
accommodate arbitrary embedded params in insertmanyvalues
Fixed bug in new "insertmanyvalues" feature where INSERT that included a
subquery with :func:`_sql.bindparam` inside of it would fail to render
correctly in "insertmanyvalues" format. This affected psycopg2 most
directly as "insertmanyvalues" is used unconditionally with this driver.
Fixes: #8639
Change-Id: I67903fa86afe208899d4f23f940e0727d1be2ce3
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 61 |
1 files changed, 37 insertions, 24 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index dd40bfe34..efe0ea2b4 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -94,7 +94,6 @@ if typing.TYPE_CHECKING: from .elements import BindParameter from .elements import ColumnClause from .elements import ColumnElement - from .elements import KeyedColumnElement from .elements import Label from .functions import Function from .selectable import AliasedReturnsRows @@ -236,6 +235,7 @@ BIND_TEMPLATES = { "named": ":%(name)s", } + _BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\]]") _BIND_TRANSLATE_CHARS = dict(zip("%():[]", "PAZC__")) @@ -416,7 +416,7 @@ class _InsertManyValues(NamedTuple): is_default_expr: bool single_values_expr: str - insert_crud_params: List[Tuple[KeyedColumnElement[Any], str, str]] + insert_crud_params: List[crud._CrudParamElementStr] num_positional_params_counted: int @@ -2960,6 +2960,7 @@ class SQLCompiler(Compiled): skip_bind_expression=False, literal_execute=False, render_postcompile=False, + accumulate_bind_names=None, **kwargs, ): if not skip_bind_expression: @@ -2973,6 +2974,7 @@ class SQLCompiler(Compiled): literal_binds=literal_binds, literal_execute=literal_execute, render_postcompile=render_postcompile, + accumulate_bind_names=accumulate_bind_names, **kwargs, ) if bindparam.expanding: @@ -3063,6 +3065,9 @@ class SQLCompiler(Compiled): self.binds[bindparam.key] = self.binds[name] = bindparam + if accumulate_bind_names is not None: + accumulate_bind_names.add(name) + # if we are given a cache key that we're going to match against, # relate the bindparam here to one that is most likely present # in the "extracted params" portion of the cache key. this is used @@ -4646,13 +4651,25 @@ class SQLCompiler(Compiled): all_keys = set(parameters[0]) - escaped_insert_crud_params: Sequence[Any] = [ - (escaped_bind_names.get(col.key, col.key), formatted) - for col, _, formatted in insert_crud_params - ] + def apply_placeholders(keys, formatted): + for key in keys: + key = escaped_bind_names.get(key, key) + formatted = formatted.replace( + self.bindtemplate % {"name": key}, + self.bindtemplate + % {"name": f"{key}__EXECMANY_INDEX__"}, + ) + return formatted + + formatted_values_clause = f"""({', '.join( + apply_placeholders(bind_keys, formatted) + for _, _, formatted, bind_keys in insert_crud_params + )})""" keys_to_replace = all_keys.intersection( - key for key, _ in escaped_insert_crud_params + escaped_bind_names.get(key, key) + for _, _, _, bind_keys in insert_crud_params + for key in bind_keys ) base_parameters = { key: parameters[0][key] @@ -4660,7 +4677,7 @@ class SQLCompiler(Compiled): } executemany_values_w_comma = "" else: - escaped_insert_crud_params = () + formatted_values_clause = "" keys_to_replace = set() base_parameters = {} executemany_values_w_comma = f"({imv.single_values_expr}), " @@ -4723,14 +4740,10 @@ class SQLCompiler(Compiled): replaced_parameters = base_parameters.copy() for i, param in enumerate(batch): - new_tokens = [ - formatted.replace(key, f"{key}__{i}") - if key in param - else formatted - for key, formatted in escaped_insert_crud_params - ] replaced_values_clauses.append( - f"({', '.join(new_tokens)})" + formatted_values_clause.replace( + "EXECMANY_INDEX__", str(i) + ) ) replaced_parameters.update( @@ -4841,7 +4854,7 @@ class SQLCompiler(Compiled): if crud_params_single or not supports_default_values: text += " (%s)" % ", ".join( - [expr for _, expr, _ in crud_params_single] + [expr for _, expr, _, _ in crud_params_single] ) if self.implicit_returning or insert_stmt._returning: @@ -4902,8 +4915,7 @@ class SQLCompiler(Compiled): True, self.dialect.default_metavalue_token, cast( - "List[Tuple[KeyedColumnElement[Any], str, str]]", - crud_params_single, + "List[crud._CrudParamElementStr]", crud_params_single ), (positiontup_after - positiontup_before), ) @@ -4911,7 +4923,7 @@ class SQLCompiler(Compiled): text += " VALUES %s" % ( ", ".join( "(%s)" - % (", ".join(value for _, _, value in crud_param_set)) + % (", ".join(value for _, _, value, _ in crud_param_set)) for crud_param_set in crud_params_struct.all_multi_params ) ) @@ -4921,8 +4933,9 @@ class SQLCompiler(Compiled): insert_single_values_expr = ", ".join( [ value - for _, _, value in cast( - "List[Tuple[Any, Any, str]]", crud_params_single + for _, _, value, _ in cast( + "List[crud._CrudParamElementStr]", + crud_params_single, ) ] ) @@ -4935,7 +4948,7 @@ class SQLCompiler(Compiled): False, insert_single_values_expr, cast( - "List[Tuple[KeyedColumnElement[Any], str, str]]", + "List[crud._CrudParamElementStr]", crud_params_single, ), positiontup_after - positiontup_before, @@ -5058,8 +5071,8 @@ class SQLCompiler(Compiled): text += " SET " text += ", ".join( expr + "=" + value - for _, expr, value in cast( - "List[Tuple[Any, str, str]]", crud_params + for _, expr, value, _ in cast( + "List[Tuple[Any, str, str, Any]]", crud_params ) ) |