summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2023-01-28 19:50:25 -0500
committerFederico Caselli <cfederico87@gmail.com>2023-01-30 22:28:53 +0100
commitd23dcbaea2a8e000c5fa2ba443e1b683b3b79fa6 (patch)
tree6b89a07b8bda5a469bf6c8dde165101315f571ed
parentb99b0c522ddb94468da27867ddfa1f7e2633c920 (diff)
downloadsqlalchemy-d23dcbaea2a8e000c5fa2ba443e1b683b3b79fa6.tar.gz
don't count / gather INSERT bind names inside of a CTE
Fixed regression related to the implementation for the new "insertmanyvalues" feature where an internal ``TypeError`` would occur in arrangements where a :func:`_sql.insert` would be referred to inside of another :func:`_sql.insert` via a CTE; made additional repairs for this use case for positional dialects such as asyncpg when using "insertmanyvalues". At the core here is a change to positional insertmanyvalues where we now get exactly the positions for the "manyvalues" within the larger list, allowing non-"manyvalues" on the left and right sides at the same time, not assuming anything about how RETURNING renders etc., since CTEs are in the mix also. Fixes: #9173 Change-Id: I5ff071fbef0d92a2d6046b9c4e609bb008438afd
-rw-r--r--doc/build/changelog/unreleased_20/9173.rst12
-rw-r--r--lib/sqlalchemy/sql/compiler.py137
-rw-r--r--lib/sqlalchemy/sql/crud.py11
-rw-r--r--test/sql/test_cte.py66
-rw-r--r--test/sql/test_insert_exec.py120
5 files changed, 302 insertions, 44 deletions
diff --git a/doc/build/changelog/unreleased_20/9173.rst b/doc/build/changelog/unreleased_20/9173.rst
new file mode 100644
index 000000000..0e0f59520
--- /dev/null
+++ b/doc/build/changelog/unreleased_20/9173.rst
@@ -0,0 +1,12 @@
+.. change::
+ :tags: bug, sql, regression
+ :tickets: 9173
+
+ Fixed regression related to the implementation for the new
+ "insertmanyvalues" feature where an internal ``TypeError`` would occur in
+ arrangements where a :func:`_sql.insert` would be referred to inside
+ of another :func:`_sql.insert` via a CTE; made additional repairs for this
+ use case for positional dialects such as asyncpg when using
+ "insertmanyvalues".
+
+
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 2c50081fb..d4ddc2e5d 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -1545,12 +1545,12 @@ class SQLCompiler(Compiled):
self.positiontup = list(param_pos)
if self.escaped_bind_names:
- reverse_escape = {v: k for k, v in self.escaped_bind_names.items()}
- assert len(self.escaped_bind_names) == len(reverse_escape)
+ len_before = len(param_pos)
param_pos = {
self.escaped_bind_names.get(name, name): pos
for name, pos in param_pos.items()
}
+ assert len(param_pos) == len_before
# Can't use format here since % chars are not escaped.
self.string = self._pyformat_pattern.sub(
@@ -3374,7 +3374,6 @@ class SQLCompiler(Compiled):
skip_bind_expression=False,
literal_execute=False,
render_postcompile=False,
- accumulate_bind_names=None,
**kwargs,
):
if not skip_bind_expression:
@@ -3388,7 +3387,6 @@ class SQLCompiler(Compiled):
literal_binds=literal_binds and not bindparam.expanding,
literal_execute=literal_execute,
render_postcompile=render_postcompile,
- accumulate_bind_names=accumulate_bind_names,
**kwargs,
)
if bindparam.expanding:
@@ -3490,9 +3488,6 @@ class SQLCompiler(Compiled):
self.binds[bindparam.key] = self.binds[name] = bindparam
- if accumulate_bind_names is not None:
- accumulate_bind_names.add(name)
-
# if we are given a cache key that we're going to match against,
# relate the bindparam here to one that is most likely present
# in the "extracted params" portion of the cache key. this is used
@@ -3646,11 +3641,19 @@ class SQLCompiler(Compiled):
expanding: bool = False,
escaped_from: Optional[str] = None,
bindparam_type: Optional[TypeEngine[Any]] = None,
+ accumulate_bind_names: Optional[Set[str]] = None,
+ visited_bindparam: Optional[List[str]] = None,
**kw: Any,
) -> str:
- if self._visited_bindparam is not None:
- self._visited_bindparam.append(name)
+ # TODO: accumulate_bind_names is passed by crud.py to gather
+ # names on a per-value basis, visited_bindparam is passed by
+ # visit_insert() to collect all parameters in the statement.
+ # see if this gathering can be simplified somehow
+ if accumulate_bind_names is not None:
+ accumulate_bind_names.add(name)
+ if visited_bindparam is not None:
+ visited_bindparam.append(name)
if not escaped_from:
@@ -5086,6 +5089,8 @@ class SQLCompiler(Compiled):
assert insert_crud_params is not None
escaped_bind_names: Mapping[str, str]
+ expand_pos_lower_index = expand_pos_upper_index = 0
+
if not self.positional:
if self.escaped_bind_names:
escaped_bind_names = self.escaped_bind_names
@@ -5124,6 +5129,31 @@ class SQLCompiler(Compiled):
keys_to_replace = set()
base_parameters = {}
executemany_values_w_comma = f"({imv.single_values_expr}), "
+
+ all_names_we_will_expand: Set[str] = set()
+ for elem in imv.insert_crud_params:
+ all_names_we_will_expand.update(elem[3])
+
+ # get the start and end position in a particular list
+ # of parameters where we will be doing the "expanding".
+ # statements can have params on either side or both sides,
+ # given RETURNING and CTEs
+ if all_names_we_will_expand:
+ positiontup = self.positiontup
+ assert positiontup is not None
+
+ all_expand_positions = {
+ idx
+ for idx, name in enumerate(positiontup)
+ if name in all_names_we_will_expand
+ }
+ expand_pos_lower_index = min(all_expand_positions)
+ expand_pos_upper_index = max(all_expand_positions) + 1
+ assert (
+ len(all_expand_positions)
+ == expand_pos_upper_index - expand_pos_lower_index
+ )
+
if self._numeric_binds:
escaped = re.escape(self._numeric_binds_identifier_char)
executemany_values_w_comma = re.sub(
@@ -5149,52 +5179,61 @@ class SQLCompiler(Compiled):
replaced_parameters: Any
if self.positional:
- # the assumption here is that any parameters that are not
- # in the VALUES clause are expected to be parameterized
- # expressions in the RETURNING (or maybe ON CONFLICT) clause.
- # So based on
- # which sequence comes first in the compiler's INSERT
- # statement tells us where to expand the parameters.
-
- # otherwise we probably shouldn't be doing insertmanyvalues
- # on the statement.
-
num_ins_params = imv.num_positional_params_counted
batch_iterator: Iterable[Tuple[Any, ...]]
if num_ins_params == len(batch[0]):
- extra_params = ()
+ extra_params_left = extra_params_right = ()
batch_iterator = batch
- elif self.returning_precedes_values or self._numeric_binds:
- extra_params = batch[0][:-num_ins_params]
- batch_iterator = (b[-num_ins_params:] for b in batch)
else:
- extra_params = batch[0][num_ins_params:]
- batch_iterator = (b[:num_ins_params] for b in batch)
+ extra_params_left = batch[0][:expand_pos_lower_index]
+ extra_params_right = batch[0][expand_pos_upper_index:]
+ batch_iterator = (
+ b[expand_pos_lower_index:expand_pos_upper_index]
+ for b in batch
+ )
+
+ expanded_values_string = (
+ executemany_values_w_comma * len(batch)
+ )[:-2]
- values_string = (executemany_values_w_comma * len(batch))[:-2]
if self._numeric_binds and num_ins_params > 0:
+ # numeric will always number the parameters inside of
+ # VALUES (and thus order self.positiontup) to be higher
+ # than non-VALUES parameters, no matter where in the
+ # statement those non-VALUES parameters appear (this is
+ # ensured in _process_numeric by numbering first all
+ # params that are not in _values_bindparam)
+ # therefore all extra params are always
+ # on the left side and numbered lower than the VALUES
+ # parameters
+ assert not extra_params_right
+
+ start = expand_pos_lower_index + 1
+ end = num_ins_params * (len(batch)) + start
+
# need to format here, since statement may contain
# unescaped %, while values_string contains just (%s, %s)
- start = len(extra_params) + 1
- end = num_ins_params * len(batch) + start
positions = tuple(
f"{self._numeric_binds_identifier_char}{i}"
for i in range(start, end)
)
- values_string = values_string % positions
+ expanded_values_string = expanded_values_string % positions
replaced_statement = statement.replace(
- "__EXECMANY_TOKEN__", values_string
+ "__EXECMANY_TOKEN__", expanded_values_string
)
replaced_parameters = tuple(
itertools.chain.from_iterable(batch_iterator)
)
- if self.returning_precedes_values or self._numeric_binds:
- replaced_parameters = extra_params + replaced_parameters
- else:
- replaced_parameters = replaced_parameters + extra_params
+
+ replaced_parameters = (
+ extra_params_left
+ + replaced_parameters
+ + extra_params_right
+ )
+
else:
replaced_values_clauses = []
replaced_parameters = base_parameters.copy()
@@ -5224,7 +5263,7 @@ class SQLCompiler(Compiled):
)
batchnum += 1
- def visit_insert(self, insert_stmt, **kw):
+ def visit_insert(self, insert_stmt, visited_bindparam=None, **kw):
compile_state = insert_stmt._compile_state_factory(
insert_stmt, self, **kw
@@ -5250,6 +5289,9 @@ class SQLCompiler(Compiled):
counted_bindparam = 0
+ # reset any incoming "visited_bindparam" collection
+ visited_bindparam = None
+
# for positional, insertmanyvalues needs to know how many
# bound parameters are in the VALUES sequence; there's no simple
# rule because default expressions etc. can have zero or more
@@ -5257,21 +5299,30 @@ class SQLCompiler(Compiled):
# this very simplistic "count after" works and is
# likely the least amount of callcounts, though looks clumsy
if self.positional:
- self._visited_bindparam = []
+ # if we are inside a CTE, don't count parameters
+ # here since they won't be for insertmanyvalues. keep
+ # visited_bindparam at None so no counting happens.
+ # see #9173
+ has_visiting_cte = "visiting_cte" in kw
+ if not has_visiting_cte:
+ visited_bindparam = []
crud_params_struct = crud._get_crud_params(
- self, insert_stmt, compile_state, toplevel, **kw
+ self,
+ insert_stmt,
+ compile_state,
+ toplevel,
+ visited_bindparam=visited_bindparam,
+ **kw,
)
- if self.positional:
- assert self._visited_bindparam is not None
- counted_bindparam = len(self._visited_bindparam)
+ if self.positional and visited_bindparam is not None:
+ counted_bindparam = len(visited_bindparam)
if self._numeric_binds:
if self._values_bindparam is not None:
- self._values_bindparam += self._visited_bindparam
+ self._values_bindparam += visited_bindparam
else:
- self._values_bindparam = self._visited_bindparam
- self._visited_bindparam = None
+ self._values_bindparam = visited_bindparam
crud_params_single = crud_params_struct.single_params
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index 5017afa78..04b62d1ff 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -150,6 +150,17 @@ def _get_crud_params(
compiler.update_prefetch = []
compiler.implicit_returning = []
+ visiting_cte = kw.get("visiting_cte", None)
+ if visiting_cte is not None:
+ # for insert -> CTE -> insert, don't populate an incoming
+ # _crud_accumulate_bind_names collection; the INSERT we process here
+ # will not be inline within the VALUES of the enclosing INSERT as the
+ # CTE is placed on the outside. See issue #9173
+ kw.pop("accumulate_bind_names", None)
+ assert (
+ "accumulate_bind_names" not in kw
+ ), "Don't know how to handle insert within insert without a CTE"
+
# getters - these are normally just column.key,
# but in the case of mysql multi-table update, the rules for
# .key must conditionally take tablename into account
diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py
index 502104dae..4ba4eddfe 100644
--- a/test/sql/test_cte.py
+++ b/test/sql/test_cte.py
@@ -1320,6 +1320,72 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
@testing.combinations(
("default_enhanced",),
("postgresql",),
+ ("postgresql+asyncpg",),
+ )
+ def test_insert_w_cte_in_scalar_subquery(self, dialect):
+ """test #9173"""
+
+ customer = table(
+ "customer",
+ column("id"),
+ column("name"),
+ )
+ order = table(
+ "order",
+ column("id"),
+ column("price"),
+ column("customer_id"),
+ )
+
+ inst = (
+ customer.insert()
+ .values(name="John")
+ .returning(customer.c.id)
+ .cte("inst")
+ )
+
+ stmt = (
+ order.insert()
+ .values(
+ price=1,
+ customer_id=select(inst.c.id).scalar_subquery(),
+ )
+ .add_cte(inst)
+ )
+
+ if dialect == "default_enhanced":
+ self.assert_compile(
+ stmt,
+ "WITH inst AS (INSERT INTO customer (name) VALUES (:param_1) "
+ 'RETURNING customer.id) INSERT INTO "order" '
+ "(price, customer_id) VALUES "
+ "(:price, (SELECT inst.id FROM inst))",
+ dialect=dialect,
+ )
+ elif dialect == "postgresql":
+ self.assert_compile(
+ stmt,
+ "WITH inst AS (INSERT INTO customer (name) "
+ "VALUES (%(param_1)s) "
+ 'RETURNING customer.id) INSERT INTO "order" '
+ "(price, customer_id) "
+ "VALUES (%(price)s, (SELECT inst.id FROM inst))",
+ dialect=dialect,
+ )
+ elif dialect == "postgresql+asyncpg":
+ self.assert_compile(
+ stmt,
+ "WITH inst AS (INSERT INTO customer (name) VALUES ($2) "
+ 'RETURNING customer.id) INSERT INTO "order" '
+ "(price, customer_id) VALUES ($1, (SELECT inst.id FROM inst))",
+ dialect=dialect,
+ )
+ else:
+ assert False
+
+ @testing.combinations(
+ ("default_enhanced",),
+ ("postgresql",),
)
def test_select_from_delete_cte(self, dialect):
t1 = table("table_1", column("id"), column("val"))
diff --git a/test/sql/test_insert_exec.py b/test/sql/test_insert_exec.py
index d9dac75b3..3b5a1856c 100644
--- a/test/sql/test_insert_exec.py
+++ b/test/sql/test_insert_exec.py
@@ -23,6 +23,7 @@ from sqlalchemy.testing import expect_raises_message
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
from sqlalchemy.testing import mock
+from sqlalchemy.testing import provision
from sqlalchemy.testing.provision import normalize_sequence
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
@@ -825,6 +826,119 @@ class InsertManyValuesTest(fixtures.RemovesEvents, fixtures.TablesTest):
eq_(result.inserted_primary_key_rows, [(1,), (2,), (3,)])
+ @testing.requires.ctes_on_dml
+ @testing.variation("add_expr_returning", [True, False])
+ def test_insert_w_bindparam_in_nested_insert(
+ self, connection, add_expr_returning
+ ):
+ """test related to #9173"""
+
+ data, extra_table = self.tables("data", "extra_table")
+
+ inst = (
+ extra_table.insert()
+ .values(x_value="x", y_value="y")
+ .returning(extra_table.c.id)
+ .cte("inst")
+ )
+
+ stmt = (
+ data.insert()
+ .values(x="the x", z=select(inst.c.id).scalar_subquery())
+ .add_cte(inst)
+ )
+
+ if add_expr_returning:
+ stmt = stmt.returning(data.c.id, data.c.y + " returned y")
+ else:
+ stmt = stmt.returning(data.c.id)
+
+ result = connection.execute(
+ stmt,
+ [
+ {"y": "y1"},
+ {"y": "y2"},
+ {"y": "y3"},
+ ],
+ )
+
+ result_rows = result.all()
+
+ ids = [row[0] for row in result_rows]
+
+ extra_row = connection.execute(
+ select(extra_table).order_by(extra_table.c.id)
+ ).one()
+ extra_row_id = extra_row[0]
+ eq_(extra_row, (extra_row_id, "x", "y"))
+ eq_(
+ connection.execute(select(data).order_by(data.c.id)).all(),
+ [
+ (ids[0], "the x", "y1", extra_row_id),
+ (ids[1], "the x", "y2", extra_row_id),
+ (ids[2], "the x", "y3", extra_row_id),
+ ],
+ )
+
+ @testing.requires.provisioned_upsert
+ def test_upsert_w_returning(self, connection):
+ """test cases that will exercise SQL similar to that of
+ test/orm/dml/test_bulk_statements.py
+
+ """
+
+ data = self.tables.data
+
+ initial_data = [
+ {"x": "x1", "y": "y1", "z": 4},
+ {"x": "x2", "y": "y2", "z": 8},
+ ]
+ ids = connection.scalars(
+ data.insert().returning(data.c.id), initial_data
+ ).all()
+
+ upsert_data = [
+ {
+ "id": ids[0],
+ "x": "x1",
+ "y": "y1",
+ },
+ {
+ "id": 32,
+ "x": "x19",
+ "y": "y7",
+ },
+ {
+ "id": ids[1],
+ "x": "x5",
+ "y": "y6",
+ },
+ {
+ "id": 28,
+ "x": "x9",
+ "y": "y15",
+ },
+ ]
+
+ stmt = provision.upsert(
+ config,
+ data,
+ (data,),
+ lambda inserted: {"x": inserted.x + " upserted"},
+ )
+
+ result = connection.execute(stmt, upsert_data)
+
+ eq_(
+ result.all(),
+ [
+ (ids[0], "x1 upserted", "y1", 4),
+ (32, "x19", "y7", 5),
+ (ids[1], "x5 upserted", "y2", 8),
+ (28, "x9", "y15", 5),
+ ],
+ )
+
@testing.combinations(True, False, argnames="use_returning")
@testing.combinations(1, 2, argnames="num_embedded_params")
@testing.combinations(True, False, argnames="use_whereclause")
@@ -835,7 +949,11 @@ class InsertManyValuesTest(fixtures.RemovesEvents, fixtures.TablesTest):
def test_insert_w_bindparam_in_subq(
self, connection, use_returning, num_embedded_params, use_whereclause
):
- """test #8639"""
+ """test #8639
+
+ see also test_insert_w_bindparam_in_nested_insert
+
+ """
t = self.tables.data
extra = self.tables.extra_table