diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2019-02-04 15:50:29 -0500 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2019-02-06 22:53:16 -0500 |
| commit | 30307c4616ad67c01ddae2e1e8e34fabf6028414 (patch) | |
| tree | 6e8edd4cb13132aa19f916409f3a3f3dcba7fd0c /lib/sqlalchemy | |
| parent | 11845453d76e1576f637161e660160f0a6117af6 (diff) | |
| download | sqlalchemy-30307c4616ad67c01ddae2e1e8e34fabf6028414.tar.gz | |
Remove all remaining text() coercions and ensure identifiers are safe
Fully removed the behavior of strings passed directly as components of a
:func:`.select` or :class:`.Query` object being coerced to :func:`.text`
constructs automatically; the warning that has been emitted is now an
ArgumentError or in the case of order_by() / group_by() a CompileError.
This has emitted a warning since version 1.0 however its presence continues
to create concerns for the potential of mis-use of this behavior.
Note that public CVEs have been posted for order_by() / group_by() which
are resolved by this commit: CVE-2019-7164 CVE-2019-7548
Added "SQL phrase validation" to key DDL phrases that are accepted as plain
strings, including :paramref:`.ForeignKeyConstraint.on_delete`,
:paramref:`.ForeignKeyConstraint.on_update`,
:paramref:`.ExcludeConstraint.using`,
:paramref:`.ForeignKeyConstraint.initially`, for areas where a series of SQL
keywords only are expected.Any non-space characters that suggest the phrase
would need to be quoted will raise a :class:`.CompileError`. This change
is related to the series of changes committed as part of :ticket:`4481`.
Fixed issue where using an uppercase name for an index type (e.g. GIST,
BTREE, etc. ) or an EXCLUDE constraint would treat it as an identifier to
be quoted, rather than rendering it as is. The new behavior converts these
types to lowercase and ensures they contain only valid SQL characters.
Quoting is applied to :class:`.Function` names, those which are usually but
not necessarily generated from the :attr:`.sql.func` construct, at compile
time if they contain illegal characters, such as spaces or punctuation. The
names are as before treated as case insensitive however, meaning if the
names contain uppercase or mixed case characters, that alone does not
trigger quoting. The case insensitivity is currently maintained for
backwards compatibility.
Fixes: #4481
Fixes: #4473
Fixes: #4467
Change-Id: Ib22a27d62930e24702e2f0f7c74a0473385a08eb
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/__init__.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/base.py | 11 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/ext.py | 38 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/session.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 74 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 87 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/expression.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/schema.py | 8 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 32 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/__init__.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/deprecations.py | 64 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/langhelpers.py | 80 |
12 files changed, 263 insertions, 139 deletions
diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py index 5fda721fe..33a0e4af2 100644 --- a/lib/sqlalchemy/__init__.py +++ b/lib/sqlalchemy/__init__.py @@ -122,7 +122,7 @@ from .engine import create_engine # noqa nosort from .engine import engine_from_config # noqa nosort -__version__ = '1.3.0b3' +__version__ = "1.3.0b3" def __go(lcls): diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 4004a2b9a..4d302dabe 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -948,6 +948,8 @@ except ImportError: _python_UUID = None +IDX_USING = re.compile(r"^(?:btree|hash|gist|gin|[\w_]+)$", re.I) + AUTOCOMMIT_REGEXP = re.compile( r"\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER|GRANT|REVOKE|" "IMPORT FOREIGN SCHEMA|REFRESH MATERIALIZED VIEW|TRUNCATE)", @@ -1908,7 +1910,10 @@ class PGDDLCompiler(compiler.DDLCompiler): using = index.dialect_options["postgresql"]["using"] if using: - text += "USING %s " % preparer.quote(using) + text += ( + "USING %s " + % self.preparer.validate_sql_phrase(using, IDX_USING).lower() + ) ops = index.dialect_options["postgresql"]["ops"] text += "(%s)" % ( @@ -1983,7 +1988,9 @@ class PGDDLCompiler(compiler.DDLCompiler): "%s WITH %s" % (self.sql_compiler.process(expr, **kw), op) ) text += "EXCLUDE USING %s (%s)" % ( - constraint.using, + self.preparer.validate_sql_phrase( + constraint.using, IDX_USING + ).lower(), ", ".join(elements), ) if constraint.where is not None: diff --git a/lib/sqlalchemy/dialects/postgresql/ext.py b/lib/sqlalchemy/dialects/postgresql/ext.py index 49b5e0ec0..426028239 100644 --- a/lib/sqlalchemy/dialects/postgresql/ext.py +++ b/lib/sqlalchemy/dialects/postgresql/ext.py @@ -91,6 +91,11 @@ class ExcludeConstraint(ColumnCollectionConstraint): where = None + @elements._document_text_coercion( + "where", + ":class:`.ExcludeConstraint`", + ":paramref:`.ExcludeConstraint.where`", + ) def __init__(self, *elements, **kw): r""" Create an :class:`.ExcludeConstraint` object. @@ -123,21 +128,15 @@ class ExcludeConstraint(ColumnCollectionConstraint): ) :param \*elements: + A sequence of two tuples of the form ``(column, operator)`` where "column" is a SQL expression element or a raw SQL string, most - typically a :class:`.Column` object, - and "operator" is a string containing the operator to use. - - .. note:: - - A plain string passed for the value of "column" is interpreted - as an arbitrary SQL expression; when passing a plain string, - any necessary quoting and escaping syntaxes must be applied - manually. In order to specify a column name when a - :class:`.Column` object is not available, while ensuring that - any necessary quoting rules take effect, an ad-hoc - :class:`.Column` or :func:`.sql.expression.column` object may - be used. + typically a :class:`.Column` object, and "operator" is a string + containing the operator to use. In order to specify a column name + when a :class:`.Column` object is not available, while ensuring + that any necessary quoting rules take effect, an ad-hoc + :class:`.Column` or :func:`.sql.expression.column` object should be + used. :param name: Optional, the in-database name of this constraint. @@ -159,12 +158,6 @@ class ExcludeConstraint(ColumnCollectionConstraint): If set, emit WHERE <predicate> when issuing DDL for this constraint. - .. note:: - - A plain string passed here is interpreted as an arbitrary SQL - expression; when passing a plain string, any necessary quoting - and escaping syntaxes must be applied manually. - """ columns = [] render_exprs = [] @@ -184,11 +177,12 @@ class ExcludeConstraint(ColumnCollectionConstraint): # backwards compat self.operators[name] = operator - expr = expression._literal_as_text(expr) + expr = expression._literal_as_column(expr) render_exprs.append((expr, name, operator)) self._render_exprs = render_exprs + ColumnCollectionConstraint.__init__( self, *columns, @@ -199,7 +193,9 @@ class ExcludeConstraint(ColumnCollectionConstraint): self.using = kw.get("using", "gist") where = kw.get("where") if where is not None: - self.where = expression._literal_as_text(where) + self.where = expression._literal_as_text( + where, allow_coercion_to_text=True + ) def copy(self, **kw): elements = [(col, self.operators[col]) for col in self.columns.keys()] diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 9e52ef208..6d4198a4e 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -1257,7 +1257,9 @@ class Session(_SessionClassMethods): in order to execute the statement. """ - clause = expression._literal_as_text(clause) + clause = expression._literal_as_text( + clause, allow_coercion_to_text=True + ) if bind is None: bind = self.get_bind(mapper, clause=clause, **kw) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index b703c59f2..15ddd7d6f 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -139,8 +139,16 @@ RESERVED_WORDS = set( ) LEGAL_CHARACTERS = re.compile(r"^[A-Z0-9_$]+$", re.I) +LEGAL_CHARACTERS_PLUS_SPACE = re.compile(r"^[A-Z0-9_ $]+$", re.I) ILLEGAL_INITIAL_CHARACTERS = {str(x) for x in range(0, 10)}.union(["$"]) +FK_ON_DELETE = re.compile( + r"^(?:RESTRICT|CASCADE|SET NULL|NO ACTION|SET DEFAULT)$", re.I +) +FK_ON_UPDATE = re.compile( + r"^(?:RESTRICT|CASCADE|SET NULL|NO ACTION|SET DEFAULT)$", re.I +) +FK_INITIALLY = re.compile(r"^(?:DEFERRED|IMMEDIATE)$", re.I) BIND_PARAMS = re.compile(r"(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])", re.UNICODE) BIND_PARAMS_ESC = re.compile(r"\x5c(:[\w\$]*)(?![:\w\$])", re.UNICODE) @@ -758,12 +766,11 @@ class SQLCompiler(Compiled): else: col = with_cols[element.element] except KeyError: - # treat it like text() - util.warn_limited( - "Can't resolve label reference %r; converting to text()", - util.ellipses_string(element.element), + elements._no_text_coercion( + element.element, + exc.CompileError, + "Can't resolve label reference for ORDER BY / GROUP BY.", ) - return self.process(element._text_clause) else: kwargs["render_label_as_label"] = col return self.process( @@ -1076,10 +1083,24 @@ class SQLCompiler(Compiled): if func._has_args: name += "%(expr)s" else: - name = func.name + "%(expr)s" - return ".".join(list(func.packagenames) + [name]) % { - "expr": self.function_argspec(func, **kwargs) - } + name = func.name + name = ( + self.preparer.quote(name) + if self.preparer._requires_quotes_illegal_chars(name) + else name + ) + name = name + "%(expr)s" + return ".".join( + [ + ( + self.preparer.quote(tok) + if self.preparer._requires_quotes_illegal_chars(tok) + else tok + ) + for tok in func.packagenames + ] + + [name] + ) % {"expr": self.function_argspec(func, **kwargs)} def visit_next_value_func(self, next_value, **kw): return self.visit_sequence(next_value.sequence) @@ -3153,9 +3174,13 @@ class DDLCompiler(Compiled): def define_constraint_cascades(self, constraint): text = "" if constraint.ondelete is not None: - text += " ON DELETE %s" % constraint.ondelete + text += " ON DELETE %s" % self.preparer.validate_sql_phrase( + constraint.ondelete, FK_ON_DELETE + ) if constraint.onupdate is not None: - text += " ON UPDATE %s" % constraint.onupdate + text += " ON UPDATE %s" % self.preparer.validate_sql_phrase( + constraint.onupdate, FK_ON_UPDATE + ) return text def define_constraint_deferrability(self, constraint): @@ -3166,7 +3191,9 @@ class DDLCompiler(Compiled): else: text += " NOT DEFERRABLE" if constraint.initially is not None: - text += " INITIALLY %s" % constraint.initially + text += " INITIALLY %s" % self.preparer.validate_sql_phrase( + constraint.initially, FK_INITIALLY + ) return text def define_constraint_match(self, constraint): @@ -3416,6 +3443,24 @@ class IdentifierPreparer(object): return value.replace(self.escape_to_quote, self.escape_quote) + def validate_sql_phrase(self, element, reg): + """keyword sequence filter. + + a filter for elements that are intended to represent keyword sequences, + such as "INITIALLY", "INTIALLY DEFERRED", etc. no special characters + should be present. + + .. versionadded:: 1.3 + + """ + + if element is not None and not reg.match(element): + raise exc.CompileError( + "Unexpected SQL phrase: %r (matching against %r)" + % (element, reg.pattern) + ) + return element + def quote_identifier(self, value): """Quote an identifier. @@ -3439,6 +3484,11 @@ class IdentifierPreparer(object): or (lc_value != value) ) + def _requires_quotes_illegal_chars(self, value): + """Return True if the given identifier requires quoting, but + not taking case convention into account.""" + return not self.legal_characters.match(util.text_type(value)) + def quote_schema(self, schema, force=None): """Conditionally quote a schema name. diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 9e4f5d95d..a4623128f 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -37,6 +37,20 @@ def _clone(element, **kw): return element._clone() +def _document_text_coercion(paramname, meth_rst, param_rst): + return util.add_parameter_text( + paramname, + ( + ".. warning:: " + "The %s argument to %s can be passed as a Python string argument, " + "which will be treated " + "as **trusted SQL text** and rendered as given. **DO NOT PASS " + "UNTRUSTED INPUT TO THIS PARAMETER**." + ) + % (param_rst, meth_rst), + ) + + def collate(expression, collation): """Return the clause ``expression COLLATE collation``. @@ -1343,6 +1357,7 @@ class TextClause(Executable, ClauseElement): "refer to the :meth:`.TextClause.columns` method.", ), ) + @_document_text_coercion("text", ":func:`.text`", ":paramref:`.text.text`") def _create_text( self, text, bind=None, bindparams=None, typemap=None, autocommit=None ): @@ -4430,32 +4445,64 @@ def _literal_and_labels_as_label_reference(element): def _expression_literal_as_text(element): - return _literal_as_text(element, warn=True) + return _literal_as_text(element) -def _literal_as_text(element, warn=False): +def _literal_as(element, text_fallback): if isinstance(element, Visitable): return element elif hasattr(element, "__clause_element__"): return element.__clause_element__() elif isinstance(element, util.string_types): - if warn: - util.warn_limited( - "Textual SQL expression %(expr)r should be " - "explicitly declared as text(%(expr)r)", - {"expr": util.ellipses_string(element)}, - ) - - return TextClause(util.text_type(element)) + return text_fallback(element) elif isinstance(element, (util.NoneType, bool)): return _const_expr(element) else: raise exc.ArgumentError( - "SQL expression object or string expected, got object of type %r " + "SQL expression object expected, got object of type %r " "instead" % type(element) ) +def _literal_as_text(element, allow_coercion_to_text=False): + if allow_coercion_to_text: + return _literal_as(element, TextClause) + else: + return _literal_as(element, _no_text_coercion) + + +def _literal_as_column(element): + return _literal_as(element, ColumnClause) + + +def _no_column_coercion(element): + element = str(element) + guess_is_literal = not _guess_straight_column.match(element) + raise exc.ArgumentError( + "Textual column expression %(column)r should be " + "explicitly declared with text(%(column)r), " + "or use %(literal_column)s(%(column)r) " + "for more specificity" + % { + "column": util.ellipses_string(element), + "literal_column": "literal_column" + if guess_is_literal + else "column", + } + ) + + +def _no_text_coercion(element, exc_cls=exc.ArgumentError, extra=None): + raise exc_cls( + "%(extra)sTextual SQL expression %(expr)r should be " + "explicitly declared as text(%(expr)r)" + % { + "expr": util.ellipses_string(element), + "extra": "%s " % extra if extra else "", + } + ) + + def _no_literals(element): if hasattr(element, "__clause_element__"): return element.__clause_element__() @@ -4529,23 +4576,7 @@ def _interpret_as_column_or_from(element): elif isinstance(element, (numbers.Number)): return ColumnClause(str(element), is_literal=True) else: - element = str(element) - # give into temptation, as this fact we are guessing about - # is not one we've previously ever needed our users tell us; - # but let them know we are not happy about it - guess_is_literal = not _guess_straight_column.match(element) - util.warn_limited( - "Textual column expression %(column)r should be " - "explicitly declared with text(%(column)r), " - "or use %(literal_column)s(%(column)r) " - "for more specificity", - { - "column": util.ellipses_string(element), - "literal_column": "literal_column" - if guess_is_literal - else "column", - }, - ) + _no_column_coercion(element) return ColumnClause(element, is_literal=guess_is_literal) diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 2a27d0b73..82fe93029 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -101,6 +101,7 @@ from .elements import _expression_literal_as_text # noqa from .elements import _is_column # noqa from .elements import _labeled # noqa from .elements import _literal_as_binds # noqa +from .elements import _literal_as_column # noqa from .elements import _literal_as_label_reference # noqa from .elements import _literal_as_text # noqa from .elements import _only_column_elements # noqa diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 8997e119f..e981d7aed 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -42,6 +42,7 @@ from .base import ColumnCollection from .base import DialectKWArgs from .base import SchemaEventTarget from .elements import _as_truncated +from .elements import _document_text_coercion from .elements import _literal_as_text from .elements import ClauseElement from .elements import ColumnClause @@ -2884,6 +2885,11 @@ class CheckConstraint(ColumnCollectionConstraint): _allow_multiple_tables = True + @_document_text_coercion( + "sqltext", + ":class:`.CheckConstraint`", + ":paramref:`.CheckConstraint.sqltext`", + ) def __init__( self, sqltext, @@ -2925,7 +2931,7 @@ class CheckConstraint(ColumnCollectionConstraint): """ - self.sqltext = _literal_as_text(sqltext, warn=False) + self.sqltext = _literal_as_text(sqltext, allow_coercion_to_text=True) columns = [] visitors.traverse(self.sqltext, {}, {"column": columns.append}) diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index a5dee068c..ac08604f5 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -31,11 +31,13 @@ from .elements import _clause_element_as_expr from .elements import _clone from .elements import _cloned_difference from .elements import _cloned_intersection +from .elements import _document_text_coercion from .elements import _expand_cloned from .elements import _interpret_as_column_or_from from .elements import _literal_and_labels_as_label_reference from .elements import _literal_as_label_reference from .elements import _literal_as_text +from .elements import _no_text_coercion from .elements import _select_iterables from .elements import and_ from .elements import BindParameter @@ -43,7 +45,6 @@ from .elements import ClauseElement from .elements import ClauseList from .elements import Grouping from .elements import literal_column -from .elements import TextClause from .elements import True_ from .elements import UnaryExpression from .. import exc @@ -55,14 +56,7 @@ def _interpret_as_from(element): insp = inspection.inspect(element, raiseerr=False) if insp is None: if isinstance(element, util.string_types): - util.warn_limited( - "Textual SQL FROM expression %(expr)r should be " - "explicitly declared as text(%(expr)r), " - "or use table(%(expr)r) for more specificity", - {"expr": util.ellipses_string(element)}, - ) - - return TextClause(util.text_type(element)) + _no_text_coercion(element) try: return insp.selectable except AttributeError: @@ -266,6 +260,11 @@ class HasPrefixes(object): _prefixes = () @_generative + @_document_text_coercion( + "expr", + ":meth:`.HasPrefixes.prefix_with`", + ":paramref:`.HasPrefixes.prefix_with.*expr`", + ) def prefix_with(self, *expr, **kw): r"""Add one or more expressions following the statement keyword, i.e. SELECT, INSERT, UPDATE, or DELETE. Generative. @@ -297,7 +296,10 @@ class HasPrefixes(object): def _setup_prefixes(self, prefixes, dialect=None): self._prefixes = self._prefixes + tuple( - [(_literal_as_text(p, warn=False), dialect) for p in prefixes] + [ + (_literal_as_text(p, allow_coercion_to_text=True), dialect) + for p in prefixes + ] ) @@ -305,6 +307,11 @@ class HasSuffixes(object): _suffixes = () @_generative + @_document_text_coercion( + "expr", + ":meth:`.HasSuffixes.suffix_with`", + ":paramref:`.HasSuffixes.suffix_with.*expr`", + ) def suffix_with(self, *expr, **kw): r"""Add one or more expressions following the statement as a whole. @@ -335,7 +342,10 @@ class HasSuffixes(object): def _setup_suffixes(self, suffixes, dialect=None): self._suffixes = self._suffixes + tuple( - [(_literal_as_text(p, warn=False), dialect) for p in suffixes] + [ + (_literal_as_text(p, allow_coercion_to_text=True), dialect) + for p in suffixes + ] ) diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index 1e54ef80b..2f3deb191 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -93,6 +93,7 @@ from .deprecations import inject_docstring_text # noqa from .deprecations import pending_deprecation # noqa from .deprecations import warn_deprecated # noqa from .deprecations import warn_pending_deprecation # noqa +from .langhelpers import add_parameter_text # noqa from .langhelpers import as_interface # noqa from .langhelpers import asbool # noqa from .langhelpers import asint # noqa diff --git a/lib/sqlalchemy/util/deprecations.py b/lib/sqlalchemy/util/deprecations.py index a43acc72e..9abf4a6be 100644 --- a/lib/sqlalchemy/util/deprecations.py +++ b/lib/sqlalchemy/util/deprecations.py @@ -9,11 +9,12 @@ functionality.""" import re -import textwrap import warnings from . import compat from .langhelpers import decorator +from .langhelpers import inject_docstring_text +from .langhelpers import inject_param_text from .. import exc @@ -247,64 +248,3 @@ def _decorate_with_warning(func, wtype, message, docstring_header=None): decorated.__doc__ = doc decorated._sa_warn = lambda: warnings.warn(message, wtype, stacklevel=3) return decorated - - -def _dedent_docstring(text): - split_text = text.split("\n", 1) - if len(split_text) == 1: - return text - else: - firstline, remaining = split_text - if not firstline.startswith(" "): - return firstline + "\n" + textwrap.dedent(remaining) - else: - return textwrap.dedent(text) - - -def inject_docstring_text(doctext, injecttext, pos): - doctext = _dedent_docstring(doctext or "") - lines = doctext.split("\n") - injectlines = textwrap.dedent(injecttext).split("\n") - if injectlines[0]: - injectlines.insert(0, "") - - blanks = [num for num, line in enumerate(lines) if not line.strip()] - blanks.insert(0, 0) - - inject_pos = blanks[min(pos, len(blanks) - 1)] - - lines = lines[0:inject_pos] + injectlines + lines[inject_pos:] - return "\n".join(lines) - - -def inject_param_text(doctext, inject_params): - doclines = doctext.splitlines() - lines = [] - - to_inject = None - while doclines: - line = doclines.pop(0) - if to_inject is None: - m = re.match(r"(\s+):param (.+?):", line) - if m: - param = m.group(2) - if param in inject_params: - # default indent to that of :param: plus one - indent = " " * len(m.group(1)) + " " - - # but if the next line has text, use that line's - # indentntation - if doclines: - m2 = re.match(r"(\s+)\S", doclines[0]) - if m2: - indent = " " * len(m2.group(1)) - - to_inject = indent + inject_params[param] - elif not line.rstrip(): - lines.append(line) - lines.append(to_inject) - lines.append("\n") - to_inject = None - lines.append(line) - - return "\n".join(lines) diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index bfe3fd275..198a23a59 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -16,6 +16,7 @@ import itertools import operator import re import sys +import textwrap import types import warnings @@ -1572,3 +1573,82 @@ def quoted_token_parser(value): idx += 1 return ["".join(token) for token in result] + + +def add_parameter_text(params, text): + params = _collections.to_list(params) + + def decorate(fn): + doc = fn.__doc__ is not None and fn.__doc__ or "" + if doc: + doc = inject_param_text(doc, {param: text for param in params}) + fn.__doc__ = doc + return fn + + return decorate + + +def _dedent_docstring(text): + split_text = text.split("\n", 1) + if len(split_text) == 1: + return text + else: + firstline, remaining = split_text + if not firstline.startswith(" "): + return firstline + "\n" + textwrap.dedent(remaining) + else: + return textwrap.dedent(text) + + +def inject_docstring_text(doctext, injecttext, pos): + doctext = _dedent_docstring(doctext or "") + lines = doctext.split("\n") + injectlines = textwrap.dedent(injecttext).split("\n") + if injectlines[0]: + injectlines.insert(0, "") + + blanks = [num for num, line in enumerate(lines) if not line.strip()] + blanks.insert(0, 0) + + inject_pos = blanks[min(pos, len(blanks) - 1)] + + lines = lines[0:inject_pos] + injectlines + lines[inject_pos:] + return "\n".join(lines) + + +def inject_param_text(doctext, inject_params): + doclines = doctext.splitlines() + lines = [] + + to_inject = None + while doclines: + line = doclines.pop(0) + if to_inject is None: + m = re.match(r"(\s+):param (?:\\\*\*?)?(.+?):", line) + if m: + param = m.group(2) + if param in inject_params: + # default indent to that of :param: plus one + indent = " " * len(m.group(1)) + " " + + # but if the next line has text, use that line's + # indentntation + if doclines: + m2 = re.match(r"(\s+)\S", doclines[0]) + if m2: + indent = " " * len(m2.group(1)) + + to_inject = indent + inject_params[param] + elif line.lstrip().startswith(":param "): + lines.append("\n") + lines.append(to_inject) + lines.append("\n") + to_inject = None + elif not line.rstrip(): + lines.append(line) + lines.append(to_inject) + lines.append("\n") + to_inject = None + lines.append(line) + + return "\n".join(lines) |
