summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/dialects
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2022-12-05 16:29:52 +0000
committerGerrit Code Review <gerrit@ci3.zzzcomputing.com>2022-12-05 16:29:52 +0000
commit679b3f7e100e40b36d6d0abb947bd62589ef3e05 (patch)
treed51f54c74ef3e542b461f6a9c8c35e0edbe462e2 /lib/sqlalchemy/dialects
parent422d8d3bcbf2b60f053ab76c3fc29f33242ccf4b (diff)
parent06c234d037bdab48e716d6c5f5dc200095269474 (diff)
downloadsqlalchemy-679b3f7e100e40b36d6d0abb947bd62589ef3e05.tar.gz
Merge "Rewrite positional handling, test for "numeric"" into main
Diffstat (limited to 'lib/sqlalchemy/dialects')
-rw-r--r--lib/sqlalchemy/dialects/postgresql/asyncpg.py17
-rw-r--r--lib/sqlalchemy/dialects/sqlite/provision.py8
-rw-r--r--lib/sqlalchemy/dialects/sqlite/pysqlite.py107
3 files changed, 117 insertions, 15 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
index 751dc3dcf..b8f614eba 100644
--- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py
+++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
@@ -438,9 +438,6 @@ class AsyncAdapt_asyncpg_cursor:
def _handle_exception(self, error):
self._adapt_connection._handle_exception(error)
- def _parameter_placeholders(self, params):
- return tuple(f"${idx:d}" for idx, _ in enumerate(params, 1))
-
async def _prepare_and_execute(self, operation, parameters):
adapt_connection = self._adapt_connection
@@ -449,11 +446,7 @@ class AsyncAdapt_asyncpg_cursor:
if not adapt_connection._started:
await adapt_connection._start_transaction()
- if parameters is not None:
- operation = operation % self._parameter_placeholders(
- parameters
- )
- else:
+ if parameters is None:
parameters = ()
try:
@@ -506,10 +499,6 @@ class AsyncAdapt_asyncpg_cursor:
if not adapt_connection._started:
await adapt_connection._start_transaction()
- operation = operation % self._parameter_placeholders(
- seq_of_parameters[0]
- )
-
try:
return await self._connection.executemany(
operation, seq_of_parameters
@@ -808,7 +797,7 @@ class AsyncAdaptFallback_asyncpg_connection(AsyncAdapt_asyncpg_connection):
class AsyncAdapt_asyncpg_dbapi:
def __init__(self, asyncpg):
self.asyncpg = asyncpg
- self.paramstyle = "format"
+ self.paramstyle = "numeric_dollar"
def connect(self, *arg, **kw):
async_fallback = kw.pop("async_fallback", False)
@@ -900,7 +889,7 @@ class PGDialect_asyncpg(PGDialect):
render_bind_cast = True
has_terminate = True
- default_paramstyle = "format"
+ default_paramstyle = "numeric_dollar"
supports_sane_multi_rowcount = False
execution_ctx_cls = PGExecutionContext_asyncpg
statement_compiler = PGCompiler_asyncpg
diff --git a/lib/sqlalchemy/dialects/sqlite/provision.py b/lib/sqlalchemy/dialects/sqlite/provision.py
index 05ee6c625..851f0951f 100644
--- a/lib/sqlalchemy/dialects/sqlite/provision.py
+++ b/lib/sqlalchemy/dialects/sqlite/provision.py
@@ -18,7 +18,13 @@ from ...testing.provision import upsert
# TODO: I can't get this to build dynamically with pytest-xdist procs
-_drivernames = {"pysqlite", "aiosqlite", "pysqlcipher"}
+_drivernames = {
+ "pysqlite",
+ "aiosqlite",
+ "pysqlcipher",
+ "pysqlite_numeric",
+ "pysqlite_dollar",
+}
@generate_driver_url.for_db("sqlite")
diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py
index 4475ccae7..c04a3601d 100644
--- a/lib/sqlalchemy/dialects/sqlite/pysqlite.py
+++ b/lib/sqlalchemy/dialects/sqlite/pysqlite.py
@@ -637,3 +637,110 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
dialect = SQLiteDialect_pysqlite
+
+
+class _SQLiteDialect_pysqlite_numeric(SQLiteDialect_pysqlite):
+ """numeric dialect for testing only
+
+ internal use only. This dialect is **NOT** supported by SQLAlchemy
+ and may change at any time.
+
+ """
+
+ supports_statement_cache = True
+ default_paramstyle = "numeric"
+ driver = "pysqlite_numeric"
+
+ _first_bind = ":1"
+ _not_in_statement_regexp = None
+
+ def __init__(self, *arg, **kw):
+ kw.setdefault("paramstyle", "numeric")
+ super().__init__(*arg, **kw)
+
+ def create_connect_args(self, url):
+ arg, opts = super().create_connect_args(url)
+ opts["factory"] = self._fix_sqlite_issue_99953()
+ return arg, opts
+
+ def _fix_sqlite_issue_99953(self):
+ import sqlite3
+
+ first_bind = self._first_bind
+ if self._not_in_statement_regexp:
+ nis = self._not_in_statement_regexp
+
+ def _test_sql(sql):
+ m = nis.search(sql)
+ assert not m, f"Found {nis.pattern!r} in {sql!r}"
+
+ else:
+
+ def _test_sql(sql):
+ pass
+
+ def _numeric_param_as_dict(parameters):
+ if parameters:
+ assert isinstance(parameters, tuple)
+ return {
+ str(idx): value for idx, value in enumerate(parameters, 1)
+ }
+ else:
+ return ()
+
+ class SQLiteFix99953Cursor(sqlite3.Cursor):
+ def execute(self, sql, parameters=()):
+ _test_sql(sql)
+ if first_bind in sql:
+ parameters = _numeric_param_as_dict(parameters)
+ return super().execute(sql, parameters)
+
+ def executemany(self, sql, parameters):
+ _test_sql(sql)
+ if first_bind in sql:
+ parameters = [
+ _numeric_param_as_dict(p) for p in parameters
+ ]
+ return super().executemany(sql, parameters)
+
+ class SQLiteFix99953Connection(sqlite3.Connection):
+ def cursor(self, factory=None):
+ if factory is None:
+ factory = SQLiteFix99953Cursor
+ return super().cursor(factory=factory)
+
+ def execute(self, sql, parameters=()):
+ _test_sql(sql)
+ if first_bind in sql:
+ parameters = _numeric_param_as_dict(parameters)
+ return super().execute(sql, parameters)
+
+ def executemany(self, sql, parameters):
+ _test_sql(sql)
+ if first_bind in sql:
+ parameters = [
+ _numeric_param_as_dict(p) for p in parameters
+ ]
+ return super().executemany(sql, parameters)
+
+ return SQLiteFix99953Connection
+
+
+class _SQLiteDialect_pysqlite_dollar(_SQLiteDialect_pysqlite_numeric):
+ """numeric dialect that uses $ for testing only
+
+ internal use only. This dialect is **NOT** supported by SQLAlchemy
+ and may change at any time.
+
+ """
+
+ supports_statement_cache = True
+ default_paramstyle = "numeric_dollar"
+ driver = "pysqlite_dollar"
+
+ _first_bind = "$1"
+ _not_in_statement_regexp = re.compile(r"[^\d]:\d+")
+
+ def __init__(self, *arg, **kw):
+ kw.setdefault("paramstyle", "numeric_dollar")
+ super().__init__(*arg, **kw)