diff options
| -rw-r--r-- | doc/build/changelog/unreleased_20/8994.rst | 13 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/oracle/cx_oracle.py | 35 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 75 | ||||
| -rw-r--r-- | test/sql/test_compiler.py | 39 |
4 files changed, 144 insertions, 18 deletions
diff --git a/doc/build/changelog/unreleased_20/8994.rst b/doc/build/changelog/unreleased_20/8994.rst new file mode 100644 index 000000000..cd2a056fa --- /dev/null +++ b/doc/build/changelog/unreleased_20/8994.rst @@ -0,0 +1,13 @@ +.. change:: + :tags: bug, sql + :tickets: 8994 + + To accommodate for third party dialects with different character escaping + needs regarding bound parameters, the system by which SQLAlchemy "escapes" + (i.e., replaces with another character in its place) special characters in + bound parameter names has been made extensible for third party dialects, + using the :attr:`.SQLCompiler.bindname_escape_chars` dictionary which can + be overridden at the class declaration level on any :class:`.SQLCompiler` + subclass. As part of this change, also added the dot ``"."`` as a default + "escaped" character. + diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index 8f80aed65..c45aafae6 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -445,15 +445,6 @@ from ...sql._typing import is_sql_compiler _CX_ORACLE_MAGIC_LOB_SIZE = 131072 -_ORACLE_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\]\.\/\? ]") - -# Oracle bind names can't start with digits or underscores. -# currently we rely upon Oracle-specific quoting of bind names in most cases. -# however for expanding params, the escape chars are used. -# see #8708 -_ORACLE_BIND_TRANSLATE_CHARS = dict(zip("%():[]./? ", "PAZCCCCCCCC")) - - class _OracleInteger(sqltypes.Integer): def get_dbapi_type(self, dbapi): # see https://github.com/oracle/python-cx_Oracle/issues/ @@ -694,6 +685,26 @@ class OracleCompiler_cx_oracle(OracleCompiler): _oracle_returning = False + # Oracle bind names can't start with digits or underscores. + # currently we rely upon Oracle-specific quoting of bind names in most + # cases. however for expanding params, the escape chars are used. + # see #8708 + bindname_escape_characters = util.immutabledict( + { + "%": "P", + "(": "A", + ")": "Z", + ":": "C", + ".": "C", + "[": "C", + "]": "C", + " ": "C", + "\\": "C", + "/": "C", + "?": "C", + } + ) + def bindparam_string(self, name, **kw): quote = getattr(name, "quote", None) if ( @@ -721,12 +732,12 @@ class OracleCompiler_cx_oracle(OracleCompiler): escaped_from = kw.get("escaped_from", None) if not escaped_from: - if _ORACLE_BIND_TRANSLATE_RE.search(name): + if self._bind_translate_re.search(name): # not quite the translate use case as we want to # also get a quick boolean if we even found # unusual characters in the name - new_name = _ORACLE_BIND_TRANSLATE_RE.sub( - lambda m: _ORACLE_BIND_TRANSLATE_CHARS[m.group(0)], + new_name = self._bind_translate_re.sub( + lambda m: self._bind_translate_chars[m.group(0)], name, ) if new_name[0].isdigit() or new_name[0] == "_": diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 66a294d10..596ca986f 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -37,6 +37,7 @@ import typing from typing import Any from typing import Callable from typing import cast +from typing import ClassVar from typing import Dict from typing import FrozenSet from typing import Iterable @@ -46,6 +47,7 @@ from typing import MutableMapping from typing import NamedTuple from typing import NoReturn from typing import Optional +from typing import Pattern from typing import Sequence from typing import Set from typing import Tuple @@ -238,9 +240,6 @@ BIND_TEMPLATES = { } -_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\] ]") -_BIND_TRANSLATE_CHARS = dict(zip("%():[] ", "PAZC___")) - OPERATORS = { # binary operators.and_: " AND ", @@ -714,6 +713,14 @@ class Compiled: self._gen_time = perf_counter() + def __init_subclass__(cls) -> None: + cls._init_compiler_cls() + return super().__init_subclass__() + + @classmethod + def _init_compiler_cls(cls): + pass + def _execute_on_connection( self, connection, distilled_params, execution_options ): @@ -866,6 +873,52 @@ class SQLCompiler(Compiled): extract_map = EXTRACT_MAP + bindname_escape_characters: ClassVar[ + Mapping[str, str] + ] = util.immutabledict( + { + "%": "P", + "(": "A", + ")": "Z", + ":": "C", + ".": "_", + "[": "_", + "]": "_", + " ": "_", + } + ) + """A mapping (e.g. dict or similar) containing a lookup of + characters keyed to replacement characters which will be applied to all + 'bind names' used in SQL statements as a form of 'escaping'; the given + characters are replaced entirely with the 'replacement' character when + rendered in the SQL statement, and a similar translation is performed + on the incoming names used in parameter dictionaries passed to methods + like :meth:`_engine.Connection.execute`. + + This allows bound parameter names used in :func:`_sql.bindparam` and + other constructs to have any arbitrary characters present without any + concern for characters that aren't allowed at all on the target database. + + Third party dialects can establish their own dictionary here to replace the + default mapping, which will ensure that the particular characters in the + mapping will never appear in a bound parameter name. + + The dictionary is evaluated at **class creation time**, so cannot be + modified at runtime; it must be present on the class when the class + is first declared. + + Note that for dialects that have additional bound parameter rules such + as additional restrictions on leading characters, the + :meth:`_sql.SQLCompiler.bindparam_string` method may need to be augmented. + See the cx_Oracle compiler for an example of this. + + .. versionadded:: 2.0.0b5 + + """ + + _bind_translate_re: ClassVar[Pattern[str]] + _bind_translate_chars: ClassVar[Mapping[str, str]] + is_sql = True compound_keywords = COMPOUND_KEYWORDS @@ -1108,6 +1161,16 @@ class SQLCompiler(Compiled): f"{_pyformat_pattern.pattern}|{_post_compile_pattern.pattern}" ) + @classmethod + def _init_compiler_cls(cls): + cls._init_bind_translate() + + @classmethod + def _init_bind_translate(cls): + reg = re.escape("".join(cls.bindname_escape_characters)) + cls._bind_translate_re = re.compile(f"[{reg}]") + cls._bind_translate_chars = cls.bindname_escape_characters + def __init__( self, dialect: Dialect, @@ -3591,12 +3654,12 @@ class SQLCompiler(Compiled): if not escaped_from: - if _BIND_TRANSLATE_RE.search(name): + if self._bind_translate_re.search(name): # not quite the translate use case as we want to # also get a quick boolean if we even found # unusual characters in the name - new_name = _BIND_TRANSLATE_RE.sub( - lambda m: _BIND_TRANSLATE_CHARS[m.group(0)], + new_name = self._bind_translate_re.sub( + lambda m: self._bind_translate_chars[m.group(0)], name, ) escaped_from = name diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index 39971fd76..2907c6e0e 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -5152,6 +5152,45 @@ class BindParameterTest(AssertsCompiledSQL, fixtures.TestBase): render_postcompile=True, ) + def test_bind_escape_extensibility(self): + """test #8994, extensibility of the bind escape character lookup. + + The main test for actual known characters passing through for bound + params is in + sqlalchemy.testing.suite.test_dialect.DifficultParametersTest. + + """ + dialect = default.DefaultDialect() + + class Compiler(compiler.StrSQLCompiler): + bindname_escape_characters = { + "%": "P", + # chars that need regex escaping + "(": "A", + ")": "Z", + "*": "S", + "+": "L", + # completely random "normie" character + "8": "E", + ":": "C", + # left bracket is not escaped, right bracket is + "]": "_", + " ": "_", + } + + dialect.statement_compiler = Compiler + + self.assert_compile( + select( + bindparam("number8ight"), + bindparam("plus+sign"), + bindparam("par(en)s and [brackets]"), + ), + "SELECT :numberEight AS anon_1, :plusLsign AS anon_2, " + ":parAenZs_and_[brackets_ AS anon_3", + dialect=dialect, + ) + class CompileUXTest(fixtures.TestBase): """tests focused on calling stmt.compile() directly, user cases""" |
