summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/testing')
-rw-r--r--lib/sqlalchemy/testing/assertsql.py26
-rw-r--r--lib/sqlalchemy/testing/config.py2
-rw-r--r--lib/sqlalchemy/testing/plugin/plugin_base.py16
3 files changed, 34 insertions, 10 deletions
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
index d183372c3..45a2496dd 100644
--- a/lib/sqlalchemy/testing/assertsql.py
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -11,6 +11,7 @@ from __future__ import annotations
import collections
import contextlib
+import itertools
import re
from .. import event
@@ -285,7 +286,8 @@ class DialectSQL(CompiledSQL):
return received_stmt, execute_observed.context.compiled_parameters
- def _dialect_adjusted_statement(self, paramstyle):
+ def _dialect_adjusted_statement(self, dialect):
+ paramstyle = dialect.paramstyle
stmt = re.sub(r"[\n\t]", "", self.statement)
# temporarily escape out PG double colons
@@ -300,8 +302,14 @@ class DialectSQL(CompiledSQL):
repl = "?"
elif paramstyle == "format":
repl = r"%s"
- elif paramstyle == "numeric":
- repl = None
+ elif paramstyle.startswith("numeric"):
+ counter = itertools.count(1)
+
+ num_identifier = "$" if paramstyle == "numeric_dollar" else ":"
+
+ def repl(m):
+ return f"{num_identifier}{next(counter)}"
+
stmt = re.sub(r":([\w_]+)", repl, stmt)
# put them back
@@ -310,20 +318,20 @@ class DialectSQL(CompiledSQL):
return stmt
def _compare_sql(self, execute_observed, received_statement):
- paramstyle = execute_observed.context.dialect.paramstyle
- stmt = self._dialect_adjusted_statement(paramstyle)
+ stmt = self._dialect_adjusted_statement(
+ execute_observed.context.dialect
+ )
return received_statement == stmt
def _failure_message(self, execute_observed, expected_params):
- paramstyle = execute_observed.context.dialect.paramstyle
return (
"Testing for compiled statement\n%r partial params %s, "
"received\n%%(received_statement)r with params "
"%%(received_parameters)r"
% (
- self._dialect_adjusted_statement(paramstyle).replace(
- "%", "%%"
- ),
+ self._dialect_adjusted_statement(
+ execute_observed.context.dialect
+ ).replace("%", "%%"),
repr(expected_params).replace("%", "%%"),
)
)
diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py
index 957876579..6adcf5b64 100644
--- a/lib/sqlalchemy/testing/config.py
+++ b/lib/sqlalchemy/testing/config.py
@@ -189,7 +189,7 @@ def variation(argname, cases):
elif querytyp.legacy_query:
stmt = Session.query(Thing)
else:
- assert False
+ querytyp.fail()
The variable provided is a slots object of boolean variables, as well
diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py
index 656a4e98a..ffe0f453a 100644
--- a/lib/sqlalchemy/testing/plugin/plugin_base.py
+++ b/lib/sqlalchemy/testing/plugin/plugin_base.py
@@ -371,6 +371,22 @@ def _setup_options(opt, file_config):
options = opt
+@pre
+def _register_sqlite_numeric_dialect(opt, file_config):
+ from sqlalchemy.dialects import registry
+
+ registry.register(
+ "sqlite.pysqlite_numeric",
+ "sqlalchemy.dialects.sqlite.pysqlite",
+ "_SQLiteDialect_pysqlite_numeric",
+ )
+ registry.register(
+ "sqlite.pysqlite_dollar",
+ "sqlalchemy.dialects.sqlite.pysqlite",
+ "_SQLiteDialect_pysqlite_dollar",
+ )
+
+
@post
def __ensure_cext(opt, file_config):
if os.environ.get("REQUIRE_SQLALCHEMY_CEXT", "0") == "1":