diff options
Diffstat (limited to 'lib')
| -rw-r--r-- | lib/sqlalchemy/dialects/oracle/base.py | 22 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/base.py | 13 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 115 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 8 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/assertions.py | 64 |
5 files changed, 170 insertions, 52 deletions
diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 0d51bf73d..41b9ac43d 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -854,7 +854,7 @@ class OracleCompiler(compiler.SQLCompiler): def visit_function(self, func, **kw): text = super().visit_function(func, **kw) if kw.get("asfrom", False): - text = "TABLE (%s)" % func + text = "TABLE (%s)" % text return text def visit_table_valued_column(self, element, **kw): @@ -1222,20 +1222,18 @@ class OracleCompiler(compiler.SQLCompiler): self.process(binary.right), ) - def _get_regexp_args(self, binary, kw): + def visit_regexp_match_op_binary(self, binary, operator, **kw): string = self.process(binary.left, **kw) pattern = self.process(binary.right, **kw) flags = binary.modifiers["flags"] - if flags is not None: - flags = self.process(flags, **kw) - return string, pattern, flags - - def visit_regexp_match_op_binary(self, binary, operator, **kw): - string, pattern, flags = self._get_regexp_args(binary, kw) if flags is None: return "REGEXP_LIKE(%s, %s)" % (string, pattern) else: - return "REGEXP_LIKE(%s, %s, %s)" % (string, pattern, flags) + return "REGEXP_LIKE(%s, %s, %s)" % ( + string, + pattern, + self.process(flags, **kw), + ) def visit_not_regexp_match_op_binary(self, binary, operator, **kw): return "NOT %s" % self.visit_regexp_match_op_binary( @@ -1243,8 +1241,10 @@ class OracleCompiler(compiler.SQLCompiler): ) def visit_regexp_replace_op_binary(self, binary, operator, **kw): - string, pattern, flags = self._get_regexp_args(binary, kw) + string = self.process(binary.left, **kw) + pattern = self.process(binary.right, **kw) replacement = self.process(binary.modifiers["replacement"], **kw) + flags = binary.modifiers["flags"] if flags is None: return "REGEXP_REPLACE(%s, %s, %s)" % ( string, @@ -1256,7 +1256,7 @@ class OracleCompiler(compiler.SQLCompiler): string, pattern, replacement, - flags, + self.process(flags, **kw), ) diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 99c48fb2f..f9108094f 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1747,14 +1747,11 @@ class PGCompiler(compiler.SQLCompiler): return self._generate_generic_binary( binary, " %s* " % base_op, **kw ) - flags = self.process(flags, **kw) - string = self.process(binary.left, **kw) - pattern = self.process(binary.right, **kw) return "%s %s CONCAT('(?', %s, ')', %s)" % ( - string, + self.process(binary.left, **kw), base_op, - flags, - pattern, + self.process(flags, **kw), + self.process(binary.right, **kw), ) def visit_regexp_match_op_binary(self, binary, operator, **kw): @@ -1767,8 +1764,6 @@ class PGCompiler(compiler.SQLCompiler): string = self.process(binary.left, **kw) pattern = self.process(binary.right, **kw) flags = binary.modifiers["flags"] - if flags is not None: - flags = self.process(flags, **kw) replacement = self.process(binary.modifiers["replacement"], **kw) if flags is None: return "REGEXP_REPLACE(%s, %s, %s)" % ( @@ -1781,7 +1776,7 @@ class PGCompiler(compiler.SQLCompiler): string, pattern, replacement, - flags, + self.process(flags, **kw), ) def visit_empty_set_expr(self, element_types): diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 50cf9b477..7ac279ee2 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -236,8 +236,8 @@ BIND_TEMPLATES = { } -_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\]]") -_BIND_TRANSLATE_CHARS = dict(zip("%():[]", "PAZC__")) +_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\] ]") +_BIND_TRANSLATE_CHARS = dict(zip("%():[] ", "PAZC___")) OPERATORS = { # binary @@ -973,6 +973,7 @@ class SQLCompiler(Compiled): debugging use cases. """ + positiontup_level: Optional[Dict[str, int]] = None inline: bool = False @@ -988,6 +989,8 @@ class SQLCompiler(Compiled): ctes_recursive: bool cte_positional: Dict[CTE, List[str]] + cte_level: Dict[CTE, int] + cte_order: Dict[Optional[CTE], List[CTE]] def __init__( self, @@ -1052,6 +1055,7 @@ class SQLCompiler(Compiled): # true if the paramstyle is positional self.positional = dialect.positional if self.positional: + self.positiontup_level = {} self.positiontup = [] self._numeric_binds = dialect.paramstyle == "numeric" self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle] @@ -1215,6 +1219,8 @@ class SQLCompiler(Compiled): self.ctes_recursive = False if self.positional: self.cte_positional = {} + self.cte_level = {} + self.cte_order = collections.defaultdict(list) return ctes @@ -2103,7 +2109,13 @@ class SQLCompiler(Compiled): text = self.process(taf.element, **kw) if self.ctes: nesting_level = len(self.stack) if not toplevel else None - text = self._render_cte_clause(nesting_level=nesting_level) + text + text = ( + self._render_cte_clause( + nesting_level=nesting_level, + visiting_cte=kw.get("visiting_cte"), + ) + + text + ) self.stack.pop(-1) @@ -2231,6 +2243,7 @@ class SQLCompiler(Compiled): ) def visit_over(self, over, **kwargs): + text = over.element._compiler_dispatch(self, **kwargs) if over.range_: range_ = "RANGE BETWEEN %s" % self._format_frame_clause( over.range_, **kwargs @@ -2243,7 +2256,7 @@ class SQLCompiler(Compiled): range_ = None return "%s OVER (%s)" % ( - over.element._compiler_dispatch(self, **kwargs), + text, " ".join( [ "%s BY %s" @@ -2396,7 +2409,9 @@ class SQLCompiler(Compiled): nesting_level = len(self.stack) if not toplevel else None text = ( self._render_cte_clause( - nesting_level=nesting_level, include_following_stack=True + nesting_level=nesting_level, + include_following_stack=True, + visiting_cte=kwargs.get("visiting_cte"), ) + text ) @@ -3222,7 +3237,8 @@ class SQLCompiler(Compiled): positional_names.append(name) else: self.positiontup.append(name) # type: ignore[union-attr] - elif not escaped_from: + self.positiontup_level[name] = len(self.stack) # type: ignore[index] # noqa: E501 + if not escaped_from: if _BIND_TRANSLATE_RE.search(name): # not quite the translate use case as we want to @@ -3333,6 +3349,8 @@ class SQLCompiler(Compiled): self.level_name_by_cte[_reference_cte] = new_level_name + ( cte_opts, ) + if self.positional: + self.cte_level[cte] = cte_level else: cte_level = len(self.stack) if nesting else 1 @@ -3396,6 +3414,8 @@ class SQLCompiler(Compiled): self.level_name_by_cte[_reference_cte] = cte_level_name + ( cte_opts, ) + if self.positional: + self.cte_level[cte] = cte_level if pre_alias_cte not in self.ctes: self.visit_cte(pre_alias_cte, **kwargs) @@ -4129,13 +4149,16 @@ class SQLCompiler(Compiled): if per_dialect: text += " " + self.get_statement_hint_text(per_dialect) - if self.ctes: - # In compound query, CTEs are shared at the compound level - if not is_embedded_select: - nesting_level = len(self.stack) if not toplevel else None - text = ( - self._render_cte_clause(nesting_level=nesting_level) + text + # In compound query, CTEs are shared at the compound level + if self.ctes and (not is_embedded_select or toplevel): + nesting_level = len(self.stack) if not toplevel else None + text = ( + self._render_cte_clause( + nesting_level=nesting_level, + visiting_cte=kwargs.get("visiting_cte"), ) + + text + ) if select_stmt._suffixes: text += " " + self._generate_prefixes( @@ -4309,6 +4332,7 @@ class SQLCompiler(Compiled): self, nesting_level=None, include_following_stack=False, + visiting_cte=None, ): """ include_following_stack @@ -4341,19 +4365,47 @@ class SQLCompiler(Compiled): if not ctes: return "" - ctes_recursive = any([cte.recursive for cte in ctes]) if self.positional: - assert self.positiontup is not None - self.positiontup = ( - list( - itertools.chain.from_iterable( - self.cte_positional[cte] for cte in ctes - ) + self.cte_order[visiting_cte].extend(ctes) + + if visiting_cte is None and self.cte_order: + assert self.positiontup is not None + + def get_nested_positional(cte): + if cte in self.cte_order: + children = self.cte_order.pop(cte) + to_add = list( + itertools.chain.from_iterable( + get_nested_positional(child_cte) + for child_cte in children + ) + ) + if cte in self.cte_positional: + return reorder_positional( + self.cte_positional[cte], + to_add, + self.cte_level[children[0]], + ) + else: + return to_add + else: + return self.cte_positional.get(cte, []) + + def reorder_positional(pos, to_add, level): + if not level: + return to_add + pos + index = 0 + for index, name in enumerate(reversed(pos)): + if self.positiontup_level[name] < level: # type: ignore[index] # noqa: E501 + break + return pos[:-index] + to_add + pos[-index:] + + to_add = get_nested_positional(None) + self.positiontup = reorder_positional( + self.positiontup, to_add, nesting_level ) - + self.positiontup - ) cte_text = self.get_cte_preamble(ctes_recursive) + " " cte_text += ", \n".join([txt for txt in ctes.values()]) @@ -4930,6 +4982,7 @@ class SQLCompiler(Compiled): self._render_cte_clause( nesting_level=nesting_level, include_following_stack=True, + visiting_cte=kw.get("visiting_cte"), ), select_text, ) @@ -4997,7 +5050,9 @@ class SQLCompiler(Compiled): nesting_level = len(self.stack) if not toplevel else None text = ( self._render_cte_clause( - nesting_level=nesting_level, include_following_stack=True + nesting_level=nesting_level, + include_following_stack=True, + visiting_cte=kw.get("visiting_cte"), ) + text ) @@ -5146,7 +5201,13 @@ class SQLCompiler(Compiled): if self.ctes: nesting_level = len(self.stack) if not toplevel else None - text = self._render_cte_clause(nesting_level=nesting_level) + text + text = ( + self._render_cte_clause( + nesting_level=nesting_level, + visiting_cte=kw.get("visiting_cte"), + ) + + text + ) self.stack.pop(-1) @@ -5260,7 +5321,13 @@ class SQLCompiler(Compiled): if self.ctes: nesting_level = len(self.stack) if not toplevel else None - text = self._render_cte_clause(nesting_level=nesting_level) + text + text = ( + self._render_cte_clause( + nesting_level=nesting_level, + visiting_cte=kw.get("visiting_cte"), + ) + + text + ) self.stack.pop(-1) diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 97336d416..2dcc611fa 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -2052,9 +2052,7 @@ class CTE( else: self.element._generate_fromclause_column_proxies(self) - def alias( - self, name: Optional[str] = None, flat: bool = False - ) -> NamedFromClause: + def alias(self, name: Optional[str] = None, flat: bool = False) -> CTE: """Return an :class:`_expression.Alias` of this :class:`_expression.CTE`. @@ -2078,7 +2076,7 @@ class CTE( _suffixes=self._suffixes, ) - def union(self, *other): + def union(self, *other: _SelectStatementForCompoundArgument) -> CTE: r"""Return a new :class:`_expression.CTE` with a SQL ``UNION`` of the original CTE against the given selectables provided as positional arguments. @@ -2107,7 +2105,7 @@ class CTE( _suffixes=self._suffixes, ) - def union_all(self, *other): + def union_all(self, *other: _SelectStatementForCompoundArgument) -> CTE: r"""Return a new :class:`_expression.CTE` with a SQL ``UNION ALL`` of the original CTE against the given selectables provided as positional arguments. diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index 321c05b44..790a72ec8 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -9,7 +9,9 @@ from __future__ import annotations +from collections import defaultdict import contextlib +from copy import copy from itertools import filterfalse import re import sys @@ -493,6 +495,7 @@ class AssertsCompiledSQL: render_schema_translate=False, default_schema_name=None, from_linting=False, + check_param_order=True, ): if use_default_dialect: dialect = default.DefaultDialect() @@ -506,8 +509,11 @@ class AssertsCompiledSQL: if dialect is None: dialect = config.db.dialect - elif dialect == "default": - dialect = default.DefaultDialect() + elif dialect == "default" or dialect == "default_qmark": + if dialect == "default": + dialect = default.DefaultDialect() + else: + dialect = default.DefaultDialect("qmark") dialect.supports_default_values = supports_default_values dialect.supports_default_metavalue = supports_default_metavalue elif dialect == "default_enhanced": @@ -632,7 +638,7 @@ class AssertsCompiledSQL: if checkparams is not None: eq_(c.construct_params(params), checkparams) if checkpositional is not None: - p = c.construct_params(params) + p = c.construct_params(params, escape_names=False) eq_(tuple([p[x] for x in c.positiontup]), checkpositional) if check_prefetch is not None: eq_(c.prefetch, check_prefetch) @@ -652,6 +658,58 @@ class AssertsCompiledSQL: }, check_post_param, ) + if check_param_order and getattr(c, "params", None): + + def get_dialect(paramstyle, positional): + cp = copy(dialect) + cp.paramstyle = paramstyle + cp.positional = positional + return cp + + pyformat_dialect = get_dialect("pyformat", False) + pyformat_c = clause.compile(dialect=pyformat_dialect, **kw) + stmt = re.sub(r"[\n\t]", "", str(pyformat_c)) + + qmark_dialect = get_dialect("qmark", True) + qmark_c = clause.compile(dialect=qmark_dialect, **kw) + values = list(qmark_c.positiontup) + escaped = qmark_c.escaped_bind_names + + for post_param in ( + qmark_c.post_compile_params | qmark_c.literal_execute_params + ): + name = qmark_c.bind_names[post_param] + if name in values: + values = [v for v in values if v != name] + positions = [] + pos_by_value = defaultdict(list) + for v in values: + try: + if v in pos_by_value: + start = pos_by_value[v][-1] + else: + start = 0 + esc = escaped.get(v, v) + pos = stmt.index("%%(%s)s" % (esc,), start) + 2 + positions.append(pos) + pos_by_value[v].append(pos) + except ValueError: + msg = "Expected to find bindparam %r in %r" % (v, stmt) + assert False, msg + + ordered = all( + positions[i - 1] < positions[i] + for i in range(1, len(positions)) + ) + + expected = [v for _, v in sorted(zip(positions, values))] + + msg = ( + "Order of parameters %s does not match the order " + "in the statement %s. Statement %r" % (values, expected, stmt) + ) + + is_true(ordered, msg) class ComparesTables: |
