diff options
| author | mike bayer <mike_mp@zzzcomputing.com> | 2022-12-05 16:29:52 +0000 |
|---|---|---|
| committer | Gerrit Code Review <gerrit@ci3.zzzcomputing.com> | 2022-12-05 16:29:52 +0000 |
| commit | 679b3f7e100e40b36d6d0abb947bd62589ef3e05 (patch) | |
| tree | d51f54c74ef3e542b461f6a9c8c35e0edbe462e2 /lib/sqlalchemy/dialects | |
| parent | 422d8d3bcbf2b60f053ab76c3fc29f33242ccf4b (diff) | |
| parent | 06c234d037bdab48e716d6c5f5dc200095269474 (diff) | |
| download | sqlalchemy-679b3f7e100e40b36d6d0abb947bd62589ef3e05.tar.gz | |
Merge "Rewrite positional handling, test for "numeric"" into main
Diffstat (limited to 'lib/sqlalchemy/dialects')
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/asyncpg.py | 17 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/sqlite/provision.py | 8 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/sqlite/pysqlite.py | 107 |
3 files changed, 117 insertions, 15 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 751dc3dcf..b8f614eba 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -438,9 +438,6 @@ class AsyncAdapt_asyncpg_cursor: def _handle_exception(self, error): self._adapt_connection._handle_exception(error) - def _parameter_placeholders(self, params): - return tuple(f"${idx:d}" for idx, _ in enumerate(params, 1)) - async def _prepare_and_execute(self, operation, parameters): adapt_connection = self._adapt_connection @@ -449,11 +446,7 @@ class AsyncAdapt_asyncpg_cursor: if not adapt_connection._started: await adapt_connection._start_transaction() - if parameters is not None: - operation = operation % self._parameter_placeholders( - parameters - ) - else: + if parameters is None: parameters = () try: @@ -506,10 +499,6 @@ class AsyncAdapt_asyncpg_cursor: if not adapt_connection._started: await adapt_connection._start_transaction() - operation = operation % self._parameter_placeholders( - seq_of_parameters[0] - ) - try: return await self._connection.executemany( operation, seq_of_parameters @@ -808,7 +797,7 @@ class AsyncAdaptFallback_asyncpg_connection(AsyncAdapt_asyncpg_connection): class AsyncAdapt_asyncpg_dbapi: def __init__(self, asyncpg): self.asyncpg = asyncpg - self.paramstyle = "format" + self.paramstyle = "numeric_dollar" def connect(self, *arg, **kw): async_fallback = kw.pop("async_fallback", False) @@ -900,7 +889,7 @@ class PGDialect_asyncpg(PGDialect): render_bind_cast = True has_terminate = True - default_paramstyle = "format" + default_paramstyle = "numeric_dollar" supports_sane_multi_rowcount = False execution_ctx_cls = PGExecutionContext_asyncpg statement_compiler = PGCompiler_asyncpg diff --git a/lib/sqlalchemy/dialects/sqlite/provision.py b/lib/sqlalchemy/dialects/sqlite/provision.py index 05ee6c625..851f0951f 100644 --- a/lib/sqlalchemy/dialects/sqlite/provision.py +++ b/lib/sqlalchemy/dialects/sqlite/provision.py @@ -18,7 +18,13 @@ from ...testing.provision import upsert # TODO: I can't get this to build dynamically with pytest-xdist procs -_drivernames = {"pysqlite", "aiosqlite", "pysqlcipher"} +_drivernames = { + "pysqlite", + "aiosqlite", + "pysqlcipher", + "pysqlite_numeric", + "pysqlite_dollar", +} @generate_driver_url.for_db("sqlite") diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py index 4475ccae7..c04a3601d 100644 --- a/lib/sqlalchemy/dialects/sqlite/pysqlite.py +++ b/lib/sqlalchemy/dialects/sqlite/pysqlite.py @@ -637,3 +637,110 @@ class SQLiteDialect_pysqlite(SQLiteDialect): dialect = SQLiteDialect_pysqlite + + +class _SQLiteDialect_pysqlite_numeric(SQLiteDialect_pysqlite): + """numeric dialect for testing only + + internal use only. This dialect is **NOT** supported by SQLAlchemy + and may change at any time. + + """ + + supports_statement_cache = True + default_paramstyle = "numeric" + driver = "pysqlite_numeric" + + _first_bind = ":1" + _not_in_statement_regexp = None + + def __init__(self, *arg, **kw): + kw.setdefault("paramstyle", "numeric") + super().__init__(*arg, **kw) + + def create_connect_args(self, url): + arg, opts = super().create_connect_args(url) + opts["factory"] = self._fix_sqlite_issue_99953() + return arg, opts + + def _fix_sqlite_issue_99953(self): + import sqlite3 + + first_bind = self._first_bind + if self._not_in_statement_regexp: + nis = self._not_in_statement_regexp + + def _test_sql(sql): + m = nis.search(sql) + assert not m, f"Found {nis.pattern!r} in {sql!r}" + + else: + + def _test_sql(sql): + pass + + def _numeric_param_as_dict(parameters): + if parameters: + assert isinstance(parameters, tuple) + return { + str(idx): value for idx, value in enumerate(parameters, 1) + } + else: + return () + + class SQLiteFix99953Cursor(sqlite3.Cursor): + def execute(self, sql, parameters=()): + _test_sql(sql) + if first_bind in sql: + parameters = _numeric_param_as_dict(parameters) + return super().execute(sql, parameters) + + def executemany(self, sql, parameters): + _test_sql(sql) + if first_bind in sql: + parameters = [ + _numeric_param_as_dict(p) for p in parameters + ] + return super().executemany(sql, parameters) + + class SQLiteFix99953Connection(sqlite3.Connection): + def cursor(self, factory=None): + if factory is None: + factory = SQLiteFix99953Cursor + return super().cursor(factory=factory) + + def execute(self, sql, parameters=()): + _test_sql(sql) + if first_bind in sql: + parameters = _numeric_param_as_dict(parameters) + return super().execute(sql, parameters) + + def executemany(self, sql, parameters): + _test_sql(sql) + if first_bind in sql: + parameters = [ + _numeric_param_as_dict(p) for p in parameters + ] + return super().executemany(sql, parameters) + + return SQLiteFix99953Connection + + +class _SQLiteDialect_pysqlite_dollar(_SQLiteDialect_pysqlite_numeric): + """numeric dialect that uses $ for testing only + + internal use only. This dialect is **NOT** supported by SQLAlchemy + and may change at any time. + + """ + + supports_statement_cache = True + default_paramstyle = "numeric_dollar" + driver = "pysqlite_dollar" + + _first_bind = "$1" + _not_in_statement_regexp = re.compile(r"[^\d]:\d+") + + def __init__(self, *arg, **kw): + kw.setdefault("paramstyle", "numeric_dollar") + super().__init__(*arg, **kw) |
