diff options
| author | mike bayer <mike_mp@zzzcomputing.com> | 2022-12-05 16:29:52 +0000 |
|---|---|---|
| committer | Gerrit Code Review <gerrit@ci3.zzzcomputing.com> | 2022-12-05 16:29:52 +0000 |
| commit | 679b3f7e100e40b36d6d0abb947bd62589ef3e05 (patch) | |
| tree | d51f54c74ef3e542b461f6a9c8c35e0edbe462e2 /lib/sqlalchemy/testing/assertsql.py | |
| parent | 422d8d3bcbf2b60f053ab76c3fc29f33242ccf4b (diff) | |
| parent | 06c234d037bdab48e716d6c5f5dc200095269474 (diff) | |
| download | sqlalchemy-679b3f7e100e40b36d6d0abb947bd62589ef3e05.tar.gz | |
Merge "Rewrite positional handling, test for "numeric"" into main
Diffstat (limited to 'lib/sqlalchemy/testing/assertsql.py')
| -rw-r--r-- | lib/sqlalchemy/testing/assertsql.py | 26 |
1 files changed, 17 insertions, 9 deletions
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index d183372c3..45a2496dd 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -11,6 +11,7 @@ from __future__ import annotations import collections import contextlib +import itertools import re from .. import event @@ -285,7 +286,8 @@ class DialectSQL(CompiledSQL): return received_stmt, execute_observed.context.compiled_parameters - def _dialect_adjusted_statement(self, paramstyle): + def _dialect_adjusted_statement(self, dialect): + paramstyle = dialect.paramstyle stmt = re.sub(r"[\n\t]", "", self.statement) # temporarily escape out PG double colons @@ -300,8 +302,14 @@ class DialectSQL(CompiledSQL): repl = "?" elif paramstyle == "format": repl = r"%s" - elif paramstyle == "numeric": - repl = None + elif paramstyle.startswith("numeric"): + counter = itertools.count(1) + + num_identifier = "$" if paramstyle == "numeric_dollar" else ":" + + def repl(m): + return f"{num_identifier}{next(counter)}" + stmt = re.sub(r":([\w_]+)", repl, stmt) # put them back @@ -310,20 +318,20 @@ class DialectSQL(CompiledSQL): return stmt def _compare_sql(self, execute_observed, received_statement): - paramstyle = execute_observed.context.dialect.paramstyle - stmt = self._dialect_adjusted_statement(paramstyle) + stmt = self._dialect_adjusted_statement( + execute_observed.context.dialect + ) return received_statement == stmt def _failure_message(self, execute_observed, expected_params): - paramstyle = execute_observed.context.dialect.paramstyle return ( "Testing for compiled statement\n%r partial params %s, " "received\n%%(received_statement)r with params " "%%(received_parameters)r" % ( - self._dialect_adjusted_statement(paramstyle).replace( - "%", "%%" - ), + self._dialect_adjusted_statement( + execute_observed.context.dialect + ).replace("%", "%%"), repr(expected_params).replace("%", "%%"), ) ) |
