summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/build/changelog/unreleased_20/8994.rst13
-rw-r--r--lib/sqlalchemy/dialects/oracle/cx_oracle.py35
-rw-r--r--lib/sqlalchemy/sql/compiler.py75
-rw-r--r--test/sql/test_compiler.py39
4 files changed, 144 insertions, 18 deletions
diff --git a/doc/build/changelog/unreleased_20/8994.rst b/doc/build/changelog/unreleased_20/8994.rst
new file mode 100644
index 000000000..cd2a056fa
--- /dev/null
+++ b/doc/build/changelog/unreleased_20/8994.rst
@@ -0,0 +1,13 @@
+.. change::
+ :tags: bug, sql
+ :tickets: 8994
+
+ To accommodate for third party dialects with different character escaping
+ needs regarding bound parameters, the system by which SQLAlchemy "escapes"
+ (i.e., replaces with another character in its place) special characters in
+ bound parameter names has been made extensible for third party dialects,
+ using the :attr:`.SQLCompiler.bindname_escape_chars` dictionary which can
+ be overridden at the class declaration level on any :class:`.SQLCompiler`
+ subclass. As part of this change, also added the dot ``"."`` as a default
+ "escaped" character.
+
diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
index 8f80aed65..c45aafae6 100644
--- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py
+++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
@@ -445,15 +445,6 @@ from ...sql._typing import is_sql_compiler
_CX_ORACLE_MAGIC_LOB_SIZE = 131072
-_ORACLE_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\]\.\/\? ]")
-
-# Oracle bind names can't start with digits or underscores.
-# currently we rely upon Oracle-specific quoting of bind names in most cases.
-# however for expanding params, the escape chars are used.
-# see #8708
-_ORACLE_BIND_TRANSLATE_CHARS = dict(zip("%():[]./? ", "PAZCCCCCCCC"))
-
-
class _OracleInteger(sqltypes.Integer):
def get_dbapi_type(self, dbapi):
# see https://github.com/oracle/python-cx_Oracle/issues/
@@ -694,6 +685,26 @@ class OracleCompiler_cx_oracle(OracleCompiler):
_oracle_returning = False
+ # Oracle bind names can't start with digits or underscores.
+ # currently we rely upon Oracle-specific quoting of bind names in most
+ # cases. however for expanding params, the escape chars are used.
+ # see #8708
+ bindname_escape_characters = util.immutabledict(
+ {
+ "%": "P",
+ "(": "A",
+ ")": "Z",
+ ":": "C",
+ ".": "C",
+ "[": "C",
+ "]": "C",
+ " ": "C",
+ "\\": "C",
+ "/": "C",
+ "?": "C",
+ }
+ )
+
def bindparam_string(self, name, **kw):
quote = getattr(name, "quote", None)
if (
@@ -721,12 +732,12 @@ class OracleCompiler_cx_oracle(OracleCompiler):
escaped_from = kw.get("escaped_from", None)
if not escaped_from:
- if _ORACLE_BIND_TRANSLATE_RE.search(name):
+ if self._bind_translate_re.search(name):
# not quite the translate use case as we want to
# also get a quick boolean if we even found
# unusual characters in the name
- new_name = _ORACLE_BIND_TRANSLATE_RE.sub(
- lambda m: _ORACLE_BIND_TRANSLATE_CHARS[m.group(0)],
+ new_name = self._bind_translate_re.sub(
+ lambda m: self._bind_translate_chars[m.group(0)],
name,
)
if new_name[0].isdigit() or new_name[0] == "_":
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 66a294d10..596ca986f 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -37,6 +37,7 @@ import typing
from typing import Any
from typing import Callable
from typing import cast
+from typing import ClassVar
from typing import Dict
from typing import FrozenSet
from typing import Iterable
@@ -46,6 +47,7 @@ from typing import MutableMapping
from typing import NamedTuple
from typing import NoReturn
from typing import Optional
+from typing import Pattern
from typing import Sequence
from typing import Set
from typing import Tuple
@@ -238,9 +240,6 @@ BIND_TEMPLATES = {
}
-_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\] ]")
-_BIND_TRANSLATE_CHARS = dict(zip("%():[] ", "PAZC___"))
-
OPERATORS = {
# binary
operators.and_: " AND ",
@@ -714,6 +713,14 @@ class Compiled:
self._gen_time = perf_counter()
+ def __init_subclass__(cls) -> None:
+ cls._init_compiler_cls()
+ return super().__init_subclass__()
+
+ @classmethod
+ def _init_compiler_cls(cls):
+ pass
+
def _execute_on_connection(
self, connection, distilled_params, execution_options
):
@@ -866,6 +873,52 @@ class SQLCompiler(Compiled):
extract_map = EXTRACT_MAP
+ bindname_escape_characters: ClassVar[
+ Mapping[str, str]
+ ] = util.immutabledict(
+ {
+ "%": "P",
+ "(": "A",
+ ")": "Z",
+ ":": "C",
+ ".": "_",
+ "[": "_",
+ "]": "_",
+ " ": "_",
+ }
+ )
+ """A mapping (e.g. dict or similar) containing a lookup of
+ characters keyed to replacement characters which will be applied to all
+ 'bind names' used in SQL statements as a form of 'escaping'; the given
+ characters are replaced entirely with the 'replacement' character when
+ rendered in the SQL statement, and a similar translation is performed
+ on the incoming names used in parameter dictionaries passed to methods
+ like :meth:`_engine.Connection.execute`.
+
+ This allows bound parameter names used in :func:`_sql.bindparam` and
+ other constructs to have any arbitrary characters present without any
+ concern for characters that aren't allowed at all on the target database.
+
+ Third party dialects can establish their own dictionary here to replace the
+ default mapping, which will ensure that the particular characters in the
+ mapping will never appear in a bound parameter name.
+
+ The dictionary is evaluated at **class creation time**, so cannot be
+ modified at runtime; it must be present on the class when the class
+ is first declared.
+
+ Note that for dialects that have additional bound parameter rules such
+ as additional restrictions on leading characters, the
+ :meth:`_sql.SQLCompiler.bindparam_string` method may need to be augmented.
+ See the cx_Oracle compiler for an example of this.
+
+ .. versionadded:: 2.0.0b5
+
+ """
+
+ _bind_translate_re: ClassVar[Pattern[str]]
+ _bind_translate_chars: ClassVar[Mapping[str, str]]
+
is_sql = True
compound_keywords = COMPOUND_KEYWORDS
@@ -1108,6 +1161,16 @@ class SQLCompiler(Compiled):
f"{_pyformat_pattern.pattern}|{_post_compile_pattern.pattern}"
)
+ @classmethod
+ def _init_compiler_cls(cls):
+ cls._init_bind_translate()
+
+ @classmethod
+ def _init_bind_translate(cls):
+ reg = re.escape("".join(cls.bindname_escape_characters))
+ cls._bind_translate_re = re.compile(f"[{reg}]")
+ cls._bind_translate_chars = cls.bindname_escape_characters
+
def __init__(
self,
dialect: Dialect,
@@ -3591,12 +3654,12 @@ class SQLCompiler(Compiled):
if not escaped_from:
- if _BIND_TRANSLATE_RE.search(name):
+ if self._bind_translate_re.search(name):
# not quite the translate use case as we want to
# also get a quick boolean if we even found
# unusual characters in the name
- new_name = _BIND_TRANSLATE_RE.sub(
- lambda m: _BIND_TRANSLATE_CHARS[m.group(0)],
+ new_name = self._bind_translate_re.sub(
+ lambda m: self._bind_translate_chars[m.group(0)],
name,
)
escaped_from = name
diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py
index 39971fd76..2907c6e0e 100644
--- a/test/sql/test_compiler.py
+++ b/test/sql/test_compiler.py
@@ -5152,6 +5152,45 @@ class BindParameterTest(AssertsCompiledSQL, fixtures.TestBase):
render_postcompile=True,
)
+ def test_bind_escape_extensibility(self):
+ """test #8994, extensibility of the bind escape character lookup.
+
+ The main test for actual known characters passing through for bound
+ params is in
+ sqlalchemy.testing.suite.test_dialect.DifficultParametersTest.
+
+ """
+ dialect = default.DefaultDialect()
+
+ class Compiler(compiler.StrSQLCompiler):
+ bindname_escape_characters = {
+ "%": "P",
+ # chars that need regex escaping
+ "(": "A",
+ ")": "Z",
+ "*": "S",
+ "+": "L",
+ # completely random "normie" character
+ "8": "E",
+ ":": "C",
+ # left bracket is not escaped, right bracket is
+ "]": "_",
+ " ": "_",
+ }
+
+ dialect.statement_compiler = Compiler
+
+ self.assert_compile(
+ select(
+ bindparam("number8ight"),
+ bindparam("plus+sign"),
+ bindparam("par(en)s and [brackets]"),
+ ),
+ "SELECT :numberEight AS anon_1, :plusLsign AS anon_2, "
+ ":parAenZs_and_[brackets_ AS anon_3",
+ dialect=dialect,
+ )
+
class CompileUXTest(fixtures.TestBase):
"""tests focused on calling stmt.compile() directly, user cases"""