summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/build/changelog/unreleased_20/8849.rst14
-rw-r--r--doc/build/changelog/unreleased_20/8926.rst8
-rw-r--r--lib/sqlalchemy/dialects/postgresql/asyncpg.py17
-rw-r--r--lib/sqlalchemy/dialects/sqlite/provision.py8
-rw-r--r--lib/sqlalchemy/dialects/sqlite/pysqlite.py107
-rw-r--r--lib/sqlalchemy/engine/default.py7
-rw-r--r--lib/sqlalchemy/engine/interfaces.py4
-rw-r--r--lib/sqlalchemy/sql/compiler.py470
-rw-r--r--lib/sqlalchemy/sql/crud.py4
-rw-r--r--lib/sqlalchemy/testing/assertsql.py26
-rw-r--r--lib/sqlalchemy/testing/config.py2
-rw-r--r--lib/sqlalchemy/testing/plugin/plugin_base.py16
-rw-r--r--setup.cfg2
-rw-r--r--test/dialect/postgresql/test_query.py6
-rw-r--r--test/dialect/test_sqlite.py2
-rw-r--r--test/engine/test_logging.py4
-rw-r--r--test/orm/dml/test_bulk_statements.py6
-rw-r--r--test/orm/test_dynamic.py66
-rw-r--r--test/orm/test_merge.py2
-rw-r--r--test/orm/test_unitofworkv2.py2
-rw-r--r--test/requirements.py11
-rw-r--r--test/sql/test_compiler.py366
-rw-r--r--test/sql/test_cte.py20
-rw-r--r--test/sql/test_insert.py95
-rw-r--r--test/sql/test_resultset.py2
-rw-r--r--test/sql/test_types.py2
-rw-r--r--test/sql/test_update.py66
-rw-r--r--tox.ini2
28 files changed, 989 insertions, 348 deletions
diff --git a/doc/build/changelog/unreleased_20/8849.rst b/doc/build/changelog/unreleased_20/8849.rst
new file mode 100644
index 000000000..29ecf2a2c
--- /dev/null
+++ b/doc/build/changelog/unreleased_20/8849.rst
@@ -0,0 +1,14 @@
+.. change::
+ :tags: bug, sql
+ :tickets: 8849
+
+ Reworked how numeric paramstyle behavers, in particular, fixed insertmany
+ behaviour that prior to this was non functional; added support for repeated
+ parameter without duplicating them like in other positional dialects;
+ introduced new numeric paramstyle called ``numeric_dollar`` that can be
+ used to render statements that use the PostgreSQL placeholder style (
+ i.e. ``$1, $2, $3``).
+ This change requires that the dialect supports out of order placehoders,
+ that may be used used in the statements, in particular when using
+ insert-many values with statement that have parameters in the returning
+ clause.
diff --git a/doc/build/changelog/unreleased_20/8926.rst b/doc/build/changelog/unreleased_20/8926.rst
new file mode 100644
index 000000000..a0000fb87
--- /dev/null
+++ b/doc/build/changelog/unreleased_20/8926.rst
@@ -0,0 +1,8 @@
+.. change::
+ :tags: asyncpg
+ :tickets: 8926
+
+ Changed the paramstyle used by asyncpg from ``format`` to
+ ``numeric_dollar``. This has two main benefits since it does not require
+ additional processing of the statement and allows for duplicate parameters
+ to be present in the statements.
diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
index 751dc3dcf..b8f614eba 100644
--- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py
+++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
@@ -438,9 +438,6 @@ class AsyncAdapt_asyncpg_cursor:
def _handle_exception(self, error):
self._adapt_connection._handle_exception(error)
- def _parameter_placeholders(self, params):
- return tuple(f"${idx:d}" for idx, _ in enumerate(params, 1))
-
async def _prepare_and_execute(self, operation, parameters):
adapt_connection = self._adapt_connection
@@ -449,11 +446,7 @@ class AsyncAdapt_asyncpg_cursor:
if not adapt_connection._started:
await adapt_connection._start_transaction()
- if parameters is not None:
- operation = operation % self._parameter_placeholders(
- parameters
- )
- else:
+ if parameters is None:
parameters = ()
try:
@@ -506,10 +499,6 @@ class AsyncAdapt_asyncpg_cursor:
if not adapt_connection._started:
await adapt_connection._start_transaction()
- operation = operation % self._parameter_placeholders(
- seq_of_parameters[0]
- )
-
try:
return await self._connection.executemany(
operation, seq_of_parameters
@@ -808,7 +797,7 @@ class AsyncAdaptFallback_asyncpg_connection(AsyncAdapt_asyncpg_connection):
class AsyncAdapt_asyncpg_dbapi:
def __init__(self, asyncpg):
self.asyncpg = asyncpg
- self.paramstyle = "format"
+ self.paramstyle = "numeric_dollar"
def connect(self, *arg, **kw):
async_fallback = kw.pop("async_fallback", False)
@@ -900,7 +889,7 @@ class PGDialect_asyncpg(PGDialect):
render_bind_cast = True
has_terminate = True
- default_paramstyle = "format"
+ default_paramstyle = "numeric_dollar"
supports_sane_multi_rowcount = False
execution_ctx_cls = PGExecutionContext_asyncpg
statement_compiler = PGCompiler_asyncpg
diff --git a/lib/sqlalchemy/dialects/sqlite/provision.py b/lib/sqlalchemy/dialects/sqlite/provision.py
index 05ee6c625..851f0951f 100644
--- a/lib/sqlalchemy/dialects/sqlite/provision.py
+++ b/lib/sqlalchemy/dialects/sqlite/provision.py
@@ -18,7 +18,13 @@ from ...testing.provision import upsert
# TODO: I can't get this to build dynamically with pytest-xdist procs
-_drivernames = {"pysqlite", "aiosqlite", "pysqlcipher"}
+_drivernames = {
+ "pysqlite",
+ "aiosqlite",
+ "pysqlcipher",
+ "pysqlite_numeric",
+ "pysqlite_dollar",
+}
@generate_driver_url.for_db("sqlite")
diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py
index 4475ccae7..c04a3601d 100644
--- a/lib/sqlalchemy/dialects/sqlite/pysqlite.py
+++ b/lib/sqlalchemy/dialects/sqlite/pysqlite.py
@@ -637,3 +637,110 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
dialect = SQLiteDialect_pysqlite
+
+
+class _SQLiteDialect_pysqlite_numeric(SQLiteDialect_pysqlite):
+ """numeric dialect for testing only
+
+ internal use only. This dialect is **NOT** supported by SQLAlchemy
+ and may change at any time.
+
+ """
+
+ supports_statement_cache = True
+ default_paramstyle = "numeric"
+ driver = "pysqlite_numeric"
+
+ _first_bind = ":1"
+ _not_in_statement_regexp = None
+
+ def __init__(self, *arg, **kw):
+ kw.setdefault("paramstyle", "numeric")
+ super().__init__(*arg, **kw)
+
+ def create_connect_args(self, url):
+ arg, opts = super().create_connect_args(url)
+ opts["factory"] = self._fix_sqlite_issue_99953()
+ return arg, opts
+
+ def _fix_sqlite_issue_99953(self):
+ import sqlite3
+
+ first_bind = self._first_bind
+ if self._not_in_statement_regexp:
+ nis = self._not_in_statement_regexp
+
+ def _test_sql(sql):
+ m = nis.search(sql)
+ assert not m, f"Found {nis.pattern!r} in {sql!r}"
+
+ else:
+
+ def _test_sql(sql):
+ pass
+
+ def _numeric_param_as_dict(parameters):
+ if parameters:
+ assert isinstance(parameters, tuple)
+ return {
+ str(idx): value for idx, value in enumerate(parameters, 1)
+ }
+ else:
+ return ()
+
+ class SQLiteFix99953Cursor(sqlite3.Cursor):
+ def execute(self, sql, parameters=()):
+ _test_sql(sql)
+ if first_bind in sql:
+ parameters = _numeric_param_as_dict(parameters)
+ return super().execute(sql, parameters)
+
+ def executemany(self, sql, parameters):
+ _test_sql(sql)
+ if first_bind in sql:
+ parameters = [
+ _numeric_param_as_dict(p) for p in parameters
+ ]
+ return super().executemany(sql, parameters)
+
+ class SQLiteFix99953Connection(sqlite3.Connection):
+ def cursor(self, factory=None):
+ if factory is None:
+ factory = SQLiteFix99953Cursor
+ return super().cursor(factory=factory)
+
+ def execute(self, sql, parameters=()):
+ _test_sql(sql)
+ if first_bind in sql:
+ parameters = _numeric_param_as_dict(parameters)
+ return super().execute(sql, parameters)
+
+ def executemany(self, sql, parameters):
+ _test_sql(sql)
+ if first_bind in sql:
+ parameters = [
+ _numeric_param_as_dict(p) for p in parameters
+ ]
+ return super().executemany(sql, parameters)
+
+ return SQLiteFix99953Connection
+
+
+class _SQLiteDialect_pysqlite_dollar(_SQLiteDialect_pysqlite_numeric):
+ """numeric dialect that uses $ for testing only
+
+ internal use only. This dialect is **NOT** supported by SQLAlchemy
+ and may change at any time.
+
+ """
+
+ supports_statement_cache = True
+ default_paramstyle = "numeric_dollar"
+ driver = "pysqlite_dollar"
+
+ _first_bind = "$1"
+ _not_in_statement_regexp = re.compile(r"[^\d]:\d+")
+
+ def __init__(self, *arg, **kw):
+ kw.setdefault("paramstyle", "numeric_dollar")
+ super().__init__(*arg, **kw)
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 3cc9cab8b..4647c84d1 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -320,7 +320,12 @@ class DefaultDialect(Dialect):
self.paramstyle = self.dbapi.paramstyle
else:
self.paramstyle = self.default_paramstyle
- self.positional = self.paramstyle in ("qmark", "format", "numeric")
+ self.positional = self.paramstyle in (
+ "qmark",
+ "format",
+ "numeric",
+ "numeric_dollar",
+ )
self.identifier_preparer = self.preparer(self)
self._on_connect_isolation_level = isolation_level
diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py
index 2f5efce25..ddf8a53fb 100644
--- a/lib/sqlalchemy/engine/interfaces.py
+++ b/lib/sqlalchemy/engine/interfaces.py
@@ -255,7 +255,9 @@ SchemaTranslateMapType = Mapping[Optional[str], Optional[str]]
_ImmutableExecuteOptions = immutabledict[str, Any]
-_ParamStyle = Literal["qmark", "numeric", "named", "format", "pyformat"]
+_ParamStyle = Literal[
+ "qmark", "numeric", "named", "format", "pyformat", "numeric_dollar"
+]
_GenericSetInputSizesType = List[Tuple[str, Any, "TypeEngine[Any]"]]
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 7ac279ee2..d7358ad3b 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -227,11 +227,13 @@ FK_INITIALLY = re.compile(r"^(?:DEFERRED|IMMEDIATE)$", re.I)
BIND_PARAMS = re.compile(r"(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])", re.UNICODE)
BIND_PARAMS_ESC = re.compile(r"\x5c(:[\w\$]*)(?![:\w\$])", re.UNICODE)
+_pyformat_template = "%%(%(name)s)s"
BIND_TEMPLATES = {
- "pyformat": "%%(%(name)s)s",
+ "pyformat": _pyformat_template,
"qmark": "?",
"format": "%%s",
"numeric": ":[_POSITION]",
+ "numeric_dollar": "$[_POSITION]",
"named": ":%(name)s",
}
@@ -420,6 +422,22 @@ class _InsertManyValues(NamedTuple):
num_positional_params_counted: int
+class CompilerState(IntEnum):
+ COMPILING = 0
+ """statement is present, compilation phase in progress"""
+
+ STRING_APPLIED = 1
+ """statement is present, string form of the statement has been applied.
+
+ Additional processors by subclasses may still be pending.
+
+ """
+
+ NO_STATEMENT = 2
+ """compiler does not have a statement to compile, is used
+ for method access"""
+
+
class Linting(IntEnum):
"""represent preferences for the 'SQL linting' feature.
@@ -527,6 +545,14 @@ class Compiled:
defaults.
"""
+ statement: Optional[ClauseElement] = None
+ "The statement to compile."
+ string: str = ""
+ "The string representation of the ``statement``"
+
+ state: CompilerState
+ """description of the compiler's state"""
+
is_sql = False
is_ddl = False
@@ -618,7 +644,6 @@ class Compiled:
"""
-
self.dialect = dialect
self.preparer = self.dialect.identifier_preparer
if schema_translate_map:
@@ -628,6 +653,7 @@ class Compiled:
)
if statement is not None:
+ self.state = CompilerState.COMPILING
self.statement = statement
self.can_execute = statement.supports_execution
self._annotations = statement._annotations
@@ -641,6 +667,11 @@ class Compiled:
self.string = self.preparer._render_schema_translates(
self.string, schema_translate_map
)
+
+ self.state = CompilerState.STRING_APPLIED
+ else:
+ self.state = CompilerState.NO_STATEMENT
+
self._gen_time = perf_counter()
def _execute_on_connection(
@@ -672,7 +703,10 @@ class Compiled:
def __str__(self) -> str:
"""Return the string text of the generated SQL or DDL."""
- return self.string or ""
+ if self.state is CompilerState.STRING_APPLIED:
+ return self.string
+ else:
+ return ""
def construct_params(
self,
@@ -859,6 +893,19 @@ class SQLCompiler(Compiled):
driver/DB enforces this
"""
+ bindtemplate: str
+ """template to render bound parameters based on paramstyle."""
+
+ compilation_bindtemplate: str
+ """template used by compiler to render parameters before positional
+ paramstyle application"""
+
+ _numeric_binds_identifier_char: str
+ """Character that's used to as the identifier of a numerical bind param.
+ For example if this char is set to ``$``, numerical binds will be rendered
+ in the form ``$1, $2, $3``.
+ """
+
_result_columns: List[ResultColumnsEntry]
"""relates label names in the final SQL to a tuple of local
column/label name, ColumnElement object (if any) and
@@ -967,13 +1014,17 @@ class SQLCompiler(Compiled):
and is combined with the :attr:`_sql.Compiled.params` dictionary to
render parameters.
+ This sequence always contains the unescaped name of the parameters.
+
.. seealso::
:ref:`faq_sql_expression_string` - includes a usage example for
debugging use cases.
"""
- positiontup_level: Optional[Dict[str, int]] = None
+ _values_bindparam: Optional[List[str]] = None
+
+ _visited_bindparam: Optional[List[str]] = None
inline: bool = False
@@ -988,9 +1039,12 @@ class SQLCompiler(Compiled):
level_name_by_cte: Dict[CTE, Tuple[int, str, selectable._CTEOpts]]
ctes_recursive: bool
- cte_positional: Dict[CTE, List[str]]
- cte_level: Dict[CTE, int]
- cte_order: Dict[Optional[CTE], List[CTE]]
+
+ _post_compile_pattern = re.compile(r"__\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]")
+ _pyformat_pattern = re.compile(r"%\(([^)]+?)\)s")
+ _positional_pattern = re.compile(
+ f"{_pyformat_pattern.pattern}|{_post_compile_pattern.pattern}"
+ )
def __init__(
self,
@@ -1055,10 +1109,15 @@ class SQLCompiler(Compiled):
# true if the paramstyle is positional
self.positional = dialect.positional
if self.positional:
- self.positiontup_level = {}
- self.positiontup = []
- self._numeric_binds = dialect.paramstyle == "numeric"
- self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
+ self._numeric_binds = nb = dialect.paramstyle.startswith("numeric")
+ if nb:
+ self._numeric_binds_identifier_char = (
+ "$" if dialect.paramstyle == "numeric_dollar" else ":"
+ )
+
+ self.compilation_bindtemplate = _pyformat_template
+ else:
+ self.compilation_bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
self.ctes = None
@@ -1095,11 +1154,17 @@ class SQLCompiler(Compiled):
):
self.inline = True
- if self.positional and self._numeric_binds:
- self._apply_numbered_params()
+ self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
+
+ if self.state is CompilerState.STRING_APPLIED:
+ if self.positional:
+ if self._numeric_binds:
+ self._process_numeric()
+ else:
+ self._process_positional()
- if self._render_postcompile:
- self._process_parameters_for_postcompile(_populate_self=True)
+ if self._render_postcompile:
+ self._process_parameters_for_postcompile(_populate_self=True)
@property
def insert_single_values_expr(self) -> Optional[str]:
@@ -1135,7 +1200,7 @@ class SQLCompiler(Compiled):
"""
if self.implicit_returning:
return self.implicit_returning
- elif is_dml(self.statement):
+ elif self.statement is not None and is_dml(self.statement):
return [
c
for c in self.statement._all_selected_columns
@@ -1217,10 +1282,6 @@ class SQLCompiler(Compiled):
self.level_name_by_cte = {}
self.ctes_recursive = False
- if self.positional:
- self.cte_positional = {}
- self.cte_level = {}
- self.cte_order = collections.defaultdict(list)
return ctes
@@ -1248,12 +1309,145 @@ class SQLCompiler(Compiled):
ordered_columns,
)
- def _apply_numbered_params(self):
- poscount = itertools.count(1)
+ def _process_positional(self):
+ assert not self.positiontup
+ assert self.state is CompilerState.STRING_APPLIED
+ assert not self._numeric_binds
+
+ if self.dialect.paramstyle == "format":
+ placeholder = "%s"
+ else:
+ assert self.dialect.paramstyle == "qmark"
+ placeholder = "?"
+
+ positions = []
+
+ def find_position(m: re.Match[str]) -> str:
+ normal_bind = m.group(1)
+ if normal_bind:
+ positions.append(normal_bind)
+ return placeholder
+ else:
+ # this a post-compile bind
+ positions.append(m.group(2))
+ return m.group(0)
+
self.string = re.sub(
- r"\[_POSITION\]", lambda m: str(next(poscount)), self.string
+ self._positional_pattern, find_position, self.string
)
+ if self.escaped_bind_names:
+ reverse_escape = {v: k for k, v in self.escaped_bind_names.items()}
+ assert len(self.escaped_bind_names) == len(reverse_escape)
+ self.positiontup = [
+ reverse_escape.get(name, name) for name in positions
+ ]
+ else:
+ self.positiontup = positions
+
+ if self._insertmanyvalues:
+ positions = []
+ single_values_expr = re.sub(
+ self._positional_pattern,
+ find_position,
+ self._insertmanyvalues.single_values_expr,
+ )
+ insert_crud_params = [
+ (
+ v[0],
+ v[1],
+ re.sub(self._positional_pattern, find_position, v[2]),
+ v[3],
+ )
+ for v in self._insertmanyvalues.insert_crud_params
+ ]
+
+ self._insertmanyvalues = _InsertManyValues(
+ is_default_expr=self._insertmanyvalues.is_default_expr,
+ single_values_expr=single_values_expr,
+ insert_crud_params=insert_crud_params,
+ num_positional_params_counted=(
+ self._insertmanyvalues.num_positional_params_counted
+ ),
+ )
+
+ def _process_numeric(self):
+ assert self._numeric_binds
+ assert self.state is CompilerState.STRING_APPLIED
+
+ num = 1
+ param_pos: Dict[str, str] = {}
+ order: Iterable[str]
+ if self._insertmanyvalues and self._values_bindparam is not None:
+ # bindparams that are not in values are always placed first.
+ # this avoids the need of changing them when using executemany
+ # values () ()
+ order = itertools.chain(
+ (
+ name
+ for name in self.bind_names.values()
+ if name not in self._values_bindparam
+ ),
+ self.bind_names.values(),
+ )
+ else:
+ order = self.bind_names.values()
+
+ for bind_name in order:
+ if bind_name in param_pos:
+ continue
+ bind = self.binds[bind_name]
+ if (
+ bind in self.post_compile_params
+ or bind in self.literal_execute_params
+ ):
+ # set to None to just mark the in positiontup, it will not
+ # be replaced below.
+ param_pos[bind_name] = None # type: ignore
+ else:
+ ph = f"{self._numeric_binds_identifier_char}{num}"
+ num += 1
+ param_pos[bind_name] = ph
+
+ self.next_numeric_pos = num
+
+ self.positiontup = list(param_pos)
+ if self.escaped_bind_names:
+ reverse_escape = {v: k for k, v in self.escaped_bind_names.items()}
+ assert len(self.escaped_bind_names) == len(reverse_escape)
+ param_pos = {
+ self.escaped_bind_names.get(name, name): pos
+ for name, pos in param_pos.items()
+ }
+
+ # Can't use format here since % chars are not escaped.
+ self.string = self._pyformat_pattern.sub(
+ lambda m: param_pos[m.group(1)], self.string
+ )
+
+ if self._insertmanyvalues:
+ single_values_expr = (
+ # format is ok here since single_values_expr includes only
+ # place-holders
+ self._insertmanyvalues.single_values_expr
+ % param_pos
+ )
+ insert_crud_params = [
+ (v[0], v[1], "%s", v[3])
+ for v in self._insertmanyvalues.insert_crud_params
+ ]
+
+ self._insertmanyvalues = _InsertManyValues(
+ is_default_expr=self._insertmanyvalues.is_default_expr,
+ # This has the numbers (:1, :2)
+ single_values_expr=single_values_expr,
+ # The single binds are instead %s so they can be formatted
+ insert_crud_params=insert_crud_params,
+ num_positional_params_counted=(
+ self._insertmanyvalues.num_positional_params_counted
+ ),
+ )
+
@util.memoized_property
def _bind_processors(
self,
@@ -1492,39 +1686,30 @@ class SQLCompiler(Compiled):
new_processors: Dict[str, _BindProcessorType[Any]] = {}
- if self.positional and self._numeric_binds:
- # I'm not familiar with any DBAPI that uses 'numeric'.
- # strategy would likely be to make use of numbers greater than
- # the highest number present; then for expanding parameters,
- # append them to the end of the parameter list. that way
- # we avoid having to renumber all the existing parameters.
- raise NotImplementedError(
- "'post-compile' bind parameters are not supported with "
- "the 'numeric' paramstyle at this time."
- )
-
replacement_expressions: Dict[str, Any] = {}
to_update_sets: Dict[str, Any] = {}
# notes:
# *unescaped* parameter names in:
- # self.bind_names, self.binds, self._bind_processors
+ # self.bind_names, self.binds, self._bind_processors, self.positiontup
#
# *escaped* parameter names in:
# construct_params(), replacement_expressions
+ numeric_positiontup: Optional[List[str]] = None
+
if self.positional and self.positiontup is not None:
names: Iterable[str] = self.positiontup
+ if self._numeric_binds:
+ numeric_positiontup = []
else:
names = self.bind_names.values()
+ ebn = self.escaped_bind_names
for name in names:
- escaped_name = (
- self.escaped_bind_names.get(name, name)
- if self.escaped_bind_names
- else name
- )
+ escaped_name = ebn.get(name, name) if ebn else name
parameter = self.binds[name]
+
if parameter in self.literal_execute_params:
if escaped_name not in replacement_expressions:
value = parameters.pop(escaped_name)
@@ -1555,10 +1740,10 @@ class SQLCompiler(Compiled):
# in the escaped_bind_names dictionary.
values = parameters.pop(name)
- leep = self._literal_execute_expanding_parameter
- to_update, replacement_expr = leep(
+ leep_res = self._literal_execute_expanding_parameter(
escaped_name, parameter, values
)
+ (to_update, replacement_expr) = leep_res
to_update_sets[escaped_name] = to_update
replacement_expressions[escaped_name] = replacement_expr
@@ -1583,7 +1768,14 @@ class SQLCompiler(Compiled):
for key, _ in to_update
if name in single_processors
)
- if positiontup is not None:
+ if numeric_positiontup is not None:
+ numeric_positiontup.extend(
+ name for name, _ in to_update
+ )
+ elif positiontup is not None:
+ # to_update has escaped names, but that's ok since
+ # these are new names, that aren't in the
+ # escaped_bind_names dict.
positiontup.extend(name for name, _ in to_update)
expanded_parameters[name] = [
expand_key for expand_key, _ in to_update
@@ -1607,11 +1799,23 @@ class SQLCompiler(Compiled):
return expr
statement = re.sub(
- r"__\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]",
- process_expanding,
- self.string,
+ self._post_compile_pattern, process_expanding, self.string
)
+ if numeric_positiontup is not None:
+ assert positiontup is not None
+ param_pos = {
+ key: f"{self._numeric_binds_identifier_char}{num}"
+ for num, key in enumerate(
+ numeric_positiontup, self.next_numeric_pos
+ )
+ }
+ # Can't use format here since % chars are not escaped.
+ statement = self._pyformat_pattern.sub(
+ lambda m: param_pos[m.group(1)], statement
+ )
+ positiontup.extend(numeric_positiontup)
+
expanded_state = ExpandedState(
statement,
parameters,
@@ -2109,13 +2313,7 @@ class SQLCompiler(Compiled):
text = self.process(taf.element, **kw)
if self.ctes:
nesting_level = len(self.stack) if not toplevel else None
- text = (
- self._render_cte_clause(
- nesting_level=nesting_level,
- visiting_cte=kw.get("visiting_cte"),
- )
- + text
- )
+ text = self._render_cte_clause(nesting_level=nesting_level) + text
self.stack.pop(-1)
@@ -2411,7 +2609,6 @@ class SQLCompiler(Compiled):
self._render_cte_clause(
nesting_level=nesting_level,
include_following_stack=True,
- visiting_cte=kwargs.get("visiting_cte"),
)
+ text
)
@@ -2625,6 +2822,11 @@ class SQLCompiler(Compiled):
dialect = self.dialect
typ_dialect_impl = parameter.type._unwrapped_dialect_impl(dialect)
+ if self._numeric_binds:
+ bind_template = self.compilation_bindtemplate
+ else:
+ bind_template = self.bindtemplate
+
if (
self.dialect._bind_typing_render_casts
and typ_dialect_impl.render_bind_cast
@@ -2634,13 +2836,13 @@ class SQLCompiler(Compiled):
return self.render_bind_cast(
parameter.type,
typ_dialect_impl,
- self.bindtemplate % {"name": name},
+ bind_template % {"name": name},
)
else:
def _render_bindtemplate(name):
- return self.bindtemplate % {"name": name}
+ return bind_template % {"name": name}
if not values:
to_update = []
@@ -3224,7 +3426,6 @@ class SQLCompiler(Compiled):
def bindparam_string(
self,
name: str,
- positional_names: Optional[List[str]] = None,
post_compile: bool = False,
expanding: bool = False,
escaped_from: Optional[str] = None,
@@ -3232,12 +3433,9 @@ class SQLCompiler(Compiled):
**kw: Any,
) -> str:
- if self.positional:
- if positional_names is not None:
- positional_names.append(name)
- else:
- self.positiontup.append(name) # type: ignore[union-attr]
- self.positiontup_level[name] = len(self.stack) # type: ignore[index] # noqa: E501
+ if self._visited_bindparam is not None:
+ self._visited_bindparam.append(name)
+
if not escaped_from:
if _BIND_TRANSLATE_RE.search(name):
@@ -3271,6 +3469,8 @@ class SQLCompiler(Compiled):
if type_impl.render_literal_cast:
ret = self.render_bind_cast(bindparam_type, type_impl, ret)
return ret
+ elif self.state is CompilerState.COMPILING:
+ ret = self.compilation_bindtemplate % {"name": name}
else:
ret = self.bindtemplate % {"name": name}
@@ -3349,8 +3549,6 @@ class SQLCompiler(Compiled):
self.level_name_by_cte[_reference_cte] = new_level_name + (
cte_opts,
)
- if self.positional:
- self.cte_level[cte] = cte_level
else:
cte_level = len(self.stack) if nesting else 1
@@ -3414,8 +3612,6 @@ class SQLCompiler(Compiled):
self.level_name_by_cte[_reference_cte] = cte_level_name + (
cte_opts,
)
- if self.positional:
- self.cte_level[cte] = cte_level
if pre_alias_cte not in self.ctes:
self.visit_cte(pre_alias_cte, **kwargs)
@@ -3455,9 +3651,6 @@ class SQLCompiler(Compiled):
)
)
- if self.positional:
- kwargs["positional_names"] = self.cte_positional[cte] = []
-
assert kwargs.get("subquery", False) is False
if not self.stack:
@@ -4152,13 +4345,7 @@ class SQLCompiler(Compiled):
# In compound query, CTEs are shared at the compound level
if self.ctes and (not is_embedded_select or toplevel):
nesting_level = len(self.stack) if not toplevel else None
- text = (
- self._render_cte_clause(
- nesting_level=nesting_level,
- visiting_cte=kwargs.get("visiting_cte"),
- )
- + text
- )
+ text = self._render_cte_clause(nesting_level=nesting_level) + text
if select_stmt._suffixes:
text += " " + self._generate_prefixes(
@@ -4332,7 +4519,6 @@ class SQLCompiler(Compiled):
self,
nesting_level=None,
include_following_stack=False,
- visiting_cte=None,
):
"""
include_following_stack
@@ -4367,46 +4553,6 @@ class SQLCompiler(Compiled):
return ""
ctes_recursive = any([cte.recursive for cte in ctes])
- if self.positional:
- self.cte_order[visiting_cte].extend(ctes)
-
- if visiting_cte is None and self.cte_order:
- assert self.positiontup is not None
-
- def get_nested_positional(cte):
- if cte in self.cte_order:
- children = self.cte_order.pop(cte)
- to_add = list(
- itertools.chain.from_iterable(
- get_nested_positional(child_cte)
- for child_cte in children
- )
- )
- if cte in self.cte_positional:
- return reorder_positional(
- self.cte_positional[cte],
- to_add,
- self.cte_level[children[0]],
- )
- else:
- return to_add
- else:
- return self.cte_positional.get(cte, [])
-
- def reorder_positional(pos, to_add, level):
- if not level:
- return to_add + pos
- index = 0
- for index, name in enumerate(reversed(pos)):
- if self.positiontup_level[name] < level: # type: ignore[index] # noqa: E501
- break
- return pos[:-index] + to_add + pos[-index:]
-
- to_add = get_nested_positional(None)
- self.positiontup = reorder_positional(
- self.positiontup, to_add, nesting_level
- )
-
cte_text = self.get_cte_preamble(ctes_recursive) + " "
cte_text += ", \n".join([txt for txt in ctes.values()])
cte_text += "\n "
@@ -4762,6 +4908,11 @@ class SQLCompiler(Compiled):
keys_to_replace = set()
base_parameters = {}
executemany_values_w_comma = f"({imv.single_values_expr}), "
+ if self._numeric_binds:
+ escaped = re.escape(self._numeric_binds_identifier_char)
+ executemany_values_w_comma = re.sub(
+ rf"{escaped}\d+", "%s", executemany_values_w_comma
+ )
while batches:
batch = batches[0:batch_size]
@@ -4794,25 +4945,37 @@ class SQLCompiler(Compiled):
num_ins_params = imv.num_positional_params_counted
+ batch_iterator: Iterable[Tuple[Any, ...]]
if num_ins_params == len(batch[0]):
extra_params = ()
- batch_iterator: Iterable[Tuple[Any, ...]] = batch
- elif self.returning_precedes_values:
+ batch_iterator = batch
+ elif self.returning_precedes_values or self._numeric_binds:
extra_params = batch[0][:-num_ins_params]
batch_iterator = (b[-num_ins_params:] for b in batch)
else:
extra_params = batch[0][num_ins_params:]
batch_iterator = (b[:num_ins_params] for b in batch)
+ values_string = (executemany_values_w_comma * len(batch))[:-2]
+ if self._numeric_binds and num_ins_params > 0:
+ # need to format here, since statement may contain
+ # unescaped %, while values_string contains just (%s, %s)
+ start = len(extra_params) + 1
+ end = num_ins_params * len(batch) + start
+ positions = tuple(
+ f"{self._numeric_binds_identifier_char}{i}"
+ for i in range(start, end)
+ )
+ values_string = values_string % positions
+
replaced_statement = statement.replace(
- "__EXECMANY_TOKEN__",
- (executemany_values_w_comma * len(batch))[:-2],
+ "__EXECMANY_TOKEN__", values_string
)
replaced_parameters = tuple(
itertools.chain.from_iterable(batch_iterator)
)
- if self.returning_precedes_values:
+ if self.returning_precedes_values or self._numeric_binds:
replaced_parameters = extra_params + replaced_parameters
else:
replaced_parameters = replaced_parameters + extra_params
@@ -4869,23 +5032,30 @@ class SQLCompiler(Compiled):
}
)
- positiontup_before = positiontup_after = 0
+ counted_bindparam = 0
# for positional, insertmanyvalues needs to know how many
# bound parameters are in the VALUES sequence; there's no simple
# rule because default expressions etc. can have zero or more
# params inside them. After multiple attempts to figure this out,
- # this very simplistic "count before, then count after" works and is
+ # this very simplistic "count after" works and is
# likely the least amount of callcounts, though looks clumsy
- if self.positiontup:
- positiontup_before = len(self.positiontup)
+ if self.positional:
+ self._visited_bindparam = []
crud_params_struct = crud._get_crud_params(
self, insert_stmt, compile_state, toplevel, **kw
)
- if self.positiontup:
- positiontup_after = len(self.positiontup)
+ if self.positional:
+ assert self._visited_bindparam is not None
+ counted_bindparam = len(self._visited_bindparam)
+ if self._numeric_binds:
+ if self._values_bindparam is not None:
+ self._values_bindparam += self._visited_bindparam
+ else:
+ self._values_bindparam = self._visited_bindparam
+ self._visited_bindparam = None
crud_params_single = crud_params_struct.single_params
@@ -4940,31 +5110,13 @@ class SQLCompiler(Compiled):
if self.implicit_returning or insert_stmt._returning:
- # if returning clause is rendered first, capture bound parameters
- # while visiting and place them prior to the VALUES oriented
- # bound parameters, when using positional parameter scheme
- rpv = self.returning_precedes_values
- flip_pt = rpv and self.positional
- if flip_pt:
- pt: Optional[List[str]] = self.positiontup
- temp_pt: Optional[List[str]]
- self.positiontup = temp_pt = []
- else:
- temp_pt = pt = None
-
returning_clause = self.returning_clause(
insert_stmt,
self.implicit_returning or insert_stmt._returning,
populate_result_map=toplevel,
)
- if flip_pt:
- if TYPE_CHECKING:
- assert temp_pt is not None
- assert pt is not None
- self.positiontup = temp_pt + pt
-
- if rpv:
+ if self.returning_precedes_values:
text += " " + returning_clause
else:
@@ -4982,7 +5134,6 @@ class SQLCompiler(Compiled):
self._render_cte_clause(
nesting_level=nesting_level,
include_following_stack=True,
- visiting_cte=kw.get("visiting_cte"),
),
select_text,
)
@@ -4999,7 +5150,7 @@ class SQLCompiler(Compiled):
cast(
"List[crud._CrudParamElementStr]", crud_params_single
),
- (positiontup_after - positiontup_before),
+ counted_bindparam,
)
elif compile_state._has_multi_parameters:
text += " VALUES %s" % (
@@ -5033,7 +5184,7 @@ class SQLCompiler(Compiled):
"List[crud._CrudParamElementStr]",
crud_params_single,
),
- positiontup_after - positiontup_before,
+ counted_bindparam,
)
if insert_stmt._post_values_clause is not None:
@@ -5052,7 +5203,6 @@ class SQLCompiler(Compiled):
self._render_cte_clause(
nesting_level=nesting_level,
include_following_stack=True,
- visiting_cte=kw.get("visiting_cte"),
)
+ text
)
@@ -5201,13 +5351,7 @@ class SQLCompiler(Compiled):
if self.ctes:
nesting_level = len(self.stack) if not toplevel else None
- text = (
- self._render_cte_clause(
- nesting_level=nesting_level,
- visiting_cte=kw.get("visiting_cte"),
- )
- + text
- )
+ text = self._render_cte_clause(nesting_level=nesting_level) + text
self.stack.pop(-1)
@@ -5321,13 +5465,7 @@ class SQLCompiler(Compiled):
if self.ctes:
nesting_level = len(self.stack) if not toplevel else None
- text = (
- self._render_cte_clause(
- nesting_level=nesting_level,
- visiting_cte=kw.get("visiting_cte"),
- )
- + text
- )
+ text = self._render_cte_clause(nesting_level=nesting_level) + text
self.stack.pop(-1)
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index 017ff7baa..ae1b032ae 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -85,8 +85,8 @@ _CrudParamElement = Tuple[
]
_CrudParamElementStr = Tuple[
"KeyedColumnElement[Any]",
- str,
- str,
+ str, # column name
+ str, # placeholder
Iterable[str],
]
_CrudParamElementSQLExpr = Tuple[
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
index d183372c3..45a2496dd 100644
--- a/lib/sqlalchemy/testing/assertsql.py
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -11,6 +11,7 @@ from __future__ import annotations
import collections
import contextlib
+import itertools
import re
from .. import event
@@ -285,7 +286,8 @@ class DialectSQL(CompiledSQL):
return received_stmt, execute_observed.context.compiled_parameters
- def _dialect_adjusted_statement(self, paramstyle):
+ def _dialect_adjusted_statement(self, dialect):
+ paramstyle = dialect.paramstyle
stmt = re.sub(r"[\n\t]", "", self.statement)
# temporarily escape out PG double colons
@@ -300,8 +302,14 @@ class DialectSQL(CompiledSQL):
repl = "?"
elif paramstyle == "format":
repl = r"%s"
- elif paramstyle == "numeric":
- repl = None
+ elif paramstyle.startswith("numeric"):
+ counter = itertools.count(1)
+
+ num_identifier = "$" if paramstyle == "numeric_dollar" else ":"
+
+ def repl(m):
+ return f"{num_identifier}{next(counter)}"
+
stmt = re.sub(r":([\w_]+)", repl, stmt)
# put them back
@@ -310,20 +318,20 @@ class DialectSQL(CompiledSQL):
return stmt
def _compare_sql(self, execute_observed, received_statement):
- paramstyle = execute_observed.context.dialect.paramstyle
- stmt = self._dialect_adjusted_statement(paramstyle)
+ stmt = self._dialect_adjusted_statement(
+ execute_observed.context.dialect
+ )
return received_statement == stmt
def _failure_message(self, execute_observed, expected_params):
- paramstyle = execute_observed.context.dialect.paramstyle
return (
"Testing for compiled statement\n%r partial params %s, "
"received\n%%(received_statement)r with params "
"%%(received_parameters)r"
% (
- self._dialect_adjusted_statement(paramstyle).replace(
- "%", "%%"
- ),
+ self._dialect_adjusted_statement(
+ execute_observed.context.dialect
+ ).replace("%", "%%"),
repr(expected_params).replace("%", "%%"),
)
)
diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py
index 957876579..6adcf5b64 100644
--- a/lib/sqlalchemy/testing/config.py
+++ b/lib/sqlalchemy/testing/config.py
@@ -189,7 +189,7 @@ def variation(argname, cases):
elif querytyp.legacy_query:
stmt = Session.query(Thing)
else:
- assert False
+ querytyp.fail()
The variable provided is a slots object of boolean variables, as well
diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py
index 656a4e98a..ffe0f453a 100644
--- a/lib/sqlalchemy/testing/plugin/plugin_base.py
+++ b/lib/sqlalchemy/testing/plugin/plugin_base.py
@@ -371,6 +371,22 @@ def _setup_options(opt, file_config):
options = opt
+@pre
+def _register_sqlite_numeric_dialect(opt, file_config):
+ from sqlalchemy.dialects import registry
+
+ registry.register(
+ "sqlite.pysqlite_numeric",
+ "sqlalchemy.dialects.sqlite.pysqlite",
+ "_SQLiteDialect_pysqlite_numeric",
+ )
+ registry.register(
+ "sqlite.pysqlite_dollar",
+ "sqlalchemy.dialects.sqlite.pysqlite",
+ "_SQLiteDialect_pysqlite_dollar",
+ )
+
+
@post
def __ensure_cext(opt, file_config):
if os.environ.get("REQUIRE_SQLALCHEMY_CEXT", "0") == "1":
diff --git a/setup.cfg b/setup.cfg
index b02ad2682..485f1d682 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -144,6 +144,8 @@ oracle_db_link2 = test_link2
[db]
default = sqlite:///:memory:
sqlite = sqlite:///:memory:
+sqlite_numeric = sqlite+pysqlite_numeric:///:memory:
+sqlite_dollar = sqlite+pysqlite_dollar:///:memory:
aiosqlite = sqlite+aiosqlite:///:memory:
sqlite_file = sqlite:///querytest.db
aiosqlite_file = sqlite+aiosqlite:///async_querytest.db
diff --git a/test/dialect/postgresql/test_query.py b/test/dialect/postgresql/test_query.py
index 42ec20743..2b32d6db7 100644
--- a/test/dialect/postgresql/test_query.py
+++ b/test/dialect/postgresql/test_query.py
@@ -990,19 +990,19 @@ class MatchTest(fixtures.TablesTest, AssertsCompiledSQL):
"matchtable.title @@ plainto_tsquery(%(title_1)s)",
)
- @testing.requires.format_paramstyle
+ @testing.only_if("+asyncpg")
def test_expression_positional(self, connection):
matchtable = self.tables.matchtable
if self._strs_render_bind_casts(connection):
self.assert_compile(
matchtable.c.title.match("somstr"),
- "matchtable.title @@ plainto_tsquery(%s::VARCHAR(200))",
+ "matchtable.title @@ plainto_tsquery($1::VARCHAR(200))",
)
else:
self.assert_compile(
matchtable.c.title.match("somstr"),
- "matchtable.title @@ plainto_tsquery(%s)",
+ "matchtable.title @@ plainto_tsquery($1)",
)
def test_simple_match(self, connection):
diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py
index c5147e37f..07117b862 100644
--- a/test/dialect/test_sqlite.py
+++ b/test/dialect/test_sqlite.py
@@ -2916,6 +2916,8 @@ class OnConflictTest(AssertsCompiledSQL, fixtures.TablesTest):
)
@testing.combinations("control", "excluded", "dict")
+ @testing.skip_if("+pysqlite_numeric")
+ @testing.skip_if("+pysqlite_dollar")
def test_set_excluded(self, scenario):
"""test #8014, sending all of .excluded to set"""
diff --git a/test/engine/test_logging.py b/test/engine/test_logging.py
index 277248617..19c26f43c 100644
--- a/test/engine/test_logging.py
+++ b/test/engine/test_logging.py
@@ -28,7 +28,7 @@ def exec_sql(engine, sql, *args, **kwargs):
class LogParamsTest(fixtures.TestBase):
- __only_on__ = "sqlite"
+ __only_on__ = "sqlite+pysqlite"
__requires__ = ("ad_hoc_engines",)
def setup_test(self):
@@ -704,7 +704,7 @@ class LoggingNameTest(fixtures.TestBase):
class TransactionContextLoggingTest(fixtures.TestBase):
- __only_on__ = "sqlite"
+ __only_on__ = "sqlite+pysqlite"
@testing.fixture()
def plain_assert_buf(self, plain_logging_engine):
diff --git a/test/orm/dml/test_bulk_statements.py b/test/orm/dml/test_bulk_statements.py
index 557b5e9da..78607e03d 100644
--- a/test/orm/dml/test_bulk_statements.py
+++ b/test/orm/dml/test_bulk_statements.py
@@ -958,7 +958,7 @@ class BulkDMLReturningJoinedInhTest(
BulkDMLReturningInhTest, fixtures.DeclarativeMappedTest
):
- __requires__ = ("insert_returning",)
+ __requires__ = ("insert_returning", "insert_executemany_returning")
__backend__ = True
@classmethod
@@ -1044,7 +1044,7 @@ class BulkDMLReturningJoinedInhTest(
class BulkDMLReturningSingleInhTest(
BulkDMLReturningInhTest, fixtures.DeclarativeMappedTest
):
- __requires__ = ("insert_returning",)
+ __requires__ = ("insert_returning", "insert_executemany_returning")
__backend__ = True
@classmethod
@@ -1075,7 +1075,7 @@ class BulkDMLReturningSingleInhTest(
class BulkDMLReturningConcreteInhTest(
BulkDMLReturningInhTest, fixtures.DeclarativeMappedTest
):
- __requires__ = ("insert_returning",)
+ __requires__ = ("insert_returning", "insert_executemany_returning")
__backend__ = True
@classmethod
diff --git a/test/orm/test_dynamic.py b/test/orm/test_dynamic.py
index df335f0f6..714878f4e 100644
--- a/test/orm/test_dynamic.py
+++ b/test/orm/test_dynamic.py
@@ -1791,14 +1791,33 @@ class WriteOnlyBulkTest(
"INSERT INTO users (name) VALUES (:name)",
[{"name": "x"}],
),
- CompiledSQL(
- "INSERT INTO addresses (user_id, email_address) "
- "VALUES (:user_id, :email_address) "
- "RETURNING addresses.id",
+ Conditional(
+ testing.requires.insert_executemany_returning.enabled,
+ [
+ CompiledSQL(
+ "INSERT INTO addresses "
+ "(user_id, email_address) "
+ "VALUES (:user_id, :email_address) "
+ "RETURNING addresses.id",
+ [
+ {"user_id": uid, "email_address": "e1"},
+ {"user_id": uid, "email_address": "e2"},
+ {"user_id": uid, "email_address": "e3"},
+ ],
+ )
+ ],
[
- {"user_id": uid, "email_address": "e1"},
- {"user_id": uid, "email_address": "e2"},
- {"user_id": uid, "email_address": "e3"},
+ CompiledSQL(
+ "INSERT INTO addresses "
+ "(user_id, email_address) "
+ "VALUES (:user_id, :email_address)",
+ param,
+ )
+ for param in [
+ {"user_id": uid, "email_address": "e1"},
+ {"user_id": uid, "email_address": "e2"},
+ {"user_id": uid, "email_address": "e3"},
+ ]
],
),
],
@@ -1863,14 +1882,33 @@ class WriteOnlyBulkTest(
"INSERT INTO users (name) VALUES (:name)",
[{"name": "x"}],
),
- CompiledSQL(
- "INSERT INTO addresses (user_id, email_address) "
- "VALUES (:user_id, :email_address) "
- "RETURNING addresses.id",
+ Conditional(
+ testing.requires.insert_executemany_returning.enabled,
+ [
+ CompiledSQL(
+ "INSERT INTO addresses "
+ "(user_id, email_address) "
+ "VALUES (:user_id, :email_address) "
+ "RETURNING addresses.id",
+ [
+ {"user_id": uid, "email_address": "e1"},
+ {"user_id": uid, "email_address": "e2"},
+ {"user_id": uid, "email_address": "e3"},
+ ],
+ )
+ ],
[
- {"user_id": uid, "email_address": "e1"},
- {"user_id": uid, "email_address": "e2"},
- {"user_id": uid, "email_address": "e3"},
+ CompiledSQL(
+ "INSERT INTO addresses "
+ "(user_id, email_address) "
+ "VALUES (:user_id, :email_address)",
+ param,
+ )
+ for param in [
+ {"user_id": uid, "email_address": "e1"},
+ {"user_id": uid, "email_address": "e2"},
+ {"user_id": uid, "email_address": "e3"},
+ ]
],
),
],
diff --git a/test/orm/test_merge.py b/test/orm/test_merge.py
index 36c47e27b..eb5a795e2 100644
--- a/test/orm/test_merge.py
+++ b/test/orm/test_merge.py
@@ -1458,7 +1458,7 @@ class MergeTest(_fixtures.FixtureTest):
)
attrname = "user"
else:
- assert False
+ direction.fail()
assert attrname in obj_to_merge.__dict__
diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py
index f204e954c..468d43063 100644
--- a/test/orm/test_unitofworkv2.py
+++ b/test/orm/test_unitofworkv2.py
@@ -3077,7 +3077,7 @@ class EagerDefaultsTest(fixtures.MappedTest):
asserter.assert_(
Conditional(
- testing.db.dialect.insert_executemany_returning,
+ testing.db.dialect.insert_returning,
[
CompiledSQL(
"INSERT INTO test (id) VALUES (:id) "
diff --git a/test/requirements.py b/test/requirements.py
index 5276593c9..83cd65cd8 100644
--- a/test/requirements.py
+++ b/test/requirements.py
@@ -232,7 +232,6 @@ class DefaultRequirements(SuiteRequirements):
"mariadb+pymysql",
"mariadb+cymysql",
"mariadb+mysqlconnector",
- "postgresql+asyncpg",
"postgresql+pg8000",
]
)
@@ -388,6 +387,14 @@ class DefaultRequirements(SuiteRequirements):
)
@property
+ def predictable_gc(self):
+ """target platform must remove all cycles unconditionally when
+ gc.collect() is called, as well as clean out unreferenced subclasses.
+
+ """
+ return self.cpython + skip_if("+aiosqlite")
+
+ @property
def memory_process_intensive(self):
"""Driver is able to handle the memory tests which run in a subprocess
and iterate through hundreds of connections
@@ -969,6 +976,8 @@ class DefaultRequirements(SuiteRequirements):
"mariadb",
"sqlite+aiosqlite",
"sqlite+pysqlite",
+ "sqlite+pysqlite_numeric",
+ "sqlite+pysqlite_dollar",
"sqlite+pysqlcipher",
"mssql",
)
diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py
index 205ce5157..d342b9248 100644
--- a/test/sql/test_compiler.py
+++ b/test/sql/test_compiler.py
@@ -79,6 +79,7 @@ from sqlalchemy.sql import util as sql_util
from sqlalchemy.sql.elements import BooleanClauseList
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.elements import CompilerColumnElement
+from sqlalchemy.sql.elements import Grouping
from sqlalchemy.sql.expression import ClauseElement
from sqlalchemy.sql.expression import ClauseList
from sqlalchemy.sql.selectable import LABEL_STYLE_NONE
@@ -4915,88 +4916,259 @@ class BindParameterTest(AssertsCompiledSQL, fixtures.TestBase):
dialect="default",
)
- @standalone_escape
- @testing.variation("use_assert_compile", [True, False])
@testing.variation("use_positional", [True, False])
- def test_standalone_bindparam_escape_expanding(
- self, paramname, expected, use_assert_compile, use_positional
+ def test_standalone_bindparam_escape_collision(self, use_positional):
+ """this case is currently not supported
+
+ it's kinda bad since positional takes the unescaped param
+ while non positional takes the escaped one.
+ """
+ stmt = select(table1.c.myid).where(
+ table1.c.name == bindparam("[brackets]", value="x"),
+ table1.c.description == bindparam("_brackets_", value="y"),
+ )
+
+ if use_positional:
+ self.assert_compile(
+ stmt,
+ "SELECT mytable.myid FROM mytable WHERE mytable.name = ? "
+ "AND mytable.description = ?",
+ params={"[brackets]": "a", "_brackets_": "b"},
+ checkpositional=("a", "a"),
+ dialect="sqlite",
+ )
+ else:
+ self.assert_compile(
+ stmt,
+ "SELECT mytable.myid FROM mytable WHERE mytable.name = "
+ ":_brackets_ AND mytable.description = :_brackets_",
+ params={"[brackets]": "a", "_brackets_": "b"},
+ checkparams={"_brackets_": "b"},
+ dialect="default",
+ )
+
+ paramstyle = testing.variation("paramstyle", ["named", "qmark", "numeric"])
+
+ @standalone_escape
+ @paramstyle
+ def test_standalone_bindparam_escape_expanding_compile(
+ self, paramname, expected, paramstyle
):
stmt = select(table1.c.myid).where(
table1.c.name.in_(bindparam(paramname, value=["a", "b"]))
)
- if use_assert_compile:
- if use_positional:
- self.assert_compile(
- stmt,
- "SELECT mytable.myid FROM mytable "
- "WHERE mytable.name IN (?, ?)",
- params={paramname: ["y", "z"]},
- # NOTE: this is what render_postcompile will do right now
- # if you run construct_params(). render_postcompile mode
- # is not actually used by the execution internals, it's for
- # user-facing compilation code. So this is likely a
- # current limitation of construct_params() which is not
- # doing the full blown postcompile; just assert that's
- # what it does for now. it likely should be corrected
- # to make more sense.
- checkpositional=(["y", "z"], ["y", "z"]),
- dialect="sqlite",
- render_postcompile=True,
- )
- else:
- self.assert_compile(
- stmt,
- "SELECT mytable.myid FROM mytable WHERE mytable.name IN "
- "(:%s_1, :%s_2)" % (expected, expected),
- params={paramname: ["y", "z"]},
- # NOTE: this is what render_postcompile will do right now
- # if you run construct_params(). render_postcompile mode
- # is not actually used by the execution internals, it's for
- # user-facing compilation code. So this is likely a
- # current limitation of construct_params() which is not
- # doing the full blown postcompile; just assert that's
- # what it does for now. it likely should be corrected
- # to make more sense.
- checkparams={
- "%s_1" % expected: ["y", "z"],
- "%s_2" % expected: ["y", "z"],
- },
- dialect="default",
- render_postcompile=True,
- )
+ # NOTE: below the rendered params are just what
+ # render_postcompile will do right now
+ # if you run construct_params(). render_postcompile mode
+ # is not actually used by the execution internals, it's for
+ # user-facing compilation code. So this is likely a
+ # current limitation of construct_params() which is not
+ # doing the full blown postcompile; just assert that's
+ # what it does for now. it likely should be corrected
+ # to make more sense.
+ if paramstyle.qmark:
+ self.assert_compile(
+ stmt,
+ "SELECT mytable.myid FROM mytable "
+ "WHERE mytable.name IN (?, ?)",
+ params={paramname: ["y", "z"]},
+ checkpositional=(["y", "z"], ["y", "z"]),
+ dialect="sqlite",
+ render_postcompile=True,
+ )
+ elif paramstyle.numeric:
+ self.assert_compile(
+ stmt,
+ "SELECT mytable.myid FROM mytable "
+ "WHERE mytable.name IN (:1, :2)",
+ params={paramname: ["y", "z"]},
+ checkpositional=(["y", "z"], ["y", "z"]),
+ dialect=sqlite.dialect(paramstyle="numeric"),
+ render_postcompile=True,
+ )
+ elif paramstyle.named:
+ self.assert_compile(
+ stmt,
+ "SELECT mytable.myid FROM mytable WHERE mytable.name IN "
+ "(:%s_1, :%s_2)" % (expected, expected),
+ params={paramname: ["y", "z"]},
+ checkparams={
+ "%s_1" % expected: ["y", "z"],
+ "%s_2" % expected: ["y", "z"],
+ },
+ dialect="default",
+ render_postcompile=True,
+ )
else:
- # this is what DefaultDialect actually does.
- # this should be matched to DefaultDialect._init_compiled()
- if use_positional:
- compiled = stmt.compile(
- dialect=default.DefaultDialect(paramstyle="qmark")
- )
- else:
- compiled = stmt.compile(dialect=default.DefaultDialect())
+ paramstyle.fail()
- checkparams = compiled.construct_params(
- {paramname: ["y", "z"]}, escape_names=False
- )
+ @standalone_escape
+ @paramstyle
+ def test_standalone_bindparam_escape_expanding(
+ self, paramname, expected, paramstyle
+ ):
+ stmt = select(table1.c.myid).where(
+ table1.c.name.in_(bindparam(paramname, value=["a", "b"]))
+ )
+ # this is what DefaultDialect actually does.
+ # this should be matched to DefaultDialect._init_compiled()
+ if paramstyle.qmark:
+ dialect = default.DefaultDialect(paramstyle="qmark")
+ elif paramstyle.numeric:
+ dialect = default.DefaultDialect(paramstyle="numeric")
+ else:
+ dialect = default.DefaultDialect()
- # nothing actually happened. if the compiler had
- # render_postcompile set, the
- # above weird param thing happens
- eq_(checkparams, {paramname: ["y", "z"]})
+ compiled = stmt.compile(dialect=dialect)
+ checkparams = compiled.construct_params(
+ {paramname: ["y", "z"]}, escape_names=False
+ )
- expanded_state = compiled._process_parameters_for_postcompile(
- checkparams
- )
+ # nothing actually happened. if the compiler had
+ # render_postcompile set, the
+ # above weird param thing happens
+ eq_(checkparams, {paramname: ["y", "z"]})
+
+ expanded_state = compiled._process_parameters_for_postcompile(
+ checkparams
+ )
+ eq_(
+ expanded_state.additional_parameters,
+ {f"{expected}_1": "y", f"{expected}_2": "z"},
+ )
+
+ if paramstyle.qmark or paramstyle.numeric:
eq_(
- expanded_state.additional_parameters,
- {f"{expected}_1": "y", f"{expected}_2": "z"},
+ expanded_state.positiontup,
+ [f"{expected}_1", f"{expected}_2"],
)
- if use_positional:
- eq_(
- expanded_state.positiontup,
- [f"{expected}_1", f"{expected}_2"],
+ @paramstyle
+ def test_expanding_in_repeated(self, paramstyle):
+ stmt = (
+ select(table1)
+ .where(
+ table1.c.name.in_(
+ bindparam("uname", value=["h", "e"], expanding=True)
+ )
+ | table1.c.name.in_(
+ bindparam("uname2", value=["y"], expanding=True)
+ )
+ )
+ .where(table1.c.myid == 8)
+ )
+ stmt = stmt.union(
+ select(table1)
+ .where(
+ table1.c.name.in_(
+ bindparam("uname", value=["h", "e"], expanding=True)
+ )
+ | table1.c.name.in_(
+ bindparam("uname2", value=["y"], expanding=True)
)
+ )
+ .where(table1.c.myid == 9)
+ ).order_by("myid")
+
+ # NOTE: below the rendered params are just what
+ # render_postcompile will do right now
+ # if you run construct_params(). render_postcompile mode
+ # is not actually used by the execution internals, it's for
+ # user-facing compilation code. So this is likely a
+ # current limitation of construct_params() which is not
+ # doing the full blown postcompile; just assert that's
+ # what it does for now. it likely should be corrected
+ # to make more sense.
+
+ if paramstyle.qmark:
+ self.assert_compile(
+ stmt,
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable "
+ "WHERE (mytable.name IN (?, ?) OR "
+ "mytable.name IN (?)) "
+ "AND mytable.myid = ? "
+ "UNION SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable "
+ "WHERE (mytable.name IN (?, ?) OR "
+ "mytable.name IN (?)) "
+ "AND mytable.myid = ? ORDER BY myid",
+ params={"uname": ["y", "z"], "uname2": ["a"]},
+ checkpositional=(
+ ["y", "z"],
+ ["y", "z"],
+ ["a"],
+ 8,
+ ["y", "z"],
+ ["y", "z"],
+ ["a"],
+ 9,
+ ),
+ dialect="sqlite",
+ render_postcompile=True,
+ )
+ elif paramstyle.numeric:
+ self.assert_compile(
+ stmt,
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable "
+ "WHERE (mytable.name IN (:3, :4) OR "
+ "mytable.name IN (:5)) "
+ "AND mytable.myid = :1 "
+ "UNION SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable "
+ "WHERE (mytable.name IN (:3, :4) OR "
+ "mytable.name IN (:5)) "
+ "AND mytable.myid = :2 ORDER BY myid",
+ params={"uname": ["y", "z"], "uname2": ["a"]},
+ checkpositional=(8, 9, ["y", "z"], ["y", "z"], ["a"]),
+ dialect=sqlite.dialect(paramstyle="numeric"),
+ render_postcompile=True,
+ )
+ elif paramstyle.named:
+ self.assert_compile(
+ stmt,
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable "
+ "WHERE (mytable.name IN (:uname_1, :uname_2) OR "
+ "mytable.name IN (:uname2_1)) "
+ "AND mytable.myid = :myid_1 "
+ "UNION SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable "
+ "WHERE (mytable.name IN (:uname_1, :uname_2) OR "
+ "mytable.name IN (:uname2_1)) "
+ "AND mytable.myid = :myid_2 ORDER BY myid",
+ params={"uname": ["y", "z"], "uname2": ["a"]},
+ checkparams={
+ "uname": ["y", "z"],
+ "uname2": ["a"],
+ "uname_1": ["y", "z"],
+ "uname_2": ["y", "z"],
+ "uname2_1": ["a"],
+ "myid_1": 8,
+ "myid_2": 9,
+ },
+ dialect="default",
+ render_postcompile=True,
+ )
+ else:
+ paramstyle.fail()
+
+ def test_numeric_dollar_bindparam(self):
+ stmt = table1.select().where(
+ table1.c.name == "a", table1.c.myid.in_([1, 2])
+ )
+ self.assert_compile(
+ stmt,
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable "
+ "WHERE mytable.name = $1 "
+ "AND mytable.myid IN ($2, $3)",
+ checkpositional=("a", 1, 2),
+ dialect=default.DefaultDialect(paramstyle="numeric_dollar"),
+ render_postcompile=True,
+ )
class UnsupportedTest(fixtures.TestBase):
@@ -5096,6 +5268,28 @@ class StringifySpecialTest(fixtures.TestBase):
"INSERT INTO mytable (myid) VALUES (:myid_m0), (:myid_m1)",
)
+ def test_multirow_insert_positional(self):
+ stmt = table1.insert().values([{"myid": 1}, {"myid": 2}])
+ eq_ignore_whitespace(
+ stmt.compile(dialect=sqlite.dialect()).string,
+ "INSERT INTO mytable (myid) VALUES (?), (?)",
+ )
+
+ def test_multirow_insert_numeric(self):
+ stmt = table1.insert().values([{"myid": 1}, {"myid": 2}])
+ eq_ignore_whitespace(
+ stmt.compile(dialect=sqlite.dialect(paramstyle="numeric")).string,
+ "INSERT INTO mytable (myid) VALUES (:1), (:2)",
+ )
+
+ def test_insert_noparams_numeric(self):
+ ii = table1.insert().returning(table1.c.myid)
+ eq_ignore_whitespace(
+ ii.compile(dialect=sqlite.dialect(paramstyle="numeric")).string,
+ "INSERT INTO mytable (myid, name, description) VALUES "
+ "(:1, :2, :3) RETURNING myid",
+ )
+
def test_cte(self):
# stringify of these was supported anyway by defaultdialect.
stmt = select(table1.c.myid).cte()
@@ -5153,6 +5347,42 @@ class StringifySpecialTest(fixtures.TestBase):
"SELECT CAST(mytable.myid AS MyType()) AS myid FROM mytable",
)
+ def test_dialect_sub_compile(self):
+ class Widget(ClauseElement):
+ __visit_name__ = "widget"
+ stringify_dialect = "sqlite"
+
+ def visit_widget(self, element, **kw):
+ return "widget"
+
+ with mock.patch(
+ "sqlalchemy.dialects.sqlite.base.SQLiteCompiler.visit_widget",
+ visit_widget,
+ create=True,
+ ):
+ eq_(str(Grouping(Widget())), "(widget)")
+
+ def test_dialect_sub_compile_w_binds(self):
+ """test sub-compile into a new compiler where
+ state != CompilerState.COMPILING, but we have to render a bindparam
+ string. has to render the correct template.
+
+ """
+
+ class Widget(ClauseElement):
+ __visit_name__ = "widget"
+ stringify_dialect = "sqlite"
+
+ def visit_widget(self, element, **kw):
+ return f"widget {self.process(bindparam('q'), **kw)}"
+
+ with mock.patch(
+ "sqlalchemy.dialects.sqlite.base.SQLiteCompiler.visit_widget",
+ visit_widget,
+ create=True,
+ ):
+ eq_(str(Grouping(Widget())), "(widget ?)")
+
def test_within_group(self):
# stringify of these was supported anyway by defaultdialect.
from sqlalchemy import within_group
diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py
index b89d18de6..502104dae 100644
--- a/test/sql/test_cte.py
+++ b/test/sql/test_cte.py
@@ -993,20 +993,20 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
self.assert_compile(
s,
'WITH regional_sales AS (SELECT orders."order" '
- 'AS "order", :1 AS anon_2 FROM orders) SELECT '
- 'regional_sales."order", :2 AS anon_1 FROM regional_sales',
- checkpositional=("x", "y"),
+ 'AS "order", :2 AS anon_2 FROM orders) SELECT '
+ 'regional_sales."order", :1 AS anon_1 FROM regional_sales',
+ checkpositional=("y", "x"),
dialect=dialect,
)
self.assert_compile(
s.union(s),
'WITH regional_sales AS (SELECT orders."order" '
- 'AS "order", :1 AS anon_2 FROM orders) SELECT '
- 'regional_sales."order", :2 AS anon_1 FROM regional_sales '
- 'UNION SELECT regional_sales."order", :3 AS anon_1 '
+ 'AS "order", :2 AS anon_2 FROM orders) SELECT '
+ 'regional_sales."order", :1 AS anon_1 FROM regional_sales '
+ 'UNION SELECT regional_sales."order", :1 AS anon_1 '
"FROM regional_sales",
- checkpositional=("x", "y", "y"),
+ checkpositional=("y", "x"),
dialect=dialect,
)
@@ -1057,8 +1057,8 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
self.assert_compile(
s3,
'WITH regional_sales_1 AS (SELECT orders."order" AS "order" '
- 'FROM orders WHERE orders."order" = :1), regional_sales_2 AS '
- '(SELECT orders."order" = :2 AS anon_1, '
+ 'FROM orders WHERE orders."order" = :2), regional_sales_2 AS '
+ '(SELECT orders."order" = :1 AS anon_1, '
'anon_2."order" AS "order", '
'orders."order" AS order_1, '
'regional_sales_1."order" AS order_2 FROM orders, '
@@ -1067,7 +1067,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
'WHERE orders."order" = :3) SELECT regional_sales_2.anon_1, '
'regional_sales_2."order", regional_sales_2.order_1, '
"regional_sales_2.order_2 FROM regional_sales_2",
- checkpositional=("x", "y", "z"),
+ checkpositional=("y", "x", "z"),
dialect=dialect,
)
diff --git a/test/sql/test_insert.py b/test/sql/test_insert.py
index ac9ac4022..1c24d4c79 100644
--- a/test/sql/test_insert.py
+++ b/test/sql/test_insert.py
@@ -488,7 +488,8 @@ class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL):
dialect=postgresql.dialect(),
)
- def test_heterogeneous_multi_values(self):
+ @testing.variation("paramstyle", ["pg", "qmark", "numeric", "dollar"])
+ def test_heterogeneous_multi_values(self, paramstyle):
"""for #6047, originally I thought we'd take any insert().values()
and be able to convert it to a "many" style execution that we can
cache.
@@ -519,33 +520,81 @@ class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL):
]
)
+ pos_par = (
+ 1,
+ 1,
+ 2,
+ 2,
+ 1,
+ 2,
+ 2,
+ 3,
+ 1,
+ 2,
+ 2,
+ 10,
+ )
+
# SQL expressions in the params at arbitrary locations means
# we have to scan them at compile time, and the shape of the bound
# parameters is not predictable. so for #6047 where I originally
# thought all of values() could be rewritten, this makes it not
# really worth it.
- self.assert_compile(
- stmt,
- "INSERT INTO t (x, y, z) VALUES "
- "(%(x_m0)s, sum(%(sum_1)s, %(sum_2)s), %(z_m0)s), "
- "(sum(%(sum_3)s, %(sum_4)s), %(y_m1)s, %(z_m1)s), "
- "(sum(%(sum_5)s, %(sum_6)s), %(y_m2)s, foo(%(foo_1)s))",
- checkparams={
- "x_m0": 1,
- "sum_1": 1,
- "sum_2": 2,
- "z_m0": 2,
- "sum_3": 1,
- "sum_4": 2,
- "y_m1": 2,
- "z_m1": 3,
- "sum_5": 1,
- "sum_6": 2,
- "y_m2": 2,
- "foo_1": 10,
- },
- dialect=postgresql.dialect(),
- )
+ if paramstyle.pg:
+ self.assert_compile(
+ stmt,
+ "INSERT INTO t (x, y, z) VALUES "
+ "(%(x_m0)s, sum(%(sum_1)s, %(sum_2)s), %(z_m0)s), "
+ "(sum(%(sum_3)s, %(sum_4)s), %(y_m1)s, %(z_m1)s), "
+ "(sum(%(sum_5)s, %(sum_6)s), %(y_m2)s, foo(%(foo_1)s))",
+ checkparams={
+ "x_m0": 1,
+ "sum_1": 1,
+ "sum_2": 2,
+ "z_m0": 2,
+ "sum_3": 1,
+ "sum_4": 2,
+ "y_m1": 2,
+ "z_m1": 3,
+ "sum_5": 1,
+ "sum_6": 2,
+ "y_m2": 2,
+ "foo_1": 10,
+ },
+ dialect=postgresql.dialect(),
+ )
+ elif paramstyle.qmark:
+ self.assert_compile(
+ stmt,
+ "INSERT INTO t (x, y, z) VALUES "
+ "(?, sum(?, ?), ?), "
+ "(sum(?, ?), ?, ?), "
+ "(sum(?, ?), ?, foo(?))",
+ checkpositional=pos_par,
+ dialect=sqlite.dialect(),
+ )
+ elif paramstyle.numeric:
+ self.assert_compile(
+ stmt,
+ "INSERT INTO t (x, y, z) VALUES "
+ "(:1, sum(:2, :3), :4), "
+ "(sum(:5, :6), :7, :8), "
+ "(sum(:9, :10), :11, foo(:12))",
+ checkpositional=pos_par,
+ dialect=sqlite.dialect(paramstyle="numeric"),
+ )
+ elif paramstyle.dollar:
+ self.assert_compile(
+ stmt,
+ "INSERT INTO t (x, y, z) VALUES "
+ "($1, sum($2, $3), $4), "
+ "(sum($5, $6), $7, $8), "
+ "(sum($9, $10), $11, foo($12))",
+ checkpositional=pos_par,
+ dialect=sqlite.dialect(paramstyle="numeric_dollar"),
+ )
+ else:
+ paramstyle.fail()
def test_insert_seq_pk_multi_values_seq_not_supported(self):
m = MetaData()
diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py
index b856acfd3..7f1124c84 100644
--- a/test/sql/test_resultset.py
+++ b/test/sql/test_resultset.py
@@ -107,7 +107,7 @@ class CursorResultTest(fixtures.TablesTest):
Column("y", String(50)),
)
- @testing.requires.insert_returning
+ @testing.requires.insert_executemany_returning
def test_splice_horizontally(self, connection):
users = self.tables.users
addresses = self.tables.addresses
diff --git a/test/sql/test_types.py b/test/sql/test_types.py
index 91413ff35..59519a5ec 100644
--- a/test/sql/test_types.py
+++ b/test/sql/test_types.py
@@ -3322,7 +3322,7 @@ class ExpressionTest(
elif expression_type.right_side:
expr = (column("x", Integer) == Widget(52)).right
else:
- assert False
+ expression_type.fail()
if secondary_adapt:
is_(expr.type._type_affinity, String)
diff --git a/test/sql/test_update.py b/test/sql/test_update.py
index cd7f992e2..66971f64e 100644
--- a/test/sql/test_update.py
+++ b/test/sql/test_update.py
@@ -907,7 +907,8 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL):
dialect=dialect,
)
- def test_update_bound_ordering(self):
+ @testing.variation("paramstyle", ["qmark", "format", "numeric"])
+ def test_update_bound_ordering(self, paramstyle):
"""test that bound parameters between the UPDATE and FROM clauses
order correctly in different SQL compilation scenarios.
@@ -921,30 +922,47 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL):
.values(name="foo")
)
- dialect = default.StrCompileDialect()
- dialect.positional = True
- self.assert_compile(
- upd,
- "UPDATE mytable SET name=:name FROM (SELECT "
- "myothertable.otherid AS otherid, "
- "myothertable.othername AS othername "
- "FROM myothertable "
- "WHERE myothertable.otherid = :otherid_1) AS anon_1 "
- "WHERE mytable.name = anon_1.othername",
- checkpositional=("foo", 5),
- dialect=dialect,
- )
+ if paramstyle.qmark:
- self.assert_compile(
- upd,
- "UPDATE mytable, (SELECT myothertable.otherid AS otherid, "
- "myothertable.othername AS othername "
- "FROM myothertable "
- "WHERE myothertable.otherid = %s) AS anon_1 SET mytable.name=%s "
- "WHERE mytable.name = anon_1.othername",
- checkpositional=(5, "foo"),
- dialect=mysql.dialect(),
- )
+ dialect = default.StrCompileDialect(paramstyle="qmark")
+ self.assert_compile(
+ upd,
+ "UPDATE mytable SET name=? FROM (SELECT "
+ "myothertable.otherid AS otherid, "
+ "myothertable.othername AS othername "
+ "FROM myothertable "
+ "WHERE myothertable.otherid = ?) AS anon_1 "
+ "WHERE mytable.name = anon_1.othername",
+ checkpositional=("foo", 5),
+ dialect=dialect,
+ )
+ elif paramstyle.format:
+ self.assert_compile(
+ upd,
+ "UPDATE mytable, (SELECT myothertable.otherid AS otherid, "
+ "myothertable.othername AS othername "
+ "FROM myothertable "
+ "WHERE myothertable.otherid = %s) AS anon_1 "
+ "SET mytable.name=%s "
+ "WHERE mytable.name = anon_1.othername",
+ checkpositional=(5, "foo"),
+ dialect=mysql.dialect(),
+ )
+ elif paramstyle.numeric:
+ dialect = default.StrCompileDialect(paramstyle="numeric")
+ self.assert_compile(
+ upd,
+ "UPDATE mytable SET name=:1 FROM (SELECT "
+ "myothertable.otherid AS otherid, "
+ "myothertable.othername AS othername "
+ "FROM myothertable "
+ "WHERE myothertable.otherid = :2) AS anon_1 "
+ "WHERE mytable.name = anon_1.othername",
+ checkpositional=("foo", 5),
+ dialect=dialect,
+ )
+ else:
+ paramstyle.fail()
class UpdateFromCompileTest(
diff --git a/tox.ini b/tox.ini
index 50c39f610..b74805052 100644
--- a/tox.ini
+++ b/tox.ini
@@ -89,7 +89,7 @@ setenv=
sqlite: SQLITE={env:TOX_SQLITE:--db sqlite}
sqlite_file: SQLITE={env:TOX_SQLITE_FILE:--db sqlite_file}
- py3{,7,8,9,10,11}-sqlite: EXTRA_SQLITE_DRIVERS={env:EXTRA_SQLITE_DRIVERS:--dbdriver sqlite --dbdriver aiosqlite}
+ py3{,7,8,9,10,11}-sqlite: EXTRA_SQLITE_DRIVERS={env:EXTRA_SQLITE_DRIVERS:--dbdriver sqlite --dbdriver pysqlite_numeric --dbdriver aiosqlite}
py3{,7,8,9}-sqlite_file: EXTRA_SQLITE_DRIVERS={env:EXTRA_SQLITE_DRIVERS:--dbdriver sqlite --dbdriver aiosqlite --dbdriver pysqlcipher}