summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/sqlalchemy/dialects/oracle/base.py22
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py13
-rw-r--r--lib/sqlalchemy/sql/compiler.py115
-rw-r--r--lib/sqlalchemy/sql/selectable.py8
-rw-r--r--lib/sqlalchemy/testing/assertions.py64
5 files changed, 170 insertions, 52 deletions
diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py
index 0d51bf73d..41b9ac43d 100644
--- a/lib/sqlalchemy/dialects/oracle/base.py
+++ b/lib/sqlalchemy/dialects/oracle/base.py
@@ -854,7 +854,7 @@ class OracleCompiler(compiler.SQLCompiler):
def visit_function(self, func, **kw):
text = super().visit_function(func, **kw)
if kw.get("asfrom", False):
- text = "TABLE (%s)" % func
+ text = "TABLE (%s)" % text
return text
def visit_table_valued_column(self, element, **kw):
@@ -1222,20 +1222,18 @@ class OracleCompiler(compiler.SQLCompiler):
self.process(binary.right),
)
- def _get_regexp_args(self, binary, kw):
+ def visit_regexp_match_op_binary(self, binary, operator, **kw):
string = self.process(binary.left, **kw)
pattern = self.process(binary.right, **kw)
flags = binary.modifiers["flags"]
- if flags is not None:
- flags = self.process(flags, **kw)
- return string, pattern, flags
-
- def visit_regexp_match_op_binary(self, binary, operator, **kw):
- string, pattern, flags = self._get_regexp_args(binary, kw)
if flags is None:
return "REGEXP_LIKE(%s, %s)" % (string, pattern)
else:
- return "REGEXP_LIKE(%s, %s, %s)" % (string, pattern, flags)
+ return "REGEXP_LIKE(%s, %s, %s)" % (
+ string,
+ pattern,
+ self.process(flags, **kw),
+ )
def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
return "NOT %s" % self.visit_regexp_match_op_binary(
@@ -1243,8 +1241,10 @@ class OracleCompiler(compiler.SQLCompiler):
)
def visit_regexp_replace_op_binary(self, binary, operator, **kw):
- string, pattern, flags = self._get_regexp_args(binary, kw)
+ string = self.process(binary.left, **kw)
+ pattern = self.process(binary.right, **kw)
replacement = self.process(binary.modifiers["replacement"], **kw)
+ flags = binary.modifiers["flags"]
if flags is None:
return "REGEXP_REPLACE(%s, %s, %s)" % (
string,
@@ -1256,7 +1256,7 @@ class OracleCompiler(compiler.SQLCompiler):
string,
pattern,
replacement,
- flags,
+ self.process(flags, **kw),
)
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 99c48fb2f..f9108094f 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -1747,14 +1747,11 @@ class PGCompiler(compiler.SQLCompiler):
return self._generate_generic_binary(
binary, " %s* " % base_op, **kw
)
- flags = self.process(flags, **kw)
- string = self.process(binary.left, **kw)
- pattern = self.process(binary.right, **kw)
return "%s %s CONCAT('(?', %s, ')', %s)" % (
- string,
+ self.process(binary.left, **kw),
base_op,
- flags,
- pattern,
+ self.process(flags, **kw),
+ self.process(binary.right, **kw),
)
def visit_regexp_match_op_binary(self, binary, operator, **kw):
@@ -1767,8 +1764,6 @@ class PGCompiler(compiler.SQLCompiler):
string = self.process(binary.left, **kw)
pattern = self.process(binary.right, **kw)
flags = binary.modifiers["flags"]
- if flags is not None:
- flags = self.process(flags, **kw)
replacement = self.process(binary.modifiers["replacement"], **kw)
if flags is None:
return "REGEXP_REPLACE(%s, %s, %s)" % (
@@ -1781,7 +1776,7 @@ class PGCompiler(compiler.SQLCompiler):
string,
pattern,
replacement,
- flags,
+ self.process(flags, **kw),
)
def visit_empty_set_expr(self, element_types):
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 50cf9b477..7ac279ee2 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -236,8 +236,8 @@ BIND_TEMPLATES = {
}
-_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\]]")
-_BIND_TRANSLATE_CHARS = dict(zip("%():[]", "PAZC__"))
+_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\] ]")
+_BIND_TRANSLATE_CHARS = dict(zip("%():[] ", "PAZC___"))
OPERATORS = {
# binary
@@ -973,6 +973,7 @@ class SQLCompiler(Compiled):
debugging use cases.
"""
+ positiontup_level: Optional[Dict[str, int]] = None
inline: bool = False
@@ -988,6 +989,8 @@ class SQLCompiler(Compiled):
ctes_recursive: bool
cte_positional: Dict[CTE, List[str]]
+ cte_level: Dict[CTE, int]
+ cte_order: Dict[Optional[CTE], List[CTE]]
def __init__(
self,
@@ -1052,6 +1055,7 @@ class SQLCompiler(Compiled):
# true if the paramstyle is positional
self.positional = dialect.positional
if self.positional:
+ self.positiontup_level = {}
self.positiontup = []
self._numeric_binds = dialect.paramstyle == "numeric"
self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
@@ -1215,6 +1219,8 @@ class SQLCompiler(Compiled):
self.ctes_recursive = False
if self.positional:
self.cte_positional = {}
+ self.cte_level = {}
+ self.cte_order = collections.defaultdict(list)
return ctes
@@ -2103,7 +2109,13 @@ class SQLCompiler(Compiled):
text = self.process(taf.element, **kw)
if self.ctes:
nesting_level = len(self.stack) if not toplevel else None
- text = self._render_cte_clause(nesting_level=nesting_level) + text
+ text = (
+ self._render_cte_clause(
+ nesting_level=nesting_level,
+ visiting_cte=kw.get("visiting_cte"),
+ )
+ + text
+ )
self.stack.pop(-1)
@@ -2231,6 +2243,7 @@ class SQLCompiler(Compiled):
)
def visit_over(self, over, **kwargs):
+ text = over.element._compiler_dispatch(self, **kwargs)
if over.range_:
range_ = "RANGE BETWEEN %s" % self._format_frame_clause(
over.range_, **kwargs
@@ -2243,7 +2256,7 @@ class SQLCompiler(Compiled):
range_ = None
return "%s OVER (%s)" % (
- over.element._compiler_dispatch(self, **kwargs),
+ text,
" ".join(
[
"%s BY %s"
@@ -2396,7 +2409,9 @@ class SQLCompiler(Compiled):
nesting_level = len(self.stack) if not toplevel else None
text = (
self._render_cte_clause(
- nesting_level=nesting_level, include_following_stack=True
+ nesting_level=nesting_level,
+ include_following_stack=True,
+ visiting_cte=kwargs.get("visiting_cte"),
)
+ text
)
@@ -3222,7 +3237,8 @@ class SQLCompiler(Compiled):
positional_names.append(name)
else:
self.positiontup.append(name) # type: ignore[union-attr]
- elif not escaped_from:
+ self.positiontup_level[name] = len(self.stack) # type: ignore[index] # noqa: E501
+ if not escaped_from:
if _BIND_TRANSLATE_RE.search(name):
# not quite the translate use case as we want to
@@ -3333,6 +3349,8 @@ class SQLCompiler(Compiled):
self.level_name_by_cte[_reference_cte] = new_level_name + (
cte_opts,
)
+ if self.positional:
+ self.cte_level[cte] = cte_level
else:
cte_level = len(self.stack) if nesting else 1
@@ -3396,6 +3414,8 @@ class SQLCompiler(Compiled):
self.level_name_by_cte[_reference_cte] = cte_level_name + (
cte_opts,
)
+ if self.positional:
+ self.cte_level[cte] = cte_level
if pre_alias_cte not in self.ctes:
self.visit_cte(pre_alias_cte, **kwargs)
@@ -4129,13 +4149,16 @@ class SQLCompiler(Compiled):
if per_dialect:
text += " " + self.get_statement_hint_text(per_dialect)
- if self.ctes:
- # In compound query, CTEs are shared at the compound level
- if not is_embedded_select:
- nesting_level = len(self.stack) if not toplevel else None
- text = (
- self._render_cte_clause(nesting_level=nesting_level) + text
+ # In compound query, CTEs are shared at the compound level
+ if self.ctes and (not is_embedded_select or toplevel):
+ nesting_level = len(self.stack) if not toplevel else None
+ text = (
+ self._render_cte_clause(
+ nesting_level=nesting_level,
+ visiting_cte=kwargs.get("visiting_cte"),
)
+ + text
+ )
if select_stmt._suffixes:
text += " " + self._generate_prefixes(
@@ -4309,6 +4332,7 @@ class SQLCompiler(Compiled):
self,
nesting_level=None,
include_following_stack=False,
+ visiting_cte=None,
):
"""
include_following_stack
@@ -4341,19 +4365,47 @@ class SQLCompiler(Compiled):
if not ctes:
return ""
-
ctes_recursive = any([cte.recursive for cte in ctes])
if self.positional:
- assert self.positiontup is not None
- self.positiontup = (
- list(
- itertools.chain.from_iterable(
- self.cte_positional[cte] for cte in ctes
- )
+ self.cte_order[visiting_cte].extend(ctes)
+
+ if visiting_cte is None and self.cte_order:
+ assert self.positiontup is not None
+
+ def get_nested_positional(cte):
+ if cte in self.cte_order:
+ children = self.cte_order.pop(cte)
+ to_add = list(
+ itertools.chain.from_iterable(
+ get_nested_positional(child_cte)
+ for child_cte in children
+ )
+ )
+ if cte in self.cte_positional:
+ return reorder_positional(
+ self.cte_positional[cte],
+ to_add,
+ self.cte_level[children[0]],
+ )
+ else:
+ return to_add
+ else:
+ return self.cte_positional.get(cte, [])
+
+ def reorder_positional(pos, to_add, level):
+ if not level:
+ return to_add + pos
+ index = 0
+ for index, name in enumerate(reversed(pos)):
+ if self.positiontup_level[name] < level: # type: ignore[index] # noqa: E501
+ break
+ return pos[:-index] + to_add + pos[-index:]
+
+ to_add = get_nested_positional(None)
+ self.positiontup = reorder_positional(
+ self.positiontup, to_add, nesting_level
)
- + self.positiontup
- )
cte_text = self.get_cte_preamble(ctes_recursive) + " "
cte_text += ", \n".join([txt for txt in ctes.values()])
@@ -4930,6 +4982,7 @@ class SQLCompiler(Compiled):
self._render_cte_clause(
nesting_level=nesting_level,
include_following_stack=True,
+ visiting_cte=kw.get("visiting_cte"),
),
select_text,
)
@@ -4997,7 +5050,9 @@ class SQLCompiler(Compiled):
nesting_level = len(self.stack) if not toplevel else None
text = (
self._render_cte_clause(
- nesting_level=nesting_level, include_following_stack=True
+ nesting_level=nesting_level,
+ include_following_stack=True,
+ visiting_cte=kw.get("visiting_cte"),
)
+ text
)
@@ -5146,7 +5201,13 @@ class SQLCompiler(Compiled):
if self.ctes:
nesting_level = len(self.stack) if not toplevel else None
- text = self._render_cte_clause(nesting_level=nesting_level) + text
+ text = (
+ self._render_cte_clause(
+ nesting_level=nesting_level,
+ visiting_cte=kw.get("visiting_cte"),
+ )
+ + text
+ )
self.stack.pop(-1)
@@ -5260,7 +5321,13 @@ class SQLCompiler(Compiled):
if self.ctes:
nesting_level = len(self.stack) if not toplevel else None
- text = self._render_cte_clause(nesting_level=nesting_level) + text
+ text = (
+ self._render_cte_clause(
+ nesting_level=nesting_level,
+ visiting_cte=kw.get("visiting_cte"),
+ )
+ + text
+ )
self.stack.pop(-1)
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 97336d416..2dcc611fa 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -2052,9 +2052,7 @@ class CTE(
else:
self.element._generate_fromclause_column_proxies(self)
- def alias(
- self, name: Optional[str] = None, flat: bool = False
- ) -> NamedFromClause:
+ def alias(self, name: Optional[str] = None, flat: bool = False) -> CTE:
"""Return an :class:`_expression.Alias` of this
:class:`_expression.CTE`.
@@ -2078,7 +2076,7 @@ class CTE(
_suffixes=self._suffixes,
)
- def union(self, *other):
+ def union(self, *other: _SelectStatementForCompoundArgument) -> CTE:
r"""Return a new :class:`_expression.CTE` with a SQL ``UNION``
of the original CTE against the given selectables provided
as positional arguments.
@@ -2107,7 +2105,7 @@ class CTE(
_suffixes=self._suffixes,
)
- def union_all(self, *other):
+ def union_all(self, *other: _SelectStatementForCompoundArgument) -> CTE:
r"""Return a new :class:`_expression.CTE` with a SQL ``UNION ALL``
of the original CTE against the given selectables provided
as positional arguments.
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
index 321c05b44..790a72ec8 100644
--- a/lib/sqlalchemy/testing/assertions.py
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -9,7 +9,9 @@
from __future__ import annotations
+from collections import defaultdict
import contextlib
+from copy import copy
from itertools import filterfalse
import re
import sys
@@ -493,6 +495,7 @@ class AssertsCompiledSQL:
render_schema_translate=False,
default_schema_name=None,
from_linting=False,
+ check_param_order=True,
):
if use_default_dialect:
dialect = default.DefaultDialect()
@@ -506,8 +509,11 @@ class AssertsCompiledSQL:
if dialect is None:
dialect = config.db.dialect
- elif dialect == "default":
- dialect = default.DefaultDialect()
+ elif dialect == "default" or dialect == "default_qmark":
+ if dialect == "default":
+ dialect = default.DefaultDialect()
+ else:
+ dialect = default.DefaultDialect("qmark")
dialect.supports_default_values = supports_default_values
dialect.supports_default_metavalue = supports_default_metavalue
elif dialect == "default_enhanced":
@@ -632,7 +638,7 @@ class AssertsCompiledSQL:
if checkparams is not None:
eq_(c.construct_params(params), checkparams)
if checkpositional is not None:
- p = c.construct_params(params)
+ p = c.construct_params(params, escape_names=False)
eq_(tuple([p[x] for x in c.positiontup]), checkpositional)
if check_prefetch is not None:
eq_(c.prefetch, check_prefetch)
@@ -652,6 +658,58 @@ class AssertsCompiledSQL:
},
check_post_param,
)
+ if check_param_order and getattr(c, "params", None):
+
+ def get_dialect(paramstyle, positional):
+ cp = copy(dialect)
+ cp.paramstyle = paramstyle
+ cp.positional = positional
+ return cp
+
+ pyformat_dialect = get_dialect("pyformat", False)
+ pyformat_c = clause.compile(dialect=pyformat_dialect, **kw)
+ stmt = re.sub(r"[\n\t]", "", str(pyformat_c))
+
+ qmark_dialect = get_dialect("qmark", True)
+ qmark_c = clause.compile(dialect=qmark_dialect, **kw)
+ values = list(qmark_c.positiontup)
+ escaped = qmark_c.escaped_bind_names
+
+ for post_param in (
+ qmark_c.post_compile_params | qmark_c.literal_execute_params
+ ):
+ name = qmark_c.bind_names[post_param]
+ if name in values:
+ values = [v for v in values if v != name]
+ positions = []
+ pos_by_value = defaultdict(list)
+ for v in values:
+ try:
+ if v in pos_by_value:
+ start = pos_by_value[v][-1]
+ else:
+ start = 0
+ esc = escaped.get(v, v)
+ pos = stmt.index("%%(%s)s" % (esc,), start) + 2
+ positions.append(pos)
+ pos_by_value[v].append(pos)
+ except ValueError:
+ msg = "Expected to find bindparam %r in %r" % (v, stmt)
+ assert False, msg
+
+ ordered = all(
+ positions[i - 1] < positions[i]
+ for i in range(1, len(positions))
+ )
+
+ expected = [v for _, v in sorted(zip(positions, values))]
+
+ msg = (
+ "Order of parameters %s does not match the order "
+ "in the statement %s. Statement %r" % (values, expected, stmt)
+ )
+
+ is_true(ordered, msg)
class ComparesTables: