From 0f2baae6bf72353f785bad394684f2d6fa53e0ef Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Sat, 19 Nov 2022 20:39:10 +0100 Subject: Fix positional compiling bugs Fixed a series of issues regarding positionally rendered bound parameters, such as those used for SQLite, asyncpg, MySQL and others. Some compiled forms would not maintain the order of parameters correctly, such as the PostgreSQL ``regexp_replace()`` function as well as within the "nesting" feature of the :class:`.CTE` construct first introduced in :ticket:`4123`. Fixes: #8827 Change-Id: I9813ed7c358cc5c1e26725c48df546b209a442cb --- lib/sqlalchemy/testing/assertions.py | 64 ++++++++++++++++++++++++++++++++++-- 1 file changed, 61 insertions(+), 3 deletions(-) (limited to 'lib/sqlalchemy/testing/assertions.py') diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index 321c05b44..790a72ec8 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -9,7 +9,9 @@ from __future__ import annotations +from collections import defaultdict import contextlib +from copy import copy from itertools import filterfalse import re import sys @@ -493,6 +495,7 @@ class AssertsCompiledSQL: render_schema_translate=False, default_schema_name=None, from_linting=False, + check_param_order=True, ): if use_default_dialect: dialect = default.DefaultDialect() @@ -506,8 +509,11 @@ class AssertsCompiledSQL: if dialect is None: dialect = config.db.dialect - elif dialect == "default": - dialect = default.DefaultDialect() + elif dialect == "default" or dialect == "default_qmark": + if dialect == "default": + dialect = default.DefaultDialect() + else: + dialect = default.DefaultDialect("qmark") dialect.supports_default_values = supports_default_values dialect.supports_default_metavalue = supports_default_metavalue elif dialect == "default_enhanced": @@ -632,7 +638,7 @@ class AssertsCompiledSQL: if checkparams is not None: eq_(c.construct_params(params), checkparams) if checkpositional is not None: - p = c.construct_params(params) + p = c.construct_params(params, escape_names=False) eq_(tuple([p[x] for x in c.positiontup]), checkpositional) if check_prefetch is not None: eq_(c.prefetch, check_prefetch) @@ -652,6 +658,58 @@ class AssertsCompiledSQL: }, check_post_param, ) + if check_param_order and getattr(c, "params", None): + + def get_dialect(paramstyle, positional): + cp = copy(dialect) + cp.paramstyle = paramstyle + cp.positional = positional + return cp + + pyformat_dialect = get_dialect("pyformat", False) + pyformat_c = clause.compile(dialect=pyformat_dialect, **kw) + stmt = re.sub(r"[\n\t]", "", str(pyformat_c)) + + qmark_dialect = get_dialect("qmark", True) + qmark_c = clause.compile(dialect=qmark_dialect, **kw) + values = list(qmark_c.positiontup) + escaped = qmark_c.escaped_bind_names + + for post_param in ( + qmark_c.post_compile_params | qmark_c.literal_execute_params + ): + name = qmark_c.bind_names[post_param] + if name in values: + values = [v for v in values if v != name] + positions = [] + pos_by_value = defaultdict(list) + for v in values: + try: + if v in pos_by_value: + start = pos_by_value[v][-1] + else: + start = 0 + esc = escaped.get(v, v) + pos = stmt.index("%%(%s)s" % (esc,), start) + 2 + positions.append(pos) + pos_by_value[v].append(pos) + except ValueError: + msg = "Expected to find bindparam %r in %r" % (v, stmt) + assert False, msg + + ordered = all( + positions[i - 1] < positions[i] + for i in range(1, len(positions)) + ) + + expected = [v for _, v in sorted(zip(positions, values))] + + msg = ( + "Order of parameters %s does not match the order " + "in the statement %s. Statement %r" % (values, expected, stmt) + ) + + is_true(ordered, msg) class ComparesTables: -- cgit v1.2.1