diff options
28 files changed, 989 insertions, 348 deletions
diff --git a/doc/build/changelog/unreleased_20/8849.rst b/doc/build/changelog/unreleased_20/8849.rst new file mode 100644 index 000000000..29ecf2a2c --- /dev/null +++ b/doc/build/changelog/unreleased_20/8849.rst @@ -0,0 +1,14 @@ +.. change:: + :tags: bug, sql + :tickets: 8849 + + Reworked how numeric paramstyle behavers, in particular, fixed insertmany + behaviour that prior to this was non functional; added support for repeated + parameter without duplicating them like in other positional dialects; + introduced new numeric paramstyle called ``numeric_dollar`` that can be + used to render statements that use the PostgreSQL placeholder style ( + i.e. ``$1, $2, $3``). + This change requires that the dialect supports out of order placehoders, + that may be used used in the statements, in particular when using + insert-many values with statement that have parameters in the returning + clause. diff --git a/doc/build/changelog/unreleased_20/8926.rst b/doc/build/changelog/unreleased_20/8926.rst new file mode 100644 index 000000000..a0000fb87 --- /dev/null +++ b/doc/build/changelog/unreleased_20/8926.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: asyncpg + :tickets: 8926 + + Changed the paramstyle used by asyncpg from ``format`` to + ``numeric_dollar``. This has two main benefits since it does not require + additional processing of the statement and allows for duplicate parameters + to be present in the statements. diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 751dc3dcf..b8f614eba 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -438,9 +438,6 @@ class AsyncAdapt_asyncpg_cursor: def _handle_exception(self, error): self._adapt_connection._handle_exception(error) - def _parameter_placeholders(self, params): - return tuple(f"${idx:d}" for idx, _ in enumerate(params, 1)) - async def _prepare_and_execute(self, operation, parameters): adapt_connection = self._adapt_connection @@ -449,11 +446,7 @@ class AsyncAdapt_asyncpg_cursor: if not adapt_connection._started: await adapt_connection._start_transaction() - if parameters is not None: - operation = operation % self._parameter_placeholders( - parameters - ) - else: + if parameters is None: parameters = () try: @@ -506,10 +499,6 @@ class AsyncAdapt_asyncpg_cursor: if not adapt_connection._started: await adapt_connection._start_transaction() - operation = operation % self._parameter_placeholders( - seq_of_parameters[0] - ) - try: return await self._connection.executemany( operation, seq_of_parameters @@ -808,7 +797,7 @@ class AsyncAdaptFallback_asyncpg_connection(AsyncAdapt_asyncpg_connection): class AsyncAdapt_asyncpg_dbapi: def __init__(self, asyncpg): self.asyncpg = asyncpg - self.paramstyle = "format" + self.paramstyle = "numeric_dollar" def connect(self, *arg, **kw): async_fallback = kw.pop("async_fallback", False) @@ -900,7 +889,7 @@ class PGDialect_asyncpg(PGDialect): render_bind_cast = True has_terminate = True - default_paramstyle = "format" + default_paramstyle = "numeric_dollar" supports_sane_multi_rowcount = False execution_ctx_cls = PGExecutionContext_asyncpg statement_compiler = PGCompiler_asyncpg diff --git a/lib/sqlalchemy/dialects/sqlite/provision.py b/lib/sqlalchemy/dialects/sqlite/provision.py index 05ee6c625..851f0951f 100644 --- a/lib/sqlalchemy/dialects/sqlite/provision.py +++ b/lib/sqlalchemy/dialects/sqlite/provision.py @@ -18,7 +18,13 @@ from ...testing.provision import upsert # TODO: I can't get this to build dynamically with pytest-xdist procs -_drivernames = {"pysqlite", "aiosqlite", "pysqlcipher"} +_drivernames = { + "pysqlite", + "aiosqlite", + "pysqlcipher", + "pysqlite_numeric", + "pysqlite_dollar", +} @generate_driver_url.for_db("sqlite") diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py index 4475ccae7..c04a3601d 100644 --- a/lib/sqlalchemy/dialects/sqlite/pysqlite.py +++ b/lib/sqlalchemy/dialects/sqlite/pysqlite.py @@ -637,3 +637,110 @@ class SQLiteDialect_pysqlite(SQLiteDialect): dialect = SQLiteDialect_pysqlite + + +class _SQLiteDialect_pysqlite_numeric(SQLiteDialect_pysqlite): + """numeric dialect for testing only + + internal use only. This dialect is **NOT** supported by SQLAlchemy + and may change at any time. + + """ + + supports_statement_cache = True + default_paramstyle = "numeric" + driver = "pysqlite_numeric" + + _first_bind = ":1" + _not_in_statement_regexp = None + + def __init__(self, *arg, **kw): + kw.setdefault("paramstyle", "numeric") + super().__init__(*arg, **kw) + + def create_connect_args(self, url): + arg, opts = super().create_connect_args(url) + opts["factory"] = self._fix_sqlite_issue_99953() + return arg, opts + + def _fix_sqlite_issue_99953(self): + import sqlite3 + + first_bind = self._first_bind + if self._not_in_statement_regexp: + nis = self._not_in_statement_regexp + + def _test_sql(sql): + m = nis.search(sql) + assert not m, f"Found {nis.pattern!r} in {sql!r}" + + else: + + def _test_sql(sql): + pass + + def _numeric_param_as_dict(parameters): + if parameters: + assert isinstance(parameters, tuple) + return { + str(idx): value for idx, value in enumerate(parameters, 1) + } + else: + return () + + class SQLiteFix99953Cursor(sqlite3.Cursor): + def execute(self, sql, parameters=()): + _test_sql(sql) + if first_bind in sql: + parameters = _numeric_param_as_dict(parameters) + return super().execute(sql, parameters) + + def executemany(self, sql, parameters): + _test_sql(sql) + if first_bind in sql: + parameters = [ + _numeric_param_as_dict(p) for p in parameters + ] + return super().executemany(sql, parameters) + + class SQLiteFix99953Connection(sqlite3.Connection): + def cursor(self, factory=None): + if factory is None: + factory = SQLiteFix99953Cursor + return super().cursor(factory=factory) + + def execute(self, sql, parameters=()): + _test_sql(sql) + if first_bind in sql: + parameters = _numeric_param_as_dict(parameters) + return super().execute(sql, parameters) + + def executemany(self, sql, parameters): + _test_sql(sql) + if first_bind in sql: + parameters = [ + _numeric_param_as_dict(p) for p in parameters + ] + return super().executemany(sql, parameters) + + return SQLiteFix99953Connection + + +class _SQLiteDialect_pysqlite_dollar(_SQLiteDialect_pysqlite_numeric): + """numeric dialect that uses $ for testing only + + internal use only. This dialect is **NOT** supported by SQLAlchemy + and may change at any time. + + """ + + supports_statement_cache = True + default_paramstyle = "numeric_dollar" + driver = "pysqlite_dollar" + + _first_bind = "$1" + _not_in_statement_regexp = re.compile(r"[^\d]:\d+") + + def __init__(self, *arg, **kw): + kw.setdefault("paramstyle", "numeric_dollar") + super().__init__(*arg, **kw) diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 3cc9cab8b..4647c84d1 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -320,7 +320,12 @@ class DefaultDialect(Dialect): self.paramstyle = self.dbapi.paramstyle else: self.paramstyle = self.default_paramstyle - self.positional = self.paramstyle in ("qmark", "format", "numeric") + self.positional = self.paramstyle in ( + "qmark", + "format", + "numeric", + "numeric_dollar", + ) self.identifier_preparer = self.preparer(self) self._on_connect_isolation_level = isolation_level diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index 2f5efce25..ddf8a53fb 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -255,7 +255,9 @@ SchemaTranslateMapType = Mapping[Optional[str], Optional[str]] _ImmutableExecuteOptions = immutabledict[str, Any] -_ParamStyle = Literal["qmark", "numeric", "named", "format", "pyformat"] +_ParamStyle = Literal[ + "qmark", "numeric", "named", "format", "pyformat", "numeric_dollar" +] _GenericSetInputSizesType = List[Tuple[str, Any, "TypeEngine[Any]"]] diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 7ac279ee2..d7358ad3b 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -227,11 +227,13 @@ FK_INITIALLY = re.compile(r"^(?:DEFERRED|IMMEDIATE)$", re.I) BIND_PARAMS = re.compile(r"(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])", re.UNICODE) BIND_PARAMS_ESC = re.compile(r"\x5c(:[\w\$]*)(?![:\w\$])", re.UNICODE) +_pyformat_template = "%%(%(name)s)s" BIND_TEMPLATES = { - "pyformat": "%%(%(name)s)s", + "pyformat": _pyformat_template, "qmark": "?", "format": "%%s", "numeric": ":[_POSITION]", + "numeric_dollar": "$[_POSITION]", "named": ":%(name)s", } @@ -420,6 +422,22 @@ class _InsertManyValues(NamedTuple): num_positional_params_counted: int +class CompilerState(IntEnum): + COMPILING = 0 + """statement is present, compilation phase in progress""" + + STRING_APPLIED = 1 + """statement is present, string form of the statement has been applied. + + Additional processors by subclasses may still be pending. + + """ + + NO_STATEMENT = 2 + """compiler does not have a statement to compile, is used + for method access""" + + class Linting(IntEnum): """represent preferences for the 'SQL linting' feature. @@ -527,6 +545,14 @@ class Compiled: defaults. """ + statement: Optional[ClauseElement] = None + "The statement to compile." + string: str = "" + "The string representation of the ``statement``" + + state: CompilerState + """description of the compiler's state""" + is_sql = False is_ddl = False @@ -618,7 +644,6 @@ class Compiled: """ - self.dialect = dialect self.preparer = self.dialect.identifier_preparer if schema_translate_map: @@ -628,6 +653,7 @@ class Compiled: ) if statement is not None: + self.state = CompilerState.COMPILING self.statement = statement self.can_execute = statement.supports_execution self._annotations = statement._annotations @@ -641,6 +667,11 @@ class Compiled: self.string = self.preparer._render_schema_translates( self.string, schema_translate_map ) + + self.state = CompilerState.STRING_APPLIED + else: + self.state = CompilerState.NO_STATEMENT + self._gen_time = perf_counter() def _execute_on_connection( @@ -672,7 +703,10 @@ class Compiled: def __str__(self) -> str: """Return the string text of the generated SQL or DDL.""" - return self.string or "" + if self.state is CompilerState.STRING_APPLIED: + return self.string + else: + return "" def construct_params( self, @@ -859,6 +893,19 @@ class SQLCompiler(Compiled): driver/DB enforces this """ + bindtemplate: str + """template to render bound parameters based on paramstyle.""" + + compilation_bindtemplate: str + """template used by compiler to render parameters before positional + paramstyle application""" + + _numeric_binds_identifier_char: str + """Character that's used to as the identifier of a numerical bind param. + For example if this char is set to ``$``, numerical binds will be rendered + in the form ``$1, $2, $3``. + """ + _result_columns: List[ResultColumnsEntry] """relates label names in the final SQL to a tuple of local column/label name, ColumnElement object (if any) and @@ -967,13 +1014,17 @@ class SQLCompiler(Compiled): and is combined with the :attr:`_sql.Compiled.params` dictionary to render parameters. + This sequence always contains the unescaped name of the parameters. + .. seealso:: :ref:`faq_sql_expression_string` - includes a usage example for debugging use cases. """ - positiontup_level: Optional[Dict[str, int]] = None + _values_bindparam: Optional[List[str]] = None + + _visited_bindparam: Optional[List[str]] = None inline: bool = False @@ -988,9 +1039,12 @@ class SQLCompiler(Compiled): level_name_by_cte: Dict[CTE, Tuple[int, str, selectable._CTEOpts]] ctes_recursive: bool - cte_positional: Dict[CTE, List[str]] - cte_level: Dict[CTE, int] - cte_order: Dict[Optional[CTE], List[CTE]] + + _post_compile_pattern = re.compile(r"__\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]") + _pyformat_pattern = re.compile(r"%\(([^)]+?)\)s") + _positional_pattern = re.compile( + f"{_pyformat_pattern.pattern}|{_post_compile_pattern.pattern}" + ) def __init__( self, @@ -1055,10 +1109,15 @@ class SQLCompiler(Compiled): # true if the paramstyle is positional self.positional = dialect.positional if self.positional: - self.positiontup_level = {} - self.positiontup = [] - self._numeric_binds = dialect.paramstyle == "numeric" - self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle] + self._numeric_binds = nb = dialect.paramstyle.startswith("numeric") + if nb: + self._numeric_binds_identifier_char = ( + "$" if dialect.paramstyle == "numeric_dollar" else ":" + ) + + self.compilation_bindtemplate = _pyformat_template + else: + self.compilation_bindtemplate = BIND_TEMPLATES[dialect.paramstyle] self.ctes = None @@ -1095,11 +1154,17 @@ class SQLCompiler(Compiled): ): self.inline = True - if self.positional and self._numeric_binds: - self._apply_numbered_params() + self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle] + + if self.state is CompilerState.STRING_APPLIED: + if self.positional: + if self._numeric_binds: + self._process_numeric() + else: + self._process_positional() - if self._render_postcompile: - self._process_parameters_for_postcompile(_populate_self=True) + if self._render_postcompile: + self._process_parameters_for_postcompile(_populate_self=True) @property def insert_single_values_expr(self) -> Optional[str]: @@ -1135,7 +1200,7 @@ class SQLCompiler(Compiled): """ if self.implicit_returning: return self.implicit_returning - elif is_dml(self.statement): + elif self.statement is not None and is_dml(self.statement): return [ c for c in self.statement._all_selected_columns @@ -1217,10 +1282,6 @@ class SQLCompiler(Compiled): self.level_name_by_cte = {} self.ctes_recursive = False - if self.positional: - self.cte_positional = {} - self.cte_level = {} - self.cte_order = collections.defaultdict(list) return ctes @@ -1248,12 +1309,145 @@ class SQLCompiler(Compiled): ordered_columns, ) - def _apply_numbered_params(self): - poscount = itertools.count(1) + def _process_positional(self): + assert not self.positiontup + assert self.state is CompilerState.STRING_APPLIED + assert not self._numeric_binds + + if self.dialect.paramstyle == "format": + placeholder = "%s" + else: + assert self.dialect.paramstyle == "qmark" + placeholder = "?" + + positions = [] + + def find_position(m: re.Match[str]) -> str: + normal_bind = m.group(1) + if normal_bind: + positions.append(normal_bind) + return placeholder + else: + # this a post-compile bind + positions.append(m.group(2)) + return m.group(0) + self.string = re.sub( - r"\[_POSITION\]", lambda m: str(next(poscount)), self.string + self._positional_pattern, find_position, self.string ) + if self.escaped_bind_names: + reverse_escape = {v: k for k, v in self.escaped_bind_names.items()} + assert len(self.escaped_bind_names) == len(reverse_escape) + self.positiontup = [ + reverse_escape.get(name, name) for name in positions + ] + else: + self.positiontup = positions + + if self._insertmanyvalues: + positions = [] + single_values_expr = re.sub( + self._positional_pattern, + find_position, + self._insertmanyvalues.single_values_expr, + ) + insert_crud_params = [ + ( + v[0], + v[1], + re.sub(self._positional_pattern, find_position, v[2]), + v[3], + ) + for v in self._insertmanyvalues.insert_crud_params + ] + + self._insertmanyvalues = _InsertManyValues( + is_default_expr=self._insertmanyvalues.is_default_expr, + single_values_expr=single_values_expr, + insert_crud_params=insert_crud_params, + num_positional_params_counted=( + self._insertmanyvalues.num_positional_params_counted + ), + ) + + def _process_numeric(self): + assert self._numeric_binds + assert self.state is CompilerState.STRING_APPLIED + + num = 1 + param_pos: Dict[str, str] = {} + order: Iterable[str] + if self._insertmanyvalues and self._values_bindparam is not None: + # bindparams that are not in values are always placed first. + # this avoids the need of changing them when using executemany + # values () () + order = itertools.chain( + ( + name + for name in self.bind_names.values() + if name not in self._values_bindparam + ), + self.bind_names.values(), + ) + else: + order = self.bind_names.values() + + for bind_name in order: + if bind_name in param_pos: + continue + bind = self.binds[bind_name] + if ( + bind in self.post_compile_params + or bind in self.literal_execute_params + ): + # set to None to just mark the in positiontup, it will not + # be replaced below. + param_pos[bind_name] = None # type: ignore + else: + ph = f"{self._numeric_binds_identifier_char}{num}" + num += 1 + param_pos[bind_name] = ph + + self.next_numeric_pos = num + + self.positiontup = list(param_pos) + if self.escaped_bind_names: + reverse_escape = {v: k for k, v in self.escaped_bind_names.items()} + assert len(self.escaped_bind_names) == len(reverse_escape) + param_pos = { + self.escaped_bind_names.get(name, name): pos + for name, pos in param_pos.items() + } + + # Can't use format here since % chars are not escaped. + self.string = self._pyformat_pattern.sub( + lambda m: param_pos[m.group(1)], self.string + ) + + if self._insertmanyvalues: + single_values_expr = ( + # format is ok here since single_values_expr includes only + # place-holders + self._insertmanyvalues.single_values_expr + % param_pos + ) + insert_crud_params = [ + (v[0], v[1], "%s", v[3]) + for v in self._insertmanyvalues.insert_crud_params + ] + + self._insertmanyvalues = _InsertManyValues( + is_default_expr=self._insertmanyvalues.is_default_expr, + # This has the numbers (:1, :2) + single_values_expr=single_values_expr, + # The single binds are instead %s so they can be formatted + insert_crud_params=insert_crud_params, + num_positional_params_counted=( + self._insertmanyvalues.num_positional_params_counted + ), + ) + @util.memoized_property def _bind_processors( self, @@ -1492,39 +1686,30 @@ class SQLCompiler(Compiled): new_processors: Dict[str, _BindProcessorType[Any]] = {} - if self.positional and self._numeric_binds: - # I'm not familiar with any DBAPI that uses 'numeric'. - # strategy would likely be to make use of numbers greater than - # the highest number present; then for expanding parameters, - # append them to the end of the parameter list. that way - # we avoid having to renumber all the existing parameters. - raise NotImplementedError( - "'post-compile' bind parameters are not supported with " - "the 'numeric' paramstyle at this time." - ) - replacement_expressions: Dict[str, Any] = {} to_update_sets: Dict[str, Any] = {} # notes: # *unescaped* parameter names in: - # self.bind_names, self.binds, self._bind_processors + # self.bind_names, self.binds, self._bind_processors, self.positiontup # # *escaped* parameter names in: # construct_params(), replacement_expressions + numeric_positiontup: Optional[List[str]] = None + if self.positional and self.positiontup is not None: names: Iterable[str] = self.positiontup + if self._numeric_binds: + numeric_positiontup = [] else: names = self.bind_names.values() + ebn = self.escaped_bind_names for name in names: - escaped_name = ( - self.escaped_bind_names.get(name, name) - if self.escaped_bind_names - else name - ) + escaped_name = ebn.get(name, name) if ebn else name parameter = self.binds[name] + if parameter in self.literal_execute_params: if escaped_name not in replacement_expressions: value = parameters.pop(escaped_name) @@ -1555,10 +1740,10 @@ class SQLCompiler(Compiled): # in the escaped_bind_names dictionary. values = parameters.pop(name) - leep = self._literal_execute_expanding_parameter - to_update, replacement_expr = leep( + leep_res = self._literal_execute_expanding_parameter( escaped_name, parameter, values ) + (to_update, replacement_expr) = leep_res to_update_sets[escaped_name] = to_update replacement_expressions[escaped_name] = replacement_expr @@ -1583,7 +1768,14 @@ class SQLCompiler(Compiled): for key, _ in to_update if name in single_processors ) - if positiontup is not None: + if numeric_positiontup is not None: + numeric_positiontup.extend( + name for name, _ in to_update + ) + elif positiontup is not None: + # to_update has escaped names, but that's ok since + # these are new names, that aren't in the + # escaped_bind_names dict. positiontup.extend(name for name, _ in to_update) expanded_parameters[name] = [ expand_key for expand_key, _ in to_update @@ -1607,11 +1799,23 @@ class SQLCompiler(Compiled): return expr statement = re.sub( - r"__\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]", - process_expanding, - self.string, + self._post_compile_pattern, process_expanding, self.string ) + if numeric_positiontup is not None: + assert positiontup is not None + param_pos = { + key: f"{self._numeric_binds_identifier_char}{num}" + for num, key in enumerate( + numeric_positiontup, self.next_numeric_pos + ) + } + # Can't use format here since % chars are not escaped. + statement = self._pyformat_pattern.sub( + lambda m: param_pos[m.group(1)], statement + ) + positiontup.extend(numeric_positiontup) + expanded_state = ExpandedState( statement, parameters, @@ -2109,13 +2313,7 @@ class SQLCompiler(Compiled): text = self.process(taf.element, **kw) if self.ctes: nesting_level = len(self.stack) if not toplevel else None - text = ( - self._render_cte_clause( - nesting_level=nesting_level, - visiting_cte=kw.get("visiting_cte"), - ) - + text - ) + text = self._render_cte_clause(nesting_level=nesting_level) + text self.stack.pop(-1) @@ -2411,7 +2609,6 @@ class SQLCompiler(Compiled): self._render_cte_clause( nesting_level=nesting_level, include_following_stack=True, - visiting_cte=kwargs.get("visiting_cte"), ) + text ) @@ -2625,6 +2822,11 @@ class SQLCompiler(Compiled): dialect = self.dialect typ_dialect_impl = parameter.type._unwrapped_dialect_impl(dialect) + if self._numeric_binds: + bind_template = self.compilation_bindtemplate + else: + bind_template = self.bindtemplate + if ( self.dialect._bind_typing_render_casts and typ_dialect_impl.render_bind_cast @@ -2634,13 +2836,13 @@ class SQLCompiler(Compiled): return self.render_bind_cast( parameter.type, typ_dialect_impl, - self.bindtemplate % {"name": name}, + bind_template % {"name": name}, ) else: def _render_bindtemplate(name): - return self.bindtemplate % {"name": name} + return bind_template % {"name": name} if not values: to_update = [] @@ -3224,7 +3426,6 @@ class SQLCompiler(Compiled): def bindparam_string( self, name: str, - positional_names: Optional[List[str]] = None, post_compile: bool = False, expanding: bool = False, escaped_from: Optional[str] = None, @@ -3232,12 +3433,9 @@ class SQLCompiler(Compiled): **kw: Any, ) -> str: - if self.positional: - if positional_names is not None: - positional_names.append(name) - else: - self.positiontup.append(name) # type: ignore[union-attr] - self.positiontup_level[name] = len(self.stack) # type: ignore[index] # noqa: E501 + if self._visited_bindparam is not None: + self._visited_bindparam.append(name) + if not escaped_from: if _BIND_TRANSLATE_RE.search(name): @@ -3271,6 +3469,8 @@ class SQLCompiler(Compiled): if type_impl.render_literal_cast: ret = self.render_bind_cast(bindparam_type, type_impl, ret) return ret + elif self.state is CompilerState.COMPILING: + ret = self.compilation_bindtemplate % {"name": name} else: ret = self.bindtemplate % {"name": name} @@ -3349,8 +3549,6 @@ class SQLCompiler(Compiled): self.level_name_by_cte[_reference_cte] = new_level_name + ( cte_opts, ) - if self.positional: - self.cte_level[cte] = cte_level else: cte_level = len(self.stack) if nesting else 1 @@ -3414,8 +3612,6 @@ class SQLCompiler(Compiled): self.level_name_by_cte[_reference_cte] = cte_level_name + ( cte_opts, ) - if self.positional: - self.cte_level[cte] = cte_level if pre_alias_cte not in self.ctes: self.visit_cte(pre_alias_cte, **kwargs) @@ -3455,9 +3651,6 @@ class SQLCompiler(Compiled): ) ) - if self.positional: - kwargs["positional_names"] = self.cte_positional[cte] = [] - assert kwargs.get("subquery", False) is False if not self.stack: @@ -4152,13 +4345,7 @@ class SQLCompiler(Compiled): # In compound query, CTEs are shared at the compound level if self.ctes and (not is_embedded_select or toplevel): nesting_level = len(self.stack) if not toplevel else None - text = ( - self._render_cte_clause( - nesting_level=nesting_level, - visiting_cte=kwargs.get("visiting_cte"), - ) - + text - ) + text = self._render_cte_clause(nesting_level=nesting_level) + text if select_stmt._suffixes: text += " " + self._generate_prefixes( @@ -4332,7 +4519,6 @@ class SQLCompiler(Compiled): self, nesting_level=None, include_following_stack=False, - visiting_cte=None, ): """ include_following_stack @@ -4367,46 +4553,6 @@ class SQLCompiler(Compiled): return "" ctes_recursive = any([cte.recursive for cte in ctes]) - if self.positional: - self.cte_order[visiting_cte].extend(ctes) - - if visiting_cte is None and self.cte_order: - assert self.positiontup is not None - - def get_nested_positional(cte): - if cte in self.cte_order: - children = self.cte_order.pop(cte) - to_add = list( - itertools.chain.from_iterable( - get_nested_positional(child_cte) - for child_cte in children - ) - ) - if cte in self.cte_positional: - return reorder_positional( - self.cte_positional[cte], - to_add, - self.cte_level[children[0]], - ) - else: - return to_add - else: - return self.cte_positional.get(cte, []) - - def reorder_positional(pos, to_add, level): - if not level: - return to_add + pos - index = 0 - for index, name in enumerate(reversed(pos)): - if self.positiontup_level[name] < level: # type: ignore[index] # noqa: E501 - break - return pos[:-index] + to_add + pos[-index:] - - to_add = get_nested_positional(None) - self.positiontup = reorder_positional( - self.positiontup, to_add, nesting_level - ) - cte_text = self.get_cte_preamble(ctes_recursive) + " " cte_text += ", \n".join([txt for txt in ctes.values()]) cte_text += "\n " @@ -4762,6 +4908,11 @@ class SQLCompiler(Compiled): keys_to_replace = set() base_parameters = {} executemany_values_w_comma = f"({imv.single_values_expr}), " + if self._numeric_binds: + escaped = re.escape(self._numeric_binds_identifier_char) + executemany_values_w_comma = re.sub( + rf"{escaped}\d+", "%s", executemany_values_w_comma + ) while batches: batch = batches[0:batch_size] @@ -4794,25 +4945,37 @@ class SQLCompiler(Compiled): num_ins_params = imv.num_positional_params_counted + batch_iterator: Iterable[Tuple[Any, ...]] if num_ins_params == len(batch[0]): extra_params = () - batch_iterator: Iterable[Tuple[Any, ...]] = batch - elif self.returning_precedes_values: + batch_iterator = batch + elif self.returning_precedes_values or self._numeric_binds: extra_params = batch[0][:-num_ins_params] batch_iterator = (b[-num_ins_params:] for b in batch) else: extra_params = batch[0][num_ins_params:] batch_iterator = (b[:num_ins_params] for b in batch) + values_string = (executemany_values_w_comma * len(batch))[:-2] + if self._numeric_binds and num_ins_params > 0: + # need to format here, since statement may contain + # unescaped %, while values_string contains just (%s, %s) + start = len(extra_params) + 1 + end = num_ins_params * len(batch) + start + positions = tuple( + f"{self._numeric_binds_identifier_char}{i}" + for i in range(start, end) + ) + values_string = values_string % positions + replaced_statement = statement.replace( - "__EXECMANY_TOKEN__", - (executemany_values_w_comma * len(batch))[:-2], + "__EXECMANY_TOKEN__", values_string ) replaced_parameters = tuple( itertools.chain.from_iterable(batch_iterator) ) - if self.returning_precedes_values: + if self.returning_precedes_values or self._numeric_binds: replaced_parameters = extra_params + replaced_parameters else: replaced_parameters = replaced_parameters + extra_params @@ -4869,23 +5032,30 @@ class SQLCompiler(Compiled): } ) - positiontup_before = positiontup_after = 0 + counted_bindparam = 0 # for positional, insertmanyvalues needs to know how many # bound parameters are in the VALUES sequence; there's no simple # rule because default expressions etc. can have zero or more # params inside them. After multiple attempts to figure this out, - # this very simplistic "count before, then count after" works and is + # this very simplistic "count after" works and is # likely the least amount of callcounts, though looks clumsy - if self.positiontup: - positiontup_before = len(self.positiontup) + if self.positional: + self._visited_bindparam = [] crud_params_struct = crud._get_crud_params( self, insert_stmt, compile_state, toplevel, **kw ) - if self.positiontup: - positiontup_after = len(self.positiontup) + if self.positional: + assert self._visited_bindparam is not None + counted_bindparam = len(self._visited_bindparam) + if self._numeric_binds: + if self._values_bindparam is not None: + self._values_bindparam += self._visited_bindparam + else: + self._values_bindparam = self._visited_bindparam + self._visited_bindparam = None crud_params_single = crud_params_struct.single_params @@ -4940,31 +5110,13 @@ class SQLCompiler(Compiled): if self.implicit_returning or insert_stmt._returning: - # if returning clause is rendered first, capture bound parameters - # while visiting and place them prior to the VALUES oriented - # bound parameters, when using positional parameter scheme - rpv = self.returning_precedes_values - flip_pt = rpv and self.positional - if flip_pt: - pt: Optional[List[str]] = self.positiontup - temp_pt: Optional[List[str]] - self.positiontup = temp_pt = [] - else: - temp_pt = pt = None - returning_clause = self.returning_clause( insert_stmt, self.implicit_returning or insert_stmt._returning, populate_result_map=toplevel, ) - if flip_pt: - if TYPE_CHECKING: - assert temp_pt is not None - assert pt is not None - self.positiontup = temp_pt + pt - - if rpv: + if self.returning_precedes_values: text += " " + returning_clause else: @@ -4982,7 +5134,6 @@ class SQLCompiler(Compiled): self._render_cte_clause( nesting_level=nesting_level, include_following_stack=True, - visiting_cte=kw.get("visiting_cte"), ), select_text, ) @@ -4999,7 +5150,7 @@ class SQLCompiler(Compiled): cast( "List[crud._CrudParamElementStr]", crud_params_single ), - (positiontup_after - positiontup_before), + counted_bindparam, ) elif compile_state._has_multi_parameters: text += " VALUES %s" % ( @@ -5033,7 +5184,7 @@ class SQLCompiler(Compiled): "List[crud._CrudParamElementStr]", crud_params_single, ), - positiontup_after - positiontup_before, + counted_bindparam, ) if insert_stmt._post_values_clause is not None: @@ -5052,7 +5203,6 @@ class SQLCompiler(Compiled): self._render_cte_clause( nesting_level=nesting_level, include_following_stack=True, - visiting_cte=kw.get("visiting_cte"), ) + text ) @@ -5201,13 +5351,7 @@ class SQLCompiler(Compiled): if self.ctes: nesting_level = len(self.stack) if not toplevel else None - text = ( - self._render_cte_clause( - nesting_level=nesting_level, - visiting_cte=kw.get("visiting_cte"), - ) - + text - ) + text = self._render_cte_clause(nesting_level=nesting_level) + text self.stack.pop(-1) @@ -5321,13 +5465,7 @@ class SQLCompiler(Compiled): if self.ctes: nesting_level = len(self.stack) if not toplevel else None - text = ( - self._render_cte_clause( - nesting_level=nesting_level, - visiting_cte=kw.get("visiting_cte"), - ) - + text - ) + text = self._render_cte_clause(nesting_level=nesting_level) + text self.stack.pop(-1) diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 017ff7baa..ae1b032ae 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -85,8 +85,8 @@ _CrudParamElement = Tuple[ ] _CrudParamElementStr = Tuple[ "KeyedColumnElement[Any]", - str, - str, + str, # column name + str, # placeholder Iterable[str], ] _CrudParamElementSQLExpr = Tuple[ diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index d183372c3..45a2496dd 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -11,6 +11,7 @@ from __future__ import annotations import collections import contextlib +import itertools import re from .. import event @@ -285,7 +286,8 @@ class DialectSQL(CompiledSQL): return received_stmt, execute_observed.context.compiled_parameters - def _dialect_adjusted_statement(self, paramstyle): + def _dialect_adjusted_statement(self, dialect): + paramstyle = dialect.paramstyle stmt = re.sub(r"[\n\t]", "", self.statement) # temporarily escape out PG double colons @@ -300,8 +302,14 @@ class DialectSQL(CompiledSQL): repl = "?" elif paramstyle == "format": repl = r"%s" - elif paramstyle == "numeric": - repl = None + elif paramstyle.startswith("numeric"): + counter = itertools.count(1) + + num_identifier = "$" if paramstyle == "numeric_dollar" else ":" + + def repl(m): + return f"{num_identifier}{next(counter)}" + stmt = re.sub(r":([\w_]+)", repl, stmt) # put them back @@ -310,20 +318,20 @@ class DialectSQL(CompiledSQL): return stmt def _compare_sql(self, execute_observed, received_statement): - paramstyle = execute_observed.context.dialect.paramstyle - stmt = self._dialect_adjusted_statement(paramstyle) + stmt = self._dialect_adjusted_statement( + execute_observed.context.dialect + ) return received_statement == stmt def _failure_message(self, execute_observed, expected_params): - paramstyle = execute_observed.context.dialect.paramstyle return ( "Testing for compiled statement\n%r partial params %s, " "received\n%%(received_statement)r with params " "%%(received_parameters)r" % ( - self._dialect_adjusted_statement(paramstyle).replace( - "%", "%%" - ), + self._dialect_adjusted_statement( + execute_observed.context.dialect + ).replace("%", "%%"), repr(expected_params).replace("%", "%%"), ) ) diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index 957876579..6adcf5b64 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -189,7 +189,7 @@ def variation(argname, cases): elif querytyp.legacy_query: stmt = Session.query(Thing) else: - assert False + querytyp.fail() The variable provided is a slots object of boolean variables, as well diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py index 656a4e98a..ffe0f453a 100644 --- a/lib/sqlalchemy/testing/plugin/plugin_base.py +++ b/lib/sqlalchemy/testing/plugin/plugin_base.py @@ -371,6 +371,22 @@ def _setup_options(opt, file_config): options = opt +@pre +def _register_sqlite_numeric_dialect(opt, file_config): + from sqlalchemy.dialects import registry + + registry.register( + "sqlite.pysqlite_numeric", + "sqlalchemy.dialects.sqlite.pysqlite", + "_SQLiteDialect_pysqlite_numeric", + ) + registry.register( + "sqlite.pysqlite_dollar", + "sqlalchemy.dialects.sqlite.pysqlite", + "_SQLiteDialect_pysqlite_dollar", + ) + + @post def __ensure_cext(opt, file_config): if os.environ.get("REQUIRE_SQLALCHEMY_CEXT", "0") == "1": @@ -144,6 +144,8 @@ oracle_db_link2 = test_link2 [db] default = sqlite:///:memory: sqlite = sqlite:///:memory: +sqlite_numeric = sqlite+pysqlite_numeric:///:memory: +sqlite_dollar = sqlite+pysqlite_dollar:///:memory: aiosqlite = sqlite+aiosqlite:///:memory: sqlite_file = sqlite:///querytest.db aiosqlite_file = sqlite+aiosqlite:///async_querytest.db diff --git a/test/dialect/postgresql/test_query.py b/test/dialect/postgresql/test_query.py index 42ec20743..2b32d6db7 100644 --- a/test/dialect/postgresql/test_query.py +++ b/test/dialect/postgresql/test_query.py @@ -990,19 +990,19 @@ class MatchTest(fixtures.TablesTest, AssertsCompiledSQL): "matchtable.title @@ plainto_tsquery(%(title_1)s)", ) - @testing.requires.format_paramstyle + @testing.only_if("+asyncpg") def test_expression_positional(self, connection): matchtable = self.tables.matchtable if self._strs_render_bind_casts(connection): self.assert_compile( matchtable.c.title.match("somstr"), - "matchtable.title @@ plainto_tsquery(%s::VARCHAR(200))", + "matchtable.title @@ plainto_tsquery($1::VARCHAR(200))", ) else: self.assert_compile( matchtable.c.title.match("somstr"), - "matchtable.title @@ plainto_tsquery(%s)", + "matchtable.title @@ plainto_tsquery($1)", ) def test_simple_match(self, connection): diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py index c5147e37f..07117b862 100644 --- a/test/dialect/test_sqlite.py +++ b/test/dialect/test_sqlite.py @@ -2916,6 +2916,8 @@ class OnConflictTest(AssertsCompiledSQL, fixtures.TablesTest): ) @testing.combinations("control", "excluded", "dict") + @testing.skip_if("+pysqlite_numeric") + @testing.skip_if("+pysqlite_dollar") def test_set_excluded(self, scenario): """test #8014, sending all of .excluded to set""" diff --git a/test/engine/test_logging.py b/test/engine/test_logging.py index 277248617..19c26f43c 100644 --- a/test/engine/test_logging.py +++ b/test/engine/test_logging.py @@ -28,7 +28,7 @@ def exec_sql(engine, sql, *args, **kwargs): class LogParamsTest(fixtures.TestBase): - __only_on__ = "sqlite" + __only_on__ = "sqlite+pysqlite" __requires__ = ("ad_hoc_engines",) def setup_test(self): @@ -704,7 +704,7 @@ class LoggingNameTest(fixtures.TestBase): class TransactionContextLoggingTest(fixtures.TestBase): - __only_on__ = "sqlite" + __only_on__ = "sqlite+pysqlite" @testing.fixture() def plain_assert_buf(self, plain_logging_engine): diff --git a/test/orm/dml/test_bulk_statements.py b/test/orm/dml/test_bulk_statements.py index 557b5e9da..78607e03d 100644 --- a/test/orm/dml/test_bulk_statements.py +++ b/test/orm/dml/test_bulk_statements.py @@ -958,7 +958,7 @@ class BulkDMLReturningJoinedInhTest( BulkDMLReturningInhTest, fixtures.DeclarativeMappedTest ): - __requires__ = ("insert_returning",) + __requires__ = ("insert_returning", "insert_executemany_returning") __backend__ = True @classmethod @@ -1044,7 +1044,7 @@ class BulkDMLReturningJoinedInhTest( class BulkDMLReturningSingleInhTest( BulkDMLReturningInhTest, fixtures.DeclarativeMappedTest ): - __requires__ = ("insert_returning",) + __requires__ = ("insert_returning", "insert_executemany_returning") __backend__ = True @classmethod @@ -1075,7 +1075,7 @@ class BulkDMLReturningSingleInhTest( class BulkDMLReturningConcreteInhTest( BulkDMLReturningInhTest, fixtures.DeclarativeMappedTest ): - __requires__ = ("insert_returning",) + __requires__ = ("insert_returning", "insert_executemany_returning") __backend__ = True @classmethod diff --git a/test/orm/test_dynamic.py b/test/orm/test_dynamic.py index df335f0f6..714878f4e 100644 --- a/test/orm/test_dynamic.py +++ b/test/orm/test_dynamic.py @@ -1791,14 +1791,33 @@ class WriteOnlyBulkTest( "INSERT INTO users (name) VALUES (:name)", [{"name": "x"}], ), - CompiledSQL( - "INSERT INTO addresses (user_id, email_address) " - "VALUES (:user_id, :email_address) " - "RETURNING addresses.id", + Conditional( + testing.requires.insert_executemany_returning.enabled, + [ + CompiledSQL( + "INSERT INTO addresses " + "(user_id, email_address) " + "VALUES (:user_id, :email_address) " + "RETURNING addresses.id", + [ + {"user_id": uid, "email_address": "e1"}, + {"user_id": uid, "email_address": "e2"}, + {"user_id": uid, "email_address": "e3"}, + ], + ) + ], [ - {"user_id": uid, "email_address": "e1"}, - {"user_id": uid, "email_address": "e2"}, - {"user_id": uid, "email_address": "e3"}, + CompiledSQL( + "INSERT INTO addresses " + "(user_id, email_address) " + "VALUES (:user_id, :email_address)", + param, + ) + for param in [ + {"user_id": uid, "email_address": "e1"}, + {"user_id": uid, "email_address": "e2"}, + {"user_id": uid, "email_address": "e3"}, + ] ], ), ], @@ -1863,14 +1882,33 @@ class WriteOnlyBulkTest( "INSERT INTO users (name) VALUES (:name)", [{"name": "x"}], ), - CompiledSQL( - "INSERT INTO addresses (user_id, email_address) " - "VALUES (:user_id, :email_address) " - "RETURNING addresses.id", + Conditional( + testing.requires.insert_executemany_returning.enabled, + [ + CompiledSQL( + "INSERT INTO addresses " + "(user_id, email_address) " + "VALUES (:user_id, :email_address) " + "RETURNING addresses.id", + [ + {"user_id": uid, "email_address": "e1"}, + {"user_id": uid, "email_address": "e2"}, + {"user_id": uid, "email_address": "e3"}, + ], + ) + ], [ - {"user_id": uid, "email_address": "e1"}, - {"user_id": uid, "email_address": "e2"}, - {"user_id": uid, "email_address": "e3"}, + CompiledSQL( + "INSERT INTO addresses " + "(user_id, email_address) " + "VALUES (:user_id, :email_address)", + param, + ) + for param in [ + {"user_id": uid, "email_address": "e1"}, + {"user_id": uid, "email_address": "e2"}, + {"user_id": uid, "email_address": "e3"}, + ] ], ), ], diff --git a/test/orm/test_merge.py b/test/orm/test_merge.py index 36c47e27b..eb5a795e2 100644 --- a/test/orm/test_merge.py +++ b/test/orm/test_merge.py @@ -1458,7 +1458,7 @@ class MergeTest(_fixtures.FixtureTest): ) attrname = "user" else: - assert False + direction.fail() assert attrname in obj_to_merge.__dict__ diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py index f204e954c..468d43063 100644 --- a/test/orm/test_unitofworkv2.py +++ b/test/orm/test_unitofworkv2.py @@ -3077,7 +3077,7 @@ class EagerDefaultsTest(fixtures.MappedTest): asserter.assert_( Conditional( - testing.db.dialect.insert_executemany_returning, + testing.db.dialect.insert_returning, [ CompiledSQL( "INSERT INTO test (id) VALUES (:id) " diff --git a/test/requirements.py b/test/requirements.py index 5276593c9..83cd65cd8 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -232,7 +232,6 @@ class DefaultRequirements(SuiteRequirements): "mariadb+pymysql", "mariadb+cymysql", "mariadb+mysqlconnector", - "postgresql+asyncpg", "postgresql+pg8000", ] ) @@ -388,6 +387,14 @@ class DefaultRequirements(SuiteRequirements): ) @property + def predictable_gc(self): + """target platform must remove all cycles unconditionally when + gc.collect() is called, as well as clean out unreferenced subclasses. + + """ + return self.cpython + skip_if("+aiosqlite") + + @property def memory_process_intensive(self): """Driver is able to handle the memory tests which run in a subprocess and iterate through hundreds of connections @@ -969,6 +976,8 @@ class DefaultRequirements(SuiteRequirements): "mariadb", "sqlite+aiosqlite", "sqlite+pysqlite", + "sqlite+pysqlite_numeric", + "sqlite+pysqlite_dollar", "sqlite+pysqlcipher", "mssql", ) diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index 205ce5157..d342b9248 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -79,6 +79,7 @@ from sqlalchemy.sql import util as sql_util from sqlalchemy.sql.elements import BooleanClauseList from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.elements import CompilerColumnElement +from sqlalchemy.sql.elements import Grouping from sqlalchemy.sql.expression import ClauseElement from sqlalchemy.sql.expression import ClauseList from sqlalchemy.sql.selectable import LABEL_STYLE_NONE @@ -4915,88 +4916,259 @@ class BindParameterTest(AssertsCompiledSQL, fixtures.TestBase): dialect="default", ) - @standalone_escape - @testing.variation("use_assert_compile", [True, False]) @testing.variation("use_positional", [True, False]) - def test_standalone_bindparam_escape_expanding( - self, paramname, expected, use_assert_compile, use_positional + def test_standalone_bindparam_escape_collision(self, use_positional): + """this case is currently not supported + + it's kinda bad since positional takes the unescaped param + while non positional takes the escaped one. + """ + stmt = select(table1.c.myid).where( + table1.c.name == bindparam("[brackets]", value="x"), + table1.c.description == bindparam("_brackets_", value="y"), + ) + + if use_positional: + self.assert_compile( + stmt, + "SELECT mytable.myid FROM mytable WHERE mytable.name = ? " + "AND mytable.description = ?", + params={"[brackets]": "a", "_brackets_": "b"}, + checkpositional=("a", "a"), + dialect="sqlite", + ) + else: + self.assert_compile( + stmt, + "SELECT mytable.myid FROM mytable WHERE mytable.name = " + ":_brackets_ AND mytable.description = :_brackets_", + params={"[brackets]": "a", "_brackets_": "b"}, + checkparams={"_brackets_": "b"}, + dialect="default", + ) + + paramstyle = testing.variation("paramstyle", ["named", "qmark", "numeric"]) + + @standalone_escape + @paramstyle + def test_standalone_bindparam_escape_expanding_compile( + self, paramname, expected, paramstyle ): stmt = select(table1.c.myid).where( table1.c.name.in_(bindparam(paramname, value=["a", "b"])) ) - if use_assert_compile: - if use_positional: - self.assert_compile( - stmt, - "SELECT mytable.myid FROM mytable " - "WHERE mytable.name IN (?, ?)", - params={paramname: ["y", "z"]}, - # NOTE: this is what render_postcompile will do right now - # if you run construct_params(). render_postcompile mode - # is not actually used by the execution internals, it's for - # user-facing compilation code. So this is likely a - # current limitation of construct_params() which is not - # doing the full blown postcompile; just assert that's - # what it does for now. it likely should be corrected - # to make more sense. - checkpositional=(["y", "z"], ["y", "z"]), - dialect="sqlite", - render_postcompile=True, - ) - else: - self.assert_compile( - stmt, - "SELECT mytable.myid FROM mytable WHERE mytable.name IN " - "(:%s_1, :%s_2)" % (expected, expected), - params={paramname: ["y", "z"]}, - # NOTE: this is what render_postcompile will do right now - # if you run construct_params(). render_postcompile mode - # is not actually used by the execution internals, it's for - # user-facing compilation code. So this is likely a - # current limitation of construct_params() which is not - # doing the full blown postcompile; just assert that's - # what it does for now. it likely should be corrected - # to make more sense. - checkparams={ - "%s_1" % expected: ["y", "z"], - "%s_2" % expected: ["y", "z"], - }, - dialect="default", - render_postcompile=True, - ) + # NOTE: below the rendered params are just what + # render_postcompile will do right now + # if you run construct_params(). render_postcompile mode + # is not actually used by the execution internals, it's for + # user-facing compilation code. So this is likely a + # current limitation of construct_params() which is not + # doing the full blown postcompile; just assert that's + # what it does for now. it likely should be corrected + # to make more sense. + if paramstyle.qmark: + self.assert_compile( + stmt, + "SELECT mytable.myid FROM mytable " + "WHERE mytable.name IN (?, ?)", + params={paramname: ["y", "z"]}, + checkpositional=(["y", "z"], ["y", "z"]), + dialect="sqlite", + render_postcompile=True, + ) + elif paramstyle.numeric: + self.assert_compile( + stmt, + "SELECT mytable.myid FROM mytable " + "WHERE mytable.name IN (:1, :2)", + params={paramname: ["y", "z"]}, + checkpositional=(["y", "z"], ["y", "z"]), + dialect=sqlite.dialect(paramstyle="numeric"), + render_postcompile=True, + ) + elif paramstyle.named: + self.assert_compile( + stmt, + "SELECT mytable.myid FROM mytable WHERE mytable.name IN " + "(:%s_1, :%s_2)" % (expected, expected), + params={paramname: ["y", "z"]}, + checkparams={ + "%s_1" % expected: ["y", "z"], + "%s_2" % expected: ["y", "z"], + }, + dialect="default", + render_postcompile=True, + ) else: - # this is what DefaultDialect actually does. - # this should be matched to DefaultDialect._init_compiled() - if use_positional: - compiled = stmt.compile( - dialect=default.DefaultDialect(paramstyle="qmark") - ) - else: - compiled = stmt.compile(dialect=default.DefaultDialect()) + paramstyle.fail() - checkparams = compiled.construct_params( - {paramname: ["y", "z"]}, escape_names=False - ) + @standalone_escape + @paramstyle + def test_standalone_bindparam_escape_expanding( + self, paramname, expected, paramstyle + ): + stmt = select(table1.c.myid).where( + table1.c.name.in_(bindparam(paramname, value=["a", "b"])) + ) + # this is what DefaultDialect actually does. + # this should be matched to DefaultDialect._init_compiled() + if paramstyle.qmark: + dialect = default.DefaultDialect(paramstyle="qmark") + elif paramstyle.numeric: + dialect = default.DefaultDialect(paramstyle="numeric") + else: + dialect = default.DefaultDialect() - # nothing actually happened. if the compiler had - # render_postcompile set, the - # above weird param thing happens - eq_(checkparams, {paramname: ["y", "z"]}) + compiled = stmt.compile(dialect=dialect) + checkparams = compiled.construct_params( + {paramname: ["y", "z"]}, escape_names=False + ) - expanded_state = compiled._process_parameters_for_postcompile( - checkparams - ) + # nothing actually happened. if the compiler had + # render_postcompile set, the + # above weird param thing happens + eq_(checkparams, {paramname: ["y", "z"]}) + + expanded_state = compiled._process_parameters_for_postcompile( + checkparams + ) + eq_( + expanded_state.additional_parameters, + {f"{expected}_1": "y", f"{expected}_2": "z"}, + ) + + if paramstyle.qmark or paramstyle.numeric: eq_( - expanded_state.additional_parameters, - {f"{expected}_1": "y", f"{expected}_2": "z"}, + expanded_state.positiontup, + [f"{expected}_1", f"{expected}_2"], ) - if use_positional: - eq_( - expanded_state.positiontup, - [f"{expected}_1", f"{expected}_2"], + @paramstyle + def test_expanding_in_repeated(self, paramstyle): + stmt = ( + select(table1) + .where( + table1.c.name.in_( + bindparam("uname", value=["h", "e"], expanding=True) + ) + | table1.c.name.in_( + bindparam("uname2", value=["y"], expanding=True) + ) + ) + .where(table1.c.myid == 8) + ) + stmt = stmt.union( + select(table1) + .where( + table1.c.name.in_( + bindparam("uname", value=["h", "e"], expanding=True) + ) + | table1.c.name.in_( + bindparam("uname2", value=["y"], expanding=True) ) + ) + .where(table1.c.myid == 9) + ).order_by("myid") + + # NOTE: below the rendered params are just what + # render_postcompile will do right now + # if you run construct_params(). render_postcompile mode + # is not actually used by the execution internals, it's for + # user-facing compilation code. So this is likely a + # current limitation of construct_params() which is not + # doing the full blown postcompile; just assert that's + # what it does for now. it likely should be corrected + # to make more sense. + + if paramstyle.qmark: + self.assert_compile( + stmt, + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable " + "WHERE (mytable.name IN (?, ?) OR " + "mytable.name IN (?)) " + "AND mytable.myid = ? " + "UNION SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable " + "WHERE (mytable.name IN (?, ?) OR " + "mytable.name IN (?)) " + "AND mytable.myid = ? ORDER BY myid", + params={"uname": ["y", "z"], "uname2": ["a"]}, + checkpositional=( + ["y", "z"], + ["y", "z"], + ["a"], + 8, + ["y", "z"], + ["y", "z"], + ["a"], + 9, + ), + dialect="sqlite", + render_postcompile=True, + ) + elif paramstyle.numeric: + self.assert_compile( + stmt, + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable " + "WHERE (mytable.name IN (:3, :4) OR " + "mytable.name IN (:5)) " + "AND mytable.myid = :1 " + "UNION SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable " + "WHERE (mytable.name IN (:3, :4) OR " + "mytable.name IN (:5)) " + "AND mytable.myid = :2 ORDER BY myid", + params={"uname": ["y", "z"], "uname2": ["a"]}, + checkpositional=(8, 9, ["y", "z"], ["y", "z"], ["a"]), + dialect=sqlite.dialect(paramstyle="numeric"), + render_postcompile=True, + ) + elif paramstyle.named: + self.assert_compile( + stmt, + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable " + "WHERE (mytable.name IN (:uname_1, :uname_2) OR " + "mytable.name IN (:uname2_1)) " + "AND mytable.myid = :myid_1 " + "UNION SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable " + "WHERE (mytable.name IN (:uname_1, :uname_2) OR " + "mytable.name IN (:uname2_1)) " + "AND mytable.myid = :myid_2 ORDER BY myid", + params={"uname": ["y", "z"], "uname2": ["a"]}, + checkparams={ + "uname": ["y", "z"], + "uname2": ["a"], + "uname_1": ["y", "z"], + "uname_2": ["y", "z"], + "uname2_1": ["a"], + "myid_1": 8, + "myid_2": 9, + }, + dialect="default", + render_postcompile=True, + ) + else: + paramstyle.fail() + + def test_numeric_dollar_bindparam(self): + stmt = table1.select().where( + table1.c.name == "a", table1.c.myid.in_([1, 2]) + ) + self.assert_compile( + stmt, + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable " + "WHERE mytable.name = $1 " + "AND mytable.myid IN ($2, $3)", + checkpositional=("a", 1, 2), + dialect=default.DefaultDialect(paramstyle="numeric_dollar"), + render_postcompile=True, + ) class UnsupportedTest(fixtures.TestBase): @@ -5096,6 +5268,28 @@ class StringifySpecialTest(fixtures.TestBase): "INSERT INTO mytable (myid) VALUES (:myid_m0), (:myid_m1)", ) + def test_multirow_insert_positional(self): + stmt = table1.insert().values([{"myid": 1}, {"myid": 2}]) + eq_ignore_whitespace( + stmt.compile(dialect=sqlite.dialect()).string, + "INSERT INTO mytable (myid) VALUES (?), (?)", + ) + + def test_multirow_insert_numeric(self): + stmt = table1.insert().values([{"myid": 1}, {"myid": 2}]) + eq_ignore_whitespace( + stmt.compile(dialect=sqlite.dialect(paramstyle="numeric")).string, + "INSERT INTO mytable (myid) VALUES (:1), (:2)", + ) + + def test_insert_noparams_numeric(self): + ii = table1.insert().returning(table1.c.myid) + eq_ignore_whitespace( + ii.compile(dialect=sqlite.dialect(paramstyle="numeric")).string, + "INSERT INTO mytable (myid, name, description) VALUES " + "(:1, :2, :3) RETURNING myid", + ) + def test_cte(self): # stringify of these was supported anyway by defaultdialect. stmt = select(table1.c.myid).cte() @@ -5153,6 +5347,42 @@ class StringifySpecialTest(fixtures.TestBase): "SELECT CAST(mytable.myid AS MyType()) AS myid FROM mytable", ) + def test_dialect_sub_compile(self): + class Widget(ClauseElement): + __visit_name__ = "widget" + stringify_dialect = "sqlite" + + def visit_widget(self, element, **kw): + return "widget" + + with mock.patch( + "sqlalchemy.dialects.sqlite.base.SQLiteCompiler.visit_widget", + visit_widget, + create=True, + ): + eq_(str(Grouping(Widget())), "(widget)") + + def test_dialect_sub_compile_w_binds(self): + """test sub-compile into a new compiler where + state != CompilerState.COMPILING, but we have to render a bindparam + string. has to render the correct template. + + """ + + class Widget(ClauseElement): + __visit_name__ = "widget" + stringify_dialect = "sqlite" + + def visit_widget(self, element, **kw): + return f"widget {self.process(bindparam('q'), **kw)}" + + with mock.patch( + "sqlalchemy.dialects.sqlite.base.SQLiteCompiler.visit_widget", + visit_widget, + create=True, + ): + eq_(str(Grouping(Widget())), "(widget ?)") + def test_within_group(self): # stringify of these was supported anyway by defaultdialect. from sqlalchemy import within_group diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index b89d18de6..502104dae 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -993,20 +993,20 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( s, 'WITH regional_sales AS (SELECT orders."order" ' - 'AS "order", :1 AS anon_2 FROM orders) SELECT ' - 'regional_sales."order", :2 AS anon_1 FROM regional_sales', - checkpositional=("x", "y"), + 'AS "order", :2 AS anon_2 FROM orders) SELECT ' + 'regional_sales."order", :1 AS anon_1 FROM regional_sales', + checkpositional=("y", "x"), dialect=dialect, ) self.assert_compile( s.union(s), 'WITH regional_sales AS (SELECT orders."order" ' - 'AS "order", :1 AS anon_2 FROM orders) SELECT ' - 'regional_sales."order", :2 AS anon_1 FROM regional_sales ' - 'UNION SELECT regional_sales."order", :3 AS anon_1 ' + 'AS "order", :2 AS anon_2 FROM orders) SELECT ' + 'regional_sales."order", :1 AS anon_1 FROM regional_sales ' + 'UNION SELECT regional_sales."order", :1 AS anon_1 ' "FROM regional_sales", - checkpositional=("x", "y", "y"), + checkpositional=("y", "x"), dialect=dialect, ) @@ -1057,8 +1057,8 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( s3, 'WITH regional_sales_1 AS (SELECT orders."order" AS "order" ' - 'FROM orders WHERE orders."order" = :1), regional_sales_2 AS ' - '(SELECT orders."order" = :2 AS anon_1, ' + 'FROM orders WHERE orders."order" = :2), regional_sales_2 AS ' + '(SELECT orders."order" = :1 AS anon_1, ' 'anon_2."order" AS "order", ' 'orders."order" AS order_1, ' 'regional_sales_1."order" AS order_2 FROM orders, ' @@ -1067,7 +1067,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): 'WHERE orders."order" = :3) SELECT regional_sales_2.anon_1, ' 'regional_sales_2."order", regional_sales_2.order_1, ' "regional_sales_2.order_2 FROM regional_sales_2", - checkpositional=("x", "y", "z"), + checkpositional=("y", "x", "z"), dialect=dialect, ) diff --git a/test/sql/test_insert.py b/test/sql/test_insert.py index ac9ac4022..1c24d4c79 100644 --- a/test/sql/test_insert.py +++ b/test/sql/test_insert.py @@ -488,7 +488,8 @@ class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): dialect=postgresql.dialect(), ) - def test_heterogeneous_multi_values(self): + @testing.variation("paramstyle", ["pg", "qmark", "numeric", "dollar"]) + def test_heterogeneous_multi_values(self, paramstyle): """for #6047, originally I thought we'd take any insert().values() and be able to convert it to a "many" style execution that we can cache. @@ -519,33 +520,81 @@ class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): ] ) + pos_par = ( + 1, + 1, + 2, + 2, + 1, + 2, + 2, + 3, + 1, + 2, + 2, + 10, + ) + # SQL expressions in the params at arbitrary locations means # we have to scan them at compile time, and the shape of the bound # parameters is not predictable. so for #6047 where I originally # thought all of values() could be rewritten, this makes it not # really worth it. - self.assert_compile( - stmt, - "INSERT INTO t (x, y, z) VALUES " - "(%(x_m0)s, sum(%(sum_1)s, %(sum_2)s), %(z_m0)s), " - "(sum(%(sum_3)s, %(sum_4)s), %(y_m1)s, %(z_m1)s), " - "(sum(%(sum_5)s, %(sum_6)s), %(y_m2)s, foo(%(foo_1)s))", - checkparams={ - "x_m0": 1, - "sum_1": 1, - "sum_2": 2, - "z_m0": 2, - "sum_3": 1, - "sum_4": 2, - "y_m1": 2, - "z_m1": 3, - "sum_5": 1, - "sum_6": 2, - "y_m2": 2, - "foo_1": 10, - }, - dialect=postgresql.dialect(), - ) + if paramstyle.pg: + self.assert_compile( + stmt, + "INSERT INTO t (x, y, z) VALUES " + "(%(x_m0)s, sum(%(sum_1)s, %(sum_2)s), %(z_m0)s), " + "(sum(%(sum_3)s, %(sum_4)s), %(y_m1)s, %(z_m1)s), " + "(sum(%(sum_5)s, %(sum_6)s), %(y_m2)s, foo(%(foo_1)s))", + checkparams={ + "x_m0": 1, + "sum_1": 1, + "sum_2": 2, + "z_m0": 2, + "sum_3": 1, + "sum_4": 2, + "y_m1": 2, + "z_m1": 3, + "sum_5": 1, + "sum_6": 2, + "y_m2": 2, + "foo_1": 10, + }, + dialect=postgresql.dialect(), + ) + elif paramstyle.qmark: + self.assert_compile( + stmt, + "INSERT INTO t (x, y, z) VALUES " + "(?, sum(?, ?), ?), " + "(sum(?, ?), ?, ?), " + "(sum(?, ?), ?, foo(?))", + checkpositional=pos_par, + dialect=sqlite.dialect(), + ) + elif paramstyle.numeric: + self.assert_compile( + stmt, + "INSERT INTO t (x, y, z) VALUES " + "(:1, sum(:2, :3), :4), " + "(sum(:5, :6), :7, :8), " + "(sum(:9, :10), :11, foo(:12))", + checkpositional=pos_par, + dialect=sqlite.dialect(paramstyle="numeric"), + ) + elif paramstyle.dollar: + self.assert_compile( + stmt, + "INSERT INTO t (x, y, z) VALUES " + "($1, sum($2, $3), $4), " + "(sum($5, $6), $7, $8), " + "(sum($9, $10), $11, foo($12))", + checkpositional=pos_par, + dialect=sqlite.dialect(paramstyle="numeric_dollar"), + ) + else: + paramstyle.fail() def test_insert_seq_pk_multi_values_seq_not_supported(self): m = MetaData() diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py index b856acfd3..7f1124c84 100644 --- a/test/sql/test_resultset.py +++ b/test/sql/test_resultset.py @@ -107,7 +107,7 @@ class CursorResultTest(fixtures.TablesTest): Column("y", String(50)), ) - @testing.requires.insert_returning + @testing.requires.insert_executemany_returning def test_splice_horizontally(self, connection): users = self.tables.users addresses = self.tables.addresses diff --git a/test/sql/test_types.py b/test/sql/test_types.py index 91413ff35..59519a5ec 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -3322,7 +3322,7 @@ class ExpressionTest( elif expression_type.right_side: expr = (column("x", Integer) == Widget(52)).right else: - assert False + expression_type.fail() if secondary_adapt: is_(expr.type._type_affinity, String) diff --git a/test/sql/test_update.py b/test/sql/test_update.py index cd7f992e2..66971f64e 100644 --- a/test/sql/test_update.py +++ b/test/sql/test_update.py @@ -907,7 +907,8 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): dialect=dialect, ) - def test_update_bound_ordering(self): + @testing.variation("paramstyle", ["qmark", "format", "numeric"]) + def test_update_bound_ordering(self, paramstyle): """test that bound parameters between the UPDATE and FROM clauses order correctly in different SQL compilation scenarios. @@ -921,30 +922,47 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): .values(name="foo") ) - dialect = default.StrCompileDialect() - dialect.positional = True - self.assert_compile( - upd, - "UPDATE mytable SET name=:name FROM (SELECT " - "myothertable.otherid AS otherid, " - "myothertable.othername AS othername " - "FROM myothertable " - "WHERE myothertable.otherid = :otherid_1) AS anon_1 " - "WHERE mytable.name = anon_1.othername", - checkpositional=("foo", 5), - dialect=dialect, - ) + if paramstyle.qmark: - self.assert_compile( - upd, - "UPDATE mytable, (SELECT myothertable.otherid AS otherid, " - "myothertable.othername AS othername " - "FROM myothertable " - "WHERE myothertable.otherid = %s) AS anon_1 SET mytable.name=%s " - "WHERE mytable.name = anon_1.othername", - checkpositional=(5, "foo"), - dialect=mysql.dialect(), - ) + dialect = default.StrCompileDialect(paramstyle="qmark") + self.assert_compile( + upd, + "UPDATE mytable SET name=? FROM (SELECT " + "myothertable.otherid AS otherid, " + "myothertable.othername AS othername " + "FROM myothertable " + "WHERE myothertable.otherid = ?) AS anon_1 " + "WHERE mytable.name = anon_1.othername", + checkpositional=("foo", 5), + dialect=dialect, + ) + elif paramstyle.format: + self.assert_compile( + upd, + "UPDATE mytable, (SELECT myothertable.otherid AS otherid, " + "myothertable.othername AS othername " + "FROM myothertable " + "WHERE myothertable.otherid = %s) AS anon_1 " + "SET mytable.name=%s " + "WHERE mytable.name = anon_1.othername", + checkpositional=(5, "foo"), + dialect=mysql.dialect(), + ) + elif paramstyle.numeric: + dialect = default.StrCompileDialect(paramstyle="numeric") + self.assert_compile( + upd, + "UPDATE mytable SET name=:1 FROM (SELECT " + "myothertable.otherid AS otherid, " + "myothertable.othername AS othername " + "FROM myothertable " + "WHERE myothertable.otherid = :2) AS anon_1 " + "WHERE mytable.name = anon_1.othername", + checkpositional=("foo", 5), + dialect=dialect, + ) + else: + paramstyle.fail() class UpdateFromCompileTest( @@ -89,7 +89,7 @@ setenv= sqlite: SQLITE={env:TOX_SQLITE:--db sqlite} sqlite_file: SQLITE={env:TOX_SQLITE_FILE:--db sqlite_file} - py3{,7,8,9,10,11}-sqlite: EXTRA_SQLITE_DRIVERS={env:EXTRA_SQLITE_DRIVERS:--dbdriver sqlite --dbdriver aiosqlite} + py3{,7,8,9,10,11}-sqlite: EXTRA_SQLITE_DRIVERS={env:EXTRA_SQLITE_DRIVERS:--dbdriver sqlite --dbdriver pysqlite_numeric --dbdriver aiosqlite} py3{,7,8,9}-sqlite_file: EXTRA_SQLITE_DRIVERS={env:EXTRA_SQLITE_DRIVERS:--dbdriver sqlite --dbdriver aiosqlite --dbdriver pysqlcipher} |