Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--   lib/sqlalchemy/sql/__init__.py                 11
-rw-r--r--   lib/sqlalchemy/sql/annotation.py               16
-rw-r--r--   lib/sqlalchemy/sql/base.py                    114
-rw-r--r--   lib/sqlalchemy/sql/compiler.py               2030
-rw-r--r--   lib/sqlalchemy/sql/crud.py                    440
-rw-r--r--   lib/sqlalchemy/sql/ddl.py                     306
-rw-r--r--   lib/sqlalchemy/sql/default_comparator.py      234
-rw-r--r--   lib/sqlalchemy/sql/dml.py                     194
-rw-r--r--   lib/sqlalchemy/sql/elements.py                800
-rw-r--r--   lib/sqlalchemy/sql/expression.py              205
-rw-r--r--   lib/sqlalchemy/sql/functions.py               139
-rw-r--r--   lib/sqlalchemy/sql/naming.py                   47
-rw-r--r--   lib/sqlalchemy/sql/operators.py               109
-rw-r--r--   lib/sqlalchemy/sql/schema.py                 1129
-rw-r--r--   lib/sqlalchemy/sql/selectable.py              812
-rw-r--r--   lib/sqlalchemy/sql/sqltypes.py                649
-rw-r--r--   lib/sqlalchemy/sql/type_api.py                122
-rw-r--r--   lib/sqlalchemy/sql/util.py                    327
-rw-r--r--   lib/sqlalchemy/sql/visitors.py                 48
19 files changed, 4603 insertions, 3129 deletions
diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py
index aa811388b..87e2fb6c3 100644
--- a/lib/sqlalchemy/sql/__init__.py
+++ b/lib/sqlalchemy/sql/__init__.py
@@ -72,7 +72,7 @@ from .expression import (
union,
union_all,
update,
- within_group
+ within_group,
)
from .visitors import ClauseVisitor
@@ -84,12 +84,16 @@ def __go(lcls):
import inspect as _inspect
- __all__ = sorted(name for name, obj in lcls.items()
- if not (name.startswith('_') or _inspect.ismodule(obj)))
+ __all__ = sorted(
+ name
+ for name, obj in lcls.items()
+ if not (name.startswith("_") or _inspect.ismodule(obj))
+ )
from .annotation import _prepare_annotations, Annotated
from .elements import AnnotatedColumnElement, ClauseList
from .selectable import AnnotatedFromClause
+
_prepare_annotations(ColumnElement, AnnotatedColumnElement)
_prepare_annotations(FromClause, AnnotatedFromClause)
_prepare_annotations(ClauseList, Annotated)
@@ -98,4 +102,5 @@ def __go(lcls):
from . import naming
+
__go(locals())
diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py
index c1d484d95..64cfa630e 100644
--- a/lib/sqlalchemy/sql/annotation.py
+++ b/lib/sqlalchemy/sql/annotation.py
@@ -76,8 +76,7 @@ class Annotated(object):
return self._with_annotations(_values)
def _compiler_dispatch(self, visitor, **kw):
- return self.__element.__class__._compiler_dispatch(
- self, visitor, **kw)
+ return self.__element.__class__._compiler_dispatch(self, visitor, **kw)
@property
def _constructor(self):
@@ -120,10 +119,13 @@ def _deep_annotate(element, annotations, exclude=None):
Elements within the exclude collection will be cloned but not annotated.
"""
+
def clone(elem):
- if exclude and \
- hasattr(elem, 'proxy_set') and \
- elem.proxy_set.intersection(exclude):
+ if (
+ exclude
+ and hasattr(elem, "proxy_set")
+ and elem.proxy_set.intersection(exclude)
+ ):
newelem = elem._clone()
elif annotations != elem._annotations:
newelem = elem._annotate(annotations)
@@ -191,8 +193,8 @@ def _new_annotation_type(cls, base_cls):
break
annotated_classes[cls] = anno_cls = type(
- "Annotated%s" % cls.__name__,
- (base_cls, cls), {})
+ "Annotated%s" % cls.__name__, (base_cls, cls), {}
+ )
globals()["Annotated%s" % cls.__name__] = anno_cls
return anno_cls
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 6b9b55753..45db215fe 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -15,8 +15,8 @@ import itertools
from .visitors import ClauseVisitor
import re
-PARSE_AUTOCOMMIT = util.symbol('PARSE_AUTOCOMMIT')
-NO_ARG = util.symbol('NO_ARG')
+PARSE_AUTOCOMMIT = util.symbol("PARSE_AUTOCOMMIT")
+NO_ARG = util.symbol("NO_ARG")
class Immutable(object):
@@ -77,7 +77,8 @@ class _DialectArgView(util.collections_abc.MutableMapping):
dialect, value_key = self._key(key)
except KeyError:
raise exc.ArgumentError(
- "Keys must be of the form <dialectname>_<argname>")
+ "Keys must be of the form <dialectname>_<argname>"
+ )
else:
self.obj.dialect_options[dialect][value_key] = value
@@ -86,15 +87,18 @@ class _DialectArgView(util.collections_abc.MutableMapping):
del self.obj.dialect_options[dialect][value_key]
def __len__(self):
- return sum(len(args._non_defaults) for args in
- self.obj.dialect_options.values())
+ return sum(
+ len(args._non_defaults)
+ for args in self.obj.dialect_options.values()
+ )
def __iter__(self):
return (
util.safe_kwarg("%s_%s" % (dialect_name, value_name))
for dialect_name in self.obj.dialect_options
- for value_name in
- self.obj.dialect_options[dialect_name]._non_defaults
+ for value_name in self.obj.dialect_options[
+ dialect_name
+ ]._non_defaults
)
@@ -187,8 +191,8 @@ class DialectKWArgs(object):
if construct_arg_dictionary is None:
raise exc.ArgumentError(
"Dialect '%s' does have keyword-argument "
- "validation and defaults enabled configured" %
- dialect_name)
+ "validation and defaults enabled configured" % dialect_name
+ )
if cls not in construct_arg_dictionary:
construct_arg_dictionary[cls] = {}
construct_arg_dictionary[cls][argument_name] = default
@@ -230,6 +234,7 @@ class DialectKWArgs(object):
if dialect_cls.construct_arguments is None:
return None
return dict(dialect_cls.construct_arguments)
+
_kw_registry = util.PopulateDict(_kw_reg_for_dialect)
def _kw_reg_for_dialect_cls(self, dialect_name):
@@ -274,11 +279,12 @@ class DialectKWArgs(object):
return
for k in kwargs:
- m = re.match('^(.+?)_(.+)$', k)
+ m = re.match("^(.+?)_(.+)$", k)
if not m:
raise TypeError(
"Additional arguments should be "
- "named <dialectname>_<argument>, got '%s'" % k)
+ "named <dialectname>_<argument>, got '%s'" % k
+ )
dialect_name, arg_name = m.group(1, 2)
try:
@@ -286,20 +292,22 @@ class DialectKWArgs(object):
except exc.NoSuchModuleError:
util.warn(
"Can't validate argument %r; can't "
- "locate any SQLAlchemy dialect named %r" %
- (k, dialect_name))
+ "locate any SQLAlchemy dialect named %r"
+ % (k, dialect_name)
+ )
self.dialect_options[dialect_name] = d = _DialectArgDict()
d._defaults.update({"*": None})
d._non_defaults[arg_name] = kwargs[k]
else:
- if "*" not in construct_arg_dictionary and \
- arg_name not in construct_arg_dictionary:
+ if (
+ "*" not in construct_arg_dictionary
+ and arg_name not in construct_arg_dictionary
+ ):
raise exc.ArgumentError(
"Argument %r is not accepted by "
- "dialect %r on behalf of %r" % (
- k,
- dialect_name, self.__class__
- ))
+ "dialect %r on behalf of %r"
+ % (k, dialect_name, self.__class__)
+ )
else:
construct_arg_dictionary[arg_name] = kwargs[k]
@@ -359,14 +367,14 @@ class Executable(Generative):
:meth:`.Query.execution_options()`
"""
- if 'isolation_level' in kw:
+ if "isolation_level" in kw:
raise exc.ArgumentError(
"'isolation_level' execution option may only be specified "
"on Connection.execution_options(), or "
"per-engine using the isolation_level "
"argument to create_engine()."
)
- if 'compiled_cache' in kw:
+ if "compiled_cache" in kw:
raise exc.ArgumentError(
"'compiled_cache' execution option may only be specified "
"on Connection.execution_options(), not per statement."
@@ -377,10 +385,12 @@ class Executable(Generative):
"""Compile and execute this :class:`.Executable`."""
e = self.bind
if e is None:
- label = getattr(self, 'description', self.__class__.__name__)
- msg = ('This %s is not directly bound to a Connection or Engine. '
- 'Use the .execute() method of a Connection or Engine '
- 'to execute this construct.' % label)
+ label = getattr(self, "description", self.__class__.__name__)
+ msg = (
+ "This %s is not directly bound to a Connection or Engine. "
+ "Use the .execute() method of a Connection or Engine "
+ "to execute this construct." % label
+ )
raise exc.UnboundExecutionError(msg)
return e._execute_clauseelement(self, multiparams, params)
@@ -434,7 +444,7 @@ class SchemaEventTarget(object):
class SchemaVisitor(ClauseVisitor):
"""Define the visiting for ``SchemaItem`` objects."""
- __traverse_options__ = {'schema_visitor': True}
+ __traverse_options__ = {"schema_visitor": True}
class ColumnCollection(util.OrderedProperties):
@@ -446,11 +456,11 @@ class ColumnCollection(util.OrderedProperties):
"""
- __slots__ = '_all_columns'
+ __slots__ = "_all_columns"
def __init__(self, *columns):
super(ColumnCollection, self).__init__()
- object.__setattr__(self, '_all_columns', [])
+ object.__setattr__(self, "_all_columns", [])
for c in columns:
self.add(c)
@@ -485,8 +495,9 @@ class ColumnCollection(util.OrderedProperties):
self._data[column.key] = column
if remove_col is not None:
- self._all_columns[:] = [column if c is remove_col
- else c for c in self._all_columns]
+ self._all_columns[:] = [
+ column if c is remove_col else c for c in self._all_columns
+ ]
else:
self._all_columns.append(column)
@@ -499,7 +510,8 @@ class ColumnCollection(util.OrderedProperties):
"""
if not column.key:
raise exc.ArgumentError(
- "Can't add unnamed column to column collection")
+ "Can't add unnamed column to column collection"
+ )
self[column.key] = column
def __delitem__(self, key):
@@ -521,10 +533,12 @@ class ColumnCollection(util.OrderedProperties):
return
if not existing.shares_lineage(value):
- util.warn('Column %r on table %r being replaced by '
- '%r, which has the same key. Consider '
- 'use_labels for select() statements.' %
- (key, getattr(existing, 'table', None), value))
+ util.warn(
+ "Column %r on table %r being replaced by "
+ "%r, which has the same key. Consider "
+ "use_labels for select() statements."
+ % (key, getattr(existing, "table", None), value)
+ )
# pop out memoized proxy_set as this
# operation may very well be occurring
@@ -540,13 +554,15 @@ class ColumnCollection(util.OrderedProperties):
def remove(self, column):
del self._data[column.key]
self._all_columns[:] = [
- c for c in self._all_columns if c is not column]
+ c for c in self._all_columns if c is not column
+ ]
def update(self, iter):
cols = list(iter)
all_col_set = set(self._all_columns)
self._all_columns.extend(
- c for label, c in cols if c not in all_col_set)
+ c for label, c in cols if c not in all_col_set
+ )
self._data.update((label, c) for label, c in cols)
def extend(self, iter):
@@ -572,12 +588,11 @@ class ColumnCollection(util.OrderedProperties):
return util.OrderedProperties.__contains__(self, other)
def __getstate__(self):
- return {'_data': self._data,
- '_all_columns': self._all_columns}
+ return {"_data": self._data, "_all_columns": self._all_columns}
def __setstate__(self, state):
- object.__setattr__(self, '_data', state['_data'])
- object.__setattr__(self, '_all_columns', state['_all_columns'])
+ object.__setattr__(self, "_data", state["_data"])
+ object.__setattr__(self, "_all_columns", state["_all_columns"])
def contains_column(self, col):
return col in set(self._all_columns)
@@ -589,7 +604,7 @@ class ColumnCollection(util.OrderedProperties):
class ImmutableColumnCollection(util.ImmutableProperties, ColumnCollection):
def __init__(self, data, all_columns):
util.ImmutableProperties.__init__(self, data)
- object.__setattr__(self, '_all_columns', all_columns)
+ object.__setattr__(self, "_all_columns", all_columns)
extend = remove = util.ImmutableProperties._immutable
@@ -622,15 +637,18 @@ def _bind_or_error(schemaitem, msg=None):
bind = schemaitem.bind
if not bind:
name = schemaitem.__class__.__name__
- label = getattr(schemaitem, 'fullname',
- getattr(schemaitem, 'name', None))
+ label = getattr(
+ schemaitem, "fullname", getattr(schemaitem, "name", None)
+ )
if label:
- item = '%s object %r' % (name, label)
+ item = "%s object %r" % (name, label)
else:
- item = '%s object' % name
+ item = "%s object" % name
if msg is None:
- msg = "%s is not bound to an Engine or Connection. "\
- "Execution can not proceed without a database to execute "\
+ msg = (
+ "%s is not bound to an Engine or Connection. "
+ "Execution can not proceed without a database to execute "
"against." % item
+ )
raise exc.UnboundExecutionError(msg)
return bind
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 80ed707ed..f641d0a84 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -25,133 +25,218 @@ To generate user-defined SQL strings, see
import contextlib
import re
-from . import schema, sqltypes, operators, functions, visitors, \
- elements, selectable, crud
+from . import (
+ schema,
+ sqltypes,
+ operators,
+ functions,
+ visitors,
+ elements,
+ selectable,
+ crud,
+)
from .. import util, exc
import itertools
-RESERVED_WORDS = set([
- 'all', 'analyse', 'analyze', 'and', 'any', 'array',
- 'as', 'asc', 'asymmetric', 'authorization', 'between',
- 'binary', 'both', 'case', 'cast', 'check', 'collate',
- 'column', 'constraint', 'create', 'cross', 'current_date',
- 'current_role', 'current_time', 'current_timestamp',
- 'current_user', 'default', 'deferrable', 'desc',
- 'distinct', 'do', 'else', 'end', 'except', 'false',
- 'for', 'foreign', 'freeze', 'from', 'full', 'grant',
- 'group', 'having', 'ilike', 'in', 'initially', 'inner',
- 'intersect', 'into', 'is', 'isnull', 'join', 'leading',
- 'left', 'like', 'limit', 'localtime', 'localtimestamp',
- 'natural', 'new', 'not', 'notnull', 'null', 'off', 'offset',
- 'old', 'on', 'only', 'or', 'order', 'outer', 'overlaps',
- 'placing', 'primary', 'references', 'right', 'select',
- 'session_user', 'set', 'similar', 'some', 'symmetric', 'table',
- 'then', 'to', 'trailing', 'true', 'union', 'unique', 'user',
- 'using', 'verbose', 'when', 'where'])
-
-LEGAL_CHARACTERS = re.compile(r'^[A-Z0-9_$]+$', re.I)
-ILLEGAL_INITIAL_CHARACTERS = {str(x) for x in range(0, 10)}.union(['$'])
-
-BIND_PARAMS = re.compile(r'(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])', re.UNICODE)
-BIND_PARAMS_ESC = re.compile(r'\x5c(:[\w\$]*)(?![:\w\$])', re.UNICODE)
+RESERVED_WORDS = set(
+ [
+ "all",
+ "analyse",
+ "analyze",
+ "and",
+ "any",
+ "array",
+ "as",
+ "asc",
+ "asymmetric",
+ "authorization",
+ "between",
+ "binary",
+ "both",
+ "case",
+ "cast",
+ "check",
+ "collate",
+ "column",
+ "constraint",
+ "create",
+ "cross",
+ "current_date",
+ "current_role",
+ "current_time",
+ "current_timestamp",
+ "current_user",
+ "default",
+ "deferrable",
+ "desc",
+ "distinct",
+ "do",
+ "else",
+ "end",
+ "except",
+ "false",
+ "for",
+ "foreign",
+ "freeze",
+ "from",
+ "full",
+ "grant",
+ "group",
+ "having",
+ "ilike",
+ "in",
+ "initially",
+ "inner",
+ "intersect",
+ "into",
+ "is",
+ "isnull",
+ "join",
+ "leading",
+ "left",
+ "like",
+ "limit",
+ "localtime",
+ "localtimestamp",
+ "natural",
+ "new",
+ "not",
+ "notnull",
+ "null",
+ "off",
+ "offset",
+ "old",
+ "on",
+ "only",
+ "or",
+ "order",
+ "outer",
+ "overlaps",
+ "placing",
+ "primary",
+ "references",
+ "right",
+ "select",
+ "session_user",
+ "set",
+ "similar",
+ "some",
+ "symmetric",
+ "table",
+ "then",
+ "to",
+ "trailing",
+ "true",
+ "union",
+ "unique",
+ "user",
+ "using",
+ "verbose",
+ "when",
+ "where",
+ ]
+)
+
+LEGAL_CHARACTERS = re.compile(r"^[A-Z0-9_$]+$", re.I)
+ILLEGAL_INITIAL_CHARACTERS = {str(x) for x in range(0, 10)}.union(["$"])
+
+BIND_PARAMS = re.compile(r"(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])", re.UNICODE)
+BIND_PARAMS_ESC = re.compile(r"\x5c(:[\w\$]*)(?![:\w\$])", re.UNICODE)
BIND_TEMPLATES = {
- 'pyformat': "%%(%(name)s)s",
- 'qmark': "?",
- 'format': "%%s",
- 'numeric': ":[_POSITION]",
- 'named': ":%(name)s"
+ "pyformat": "%%(%(name)s)s",
+ "qmark": "?",
+ "format": "%%s",
+ "numeric": ":[_POSITION]",
+ "named": ":%(name)s",
}
OPERATORS = {
# binary
- operators.and_: ' AND ',
- operators.or_: ' OR ',
- operators.add: ' + ',
- operators.mul: ' * ',
- operators.sub: ' - ',
- operators.div: ' / ',
- operators.mod: ' % ',
- operators.truediv: ' / ',
- operators.neg: '-',
- operators.lt: ' < ',
- operators.le: ' <= ',
- operators.ne: ' != ',
- operators.gt: ' > ',
- operators.ge: ' >= ',
- operators.eq: ' = ',
- operators.is_distinct_from: ' IS DISTINCT FROM ',
- operators.isnot_distinct_from: ' IS NOT DISTINCT FROM ',
- operators.concat_op: ' || ',
- operators.match_op: ' MATCH ',
- operators.notmatch_op: ' NOT MATCH ',
- operators.in_op: ' IN ',
- operators.notin_op: ' NOT IN ',
- operators.comma_op: ', ',
- operators.from_: ' FROM ',
- operators.as_: ' AS ',
- operators.is_: ' IS ',
- operators.isnot: ' IS NOT ',
- operators.collate: ' COLLATE ',
-
+ operators.and_: " AND ",
+ operators.or_: " OR ",
+ operators.add: " + ",
+ operators.mul: " * ",
+ operators.sub: " - ",
+ operators.div: " / ",
+ operators.mod: " % ",
+ operators.truediv: " / ",
+ operators.neg: "-",
+ operators.lt: " < ",
+ operators.le: " <= ",
+ operators.ne: " != ",
+ operators.gt: " > ",
+ operators.ge: " >= ",
+ operators.eq: " = ",
+ operators.is_distinct_from: " IS DISTINCT FROM ",
+ operators.isnot_distinct_from: " IS NOT DISTINCT FROM ",
+ operators.concat_op: " || ",
+ operators.match_op: " MATCH ",
+ operators.notmatch_op: " NOT MATCH ",
+ operators.in_op: " IN ",
+ operators.notin_op: " NOT IN ",
+ operators.comma_op: ", ",
+ operators.from_: " FROM ",
+ operators.as_: " AS ",
+ operators.is_: " IS ",
+ operators.isnot: " IS NOT ",
+ operators.collate: " COLLATE ",
# unary
- operators.exists: 'EXISTS ',
- operators.distinct_op: 'DISTINCT ',
- operators.inv: 'NOT ',
- operators.any_op: 'ANY ',
- operators.all_op: 'ALL ',
-
+ operators.exists: "EXISTS ",
+ operators.distinct_op: "DISTINCT ",
+ operators.inv: "NOT ",
+ operators.any_op: "ANY ",
+ operators.all_op: "ALL ",
# modifiers
- operators.desc_op: ' DESC',
- operators.asc_op: ' ASC',
- operators.nullsfirst_op: ' NULLS FIRST',
- operators.nullslast_op: ' NULLS LAST',
-
+ operators.desc_op: " DESC",
+ operators.asc_op: " ASC",
+ operators.nullsfirst_op: " NULLS FIRST",
+ operators.nullslast_op: " NULLS LAST",
}
FUNCTIONS = {
- functions.coalesce: 'coalesce',
- functions.current_date: 'CURRENT_DATE',
- functions.current_time: 'CURRENT_TIME',
- functions.current_timestamp: 'CURRENT_TIMESTAMP',
- functions.current_user: 'CURRENT_USER',
- functions.localtime: 'LOCALTIME',
- functions.localtimestamp: 'LOCALTIMESTAMP',
- functions.random: 'random',
- functions.sysdate: 'sysdate',
- functions.session_user: 'SESSION_USER',
- functions.user: 'USER',
- functions.cube: 'CUBE',
- functions.rollup: 'ROLLUP',
- functions.grouping_sets: 'GROUPING SETS',
+ functions.coalesce: "coalesce",
+ functions.current_date: "CURRENT_DATE",
+ functions.current_time: "CURRENT_TIME",
+ functions.current_timestamp: "CURRENT_TIMESTAMP",
+ functions.current_user: "CURRENT_USER",
+ functions.localtime: "LOCALTIME",
+ functions.localtimestamp: "LOCALTIMESTAMP",
+ functions.random: "random",
+ functions.sysdate: "sysdate",
+ functions.session_user: "SESSION_USER",
+ functions.user: "USER",
+ functions.cube: "CUBE",
+ functions.rollup: "ROLLUP",
+ functions.grouping_sets: "GROUPING SETS",
}
EXTRACT_MAP = {
- 'month': 'month',
- 'day': 'day',
- 'year': 'year',
- 'second': 'second',
- 'hour': 'hour',
- 'doy': 'doy',
- 'minute': 'minute',
- 'quarter': 'quarter',
- 'dow': 'dow',
- 'week': 'week',
- 'epoch': 'epoch',
- 'milliseconds': 'milliseconds',
- 'microseconds': 'microseconds',
- 'timezone_hour': 'timezone_hour',
- 'timezone_minute': 'timezone_minute'
+ "month": "month",
+ "day": "day",
+ "year": "year",
+ "second": "second",
+ "hour": "hour",
+ "doy": "doy",
+ "minute": "minute",
+ "quarter": "quarter",
+ "dow": "dow",
+ "week": "week",
+ "epoch": "epoch",
+ "milliseconds": "milliseconds",
+ "microseconds": "microseconds",
+ "timezone_hour": "timezone_hour",
+ "timezone_minute": "timezone_minute",
}
COMPOUND_KEYWORDS = {
- selectable.CompoundSelect.UNION: 'UNION',
- selectable.CompoundSelect.UNION_ALL: 'UNION ALL',
- selectable.CompoundSelect.EXCEPT: 'EXCEPT',
- selectable.CompoundSelect.EXCEPT_ALL: 'EXCEPT ALL',
- selectable.CompoundSelect.INTERSECT: 'INTERSECT',
- selectable.CompoundSelect.INTERSECT_ALL: 'INTERSECT ALL'
+ selectable.CompoundSelect.UNION: "UNION",
+ selectable.CompoundSelect.UNION_ALL: "UNION ALL",
+ selectable.CompoundSelect.EXCEPT: "EXCEPT",
+ selectable.CompoundSelect.EXCEPT_ALL: "EXCEPT ALL",
+ selectable.CompoundSelect.INTERSECT: "INTERSECT",
+ selectable.CompoundSelect.INTERSECT_ALL: "INTERSECT ALL",
}
@@ -177,9 +262,14 @@ class Compiled(object):
sub-elements of the statement can modify these.
"""
- def __init__(self, dialect, statement, bind=None,
- schema_translate_map=None,
- compile_kwargs=util.immutabledict()):
+ def __init__(
+ self,
+ dialect,
+ statement,
+ bind=None,
+ schema_translate_map=None,
+ compile_kwargs=util.immutabledict(),
+ ):
"""Construct a new :class:`.Compiled` object.
:param dialect: :class:`.Dialect` to compile against.
@@ -209,7 +299,8 @@ class Compiled(object):
self.preparer = self.dialect.identifier_preparer
if schema_translate_map:
self.preparer = self.preparer._with_schema_translate(
- schema_translate_map)
+ schema_translate_map
+ )
if statement is not None:
self.statement = statement
@@ -218,8 +309,10 @@ class Compiled(object):
self.execution_options = statement._execution_options
self.string = self.process(self.statement, **compile_kwargs)
- @util.deprecated("0.7", ":class:`.Compiled` objects now compile "
- "within the constructor.")
+ @util.deprecated(
+ "0.7",
+ ":class:`.Compiled` objects now compile " "within the constructor.",
+ )
def compile(self):
"""Produce the internal string representation of this element.
"""
@@ -247,7 +340,7 @@ class Compiled(object):
def __str__(self):
"""Return the string text of the generated SQL or DDL."""
- return self.string or ''
+ return self.string or ""
def construct_params(self, params=None):
"""Return the bind params for this compiled object.
@@ -271,7 +364,9 @@ class Compiled(object):
if e is None:
raise exc.UnboundExecutionError(
"This Compiled object is not bound to any Engine "
- "or Connection.", code="2afi")
+ "or Connection.",
+ code="2afi",
+ )
return e._execute_compiled(self, multiparams, params)
def scalar(self, *multiparams, **params):
@@ -284,7 +379,7 @@ class Compiled(object):
class TypeCompiler(util.with_metaclass(util.EnsureKWArgType, object)):
"""Produces DDL specification for TypeEngine objects."""
- ensure_kwarg = r'visit_\w+'
+ ensure_kwarg = r"visit_\w+"
def __init__(self, dialect):
self.dialect = dialect
@@ -297,8 +392,8 @@ class _CompileLabel(visitors.Visitable):
"""lightweight label object which acts as an expression.Label."""
- __visit_name__ = 'label'
- __slots__ = 'element', 'name'
+ __visit_name__ = "label"
+ __slots__ = "element", "name"
def __init__(self, col, name, alt_names=()):
self.element = col
@@ -390,8 +485,9 @@ class SQLCompiler(Compiled):
insert_prefetch = update_prefetch = ()
- def __init__(self, dialect, statement, column_keys=None,
- inline=False, **kwargs):
+ def __init__(
+ self, dialect, statement, column_keys=None, inline=False, **kwargs
+ ):
"""Construct a new :class:`.SQLCompiler` object.
:param dialect: :class:`.Dialect` to be used
@@ -412,7 +508,7 @@ class SQLCompiler(Compiled):
# compile INSERT/UPDATE defaults/sequences inlined (no pre-
# execute)
- self.inline = inline or getattr(statement, 'inline', False)
+ self.inline = inline or getattr(statement, "inline", False)
# a dictionary of bind parameter keys to BindParameter
# instances.
@@ -440,8 +536,9 @@ class SQLCompiler(Compiled):
self.ctes = None
- self.label_length = dialect.label_length \
- or dialect.max_identifier_length
+ self.label_length = (
+ dialect.label_length or dialect.max_identifier_length
+ )
# a map which tracks "anonymous" identifiers that are created on
# the fly here
@@ -453,7 +550,7 @@ class SQLCompiler(Compiled):
Compiled.__init__(self, dialect, statement, **kwargs)
if (
- self.isinsert or self.isupdate or self.isdelete
+ self.isinsert or self.isupdate or self.isdelete
) and statement._returning:
self.returning = statement._returning
@@ -482,37 +579,43 @@ class SQLCompiler(Compiled):
def _nested_result(self):
"""special API to support the use case of 'nested result sets'"""
result_columns, ordered_columns = (
- self._result_columns, self._ordered_columns)
+ self._result_columns,
+ self._ordered_columns,
+ )
self._result_columns, self._ordered_columns = [], False
try:
if self.stack:
entry = self.stack[-1]
- entry['need_result_map_for_nested'] = True
+ entry["need_result_map_for_nested"] = True
else:
entry = None
yield self._result_columns, self._ordered_columns
finally:
if entry:
- entry.pop('need_result_map_for_nested')
+ entry.pop("need_result_map_for_nested")
self._result_columns, self._ordered_columns = (
- result_columns, ordered_columns)
+ result_columns,
+ ordered_columns,
+ )
def _apply_numbered_params(self):
poscount = itertools.count(1)
self.string = re.sub(
- r'\[_POSITION\]',
- lambda m: str(util.next(poscount)),
- self.string)
+ r"\[_POSITION\]", lambda m: str(util.next(poscount)), self.string
+ )
@util.memoized_property
def _bind_processors(self):
return dict(
- (key, value) for key, value in
- ((self.bind_names[bindparam],
- bindparam.type._cached_bind_processor(self.dialect)
- )
- for bindparam in self.bind_names)
+ (key, value)
+ for key, value in (
+ (
+ self.bind_names[bindparam],
+ bindparam.type._cached_bind_processor(self.dialect),
+ )
+ for bindparam in self.bind_names
+ )
if value is not None
)
@@ -539,12 +642,16 @@ class SQLCompiler(Compiled):
if _group_number:
raise exc.InvalidRequestError(
"A value is required for bind parameter %r, "
- "in parameter group %d" %
- (bindparam.key, _group_number), code="cd3x")
+ "in parameter group %d"
+ % (bindparam.key, _group_number),
+ code="cd3x",
+ )
else:
raise exc.InvalidRequestError(
"A value is required for bind parameter %r"
- % bindparam.key, code="cd3x")
+ % bindparam.key,
+ code="cd3x",
+ )
elif bindparam.callable:
pd[name] = bindparam.effective_value
@@ -558,12 +665,16 @@ class SQLCompiler(Compiled):
if _group_number:
raise exc.InvalidRequestError(
"A value is required for bind parameter %r, "
- "in parameter group %d" %
- (bindparam.key, _group_number), code="cd3x")
+ "in parameter group %d"
+ % (bindparam.key, _group_number),
+ code="cd3x",
+ )
else:
raise exc.InvalidRequestError(
"A value is required for bind parameter %r"
- % bindparam.key, code="cd3x")
+ % bindparam.key,
+ code="cd3x",
+ )
if bindparam.callable:
pd[self.bind_names[bindparam]] = bindparam.effective_value
@@ -595,9 +706,10 @@ class SQLCompiler(Compiled):
return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")"
def visit_label_reference(
- self, element, within_columns_clause=False, **kwargs):
+ self, element, within_columns_clause=False, **kwargs
+ ):
if self.stack and self.dialect.supports_simple_order_by_label:
- selectable = self.stack[-1]['selectable']
+ selectable = self.stack[-1]["selectable"]
with_cols, only_froms, only_cols = selectable._label_resolve_dict
if within_columns_clause:
@@ -611,25 +723,30 @@ class SQLCompiler(Compiled):
# to something else like a ColumnClause expression.
order_by_elem = element.element._order_by_label_element
- if order_by_elem is not None and order_by_elem.name in \
- resolve_dict and \
- order_by_elem.shares_lineage(
- resolve_dict[order_by_elem.name]):
- kwargs['render_label_as_label'] = \
- element.element._order_by_label_element
+ if (
+ order_by_elem is not None
+ and order_by_elem.name in resolve_dict
+ and order_by_elem.shares_lineage(
+ resolve_dict[order_by_elem.name]
+ )
+ ):
+ kwargs[
+ "render_label_as_label"
+ ] = element.element._order_by_label_element
return self.process(
- element.element, within_columns_clause=within_columns_clause,
- **kwargs)
+ element.element,
+ within_columns_clause=within_columns_clause,
+ **kwargs
+ )
def visit_textual_label_reference(
- self, element, within_columns_clause=False, **kwargs):
+ self, element, within_columns_clause=False, **kwargs
+ ):
if not self.stack:
# compiling the element outside of the context of a SELECT
- return self.process(
- element._text_clause
- )
+ return self.process(element._text_clause)
- selectable = self.stack[-1]['selectable']
+ selectable = self.stack[-1]["selectable"]
with_cols, only_froms, only_cols = selectable._label_resolve_dict
try:
if within_columns_clause:
@@ -640,26 +757,30 @@ class SQLCompiler(Compiled):
# treat it like text()
util.warn_limited(
"Can't resolve label reference %r; converting to text()",
- util.ellipses_string(element.element))
- return self.process(
- element._text_clause
+ util.ellipses_string(element.element),
)
+ return self.process(element._text_clause)
else:
- kwargs['render_label_as_label'] = col
+ kwargs["render_label_as_label"] = col
return self.process(
- col, within_columns_clause=within_columns_clause, **kwargs)
-
- def visit_label(self, label,
- add_to_result_map=None,
- within_label_clause=False,
- within_columns_clause=False,
- render_label_as_label=None,
- **kw):
+ col, within_columns_clause=within_columns_clause, **kwargs
+ )
+
+ def visit_label(
+ self,
+ label,
+ add_to_result_map=None,
+ within_label_clause=False,
+ within_columns_clause=False,
+ render_label_as_label=None,
+ **kw
+ ):
# only render labels within the columns clause
# or ORDER BY clause of a select. dialect-specific compilers
# can modify this behavior.
- render_label_with_as = (within_columns_clause and not
- within_label_clause)
+ render_label_with_as = (
+ within_columns_clause and not within_label_clause
+ )
render_label_only = render_label_as_label is label
if render_label_only or render_label_with_as:
@@ -673,27 +794,35 @@ class SQLCompiler(Compiled):
add_to_result_map(
labelname,
label.name,
- (label, labelname, ) + label._alt_names,
- label.type
+ (label, labelname) + label._alt_names,
+ label.type,
)
- return label.element._compiler_dispatch(
- self, within_columns_clause=True,
- within_label_clause=True, **kw) + \
- OPERATORS[operators.as_] + \
- self.preparer.format_label(label, labelname)
+ return (
+ label.element._compiler_dispatch(
+ self,
+ within_columns_clause=True,
+ within_label_clause=True,
+ **kw
+ )
+ + OPERATORS[operators.as_]
+ + self.preparer.format_label(label, labelname)
+ )
elif render_label_only:
return self.preparer.format_label(label, labelname)
else:
return label.element._compiler_dispatch(
- self, within_columns_clause=False, **kw)
+ self, within_columns_clause=False, **kw
+ )
def _fallback_column_name(self, column):
- raise exc.CompileError("Cannot compile Column object until "
- "its 'name' is assigned.")
+ raise exc.CompileError(
+ "Cannot compile Column object until " "its 'name' is assigned."
+ )
- def visit_column(self, column, add_to_result_map=None,
- include_table=True, **kwargs):
+ def visit_column(
+ self, column, add_to_result_map=None, include_table=True, **kwargs
+ ):
name = orig_name = column.name
if name is None:
name = self._fallback_column_name(column)
@@ -704,10 +833,7 @@ class SQLCompiler(Compiled):
if add_to_result_map is not None:
add_to_result_map(
- name,
- orig_name,
- (column, name, column.key),
- column.type
+ name, orig_name, (column, name, column.key), column.type
)
if is_literal:
@@ -721,17 +847,16 @@ class SQLCompiler(Compiled):
effective_schema = self.preparer.schema_for_object(table)
if effective_schema:
- schema_prefix = self.preparer.quote_schema(
- effective_schema) + '.'
+ schema_prefix = (
+ self.preparer.quote_schema(effective_schema) + "."
+ )
else:
- schema_prefix = ''
+ schema_prefix = ""
tablename = table.name
if isinstance(tablename, elements._truncated_label):
tablename = self._truncated_identifier("alias", tablename)
- return schema_prefix + \
- self.preparer.quote(tablename) + \
- "." + name
+ return schema_prefix + self.preparer.quote(tablename) + "." + name
def visit_collation(self, element, **kw):
return self.preparer.format_collation(element.collation)
@@ -743,17 +868,17 @@ class SQLCompiler(Compiled):
return index.name
def visit_typeclause(self, typeclause, **kw):
- kw['type_expression'] = typeclause
+ kw["type_expression"] = typeclause
return self.dialect.type_compiler.process(typeclause.type, **kw)
def post_process_text(self, text):
if self.preparer._double_percents:
- text = text.replace('%', '%%')
+ text = text.replace("%", "%%")
return text
def escape_literal_column(self, text):
if self.preparer._double_percents:
- text = text.replace('%', '%%')
+ text = text.replace("%", "%%")
return text
def visit_textclause(self, textclause, **kw):
@@ -771,30 +896,36 @@ class SQLCompiler(Compiled):
return BIND_PARAMS_ESC.sub(
lambda m: m.group(1),
BIND_PARAMS.sub(
- do_bindparam,
- self.post_process_text(textclause.text))
+ do_bindparam, self.post_process_text(textclause.text)
+ ),
)
- def visit_text_as_from(self, taf,
- compound_index=None,
- asfrom=False,
- parens=True, **kw):
+ def visit_text_as_from(
+ self, taf, compound_index=None, asfrom=False, parens=True, **kw
+ ):
toplevel = not self.stack
entry = self._default_stack_entry if toplevel else self.stack[-1]
- populate_result_map = toplevel or \
- (
- compound_index == 0 and entry.get(
- 'need_result_map_for_compound', False)
- ) or entry.get('need_result_map_for_nested', False)
+ populate_result_map = (
+ toplevel
+ or (
+ compound_index == 0
+ and entry.get("need_result_map_for_compound", False)
+ )
+ or entry.get("need_result_map_for_nested", False)
+ )
if populate_result_map:
- self._ordered_columns = \
- self._textual_ordered_columns = taf.positional
+ self._ordered_columns = (
+ self._textual_ordered_columns
+ ) = taf.positional
for c in taf.column_args:
- self.process(c, within_columns_clause=True,
- add_to_result_map=self._add_to_result_map)
+ self.process(
+ c,
+ within_columns_clause=True,
+ add_to_result_map=self._add_to_result_map,
+ )
text = self.process(taf.element, **kw)
if asfrom and parens:
@@ -802,17 +933,17 @@ class SQLCompiler(Compiled):
return text
def visit_null(self, expr, **kw):
- return 'NULL'
+ return "NULL"
def visit_true(self, expr, **kw):
if self.dialect.supports_native_boolean:
- return 'true'
+ return "true"
else:
return "1"
def visit_false(self, expr, **kw):
if self.dialect.supports_native_boolean:
- return 'false'
+ return "false"
else:
return "0"
@@ -823,25 +954,29 @@ class SQLCompiler(Compiled):
else:
sep = OPERATORS[clauselist.operator]
return sep.join(
- s for s in
- (
- c._compiler_dispatch(self, **kw)
- for c in clauselist.clauses)
- if s)
+ s
+ for s in (
+ c._compiler_dispatch(self, **kw) for c in clauselist.clauses
+ )
+ if s
+ )
def visit_case(self, clause, **kwargs):
x = "CASE "
if clause.value is not None:
x += clause.value._compiler_dispatch(self, **kwargs) + " "
for cond, result in clause.whens:
- x += "WHEN " + cond._compiler_dispatch(
- self, **kwargs
- ) + " THEN " + result._compiler_dispatch(
- self, **kwargs) + " "
+ x += (
+ "WHEN "
+ + cond._compiler_dispatch(self, **kwargs)
+ + " THEN "
+ + result._compiler_dispatch(self, **kwargs)
+ + " "
+ )
if clause.else_ is not None:
- x += "ELSE " + clause.else_._compiler_dispatch(
- self, **kwargs
- ) + " "
+ x += (
+ "ELSE " + clause.else_._compiler_dispatch(self, **kwargs) + " "
+ )
x += "END"
return x
@@ -849,79 +984,84 @@ class SQLCompiler(Compiled):
return type_coerce.typed_expression._compiler_dispatch(self, **kw)
def visit_cast(self, cast, **kwargs):
- return "CAST(%s AS %s)" % \
- (cast.clause._compiler_dispatch(self, **kwargs),
- cast.typeclause._compiler_dispatch(self, **kwargs))
+ return "CAST(%s AS %s)" % (
+ cast.clause._compiler_dispatch(self, **kwargs),
+ cast.typeclause._compiler_dispatch(self, **kwargs),
+ )
def _format_frame_clause(self, range_, **kw):
- return '%s AND %s' % (
+ return "%s AND %s" % (
"UNBOUNDED PRECEDING"
if range_[0] is elements.RANGE_UNBOUNDED
- else "CURRENT ROW" if range_[0] is elements.RANGE_CURRENT
- else "%s PRECEDING" % (
- self.process(elements.literal(abs(range_[0])), **kw), )
+ else "CURRENT ROW"
+ if range_[0] is elements.RANGE_CURRENT
+ else "%s PRECEDING"
+ % (self.process(elements.literal(abs(range_[0])), **kw),)
if range_[0] < 0
- else "%s FOLLOWING" % (
- self.process(elements.literal(range_[0]), **kw), ),
-
+ else "%s FOLLOWING"
+ % (self.process(elements.literal(range_[0]), **kw),),
"UNBOUNDED FOLLOWING"
if range_[1] is elements.RANGE_UNBOUNDED
- else "CURRENT ROW" if range_[1] is elements.RANGE_CURRENT
- else "%s PRECEDING" % (
- self.process(elements.literal(abs(range_[1])), **kw), )
+ else "CURRENT ROW"
+ if range_[1] is elements.RANGE_CURRENT
+ else "%s PRECEDING"
+ % (self.process(elements.literal(abs(range_[1])), **kw),)
if range_[1] < 0
- else "%s FOLLOWING" % (
- self.process(elements.literal(range_[1]), **kw), ),
+ else "%s FOLLOWING"
+ % (self.process(elements.literal(range_[1]), **kw),),
)
def visit_over(self, over, **kwargs):
if over.range_:
range_ = "RANGE BETWEEN %s" % self._format_frame_clause(
- over.range_, **kwargs)
+ over.range_, **kwargs
+ )
elif over.rows:
range_ = "ROWS BETWEEN %s" % self._format_frame_clause(
- over.rows, **kwargs)
+ over.rows, **kwargs
+ )
else:
range_ = None
return "%s OVER (%s)" % (
over.element._compiler_dispatch(self, **kwargs),
- ' '.join([
- '%s BY %s' % (
- word, clause._compiler_dispatch(self, **kwargs)
- )
- for word, clause in (
- ('PARTITION', over.partition_by),
- ('ORDER', over.order_by)
- )
- if clause is not None and len(clause)
- ] + ([range_] if range_ else [])
- )
+ " ".join(
+ [
+ "%s BY %s"
+ % (word, clause._compiler_dispatch(self, **kwargs))
+ for word, clause in (
+ ("PARTITION", over.partition_by),
+ ("ORDER", over.order_by),
+ )
+ if clause is not None and len(clause)
+ ]
+ + ([range_] if range_ else [])
+ ),
)
def visit_withingroup(self, withingroup, **kwargs):
return "%s WITHIN GROUP (ORDER BY %s)" % (
withingroup.element._compiler_dispatch(self, **kwargs),
- withingroup.order_by._compiler_dispatch(self, **kwargs)
+ withingroup.order_by._compiler_dispatch(self, **kwargs),
)
def visit_funcfilter(self, funcfilter, **kwargs):
return "%s FILTER (WHERE %s)" % (
funcfilter.func._compiler_dispatch(self, **kwargs),
- funcfilter.criterion._compiler_dispatch(self, **kwargs)
+ funcfilter.criterion._compiler_dispatch(self, **kwargs),
)
def visit_extract(self, extract, **kwargs):
field = self.extract_map.get(extract.field, extract.field)
return "EXTRACT(%s FROM %s)" % (
- field, extract.expr._compiler_dispatch(self, **kwargs))
+ field,
+ extract.expr._compiler_dispatch(self, **kwargs),
+ )
def visit_function(self, func, add_to_result_map=None, **kwargs):
if add_to_result_map is not None:
- add_to_result_map(
- func.name, func.name, (), func.type
- )
+ add_to_result_map(func.name, func.name, (), func.type)
disp = getattr(self, "visit_%s_func" % func.name.lower(), None)
if disp:
@@ -933,51 +1073,63 @@ class SQLCompiler(Compiled):
name += "%(expr)s"
else:
name = func.name + "%(expr)s"
- return ".".join(list(func.packagenames) + [name]) % \
- {'expr': self.function_argspec(func, **kwargs)}
+ return ".".join(list(func.packagenames) + [name]) % {
+ "expr": self.function_argspec(func, **kwargs)
+ }
def visit_next_value_func(self, next_value, **kw):
return self.visit_sequence(next_value.sequence)
def visit_sequence(self, sequence, **kw):
raise NotImplementedError(
- "Dialect '%s' does not support sequence increments." %
- self.dialect.name
+ "Dialect '%s' does not support sequence increments."
+ % self.dialect.name
)
def function_argspec(self, func, **kwargs):
return func.clause_expr._compiler_dispatch(self, **kwargs)
- def visit_compound_select(self, cs, asfrom=False,
- parens=True, compound_index=0, **kwargs):
+ def visit_compound_select(
+ self, cs, asfrom=False, parens=True, compound_index=0, **kwargs
+ ):
toplevel = not self.stack
entry = self._default_stack_entry if toplevel else self.stack[-1]
- need_result_map = toplevel or \
- (compound_index == 0
- and entry.get('need_result_map_for_compound', False))
+ need_result_map = toplevel or (
+ compound_index == 0
+ and entry.get("need_result_map_for_compound", False)
+ )
self.stack.append(
{
- 'correlate_froms': entry['correlate_froms'],
- 'asfrom_froms': entry['asfrom_froms'],
- 'selectable': cs,
- 'need_result_map_for_compound': need_result_map
- })
+ "correlate_froms": entry["correlate_froms"],
+ "asfrom_froms": entry["asfrom_froms"],
+ "selectable": cs,
+ "need_result_map_for_compound": need_result_map,
+ }
+ )
keyword = self.compound_keywords.get(cs.keyword)
text = (" " + keyword + " ").join(
- (c._compiler_dispatch(self,
- asfrom=asfrom, parens=False,
- compound_index=i, **kwargs)
- for i, c in enumerate(cs.selects))
+ (
+ c._compiler_dispatch(
+ self,
+ asfrom=asfrom,
+ parens=False,
+ compound_index=i,
+ **kwargs
+ )
+ for i, c in enumerate(cs.selects)
+ )
)
text += self.group_by_clause(cs, **dict(asfrom=asfrom, **kwargs))
text += self.order_by_clause(cs, **kwargs)
- text += (cs._limit_clause is not None
- or cs._offset_clause is not None) and \
- self.limit_clause(cs, **kwargs) or ""
+ text += (
+ (cs._limit_clause is not None or cs._offset_clause is not None)
+ and self.limit_clause(cs, **kwargs)
+ or ""
+ )
if self.ctes and toplevel:
text = self._render_cte_clause() + text
@@ -990,8 +1142,10 @@ class SQLCompiler(Compiled):
def _get_operator_dispatch(self, operator_, qualifier1, qualifier2):
attrname = "visit_%s_%s%s" % (
- operator_.__name__, qualifier1,
- "_" + qualifier2 if qualifier2 else "")
+ operator_.__name__,
+ qualifier1,
+ "_" + qualifier2 if qualifier2 else "",
+ )
return getattr(self, attrname, None)
def visit_unary(self, unary, **kw):
@@ -999,51 +1153,63 @@ class SQLCompiler(Compiled):
if unary.modifier:
raise exc.CompileError(
"Unary expression does not support operator "
- "and modifier simultaneously")
+ "and modifier simultaneously"
+ )
disp = self._get_operator_dispatch(
- unary.operator, "unary", "operator")
+ unary.operator, "unary", "operator"
+ )
if disp:
return disp(unary, unary.operator, **kw)
else:
return self._generate_generic_unary_operator(
- unary, OPERATORS[unary.operator], **kw)
+ unary, OPERATORS[unary.operator], **kw
+ )
elif unary.modifier:
disp = self._get_operator_dispatch(
- unary.modifier, "unary", "modifier")
+ unary.modifier, "unary", "modifier"
+ )
if disp:
return disp(unary, unary.modifier, **kw)
else:
return self._generate_generic_unary_modifier(
- unary, OPERATORS[unary.modifier], **kw)
+ unary, OPERATORS[unary.modifier], **kw
+ )
else:
raise exc.CompileError(
- "Unary expression has no operator or modifier")
+ "Unary expression has no operator or modifier"
+ )
def visit_istrue_unary_operator(self, element, operator, **kw):
- if element._is_implicitly_boolean or \
- self.dialect.supports_native_boolean:
+ if (
+ element._is_implicitly_boolean
+ or self.dialect.supports_native_boolean
+ ):
return self.process(element.element, **kw)
else:
return "%s = 1" % self.process(element.element, **kw)
def visit_isfalse_unary_operator(self, element, operator, **kw):
- if element._is_implicitly_boolean or \
- self.dialect.supports_native_boolean:
+ if (
+ element._is_implicitly_boolean
+ or self.dialect.supports_native_boolean
+ ):
return "NOT %s" % self.process(element.element, **kw)
else:
return "%s = 0" % self.process(element.element, **kw)
def visit_notmatch_op_binary(self, binary, operator, **kw):
return "NOT %s" % self.visit_binary(
- binary, override_operator=operators.match_op)
+ binary, override_operator=operators.match_op
+ )
def _emit_empty_in_warning(self):
util.warn(
- 'The IN-predicate was invoked with an '
- 'empty sequence. This results in a '
- 'contradiction, which nonetheless can be '
- 'expensive to evaluate. Consider alternative '
- 'strategies for improved performance.')
+ "The IN-predicate was invoked with an "
+ "empty sequence. This results in a "
+ "contradiction, which nonetheless can be "
+ "expensive to evaluate. Consider alternative "
+ "strategies for improved performance."
+ )
def visit_empty_in_op_binary(self, binary, operator, **kw):
if self.dialect._use_static_in:
@@ -1063,18 +1229,21 @@ class SQLCompiler(Compiled):
def visit_empty_set_expr(self, element_types):
raise NotImplementedError(
- "Dialect '%s' does not support empty set expression." %
- self.dialect.name
+ "Dialect '%s' does not support empty set expression."
+ % self.dialect.name
)
- def visit_binary(self, binary, override_operator=None,
- eager_grouping=False, **kw):
+ def visit_binary(
+ self, binary, override_operator=None, eager_grouping=False, **kw
+ ):
# don't allow "? = ?" to render
- if self.ansi_bind_rules and \
- isinstance(binary.left, elements.BindParameter) and \
- isinstance(binary.right, elements.BindParameter):
- kw['literal_binds'] = True
+ if (
+ self.ansi_bind_rules
+ and isinstance(binary.left, elements.BindParameter)
+ and isinstance(binary.right, elements.BindParameter)
+ ):
+ kw["literal_binds"] = True
operator_ = override_operator or binary.operator
disp = self._get_operator_dispatch(operator_, "binary", None)
@@ -1093,36 +1262,50 @@ class SQLCompiler(Compiled):
def visit_mod_binary(self, binary, operator, **kw):
if self.preparer._double_percents:
- return self.process(binary.left, **kw) + " %% " + \
- self.process(binary.right, **kw)
+ return (
+ self.process(binary.left, **kw)
+ + " %% "
+ + self.process(binary.right, **kw)
+ )
else:
- return self.process(binary.left, **kw) + " % " + \
- self.process(binary.right, **kw)
+ return (
+ self.process(binary.left, **kw)
+ + " % "
+ + self.process(binary.right, **kw)
+ )
def visit_custom_op_binary(self, element, operator, **kw):
- kw['eager_grouping'] = operator.eager_grouping
+ kw["eager_grouping"] = operator.eager_grouping
return self._generate_generic_binary(
- element, " " + operator.opstring + " ", **kw)
+ element, " " + operator.opstring + " ", **kw
+ )
def visit_custom_op_unary_operator(self, element, operator, **kw):
return self._generate_generic_unary_operator(
- element, operator.opstring + " ", **kw)
+ element, operator.opstring + " ", **kw
+ )
def visit_custom_op_unary_modifier(self, element, operator, **kw):
return self._generate_generic_unary_modifier(
- element, " " + operator.opstring, **kw)
+ element, " " + operator.opstring, **kw
+ )
def _generate_generic_binary(
- self, binary, opstring, eager_grouping=False, **kw):
+ self, binary, opstring, eager_grouping=False, **kw
+ ):
- _in_binary = kw.get('_in_binary', False)
+ _in_binary = kw.get("_in_binary", False)
- kw['_in_binary'] = True
- text = binary.left._compiler_dispatch(
- self, eager_grouping=eager_grouping, **kw) + \
- opstring + \
- binary.right._compiler_dispatch(
- self, eager_grouping=eager_grouping, **kw)
+ kw["_in_binary"] = True
+ text = (
+ binary.left._compiler_dispatch(
+ self, eager_grouping=eager_grouping, **kw
+ )
+ + opstring
+ + binary.right._compiler_dispatch(
+ self, eager_grouping=eager_grouping, **kw
+ )
+ )
if _in_binary and eager_grouping:
text = "(%s)" % text
@@ -1153,17 +1336,13 @@ class SQLCompiler(Compiled):
def visit_startswith_op_binary(self, binary, operator, **kw):
binary = binary._clone()
percent = self._like_percent_literal
- binary.right = percent.__radd__(
- binary.right
- )
+ binary.right = percent.__radd__(binary.right)
return self.visit_like_op_binary(binary, operator, **kw)
def visit_notstartswith_op_binary(self, binary, operator, **kw):
binary = binary._clone()
percent = self._like_percent_literal
- binary.right = percent.__radd__(
- binary.right
- )
+ binary.right = percent.__radd__(binary.right)
return self.visit_notlike_op_binary(binary, operator, **kw)
def visit_endswith_op_binary(self, binary, operator, **kw):
@@ -1182,98 +1361,105 @@ class SQLCompiler(Compiled):
escape = binary.modifiers.get("escape", None)
# TODO: use ternary here, not "and"/ "or"
- return '%s LIKE %s' % (
+ return "%s LIKE %s" % (
binary.left._compiler_dispatch(self, **kw),
- binary.right._compiler_dispatch(self, **kw)) \
- + (
- ' ESCAPE ' +
- self.render_literal_value(escape, sqltypes.STRINGTYPE)
- if escape else ''
- )
+ binary.right._compiler_dispatch(self, **kw),
+ ) + (
+ " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape
+ else ""
+ )
def visit_notlike_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
- return '%s NOT LIKE %s' % (
+ return "%s NOT LIKE %s" % (
binary.left._compiler_dispatch(self, **kw),
- binary.right._compiler_dispatch(self, **kw)) \
- + (
- ' ESCAPE ' +
- self.render_literal_value(escape, sqltypes.STRINGTYPE)
- if escape else ''
- )
+ binary.right._compiler_dispatch(self, **kw),
+ ) + (
+ " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape
+ else ""
+ )
def visit_ilike_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
- return 'lower(%s) LIKE lower(%s)' % (
+ return "lower(%s) LIKE lower(%s)" % (
binary.left._compiler_dispatch(self, **kw),
- binary.right._compiler_dispatch(self, **kw)) \
- + (
- ' ESCAPE ' +
- self.render_literal_value(escape, sqltypes.STRINGTYPE)
- if escape else ''
- )
+ binary.right._compiler_dispatch(self, **kw),
+ ) + (
+ " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape
+ else ""
+ )
def visit_notilike_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
- return 'lower(%s) NOT LIKE lower(%s)' % (
+ return "lower(%s) NOT LIKE lower(%s)" % (
binary.left._compiler_dispatch(self, **kw),
- binary.right._compiler_dispatch(self, **kw)) \
- + (
- ' ESCAPE ' +
- self.render_literal_value(escape, sqltypes.STRINGTYPE)
- if escape else ''
- )
+ binary.right._compiler_dispatch(self, **kw),
+ ) + (
+ " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape
+ else ""
+ )
def visit_between_op_binary(self, binary, operator, **kw):
symmetric = binary.modifiers.get("symmetric", False)
return self._generate_generic_binary(
- binary, " BETWEEN SYMMETRIC "
- if symmetric else " BETWEEN ", **kw)
+ binary, " BETWEEN SYMMETRIC " if symmetric else " BETWEEN ", **kw
+ )
def visit_notbetween_op_binary(self, binary, operator, **kw):
symmetric = binary.modifiers.get("symmetric", False)
return self._generate_generic_binary(
- binary, " NOT BETWEEN SYMMETRIC "
- if symmetric else " NOT BETWEEN ", **kw)
+ binary,
+ " NOT BETWEEN SYMMETRIC " if symmetric else " NOT BETWEEN ",
+ **kw
+ )
- def visit_bindparam(self, bindparam, within_columns_clause=False,
- literal_binds=False,
- skip_bind_expression=False,
- **kwargs):
+ def visit_bindparam(
+ self,
+ bindparam,
+ within_columns_clause=False,
+ literal_binds=False,
+ skip_bind_expression=False,
+ **kwargs
+ ):
if not skip_bind_expression:
impl = bindparam.type.dialect_impl(self.dialect)
if impl._has_bind_expression:
bind_expression = impl.bind_expression(bindparam)
return self.process(
- bind_expression, skip_bind_expression=True,
+ bind_expression,
+ skip_bind_expression=True,
within_columns_clause=within_columns_clause,
literal_binds=literal_binds,
**kwargs
)
- if literal_binds or \
- (within_columns_clause and
- self.ansi_bind_rules):
+ if literal_binds or (within_columns_clause and self.ansi_bind_rules):
if bindparam.value is None and bindparam.callable is None:
- raise exc.CompileError("Bind parameter '%s' without a "
- "renderable value not allowed here."
- % bindparam.key)
+ raise exc.CompileError(
+ "Bind parameter '%s' without a "
+ "renderable value not allowed here." % bindparam.key
+ )
return self.render_literal_bindparam(
- bindparam, within_columns_clause=True, **kwargs)
+ bindparam, within_columns_clause=True, **kwargs
+ )
name = self._truncate_bindparam(bindparam)
if name in self.binds:
existing = self.binds[name]
if existing is not bindparam:
- if (existing.unique or bindparam.unique) and \
- not existing.proxy_set.intersection(
- bindparam.proxy_set):
+ if (
+ existing.unique or bindparam.unique
+ ) and not existing.proxy_set.intersection(bindparam.proxy_set):
raise exc.CompileError(
"Bind parameter '%s' conflicts with "
- "unique bind parameter of the same name" %
- bindparam.key
+ "unique bind parameter of the same name"
+ % bindparam.key
)
elif existing._is_crud or bindparam._is_crud:
raise exc.CompileError(
@@ -1282,14 +1468,15 @@ class SQLCompiler(Compiled):
"clause of this "
"insert/update statement. Please use a "
"name other than column name when using bindparam() "
- "with insert() or update() (for example, 'b_%s')." %
- (bindparam.key, bindparam.key)
+ "with insert() or update() (for example, 'b_%s')."
+ % (bindparam.key, bindparam.key)
)
self.binds[bindparam.key] = self.binds[name] = bindparam
return self.bindparam_string(
- name, expanding=bindparam.expanding, **kwargs)
+ name, expanding=bindparam.expanding, **kwargs
+ )
def render_literal_bindparam(self, bindparam, **kw):
value = bindparam.effective_value
@@ -1311,7 +1498,8 @@ class SQLCompiler(Compiled):
return processor(value)
else:
raise NotImplementedError(
- "Don't know how to literal-quote value %r" % value)
+ "Don't know how to literal-quote value %r" % value
+ )
def _truncate_bindparam(self, bindparam):
if bindparam in self.bind_names:
@@ -1334,8 +1522,11 @@ class SQLCompiler(Compiled):
if len(anonname) > self.label_length - 6:
counter = self.truncated_names.get(ident_class, 1)
- truncname = anonname[0:max(self.label_length - 6, 0)] + \
- "_" + hex(counter)[2:]
+ truncname = (
+ anonname[0 : max(self.label_length - 6, 0)]
+ + "_"
+ + hex(counter)[2:]
+ )
self.truncated_names[ident_class] = counter + 1
else:
truncname = anonname
@@ -1346,13 +1537,14 @@ class SQLCompiler(Compiled):
return name % self.anon_map
def _process_anon(self, key):
- (ident, derived) = key.split(' ', 1)
+ (ident, derived) = key.split(" ", 1)
anonymous_counter = self.anon_map.get(derived, 1)
self.anon_map[derived] = anonymous_counter + 1
return derived + "_" + str(anonymous_counter)
def bindparam_string(
- self, name, positional_names=None, expanding=False, **kw):
+ self, name, positional_names=None, expanding=False, **kw
+ ):
if self.positional:
if positional_names is not None:
positional_names.append(name)
@@ -1362,14 +1554,20 @@ class SQLCompiler(Compiled):
self.contains_expanding_parameters = True
return "([EXPANDING_%s])" % name
else:
- return self.bindtemplate % {'name': name}
-
- def visit_cte(self, cte, asfrom=False, ashint=False,
- fromhints=None, visiting_cte=None,
- **kwargs):
+ return self.bindtemplate % {"name": name}
+
+ def visit_cte(
+ self,
+ cte,
+ asfrom=False,
+ ashint=False,
+ fromhints=None,
+ visiting_cte=None,
+ **kwargs
+ ):
self._init_cte_state()
- kwargs['visiting_cte'] = cte
+ kwargs["visiting_cte"] = cte
if isinstance(cte.name, elements._truncated_label):
cte_name = self._truncated_identifier("alias", cte.name)
else:
@@ -1394,8 +1592,8 @@ class SQLCompiler(Compiled):
else:
raise exc.CompileError(
"Multiple, unrelated CTEs found with "
- "the same name: %r" %
- cte_name)
+ "the same name: %r" % cte_name
+ )
if asfrom or is_new_cte:
if cte._cte_alias is not None:
@@ -1403,7 +1601,8 @@ class SQLCompiler(Compiled):
cte_pre_alias_name = cte._cte_alias.name
if isinstance(cte_pre_alias_name, elements._truncated_label):
cte_pre_alias_name = self._truncated_identifier(
- "alias", cte_pre_alias_name)
+ "alias", cte_pre_alias_name
+ )
else:
pre_alias_cte = cte
cte_pre_alias_name = None
@@ -1412,11 +1611,17 @@ class SQLCompiler(Compiled):
self.ctes_by_name[cte_name] = cte
# look for embedded DML ctes and propagate autocommit
- if 'autocommit' in cte.element._execution_options and \
- 'autocommit' not in self.execution_options:
+ if (
+ "autocommit" in cte.element._execution_options
+ and "autocommit" not in self.execution_options
+ ):
self.execution_options = self.execution_options.union(
- {"autocommit":
- cte.element._execution_options['autocommit']})
+ {
+ "autocommit": cte.element._execution_options[
+ "autocommit"
+ ]
+ }
+ )
if pre_alias_cte not in self.ctes:
self.visit_cte(pre_alias_cte, **kwargs)
@@ -1432,25 +1637,30 @@ class SQLCompiler(Compiled):
col_source = cte.original.selects[0]
else:
assert False
- recur_cols = [c for c in
- util.unique_list(col_source.inner_columns)
- if c is not None]
-
- text += "(%s)" % (", ".join(
- self.preparer.format_column(ident)
- for ident in recur_cols))
+ recur_cols = [
+ c
+ for c in util.unique_list(col_source.inner_columns)
+ if c is not None
+ ]
+
+ text += "(%s)" % (
+ ", ".join(
+ self.preparer.format_column(ident)
+ for ident in recur_cols
+ )
+ )
if self.positional:
- kwargs['positional_names'] = self.cte_positional[cte] = []
+ kwargs["positional_names"] = self.cte_positional[cte] = []
- text += " AS \n" + \
- cte.original._compiler_dispatch(
- self, asfrom=True, **kwargs
- )
+ text += " AS \n" + cte.original._compiler_dispatch(
+ self, asfrom=True, **kwargs
+ )
if cte._suffixes:
text += " " + self._generate_prefixes(
- cte, cte._suffixes, **kwargs)
+ cte, cte._suffixes, **kwargs
+ )
self.ctes[cte] = text
@@ -1467,9 +1677,15 @@ class SQLCompiler(Compiled):
else:
return self.preparer.format_alias(cte, cte_name)
- def visit_alias(self, alias, asfrom=False, ashint=False,
- iscrud=False,
- fromhints=None, **kwargs):
+ def visit_alias(
+ self,
+ alias,
+ asfrom=False,
+ ashint=False,
+ iscrud=False,
+ fromhints=None,
+ **kwargs
+ ):
if asfrom or ashint:
if isinstance(alias.name, elements._truncated_label):
alias_name = self._truncated_identifier("alias", alias.name)
@@ -1479,31 +1695,35 @@ class SQLCompiler(Compiled):
if ashint:
return self.preparer.format_alias(alias, alias_name)
elif asfrom:
- ret = alias.original._compiler_dispatch(self,
- asfrom=True, **kwargs) + \
- self.get_render_as_alias_suffix(
- self.preparer.format_alias(alias, alias_name))
+ ret = alias.original._compiler_dispatch(
+ self, asfrom=True, **kwargs
+ ) + self.get_render_as_alias_suffix(
+ self.preparer.format_alias(alias, alias_name)
+ )
if fromhints and alias in fromhints:
- ret = self.format_from_hint_text(ret, alias,
- fromhints[alias], iscrud)
+ ret = self.format_from_hint_text(
+ ret, alias, fromhints[alias], iscrud
+ )
return ret
else:
return alias.original._compiler_dispatch(self, **kwargs)
def visit_lateral(self, lateral, **kw):
- kw['lateral'] = True
+ kw["lateral"] = True
return "LATERAL %s" % self.visit_alias(lateral, **kw)
def visit_tablesample(self, tablesample, asfrom=False, **kw):
text = "%s TABLESAMPLE %s" % (
self.visit_alias(tablesample, asfrom=True, **kw),
- tablesample._get_method()._compiler_dispatch(self, **kw))
+ tablesample._get_method()._compiler_dispatch(self, **kw),
+ )
if tablesample.seed is not None:
text += " REPEATABLE (%s)" % (
- tablesample.seed._compiler_dispatch(self, **kw))
+ tablesample.seed._compiler_dispatch(self, **kw)
+ )
return text
@@ -1513,22 +1733,27 @@ class SQLCompiler(Compiled):
def _add_to_result_map(self, keyname, name, objects, type_):
self._result_columns.append((keyname, name, objects, type_))
- def _label_select_column(self, select, column,
- populate_result_map,
- asfrom, column_clause_args,
- name=None,
- within_columns_clause=True):
+ def _label_select_column(
+ self,
+ select,
+ column,
+ populate_result_map,
+ asfrom,
+ column_clause_args,
+ name=None,
+ within_columns_clause=True,
+ ):
"""produce labeled columns present in a select()."""
impl = column.type.dialect_impl(self.dialect)
- if impl._has_column_expression and \
- populate_result_map:
+ if impl._has_column_expression and populate_result_map:
col_expr = impl.column_expression(column)
def add_to_result_map(keyname, name, objects, type_):
self._add_to_result_map(
- keyname, name,
- (column,) + objects, type_)
+ keyname, name, (column,) + objects, type_
+ )
+
else:
col_expr = column
if populate_result_map:
@@ -1541,58 +1766,56 @@ class SQLCompiler(Compiled):
elif isinstance(column, elements.Label):
if col_expr is not column:
result_expr = _CompileLabel(
- col_expr,
- column.name,
- alt_names=(column.element,)
+ col_expr, column.name, alt_names=(column.element,)
)
else:
result_expr = col_expr
elif select is not None and name:
result_expr = _CompileLabel(
+ col_expr, name, alt_names=(column._key_label,)
+ )
+
+ elif (
+ asfrom
+ and isinstance(column, elements.ColumnClause)
+ and not column.is_literal
+ and column.table is not None
+ and not isinstance(column.table, selectable.Select)
+ ):
+ result_expr = _CompileLabel(
col_expr,
- name,
- alt_names=(column._key_label,)
- )
-
- elif \
- asfrom and \
- isinstance(column, elements.ColumnClause) and \
- not column.is_literal and \
- column.table is not None and \
- not isinstance(column.table, selectable.Select):
- result_expr = _CompileLabel(col_expr,
- elements._as_truncated(column.name),
- alt_names=(column.key,))
+ elements._as_truncated(column.name),
+ alt_names=(column.key,),
+ )
elif (
- not isinstance(column, elements.TextClause) and
- (
- not isinstance(column, elements.UnaryExpression) or
- column.wraps_column_expression
- ) and
- (
- not hasattr(column, 'name') or
- isinstance(column, functions.Function)
+ not isinstance(column, elements.TextClause)
+ and (
+ not isinstance(column, elements.UnaryExpression)
+ or column.wraps_column_expression
+ )
+ and (
+ not hasattr(column, "name")
+ or isinstance(column, functions.Function)
)
):
result_expr = _CompileLabel(col_expr, column.anon_label)
elif col_expr is not column:
# TODO: are we sure "column" has a .name and .key here ?
# assert isinstance(column, elements.ColumnClause)
- result_expr = _CompileLabel(col_expr,
- elements._as_truncated(column.name),
- alt_names=(column.key,))
+ result_expr = _CompileLabel(
+ col_expr,
+ elements._as_truncated(column.name),
+ alt_names=(column.key,),
+ )
else:
result_expr = col_expr
column_clause_args.update(
within_columns_clause=within_columns_clause,
- add_to_result_map=add_to_result_map
- )
- return result_expr._compiler_dispatch(
- self,
- **column_clause_args
+ add_to_result_map=add_to_result_map,
)
+ return result_expr._compiler_dispatch(self, **column_clause_args)
def format_from_hint_text(self, sqltext, table, hint, iscrud):
hinttext = self.get_from_hint_text(table, hint)
@@ -1631,8 +1854,11 @@ class SQLCompiler(Compiled):
newelem = cloned[element] = element._clone()
- if newelem.is_selectable and newelem._is_join and \
- isinstance(newelem.right, selectable.FromGrouping):
+ if (
+ newelem.is_selectable
+ and newelem._is_join
+ and isinstance(newelem.right, selectable.FromGrouping)
+ ):
newelem._reset_exported()
newelem.left = visit(newelem.left, **kw)
@@ -1640,8 +1866,8 @@ class SQLCompiler(Compiled):
right = visit(newelem.right, **kw)
selectable_ = selectable.Select(
- [right.element],
- use_labels=True).alias()
+ [right.element], use_labels=True
+ ).alias()
for c in selectable_.c:
c._key_label = c.key
@@ -1680,17 +1906,18 @@ class SQLCompiler(Compiled):
elif newelem._is_from_container:
# if we hit an Alias, CompoundSelect or ScalarSelect, put a
# marker in the stack.
- kw['transform_clue'] = 'select_container'
+ kw["transform_clue"] = "select_container"
newelem._copy_internals(clone=visit, **kw)
elif newelem.is_selectable and newelem._is_select:
- barrier_select = kw.get('transform_clue', None) == \
- 'select_container'
+ barrier_select = (
+ kw.get("transform_clue", None) == "select_container"
+ )
# if we're still descended from an
# Alias/CompoundSelect/ScalarSelect, we're
# in a FROM clause, so start with a new translate collection
if barrier_select:
column_translate.append({})
- kw['transform_clue'] = 'inside_select'
+ kw["transform_clue"] = "inside_select"
newelem._copy_internals(clone=visit, **kw)
if barrier_select:
del column_translate[-1]
@@ -1702,24 +1929,22 @@ class SQLCompiler(Compiled):
return visit(select)
def _transform_result_map_for_nested_joins(
- self, select, transformed_select):
- inner_col = dict((c._key_label, c) for
- c in transformed_select.inner_columns)
-
- d = dict(
- (inner_col[c._key_label], c)
- for c in select.inner_columns
+ self, select, transformed_select
+ ):
+ inner_col = dict(
+ (c._key_label, c) for c in transformed_select.inner_columns
)
+ d = dict((inner_col[c._key_label], c) for c in select.inner_columns)
+
self._result_columns = [
(key, name, tuple([d.get(col, col) for col in objs]), typ)
for key, name, objs, typ in self._result_columns
]
- _default_stack_entry = util.immutabledict([
- ('correlate_froms', frozenset()),
- ('asfrom_froms', frozenset())
- ])
+ _default_stack_entry = util.immutabledict(
+ [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())]
+ )
def _display_froms_for_select(self, select, asfrom, lateral=False):
# utility method to help external dialects
@@ -1729,72 +1954,88 @@ class SQLCompiler(Compiled):
toplevel = not self.stack
entry = self._default_stack_entry if toplevel else self.stack[-1]
- correlate_froms = entry['correlate_froms']
- asfrom_froms = entry['asfrom_froms']
+ correlate_froms = entry["correlate_froms"]
+ asfrom_froms = entry["asfrom_froms"]
if asfrom and not lateral:
froms = select._get_display_froms(
explicit_correlate_froms=correlate_froms.difference(
- asfrom_froms),
- implicit_correlate_froms=())
+ asfrom_froms
+ ),
+ implicit_correlate_froms=(),
+ )
else:
froms = select._get_display_froms(
explicit_correlate_froms=correlate_froms,
- implicit_correlate_froms=asfrom_froms)
+ implicit_correlate_froms=asfrom_froms,
+ )
return froms
- def visit_select(self, select, asfrom=False, parens=True,
- fromhints=None,
- compound_index=0,
- nested_join_translation=False,
- select_wraps_for=None,
- lateral=False,
- **kwargs):
-
- needs_nested_translation = \
- select.use_labels and \
- not nested_join_translation and \
- not self.stack and \
- not self.dialect.supports_right_nested_joins
+ def visit_select(
+ self,
+ select,
+ asfrom=False,
+ parens=True,
+ fromhints=None,
+ compound_index=0,
+ nested_join_translation=False,
+ select_wraps_for=None,
+ lateral=False,
+ **kwargs
+ ):
+
+ needs_nested_translation = (
+ select.use_labels
+ and not nested_join_translation
+ and not self.stack
+ and not self.dialect.supports_right_nested_joins
+ )
if needs_nested_translation:
transformed_select = self._transform_select_for_nested_joins(
- select)
+ select
+ )
text = self.visit_select(
- transformed_select, asfrom=asfrom, parens=parens,
+ transformed_select,
+ asfrom=asfrom,
+ parens=parens,
fromhints=fromhints,
compound_index=compound_index,
- nested_join_translation=True, **kwargs
+ nested_join_translation=True,
+ **kwargs
)
toplevel = not self.stack
entry = self._default_stack_entry if toplevel else self.stack[-1]
- populate_result_map = toplevel or \
- (
- compound_index == 0 and entry.get(
- 'need_result_map_for_compound', False)
- ) or entry.get('need_result_map_for_nested', False)
+ populate_result_map = (
+ toplevel
+ or (
+ compound_index == 0
+ and entry.get("need_result_map_for_compound", False)
+ )
+ or entry.get("need_result_map_for_nested", False)
+ )
# this was first proposed as part of #3372; however, it is not
# reached in current tests and could possibly be an assertion
# instead.
- if not populate_result_map and 'add_to_result_map' in kwargs:
- del kwargs['add_to_result_map']
+ if not populate_result_map and "add_to_result_map" in kwargs:
+ del kwargs["add_to_result_map"]
if needs_nested_translation:
if populate_result_map:
self._transform_result_map_for_nested_joins(
- select, transformed_select)
+ select, transformed_select
+ )
return text
froms = self._setup_select_stack(select, entry, asfrom, lateral)
column_clause_args = kwargs.copy()
- column_clause_args.update({
- 'within_label_clause': False,
- 'within_columns_clause': False
- })
+ column_clause_args.update(
+ {"within_label_clause": False, "within_columns_clause": False}
+ )
text = "SELECT " # we're off to a good start !
@@ -1806,19 +2047,21 @@ class SQLCompiler(Compiled):
byfrom = None
if select._prefixes:
- text += self._generate_prefixes(
- select, select._prefixes, **kwargs)
+ text += self._generate_prefixes(select, select._prefixes, **kwargs)
text += self.get_select_precolumns(select, **kwargs)
# the actual list of columns to print in the SELECT column list.
inner_columns = [
- c for c in [
+ c
+ for c in [
self._label_select_column(
select,
column,
- populate_result_map, asfrom,
+ populate_result_map,
+ asfrom,
column_clause_args,
- name=name)
+ name=name,
+ )
for name, column in select._columns_plus_names
]
if c is not None
@@ -1831,8 +2074,11 @@ class SQLCompiler(Compiled):
translate = dict(
zip(
[name for (key, name) in select._columns_plus_names],
- [name for (key, name) in
- select_wraps_for._columns_plus_names])
+ [
+ name
+ for (key, name) in select_wraps_for._columns_plus_names
+ ],
+ )
)
self._result_columns = [
@@ -1841,13 +2087,14 @@ class SQLCompiler(Compiled):
]
text = self._compose_select_body(
- text, select, inner_columns, froms, byfrom, kwargs)
+ text, select, inner_columns, froms, byfrom, kwargs
+ )
if select._statement_hints:
per_dialect = [
- ht for (dialect_name, ht)
- in select._statement_hints
- if dialect_name in ('*', self.dialect.name)
+ ht
+ for (dialect_name, ht) in select._statement_hints
+ if dialect_name in ("*", self.dialect.name)
]
if per_dialect:
text += " " + self.get_statement_hint_text(per_dialect)
@@ -1857,7 +2104,8 @@ class SQLCompiler(Compiled):
if select._suffixes:
text += " " + self._generate_prefixes(
- select, select._suffixes, **kwargs)
+ select, select._suffixes, **kwargs
+ )
self.stack.pop(-1)
@@ -1867,60 +2115,73 @@ class SQLCompiler(Compiled):
return text
def _setup_select_hints(self, select):
- byfrom = dict([
- (from_, hinttext % {
- 'name': from_._compiler_dispatch(
- self, ashint=True)
- })
- for (from_, dialect), hinttext in
- select._hints.items()
- if dialect in ('*', self.dialect.name)
- ])
+ byfrom = dict(
+ [
+ (
+ from_,
+ hinttext
+ % {"name": from_._compiler_dispatch(self, ashint=True)},
+ )
+ for (from_, dialect), hinttext in select._hints.items()
+ if dialect in ("*", self.dialect.name)
+ ]
+ )
hint_text = self.get_select_hint_text(byfrom)
return hint_text, byfrom
def _setup_select_stack(self, select, entry, asfrom, lateral):
- correlate_froms = entry['correlate_froms']
- asfrom_froms = entry['asfrom_froms']
+ correlate_froms = entry["correlate_froms"]
+ asfrom_froms = entry["asfrom_froms"]
if asfrom and not lateral:
froms = select._get_display_froms(
explicit_correlate_froms=correlate_froms.difference(
- asfrom_froms),
- implicit_correlate_froms=())
+ asfrom_froms
+ ),
+ implicit_correlate_froms=(),
+ )
else:
froms = select._get_display_froms(
explicit_correlate_froms=correlate_froms,
- implicit_correlate_froms=asfrom_froms)
+ implicit_correlate_froms=asfrom_froms,
+ )
new_correlate_froms = set(selectable._from_objects(*froms))
all_correlate_froms = new_correlate_froms.union(correlate_froms)
new_entry = {
- 'asfrom_froms': new_correlate_froms,
- 'correlate_froms': all_correlate_froms,
- 'selectable': select,
+ "asfrom_froms": new_correlate_froms,
+ "correlate_froms": all_correlate_froms,
+ "selectable": select,
}
self.stack.append(new_entry)
return froms
def _compose_select_body(
- self, text, select, inner_columns, froms, byfrom, kwargs):
- text += ', '.join(inner_columns)
+ self, text, select, inner_columns, froms, byfrom, kwargs
+ ):
+ text += ", ".join(inner_columns)
if froms:
text += " \nFROM "
if select._hints:
- text += ', '.join(
- [f._compiler_dispatch(self, asfrom=True,
- fromhints=byfrom, **kwargs)
- for f in froms])
+ text += ", ".join(
+ [
+ f._compiler_dispatch(
+ self, asfrom=True, fromhints=byfrom, **kwargs
+ )
+ for f in froms
+ ]
+ )
else:
- text += ', '.join(
- [f._compiler_dispatch(self, asfrom=True, **kwargs)
- for f in froms])
+ text += ", ".join(
+ [
+ f._compiler_dispatch(self, asfrom=True, **kwargs)
+ for f in froms
+ ]
+ )
else:
text += self.default_from()
@@ -1940,8 +2201,10 @@ class SQLCompiler(Compiled):
if select._order_by_clause.clauses:
text += self.order_by_clause(select, **kwargs)
- if (select._limit_clause is not None or
- select._offset_clause is not None):
+ if (
+ select._limit_clause is not None
+ or select._offset_clause is not None
+ ):
text += self.limit_clause(select, **kwargs)
if select._for_update_arg is not None:
@@ -1953,8 +2216,7 @@ class SQLCompiler(Compiled):
clause = " ".join(
prefix._compiler_dispatch(self, **kw)
for prefix, dialect_name in prefixes
- if dialect_name is None or
- dialect_name == self.dialect.name
+ if dialect_name is None or dialect_name == self.dialect.name
)
if clause:
clause += " "
@@ -1962,14 +2224,12 @@ class SQLCompiler(Compiled):
def _render_cte_clause(self):
if self.positional:
- self.positiontup = sum([
- self.cte_positional[cte]
- for cte in self.ctes], []) + \
- self.positiontup
+ self.positiontup = (
+ sum([self.cte_positional[cte] for cte in self.ctes], [])
+ + self.positiontup
+ )
cte_text = self.get_cte_preamble(self.ctes_recursive) + " "
- cte_text += ", \n".join(
- [txt for txt in self.ctes.values()]
- )
+ cte_text += ", \n".join([txt for txt in self.ctes.values()])
cte_text += "\n "
return cte_text
@@ -2010,7 +2270,8 @@ class SQLCompiler(Compiled):
def returning_clause(self, stmt, returning_cols):
raise exc.CompileError(
"RETURNING is not supported by this "
- "dialect's statement compiler.")
+ "dialect's statement compiler."
+ )
def limit_clause(self, select, **kw):
text = ""
@@ -2022,19 +2283,31 @@ class SQLCompiler(Compiled):
text += " OFFSET " + self.process(select._offset_clause, **kw)
return text
- def visit_table(self, table, asfrom=False, iscrud=False, ashint=False,
- fromhints=None, use_schema=True, **kwargs):
+ def visit_table(
+ self,
+ table,
+ asfrom=False,
+ iscrud=False,
+ ashint=False,
+ fromhints=None,
+ use_schema=True,
+ **kwargs
+ ):
if asfrom or ashint:
effective_schema = self.preparer.schema_for_object(table)
if use_schema and effective_schema:
- ret = self.preparer.quote_schema(effective_schema) + \
- "." + self.preparer.quote(table.name)
+ ret = (
+ self.preparer.quote_schema(effective_schema)
+ + "."
+ + self.preparer.quote(table.name)
+ )
else:
ret = self.preparer.quote(table.name)
if fromhints and table in fromhints:
- ret = self.format_from_hint_text(ret, table,
- fromhints[table], iscrud)
+ ret = self.format_from_hint_text(
+ ret, table, fromhints[table], iscrud
+ )
return ret
else:
return ""
@@ -2047,26 +2320,24 @@ class SQLCompiler(Compiled):
else:
join_type = " JOIN "
return (
- join.left._compiler_dispatch(self, asfrom=True, **kwargs) +
- join_type +
- join.right._compiler_dispatch(self, asfrom=True, **kwargs) +
- " ON " +
- join.onclause._compiler_dispatch(self, **kwargs)
+ join.left._compiler_dispatch(self, asfrom=True, **kwargs)
+ + join_type
+ + join.right._compiler_dispatch(self, asfrom=True, **kwargs)
+ + " ON "
+ + join.onclause._compiler_dispatch(self, **kwargs)
)
def _setup_crud_hints(self, stmt, table_text):
- dialect_hints = dict([
- (table, hint_text)
- for (table, dialect), hint_text in
- stmt._hints.items()
- if dialect in ('*', self.dialect.name)
- ])
+ dialect_hints = dict(
+ [
+ (table, hint_text)
+ for (table, dialect), hint_text in stmt._hints.items()
+ if dialect in ("*", self.dialect.name)
+ ]
+ )
if stmt.table in dialect_hints:
table_text = self.format_from_hint_text(
- table_text,
- stmt.table,
- dialect_hints[stmt.table],
- True
+ table_text, stmt.table, dialect_hints[stmt.table], True
)
return dialect_hints, table_text
@@ -2074,28 +2345,35 @@ class SQLCompiler(Compiled):
toplevel = not self.stack
self.stack.append(
- {'correlate_froms': set(),
- "asfrom_froms": set(),
- "selectable": insert_stmt})
+ {
+ "correlate_froms": set(),
+ "asfrom_froms": set(),
+ "selectable": insert_stmt,
+ }
+ )
crud_params = crud._setup_crud_params(
- self, insert_stmt, crud.ISINSERT, **kw)
+ self, insert_stmt, crud.ISINSERT, **kw
+ )
- if not crud_params and \
- not self.dialect.supports_default_values and \
- not self.dialect.supports_empty_insert:
- raise exc.CompileError("The '%s' dialect with current database "
- "version settings does not support empty "
- "inserts." %
- self.dialect.name)
+ if (
+ not crud_params
+ and not self.dialect.supports_default_values
+ and not self.dialect.supports_empty_insert
+ ):
+ raise exc.CompileError(
+ "The '%s' dialect with current database "
+ "version settings does not support empty "
+ "inserts." % self.dialect.name
+ )
if insert_stmt._has_multi_parameters:
if not self.dialect.supports_multivalues_insert:
raise exc.CompileError(
"The '%s' dialect with current database "
"version settings does not support "
- "in-place multirow inserts." %
- self.dialect.name)
+ "in-place multirow inserts." % self.dialect.name
+ )
crud_params_single = crud_params[0]
else:
crud_params_single = crud_params
@@ -2106,27 +2384,31 @@ class SQLCompiler(Compiled):
text = "INSERT "
if insert_stmt._prefixes:
- text += self._generate_prefixes(insert_stmt,
- insert_stmt._prefixes, **kw)
+ text += self._generate_prefixes(
+ insert_stmt, insert_stmt._prefixes, **kw
+ )
text += "INTO "
table_text = preparer.format_table(insert_stmt.table)
if insert_stmt._hints:
dialect_hints, table_text = self._setup_crud_hints(
- insert_stmt, table_text)
+ insert_stmt, table_text
+ )
else:
dialect_hints = None
text += table_text
if crud_params_single or not supports_default_values:
- text += " (%s)" % ', '.join([preparer.format_column(c[0])
- for c in crud_params_single])
+ text += " (%s)" % ", ".join(
+ [preparer.format_column(c[0]) for c in crud_params_single]
+ )
if self.returning or insert_stmt._returning:
returning_clause = self.returning_clause(
- insert_stmt, self.returning or insert_stmt._returning)
+ insert_stmt, self.returning or insert_stmt._returning
+ )
if self.returning_precedes_values:
text += " " + returning_clause
@@ -2145,19 +2427,17 @@ class SQLCompiler(Compiled):
elif insert_stmt._has_multi_parameters:
text += " VALUES %s" % (
", ".join(
- "(%s)" % (
- ', '.join(c[1] for c in crud_param_set)
- )
+ "(%s)" % (", ".join(c[1] for c in crud_param_set))
for crud_param_set in crud_params
)
)
else:
- text += " VALUES (%s)" % \
- ', '.join([c[1] for c in crud_params])
+ text += " VALUES (%s)" % ", ".join([c[1] for c in crud_params])
if insert_stmt._post_values_clause is not None:
post_values_clause = self.process(
- insert_stmt._post_values_clause, **kw)
+ insert_stmt._post_values_clause, **kw
+ )
if post_values_clause:
text += " " + post_values_clause
@@ -2178,21 +2458,19 @@ class SQLCompiler(Compiled):
"""Provide a hook for MySQL to add LIMIT to the UPDATE"""
return None
- def update_tables_clause(self, update_stmt, from_table,
- extra_froms, **kw):
+ def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw):
"""Provide a hook to override the initial table clause
in an UPDATE statement.
MySQL overrides this.
"""
- kw['asfrom'] = True
+ kw["asfrom"] = True
return from_table._compiler_dispatch(self, iscrud=True, **kw)
- def update_from_clause(self, update_stmt,
- from_table, extra_froms,
- from_hints,
- **kw):
+ def update_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
"""Provide a hook to override the generation of an
UPDATE..FROM clause.
@@ -2201,7 +2479,8 @@ class SQLCompiler(Compiled):
"""
raise NotImplementedError(
"This backend does not support multiple-table "
- "criteria within UPDATE")
+ "criteria within UPDATE"
+ )
def visit_update(self, update_stmt, asfrom=False, **kw):
toplevel = not self.stack
@@ -2221,49 +2500,61 @@ class SQLCompiler(Compiled):
correlate_froms = {update_stmt.table}
self.stack.append(
- {'correlate_froms': correlate_froms,
- "asfrom_froms": correlate_froms,
- "selectable": update_stmt})
+ {
+ "correlate_froms": correlate_froms,
+ "asfrom_froms": correlate_froms,
+ "selectable": update_stmt,
+ }
+ )
text = "UPDATE "
if update_stmt._prefixes:
- text += self._generate_prefixes(update_stmt,
- update_stmt._prefixes, **kw)
+ text += self._generate_prefixes(
+ update_stmt, update_stmt._prefixes, **kw
+ )
- table_text = self.update_tables_clause(update_stmt, update_stmt.table,
- render_extra_froms, **kw)
+ table_text = self.update_tables_clause(
+ update_stmt, update_stmt.table, render_extra_froms, **kw
+ )
crud_params = crud._setup_crud_params(
- self, update_stmt, crud.ISUPDATE, **kw)
+ self, update_stmt, crud.ISUPDATE, **kw
+ )
if update_stmt._hints:
dialect_hints, table_text = self._setup_crud_hints(
- update_stmt, table_text)
+ update_stmt, table_text
+ )
else:
dialect_hints = None
text += table_text
- text += ' SET '
- include_table = is_multitable and \
- self.render_table_with_column_in_update_from
- text += ', '.join(
- c[0]._compiler_dispatch(self,
- include_table=include_table) +
- '=' + c[1] for c in crud_params
+ text += " SET "
+ include_table = (
+ is_multitable and self.render_table_with_column_in_update_from
+ )
+ text += ", ".join(
+ c[0]._compiler_dispatch(self, include_table=include_table)
+ + "="
+ + c[1]
+ for c in crud_params
)
if self.returning or update_stmt._returning:
if self.returning_precedes_values:
text += " " + self.returning_clause(
- update_stmt, self.returning or update_stmt._returning)
+ update_stmt, self.returning or update_stmt._returning
+ )
if extra_froms:
extra_from_text = self.update_from_clause(
update_stmt,
update_stmt.table,
render_extra_froms,
- dialect_hints, **kw)
+ dialect_hints,
+ **kw
+ )
if extra_from_text:
text += " " + extra_from_text
@@ -2276,10 +2567,12 @@ class SQLCompiler(Compiled):
if limit_clause:
text += " " + limit_clause
- if (self.returning or update_stmt._returning) and \
- not self.returning_precedes_values:
+ if (
+ self.returning or update_stmt._returning
+ ) and not self.returning_precedes_values:
text += " " + self.returning_clause(
- update_stmt, self.returning or update_stmt._returning)
+ update_stmt, self.returning or update_stmt._returning
+ )
if self.ctes and toplevel:
text = self._render_cte_clause() + text
@@ -2295,9 +2588,9 @@ class SQLCompiler(Compiled):
def _key_getters_for_crud_column(self):
return crud._key_getters_for_crud_column(self, self.statement)
- def delete_extra_from_clause(self, update_stmt,
- from_table, extra_froms,
- from_hints, **kw):
+ def delete_extra_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
"""Provide a hook to override the generation of an
DELETE..FROM clause.
@@ -2308,10 +2601,10 @@ class SQLCompiler(Compiled):
"""
raise NotImplementedError(
"This backend does not support multiple-table "
- "criteria within DELETE")
+ "criteria within DELETE"
+ )
- def delete_table_clause(self, delete_stmt, from_table,
- extra_froms):
+ def delete_table_clause(self, delete_stmt, from_table, extra_froms):
return from_table._compiler_dispatch(self, asfrom=True, iscrud=True)
def visit_delete(self, delete_stmt, asfrom=False, **kw):
@@ -2322,23 +2615,30 @@ class SQLCompiler(Compiled):
extra_froms = delete_stmt._extra_froms
correlate_froms = {delete_stmt.table}.union(extra_froms)
- self.stack.append({'correlate_froms': correlate_froms,
- "asfrom_froms": correlate_froms,
- "selectable": delete_stmt})
+ self.stack.append(
+ {
+ "correlate_froms": correlate_froms,
+ "asfrom_froms": correlate_froms,
+ "selectable": delete_stmt,
+ }
+ )
text = "DELETE "
if delete_stmt._prefixes:
- text += self._generate_prefixes(delete_stmt,
- delete_stmt._prefixes, **kw)
+ text += self._generate_prefixes(
+ delete_stmt, delete_stmt._prefixes, **kw
+ )
text += "FROM "
- table_text = self.delete_table_clause(delete_stmt, delete_stmt.table,
- extra_froms)
+ table_text = self.delete_table_clause(
+ delete_stmt, delete_stmt.table, extra_froms
+ )
if delete_stmt._hints:
dialect_hints, table_text = self._setup_crud_hints(
- delete_stmt, table_text)
+ delete_stmt, table_text
+ )
else:
dialect_hints = None
@@ -2347,14 +2647,17 @@ class SQLCompiler(Compiled):
if delete_stmt._returning:
if self.returning_precedes_values:
text += " " + self.returning_clause(
- delete_stmt, delete_stmt._returning)
+ delete_stmt, delete_stmt._returning
+ )
if extra_froms:
extra_from_text = self.delete_extra_from_clause(
delete_stmt,
delete_stmt.table,
extra_froms,
- dialect_hints, **kw)
+ dialect_hints,
+ **kw
+ )
if extra_from_text:
text += " " + extra_from_text
@@ -2365,7 +2668,8 @@ class SQLCompiler(Compiled):
if delete_stmt._returning and not self.returning_precedes_values:
text += " " + self.returning_clause(
- delete_stmt, delete_stmt._returning)
+ delete_stmt, delete_stmt._returning
+ )
if self.ctes and toplevel:
text = self._render_cte_clause() + text
@@ -2381,12 +2685,14 @@ class SQLCompiler(Compiled):
return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
def visit_rollback_to_savepoint(self, savepoint_stmt):
- return "ROLLBACK TO SAVEPOINT %s" % \
- self.preparer.format_savepoint(savepoint_stmt)
+ return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(
+ savepoint_stmt
+ )
def visit_release_savepoint(self, savepoint_stmt):
- return "RELEASE SAVEPOINT %s" % \
- self.preparer.format_savepoint(savepoint_stmt)
+ return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(
+ savepoint_stmt
+ )
class StrSQLCompiler(SQLCompiler):
@@ -2403,7 +2709,7 @@ class StrSQLCompiler(SQLCompiler):
def visit_getitem_binary(self, binary, operator, **kw):
return "%s[%s]" % (
self.process(binary.left, **kw),
- self.process(binary.right, **kw)
+ self.process(binary.right, **kw),
)
def visit_json_getitem_op_binary(self, binary, operator, **kw):
@@ -2421,29 +2727,26 @@ class StrSQLCompiler(SQLCompiler):
for c in elements._select_iterables(returning_cols)
]
- return 'RETURNING ' + ', '.join(columns)
+ return "RETURNING " + ", ".join(columns)
- def update_from_clause(self, update_stmt,
- from_table, extra_froms,
- from_hints,
- **kw):
- return "FROM " + ', '.join(
- t._compiler_dispatch(self, asfrom=True,
- fromhints=from_hints, **kw)
- for t in extra_froms)
+ def update_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
+ return "FROM " + ", ".join(
+ t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw)
+ for t in extra_froms
+ )
- def delete_extra_from_clause(self, update_stmt,
- from_table, extra_froms,
- from_hints,
- **kw):
- return ', ' + ', '.join(
- t._compiler_dispatch(self, asfrom=True,
- fromhints=from_hints, **kw)
- for t in extra_froms)
+ def delete_extra_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
+ return ", " + ", ".join(
+ t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw)
+ for t in extra_froms
+ )
class DDLCompiler(Compiled):
-
@util.memoized_property
def sql_compiler(self):
return self.dialect.statement_compiler(self.dialect, None)
@@ -2464,13 +2767,13 @@ class DDLCompiler(Compiled):
preparer = self.preparer
path = preparer.format_table_seq(ddl.target)
if len(path) == 1:
- table, sch = path[0], ''
+ table, sch = path[0], ""
else:
table, sch = path[-1], path[0]
- context.setdefault('table', table)
- context.setdefault('schema', sch)
- context.setdefault('fullname', preparer.format_table(ddl.target))
+ context.setdefault("table", table)
+ context.setdefault("schema", sch)
+ context.setdefault("fullname", preparer.format_table(ddl.target))
return self.sql_compiler.post_process_text(ddl.statement % context)
@@ -2507,9 +2810,9 @@ class DDLCompiler(Compiled):
for create_column in create.columns:
column = create_column.element
try:
- processed = self.process(create_column,
- first_pk=column.primary_key
- and not first_pk)
+ processed = self.process(
+ create_column, first_pk=column.primary_key and not first_pk
+ )
if processed is not None:
text += separator
separator = ", \n"
@@ -2519,13 +2822,15 @@ class DDLCompiler(Compiled):
except exc.CompileError as ce:
util.raise_from_cause(
exc.CompileError(
- util.u("(in table '%s', column '%s'): %s") %
- (table.description, column.name, ce.args[0])
- ))
+ util.u("(in table '%s', column '%s'): %s")
+ % (table.description, column.name, ce.args[0])
+ )
+ )
const = self.create_table_constraints(
- table, _include_foreign_key_constraints= # noqa
- create.include_foreign_key_constraints)
+ table,
+ _include_foreign_key_constraints=create.include_foreign_key_constraints, # noqa
+ )
if const:
text += separator + "\t" + const
@@ -2538,20 +2843,18 @@ class DDLCompiler(Compiled):
if column.system:
return None
- text = self.get_column_specification(
- column,
- first_pk=first_pk
+ text = self.get_column_specification(column, first_pk=first_pk)
+ const = " ".join(
+ self.process(constraint) for constraint in column.constraints
)
- const = " ".join(self.process(constraint)
- for constraint in column.constraints)
if const:
text += " " + const
return text
def create_table_constraints(
- self, table,
- _include_foreign_key_constraints=None):
+ self, table, _include_foreign_key_constraints=None
+ ):
# On some DB order is significant: visit PK first, then the
# other constraints (engine.ReflectionTest.testbasic failed on FB2)
@@ -2565,21 +2868,29 @@ class DDLCompiler(Compiled):
else:
omit_fkcs = set()
- constraints.extend([c for c in table._sorted_constraints
- if c is not table.primary_key and
- c not in omit_fkcs])
+ constraints.extend(
+ [
+ c
+ for c in table._sorted_constraints
+ if c is not table.primary_key and c not in omit_fkcs
+ ]
+ )
return ", \n\t".join(
- p for p in
- (self.process(constraint)
+ p
+ for p in (
+ self.process(constraint)
for constraint in constraints
if (
- constraint._create_rule is None or
- constraint._create_rule(self))
+ constraint._create_rule is None
+ or constraint._create_rule(self)
+ )
and (
- not self.dialect.supports_alter or
- not getattr(constraint, 'use_alter', False)
- )) if p is not None
+ not self.dialect.supports_alter
+ or not getattr(constraint, "use_alter", False)
+ )
+ )
+ if p is not None
)
def visit_drop_table(self, drop):
@@ -2590,34 +2901,38 @@ class DDLCompiler(Compiled):
def _verify_index_table(self, index):
if index.table is None:
- raise exc.CompileError("Index '%s' is not associated "
- "with any table." % index.name)
+ raise exc.CompileError(
+ "Index '%s' is not associated " "with any table." % index.name
+ )
- def visit_create_index(self, create, include_schema=False,
- include_table_schema=True):
+ def visit_create_index(
+ self, create, include_schema=False, include_table_schema=True
+ ):
index = create.element
self._verify_index_table(index)
preparer = self.preparer
text = "CREATE "
if index.unique:
text += "UNIQUE "
- text += "INDEX %s ON %s (%s)" \
- % (
- self._prepared_index_name(index,
- include_schema=include_schema),
- preparer.format_table(index.table,
- use_schema=include_table_schema),
- ', '.join(
- self.sql_compiler.process(
- expr, include_table=False, literal_binds=True) for
- expr in index.expressions)
- )
+ text += "INDEX %s ON %s (%s)" % (
+ self._prepared_index_name(index, include_schema=include_schema),
+ preparer.format_table(
+ index.table, use_schema=include_table_schema
+ ),
+ ", ".join(
+ self.sql_compiler.process(
+ expr, include_table=False, literal_binds=True
+ )
+ for expr in index.expressions
+ ),
+ )
return text
def visit_drop_index(self, drop):
index = drop.element
return "\nDROP INDEX " + self._prepared_index_name(
- index, include_schema=True)
+ index, include_schema=True
+ )
def _prepared_index_name(self, index, include_schema=False):
if index.table is not None:
@@ -2638,35 +2953,41 @@ class DDLCompiler(Compiled):
def visit_add_constraint(self, create):
return "ALTER TABLE %s ADD %s" % (
self.preparer.format_table(create.element.table),
- self.process(create.element)
+ self.process(create.element),
)
def visit_set_table_comment(self, create):
return "COMMENT ON TABLE %s IS %s" % (
self.preparer.format_table(create.element),
self.sql_compiler.render_literal_value(
- create.element.comment, sqltypes.String())
+ create.element.comment, sqltypes.String()
+ ),
)
def visit_drop_table_comment(self, drop):
- return "COMMENT ON TABLE %s IS NULL" % \
- self.preparer.format_table(drop.element)
+ return "COMMENT ON TABLE %s IS NULL" % self.preparer.format_table(
+ drop.element
+ )
def visit_set_column_comment(self, create):
return "COMMENT ON COLUMN %s IS %s" % (
self.preparer.format_column(
- create.element, use_table=True, use_schema=True),
+ create.element, use_table=True, use_schema=True
+ ),
self.sql_compiler.render_literal_value(
- create.element.comment, sqltypes.String())
+ create.element.comment, sqltypes.String()
+ ),
)
def visit_drop_column_comment(self, drop):
- return "COMMENT ON COLUMN %s IS NULL" % \
- self.preparer.format_column(drop.element, use_table=True)
+ return "COMMENT ON COLUMN %s IS NULL" % self.preparer.format_column(
+ drop.element, use_table=True
+ )
def visit_create_sequence(self, create):
- text = "CREATE SEQUENCE %s" % \
- self.preparer.format_sequence(create.element)
+ text = "CREATE SEQUENCE %s" % self.preparer.format_sequence(
+ create.element
+ )
if create.element.increment is not None:
text += " INCREMENT BY %d" % create.element.increment
if create.element.start is not None:
@@ -2688,8 +3009,7 @@ class DDLCompiler(Compiled):
return text
def visit_drop_sequence(self, drop):
- return "DROP SEQUENCE %s" % \
- self.preparer.format_sequence(drop.element)
+ return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element)
def visit_drop_constraint(self, drop):
constraint = drop.element
@@ -2701,17 +3021,22 @@ class DDLCompiler(Compiled):
if formatted_name is None:
raise exc.CompileError(
"Can't emit DROP CONSTRAINT for constraint %r; "
- "it has no name" % drop.element)
+ "it has no name" % drop.element
+ )
return "ALTER TABLE %s DROP CONSTRAINT %s%s" % (
self.preparer.format_table(drop.element.table),
formatted_name,
- drop.cascade and " CASCADE" or ""
+ drop.cascade and " CASCADE" or "",
)
def get_column_specification(self, column, **kwargs):
- colspec = self.preparer.format_column(column) + " " + \
- self.dialect.type_compiler.process(
- column.type, type_expression=column)
+ colspec = (
+ self.preparer.format_column(column)
+ + " "
+ + self.dialect.type_compiler.process(
+ column.type, type_expression=column
+ )
+ )
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
@@ -2721,19 +3046,21 @@ class DDLCompiler(Compiled):
return colspec
def create_table_suffix(self, table):
- return ''
+ return ""
def post_create_table(self, table):
- return ''
+ return ""
def get_column_default_string(self, column):
if isinstance(column.server_default, schema.DefaultClause):
if isinstance(column.server_default.arg, util.string_types):
return self.sql_compiler.render_literal_value(
- column.server_default.arg, sqltypes.STRINGTYPE)
+ column.server_default.arg, sqltypes.STRINGTYPE
+ )
else:
return self.sql_compiler.process(
- column.server_default.arg, literal_binds=True)
+ column.server_default.arg, literal_binds=True
+ )
else:
return None
@@ -2743,9 +3070,9 @@ class DDLCompiler(Compiled):
formatted_name = self.preparer.format_constraint(constraint)
if formatted_name is not None:
text += "CONSTRAINT %s " % formatted_name
- text += "CHECK (%s)" % self.sql_compiler.process(constraint.sqltext,
- include_table=False,
- literal_binds=True)
+ text += "CHECK (%s)" % self.sql_compiler.process(
+ constraint.sqltext, include_table=False, literal_binds=True
+ )
text += self.define_constraint_deferrability(constraint)
return text
@@ -2755,25 +3082,29 @@ class DDLCompiler(Compiled):
formatted_name = self.preparer.format_constraint(constraint)
if formatted_name is not None:
text += "CONSTRAINT %s " % formatted_name
- text += "CHECK (%s)" % self.sql_compiler.process(constraint.sqltext,
- include_table=False,
- literal_binds=True)
+ text += "CHECK (%s)" % self.sql_compiler.process(
+ constraint.sqltext, include_table=False, literal_binds=True
+ )
text += self.define_constraint_deferrability(constraint)
return text
def visit_primary_key_constraint(self, constraint):
if len(constraint) == 0:
- return ''
+ return ""
text = ""
if constraint.name is not None:
formatted_name = self.preparer.format_constraint(constraint)
if formatted_name is not None:
text += "CONSTRAINT %s " % formatted_name
text += "PRIMARY KEY "
- text += "(%s)" % ', '.join(self.preparer.quote(c.name)
- for c in (constraint.columns_autoinc_first
- if constraint._implicit_generated
- else constraint.columns))
+ text += "(%s)" % ", ".join(
+ self.preparer.quote(c.name)
+ for c in (
+ constraint.columns_autoinc_first
+ if constraint._implicit_generated
+ else constraint.columns
+ )
+ )
text += self.define_constraint_deferrability(constraint)
return text
@@ -2786,12 +3117,15 @@ class DDLCompiler(Compiled):
text += "CONSTRAINT %s " % formatted_name
remote_table = list(constraint.elements)[0].column.table
text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % (
- ', '.join(preparer.quote(f.parent.name)
- for f in constraint.elements),
+ ", ".join(
+ preparer.quote(f.parent.name) for f in constraint.elements
+ ),
self.define_constraint_remote_table(
- constraint, remote_table, preparer),
- ', '.join(preparer.quote(f.column.name)
- for f in constraint.elements)
+ constraint, remote_table, preparer
+ ),
+ ", ".join(
+ preparer.quote(f.column.name) for f in constraint.elements
+ ),
)
text += self.define_constraint_match(constraint)
text += self.define_constraint_cascades(constraint)
@@ -2805,14 +3139,14 @@ class DDLCompiler(Compiled):
def visit_unique_constraint(self, constraint):
if len(constraint) == 0:
- return ''
+ return ""
text = ""
if constraint.name is not None:
formatted_name = self.preparer.format_constraint(constraint)
text += "CONSTRAINT %s " % formatted_name
text += "UNIQUE (%s)" % (
- ', '.join(self.preparer.quote(c.name)
- for c in constraint))
+ ", ".join(self.preparer.quote(c.name) for c in constraint)
+ )
text += self.define_constraint_deferrability(constraint)
return text
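
A quick illustration, not part of the patch: the constraint visitors reformatted
above (primary key, foreign key, unique, check) are driven by CreateTable through
create_table_constraints(). A minimal sketch, assuming a 1.x install and
throwaway names:

    from sqlalchemy import (
        Column,
        Integer,
        MetaData,
        String,
        Table,
        UniqueConstraint,
    )
    from sqlalchemy.schema import CreateTable

    t = Table(
        "account",
        MetaData(),
        Column("id", Integer, primary_key=True),
        Column("email", String(100)),
        UniqueConstraint("email", name="uq_account_email"),
    )

    # visit_create_table() emits one get_column_specification() per column,
    # then create_table_constraints() renders the PK and UNIQUE clauses
    print(CreateTable(t))
    # roughly:
    # CREATE TABLE account (
    #     id INTEGER NOT NULL,
    #     email VARCHAR(100),
    #     PRIMARY KEY (id),
    #     CONSTRAINT uq_account_email UNIQUE (email)
    # )
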
@@ -2843,7 +3177,6 @@ class DDLCompiler(Compiled):
class GenericTypeCompiler(TypeCompiler):
-
def visit_FLOAT(self, type_, **kw):
return "FLOAT"
@@ -2854,23 +3187,23 @@ class GenericTypeCompiler(TypeCompiler):
if type_.precision is None:
return "NUMERIC"
elif type_.scale is None:
- return "NUMERIC(%(precision)s)" % \
- {'precision': type_.precision}
+ return "NUMERIC(%(precision)s)" % {"precision": type_.precision}
else:
- return "NUMERIC(%(precision)s, %(scale)s)" % \
- {'precision': type_.precision,
- 'scale': type_.scale}
+ return "NUMERIC(%(precision)s, %(scale)s)" % {
+ "precision": type_.precision,
+ "scale": type_.scale,
+ }
def visit_DECIMAL(self, type_, **kw):
if type_.precision is None:
return "DECIMAL"
elif type_.scale is None:
- return "DECIMAL(%(precision)s)" % \
- {'precision': type_.precision}
+ return "DECIMAL(%(precision)s)" % {"precision": type_.precision}
else:
- return "DECIMAL(%(precision)s, %(scale)s)" % \
- {'precision': type_.precision,
- 'scale': type_.scale}
+ return "DECIMAL(%(precision)s, %(scale)s)" % {
+ "precision": type_.precision,
+ "scale": type_.scale,
+ }
def visit_INTEGER(self, type_, **kw):
return "INTEGER"
@@ -2882,7 +3215,7 @@ class GenericTypeCompiler(TypeCompiler):
return "BIGINT"
def visit_TIMESTAMP(self, type_, **kw):
- return 'TIMESTAMP'
+ return "TIMESTAMP"
def visit_DATETIME(self, type_, **kw):
return "DATETIME"
@@ -2984,9 +3317,11 @@ class GenericTypeCompiler(TypeCompiler):
return self.visit_VARCHAR(type_, **kw)
def visit_null(self, type_, **kw):
- raise exc.CompileError("Can't generate DDL for %r; "
- "did you forget to specify a "
- "type on this Column?" % type_)
+ raise exc.CompileError(
+ "Can't generate DDL for %r; "
+ "did you forget to specify a "
+ "type on this Column?" % type_
+ )
def visit_type_decorator(self, type_, **kw):
return self.process(type_.type_engine(self.dialect), **kw)
@@ -3018,9 +3353,15 @@ class IdentifierPreparer(object):
schema_for_object = schema._schema_getter(None)
- def __init__(self, dialect, initial_quote='"',
- final_quote=None, escape_quote='"',
- quote_case_sensitive_collations=True, omit_schema=False):
+ def __init__(
+ self,
+ dialect,
+ initial_quote='"',
+ final_quote=None,
+ escape_quote='"',
+ quote_case_sensitive_collations=True,
+ omit_schema=False,
+ ):
"""Construct a new ``IdentifierPreparer`` object.
initial_quote
@@ -3043,7 +3384,10 @@ class IdentifierPreparer(object):
self.omit_schema = omit_schema
self.quote_case_sensitive_collations = quote_case_sensitive_collations
self._strings = {}
- self._double_percents = self.dialect.paramstyle in ('format', 'pyformat')
+ self._double_percents = self.dialect.paramstyle in (
+ "format",
+ "pyformat",
+ )
def _with_schema_translate(self, schema_translate_map):
prep = self.__class__.__new__(self.__class__)
@@ -3060,7 +3404,7 @@ class IdentifierPreparer(object):
value = value.replace(self.escape_quote, self.escape_to_quote)
if self._double_percents:
- value = value.replace('%', '%%')
+ value = value.replace("%", "%%")
return value
def _unescape_identifier(self, value):
@@ -3079,17 +3423,21 @@ class IdentifierPreparer(object):
quoting behavior.
"""
- return self.initial_quote + \
- self._escape_identifier(value) + \
- self.final_quote
+ return (
+ self.initial_quote
+ + self._escape_identifier(value)
+ + self.final_quote
+ )
def _requires_quotes(self, value):
"""Return True if the given identifier requires quoting."""
lc_value = value.lower()
- return (lc_value in self.reserved_words
- or value[0] in self.illegal_initial_characters
- or not self.legal_characters.match(util.text_type(value))
- or (lc_value != value))
+ return (
+ lc_value in self.reserved_words
+ or value[0] in self.illegal_initial_characters
+ or not self.legal_characters.match(util.text_type(value))
+ or (lc_value != value)
+ )
def quote_schema(self, schema, force=None):
"""Conditionally quote a schema.
@@ -3135,8 +3483,11 @@ class IdentifierPreparer(object):
effective_schema = self.schema_for_object(sequence)
- if (not self.omit_schema and use_schema and
- effective_schema is not None):
+ if (
+ not self.omit_schema
+ and use_schema
+ and effective_schema is not None
+ ):
name = self.quote_schema(effective_schema) + "." + name
return name
@@ -3159,7 +3510,8 @@ class IdentifierPreparer(object):
def format_constraint(self, naming, constraint):
if isinstance(constraint.name, elements._defer_name):
name = naming._constraint_name_for_table(
- constraint, constraint.table)
+ constraint, constraint.table
+ )
if name is None:
if isinstance(constraint.name, elements._defer_none_name):
@@ -3170,14 +3522,15 @@ class IdentifierPreparer(object):
name = constraint.name
if isinstance(name, elements._truncated_label):
- if constraint.__visit_name__ == 'index':
- max_ = self.dialect.max_index_name_length or \
- self.dialect.max_identifier_length
+ if constraint.__visit_name__ == "index":
+ max_ = (
+ self.dialect.max_index_name_length
+ or self.dialect.max_identifier_length
+ )
else:
max_ = self.dialect.max_identifier_length
if len(name) > max_:
- name = name[0:max_ - 8] + \
- "_" + util.md5_hex(name)[-4:]
+ name = name[0 : max_ - 8] + "_" + util.md5_hex(name)[-4:]
else:
self.dialect.validate_identifier(name)
@@ -3195,8 +3548,7 @@ class IdentifierPreparer(object):
effective_schema = self.schema_for_object(table)
- if not self.omit_schema and use_schema \
- and effective_schema:
+ if not self.omit_schema and use_schema and effective_schema:
result = self.quote_schema(effective_schema) + "." + result
return result
@@ -3205,17 +3557,27 @@ class IdentifierPreparer(object):
return self.quote(name, quote)
- def format_column(self, column, use_table=False,
- name=None, table_name=None, use_schema=False):
+ def format_column(
+ self,
+ column,
+ use_table=False,
+ name=None,
+ table_name=None,
+ use_schema=False,
+ ):
"""Prepare a quoted column name."""
if name is None:
name = column.name
- if not getattr(column, 'is_literal', False):
+ if not getattr(column, "is_literal", False):
if use_table:
- return self.format_table(
- column.table, use_schema=use_schema,
- name=table_name) + "." + self.quote(name)
+ return (
+ self.format_table(
+ column.table, use_schema=use_schema, name=table_name
+ )
+ + "."
+ + self.quote(name)
+ )
else:
return self.quote(name)
else:
@@ -3223,9 +3585,13 @@ class IdentifierPreparer(object):
# which shouldn't get quoted
if use_table:
- return self.format_table(
- column.table, use_schema=use_schema,
- name=table_name) + '.' + name
+ return (
+ self.format_table(
+ column.table, use_schema=use_schema, name=table_name
+ )
+ + "."
+ + name
+ )
else:
return name
@@ -3238,31 +3604,37 @@ class IdentifierPreparer(object):
effective_schema = self.schema_for_object(table)
- if not self.omit_schema and use_schema and \
- effective_schema:
- return (self.quote_schema(effective_schema),
- self.format_table(table, use_schema=False))
+ if not self.omit_schema and use_schema and effective_schema:
+ return (
+ self.quote_schema(effective_schema),
+ self.format_table(table, use_schema=False),
+ )
else:
- return (self.format_table(table, use_schema=False), )
+ return (self.format_table(table, use_schema=False),)
@util.memoized_property
def _r_identifiers(self):
- initial, final, escaped_final = \
- [re.escape(s) for s in
- (self.initial_quote, self.final_quote,
- self._escape_identifier(self.final_quote))]
+ initial, final, escaped_final = [
+ re.escape(s)
+ for s in (
+ self.initial_quote,
+ self.final_quote,
+ self._escape_identifier(self.final_quote),
+ )
+ ]
r = re.compile(
- r'(?:'
- r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s'
- r'|([^\.]+))(?=\.|$))+' %
- {'initial': initial,
- 'final': final,
- 'escaped': escaped_final})
+ r"(?:"
+ r"(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s"
+ r"|([^\.]+))(?=\.|$))+"
+ % {"initial": initial, "final": final, "escaped": escaped_final}
+ )
return r
def unformat_identifiers(self, identifiers):
"""Unpack 'schema.table.column'-like strings into components."""
r = self._r_identifiers
- return [self._unescape_identifier(i)
- for i in [a or b for a, b in r.findall(identifiers)]]
+ return [
+ self._unescape_identifier(i)
+ for i in [a or b for a, b in r.findall(identifiers)]
+ ]
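
A quick illustration, not part of the patch: the IdentifierPreparer helpers whose
formatting changes above (quote(), format_column(), unformat_identifiers()) can
be exercised through any stock dialect. A minimal sketch, assuming a 1.x install,
the bundled sqlite dialect, and throwaway identifier names:

    from sqlalchemy import Column, Integer, MetaData, Table
    from sqlalchemy.dialects import sqlite

    preparer = sqlite.dialect().identifier_preparer
    t = Table("user table", MetaData(), Column("id", Integer))

    # reserved words and names needing quoting come back double-quoted
    print(preparer.quote("select"))                          # "select"
    print(preparer.format_column(t.c.id, use_table=True))    # "user table".id
    # unformat_identifiers() splits the dotted, quoted form back into parts
    print(preparer.unformat_identifiers('"user table".id'))  # ['user table', 'id']
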
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index 999d48a55..602b91a25 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -15,7 +15,9 @@ from . import dml
from . import elements
import operator
-REQUIRED = util.symbol('REQUIRED', """
+REQUIRED = util.symbol(
+ "REQUIRED",
+ """
Placeholder for the value within a :class:`.BindParameter`
which is required to be present when the statement is passed
to :meth:`.Connection.execute`.
@@ -24,11 +26,12 @@ This symbol is typically used when a :func:`.expression.insert`
or :func:`.expression.update` statement is compiled without parameter
values present.
-""")
+""",
+)
-ISINSERT = util.symbol('ISINSERT')
-ISUPDATE = util.symbol('ISUPDATE')
-ISDELETE = util.symbol('ISDELETE')
+ISINSERT = util.symbol("ISINSERT")
+ISUPDATE = util.symbol("ISUPDATE")
+ISDELETE = util.symbol("ISDELETE")
def _setup_crud_params(compiler, stmt, local_stmt_type, **kw):
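
A quick illustration, not part of the patch: the REQUIRED symbol documented above
is what each column's bind becomes when an insert() or update() is compiled with
no values. A minimal sketch, assuming a 1.x install and a throwaway table:

    from sqlalchemy import Column, Integer, MetaData, String, Table

    t = Table("t", MetaData(), Column("id", Integer), Column("name", String))

    # every column renders as a named bind whose value is the REQUIRED
    # placeholder; real values must be passed to Connection.execute()
    print(t.insert())   # INSERT INTO t (id, name) VALUES (:id, :name)
    print(t.update())   # UPDATE t SET id=:id, name=:name
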
@@ -82,8 +85,7 @@ def _get_crud_params(compiler, stmt, **kw):
# compiled params - return binds for all columns
if compiler.column_keys is None and stmt.parameters is None:
return [
- (c, _create_bind_param(
- compiler, c, None, required=True))
+ (c, _create_bind_param(compiler, c, None, required=True))
for c in stmt.table.columns
]
@@ -95,26 +97,28 @@ def _get_crud_params(compiler, stmt, **kw):
# getters - these are normally just column.key,
# but in the case of mysql multi-table update, the rules for
# .key must conditionally take tablename into account
- _column_as_key, _getattr_col_key, _col_bind_name = \
- _key_getters_for_crud_column(compiler, stmt)
+ _column_as_key, _getattr_col_key, _col_bind_name = _key_getters_for_crud_column(
+ compiler, stmt
+ )
# if we have statement parameters - set defaults in the
# compiled params
if compiler.column_keys is None:
parameters = {}
else:
- parameters = dict((_column_as_key(key), REQUIRED)
- for key in compiler.column_keys
- if not stmt_parameters or
- key not in stmt_parameters)
+ parameters = dict(
+ (_column_as_key(key), REQUIRED)
+ for key in compiler.column_keys
+ if not stmt_parameters or key not in stmt_parameters
+ )
# create a list of column assignment clauses as tuples
values = []
if stmt_parameters is not None:
_get_stmt_parameters_params(
- compiler,
- parameters, stmt_parameters, _column_as_key, values, kw)
+ compiler, parameters, stmt_parameters, _column_as_key, values, kw
+ )
check_columns = {}
@@ -122,28 +126,51 @@ def _get_crud_params(compiler, stmt, **kw):
# statements
if compiler.isupdate and stmt._extra_froms and stmt_parameters:
_get_multitable_params(
- compiler, stmt, stmt_parameters, check_columns,
- _col_bind_name, _getattr_col_key, values, kw)
+ compiler,
+ stmt,
+ stmt_parameters,
+ check_columns,
+ _col_bind_name,
+ _getattr_col_key,
+ values,
+ kw,
+ )
if compiler.isinsert and stmt.select_names:
_scan_insert_from_select_cols(
- compiler, stmt, parameters,
- _getattr_col_key, _column_as_key,
- _col_bind_name, check_columns, values, kw)
+ compiler,
+ stmt,
+ parameters,
+ _getattr_col_key,
+ _column_as_key,
+ _col_bind_name,
+ check_columns,
+ values,
+ kw,
+ )
else:
_scan_cols(
- compiler, stmt, parameters,
- _getattr_col_key, _column_as_key,
- _col_bind_name, check_columns, values, kw)
+ compiler,
+ stmt,
+ parameters,
+ _getattr_col_key,
+ _column_as_key,
+ _col_bind_name,
+ check_columns,
+ values,
+ kw,
+ )
if parameters and stmt_parameters:
- check = set(parameters).intersection(
- _column_as_key(k) for k in stmt_parameters
- ).difference(check_columns)
+ check = (
+ set(parameters)
+ .intersection(_column_as_key(k) for k in stmt_parameters)
+ .difference(check_columns)
+ )
if check:
raise exc.CompileError(
- "Unconsumed column names: %s" %
- (", ".join("%s" % c for c in check))
+ "Unconsumed column names: %s"
+ % (", ".join("%s" % c for c in check))
)
if stmt._has_multi_parameters:
@@ -153,12 +180,13 @@ def _get_crud_params(compiler, stmt, **kw):
def _create_bind_param(
- compiler, col, value, process=True,
- required=False, name=None, **kw):
+ compiler, col, value, process=True, required=False, name=None, **kw
+):
if name is None:
name = col.key
bindparam = elements.BindParameter(
- name, value, type_=col.type, required=required)
+ name, value, type_=col.type, required=required
+ )
bindparam._is_crud = True
if process:
bindparam = bindparam._compiler_dispatch(compiler, **kw)
@@ -177,7 +205,7 @@ def _key_getters_for_crud_column(compiler, stmt):
def _column_as_key(key):
str_key = elements._column_as_key(key)
- if hasattr(key, 'table') and key.table in _et:
+ if hasattr(key, "table") and key.table in _et:
return (key.table.name, str_key)
else:
return str_key
@@ -202,15 +230,22 @@ def _key_getters_for_crud_column(compiler, stmt):
def _scan_insert_from_select_cols(
- compiler, stmt, parameters, _getattr_col_key,
- _column_as_key, _col_bind_name, check_columns, values, kw):
-
- need_pks, implicit_returning, \
- implicit_return_defaults, postfetch_lastrowid = \
- _get_returning_modifiers(compiler, stmt)
+ compiler,
+ stmt,
+ parameters,
+ _getattr_col_key,
+ _column_as_key,
+ _col_bind_name,
+ check_columns,
+ values,
+ kw,
+):
+
+ need_pks, implicit_returning, implicit_return_defaults, postfetch_lastrowid = _get_returning_modifiers(
+ compiler, stmt
+ )
- cols = [stmt.table.c[_column_as_key(name)]
- for name in stmt.select_names]
+ cols = [stmt.table.c[_column_as_key(name)] for name in stmt.select_names]
compiler._insert_from_select = stmt.select
@@ -228,32 +263,39 @@ def _scan_insert_from_select_cols(
values.append((c, None))
else:
_append_param_insert_select_hasdefault(
- compiler, stmt, c, add_select_cols, kw)
+ compiler, stmt, c, add_select_cols, kw
+ )
if add_select_cols:
values.extend(add_select_cols)
compiler._insert_from_select = compiler._insert_from_select._generate()
- compiler._insert_from_select._raw_columns = \
- tuple(compiler._insert_from_select._raw_columns) + tuple(
- expr for col, expr in add_select_cols)
+ compiler._insert_from_select._raw_columns = tuple(
+ compiler._insert_from_select._raw_columns
+ ) + tuple(expr for col, expr in add_select_cols)
def _scan_cols(
- compiler, stmt, parameters, _getattr_col_key,
- _column_as_key, _col_bind_name, check_columns, values, kw):
-
- need_pks, implicit_returning, \
- implicit_return_defaults, postfetch_lastrowid = \
- _get_returning_modifiers(compiler, stmt)
+ compiler,
+ stmt,
+ parameters,
+ _getattr_col_key,
+ _column_as_key,
+ _col_bind_name,
+ check_columns,
+ values,
+ kw,
+):
+
+ need_pks, implicit_returning, implicit_return_defaults, postfetch_lastrowid = _get_returning_modifiers(
+ compiler, stmt
+ )
if stmt._parameter_ordering:
parameter_ordering = [
_column_as_key(key) for key in stmt._parameter_ordering
]
ordered_keys = set(parameter_ordering)
- cols = [
- stmt.table.c[key] for key in parameter_ordering
- ] + [
+ cols = [stmt.table.c[key] for key in parameter_ordering] + [
c for c in stmt.table.c if c.key not in ordered_keys
]
else:
@@ -265,72 +307,95 @@ def _scan_cols(
if col_key in parameters and col_key not in check_columns:
_append_param_parameter(
- compiler, stmt, c, col_key, parameters, _col_bind_name,
- implicit_returning, implicit_return_defaults, values, kw)
+ compiler,
+ stmt,
+ c,
+ col_key,
+ parameters,
+ _col_bind_name,
+ implicit_returning,
+ implicit_return_defaults,
+ values,
+ kw,
+ )
elif compiler.isinsert:
- if c.primary_key and \
- need_pks and \
- (
- implicit_returning or
- not postfetch_lastrowid or
- c is not stmt.table._autoincrement_column
- ):
+ if (
+ c.primary_key
+ and need_pks
+ and (
+ implicit_returning
+ or not postfetch_lastrowid
+ or c is not stmt.table._autoincrement_column
+ )
+ ):
if implicit_returning:
_append_param_insert_pk_returning(
- compiler, stmt, c, values, kw)
+ compiler, stmt, c, values, kw
+ )
else:
_append_param_insert_pk(compiler, stmt, c, values, kw)
elif c.default is not None:
_append_param_insert_hasdefault(
- compiler, stmt, c, implicit_return_defaults,
- values, kw)
+ compiler, stmt, c, implicit_return_defaults, values, kw
+ )
elif c.server_default is not None:
- if implicit_return_defaults and \
- c in implicit_return_defaults:
+ if implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
elif not c.primary_key:
compiler.postfetch.append(c)
- elif implicit_return_defaults and \
- c in implicit_return_defaults:
+ elif implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
- elif c.primary_key and \
- c is not stmt.table._autoincrement_column and \
- not c.nullable:
+ elif (
+ c.primary_key
+ and c is not stmt.table._autoincrement_column
+ and not c.nullable
+ ):
_warn_pk_with_no_anticipated_value(c)
elif compiler.isupdate:
_append_param_update(
- compiler, stmt, c, implicit_return_defaults, values, kw)
+ compiler, stmt, c, implicit_return_defaults, values, kw
+ )
def _append_param_parameter(
- compiler, stmt, c, col_key, parameters, _col_bind_name,
- implicit_returning, implicit_return_defaults, values, kw):
+ compiler,
+ stmt,
+ c,
+ col_key,
+ parameters,
+ _col_bind_name,
+ implicit_returning,
+ implicit_return_defaults,
+ values,
+ kw,
+):
value = parameters.pop(col_key)
if elements._is_literal(value):
value = _create_bind_param(
- compiler, c, value, required=value is REQUIRED,
+ compiler,
+ c,
+ value,
+ required=value is REQUIRED,
name=_col_bind_name(c)
if not stmt._has_multi_parameters
else "%s_m0" % _col_bind_name(c),
**kw
)
else:
- if isinstance(value, elements.BindParameter) and \
- value.type._isnull:
+ if isinstance(value, elements.BindParameter) and value.type._isnull:
value = value._clone()
value.type = c.type
if c.primary_key and implicit_returning:
compiler.returning.append(c)
value = compiler.process(value.self_group(), **kw)
- elif implicit_return_defaults and \
- c in implicit_return_defaults:
+ elif implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
value = compiler.process(value.self_group(), **kw)
else:
@@ -358,22 +423,20 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
"""
if c.default is not None:
if c.default.is_sequence:
- if compiler.dialect.supports_sequences and \
- (not c.default.optional or
- not compiler.dialect.sequences_optional):
+ if compiler.dialect.supports_sequences and (
+ not c.default.optional
+ or not compiler.dialect.sequences_optional
+ ):
proc = compiler.process(c.default, **kw)
values.append((c, proc))
compiler.returning.append(c)
elif c.default.is_clause_element:
values.append(
- (c, compiler.process(
- c.default.arg.self_group(), **kw))
+ (c, compiler.process(c.default.arg.self_group(), **kw))
)
compiler.returning.append(c)
else:
- values.append(
- (c, _create_insert_prefetch_bind_param(compiler, c))
- )
+ values.append((c, _create_insert_prefetch_bind_param(compiler, c)))
elif c is stmt.table._autoincrement_column or c.server_default is not None:
compiler.returning.append(c)
elif not c.nullable:
@@ -405,9 +468,11 @@ class _multiparam_column(elements.ColumnElement):
self.type = original.type
def __eq__(self, other):
- return isinstance(other, _multiparam_column) and \
- other.key == self.key and \
- other.original == self.original
+ return (
+ isinstance(other, _multiparam_column)
+ and other.key == self.key
+ and other.original == self.original
+ )
def _process_multiparam_default_bind(compiler, stmt, c, index, kw):
@@ -416,7 +481,8 @@ def _process_multiparam_default_bind(compiler, stmt, c, index, kw):
raise exc.CompileError(
"INSERT value for column %s is explicitly rendered as a bound"
"parameter in the VALUES clause; "
- "a Python-side value or SQL expression is required" % c)
+ "a Python-side value or SQL expression is required" % c
+ )
elif c.default.is_clause_element:
return compiler.process(c.default.arg.self_group(), **kw)
else:
@@ -440,30 +506,24 @@ def _append_param_insert_pk(compiler, stmt, c, values, kw):
"""
if (
- (
- # column has a Python-side default
- c.default is not None and
- (
- # and it won't be a Sequence
- not c.default.is_sequence or
- compiler.dialect.supports_sequences
- )
- )
- or
- (
- # column is the "autoincrement column"
- c is stmt.table._autoincrement_column and
- (
- # and it's either a "sequence" or a
- # pre-executable "autoincrement" sequence
- compiler.dialect.supports_sequences or
- compiler.dialect.preexecute_autoincrement_sequences
- )
- )
- ):
- values.append(
- (c, _create_insert_prefetch_bind_param(compiler, c))
+ # column has a Python-side default
+ c.default is not None
+ and (
+ # and it won't be a Sequence
+ not c.default.is_sequence
+ or compiler.dialect.supports_sequences
)
+ ) or (
+ # column is the "autoincrement column"
+ c is stmt.table._autoincrement_column
+ and (
+ # and it's either a "sequence" or a
+ # pre-executable "autoincrement" sequence
+ compiler.dialect.supports_sequences
+ or compiler.dialect.preexecute_autoincrement_sequences
+ )
+ ):
+ values.append((c, _create_insert_prefetch_bind_param(compiler, c)))
elif c.default is None and c.server_default is None and not c.nullable:
# no .default, no .server_default, not autoincrement, we have
# no indication this primary key column will have any value
@@ -471,16 +531,16 @@ def _append_param_insert_pk(compiler, stmt, c, values, kw):
def _append_param_insert_hasdefault(
- compiler, stmt, c, implicit_return_defaults, values, kw):
+ compiler, stmt, c, implicit_return_defaults, values, kw
+):
if c.default.is_sequence:
- if compiler.dialect.supports_sequences and \
- (not c.default.optional or
- not compiler.dialect.sequences_optional):
+ if compiler.dialect.supports_sequences and (
+ not c.default.optional or not compiler.dialect.sequences_optional
+ ):
proc = compiler.process(c.default, **kw)
values.append((c, proc))
- if implicit_return_defaults and \
- c in implicit_return_defaults:
+ if implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
elif not c.primary_key:
compiler.postfetch.append(c)
@@ -488,25 +548,21 @@ def _append_param_insert_hasdefault(
proc = compiler.process(c.default.arg.self_group(), **kw)
values.append((c, proc))
- if implicit_return_defaults and \
- c in implicit_return_defaults:
+ if implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
elif not c.primary_key:
# don't add primary key column to postfetch
compiler.postfetch.append(c)
else:
- values.append(
- (c, _create_insert_prefetch_bind_param(compiler, c))
- )
+ values.append((c, _create_insert_prefetch_bind_param(compiler, c)))
-def _append_param_insert_select_hasdefault(
- compiler, stmt, c, values, kw):
+def _append_param_insert_select_hasdefault(compiler, stmt, c, values, kw):
if c.default.is_sequence:
- if compiler.dialect.supports_sequences and \
- (not c.default.optional or
- not compiler.dialect.sequences_optional):
+ if compiler.dialect.supports_sequences and (
+ not c.default.optional or not compiler.dialect.sequences_optional
+ ):
proc = c.default
values.append((c, proc.next_value()))
elif c.default.is_clause_element:
@@ -519,38 +575,43 @@ def _append_param_insert_select_hasdefault(
def _append_param_update(
- compiler, stmt, c, implicit_return_defaults, values, kw):
+ compiler, stmt, c, implicit_return_defaults, values, kw
+):
if c.onupdate is not None and not c.onupdate.is_sequence:
if c.onupdate.is_clause_element:
values.append(
- (c, compiler.process(
- c.onupdate.arg.self_group(), **kw))
+ (c, compiler.process(c.onupdate.arg.self_group(), **kw))
)
- if implicit_return_defaults and \
- c in implicit_return_defaults:
+ if implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
else:
compiler.postfetch.append(c)
else:
- values.append(
- (c, _create_update_prefetch_bind_param(compiler, c))
- )
+ values.append((c, _create_update_prefetch_bind_param(compiler, c)))
elif c.server_onupdate is not None:
- if implicit_return_defaults and \
- c in implicit_return_defaults:
+ if implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
else:
compiler.postfetch.append(c)
- elif implicit_return_defaults and \
- stmt._return_defaults is not True and \
- c in implicit_return_defaults:
+ elif (
+ implicit_return_defaults
+ and stmt._return_defaults is not True
+ and c in implicit_return_defaults
+ ):
compiler.returning.append(c)
def _get_multitable_params(
- compiler, stmt, stmt_parameters, check_columns,
- _col_bind_name, _getattr_col_key, values, kw):
+ compiler,
+ stmt,
+ stmt_parameters,
+ check_columns,
+ _col_bind_name,
+ _getattr_col_key,
+ values,
+ kw,
+):
normalized_params = dict(
(elements._clause_element_as_expr(c), param)
@@ -565,8 +626,12 @@ def _get_multitable_params(
value = normalized_params[c]
if elements._is_literal(value):
value = _create_bind_param(
- compiler, c, value, required=value is REQUIRED,
- name=_col_bind_name(c))
+ compiler,
+ c,
+ value,
+ required=value is REQUIRED,
+ name=_col_bind_name(c),
+ )
else:
compiler.postfetch.append(c)
value = compiler.process(value.self_group(), **kw)
@@ -577,20 +642,25 @@ def _get_multitable_params(
for c in t.c:
if c in normalized_params:
continue
- elif (c.onupdate is not None and not
- c.onupdate.is_sequence):
+ elif c.onupdate is not None and not c.onupdate.is_sequence:
if c.onupdate.is_clause_element:
values.append(
- (c, compiler.process(
- c.onupdate.arg.self_group(),
- **kw)
- )
+ (
+ c,
+ compiler.process(
+ c.onupdate.arg.self_group(), **kw
+ ),
+ )
)
compiler.postfetch.append(c)
else:
values.append(
- (c, _create_update_prefetch_bind_param(
- compiler, c, name=_col_bind_name(c)))
+ (
+ c,
+ _create_update_prefetch_bind_param(
+ compiler, c, name=_col_bind_name(c)
+ ),
+ )
)
elif c.server_onupdate is not None:
compiler.postfetch.append(c)
@@ -608,8 +678,11 @@ def _extend_values_for_multiparams(compiler, stmt, values, kw):
if elements._is_literal(row[key]):
new_param = _create_bind_param(
- compiler, col, row[key],
- name="%s_m%d" % (col.key, i + 1), **kw
+ compiler,
+ col,
+ row[key],
+ name="%s_m%d" % (col.key, i + 1),
+ **kw
)
else:
new_param = compiler.process(row[key].self_group(), **kw)
@@ -626,7 +699,8 @@ def _extend_values_for_multiparams(compiler, stmt, values, kw):
def _get_stmt_parameters_params(
- compiler, parameters, stmt_parameters, _column_as_key, values, kw):
+ compiler, parameters, stmt_parameters, _column_as_key, values, kw
+):
for k, v in stmt_parameters.items():
colkey = _column_as_key(k)
if colkey is not None:
@@ -637,8 +711,8 @@ def _get_stmt_parameters_params(
# coercing right side to bound param
if elements._is_literal(v):
v = compiler.process(
- elements.BindParameter(None, v, type_=k.type),
- **kw)
+ elements.BindParameter(None, v, type_=k.type), **kw
+ )
else:
v = compiler.process(v.self_group(), **kw)
@@ -646,22 +720,27 @@ def _get_stmt_parameters_params(
def _get_returning_modifiers(compiler, stmt):
- need_pks = compiler.isinsert and \
- not compiler.inline and \
- not stmt._returning and \
- not stmt._has_multi_parameters
+ need_pks = (
+ compiler.isinsert
+ and not compiler.inline
+ and not stmt._returning
+ and not stmt._has_multi_parameters
+ )
- implicit_returning = need_pks and \
- compiler.dialect.implicit_returning and \
- stmt.table.implicit_returning
+ implicit_returning = (
+ need_pks
+ and compiler.dialect.implicit_returning
+ and stmt.table.implicit_returning
+ )
if compiler.isinsert:
- implicit_return_defaults = (implicit_returning and
- stmt._return_defaults)
+ implicit_return_defaults = implicit_returning and stmt._return_defaults
elif compiler.isupdate:
- implicit_return_defaults = (compiler.dialect.implicit_returning and
- stmt.table.implicit_returning and
- stmt._return_defaults)
+ implicit_return_defaults = (
+ compiler.dialect.implicit_returning
+ and stmt.table.implicit_returning
+ and stmt._return_defaults
+ )
else:
# this line is unused, currently we are always
# isinsert or isupdate
@@ -675,8 +754,12 @@ def _get_returning_modifiers(compiler, stmt):
postfetch_lastrowid = need_pks and compiler.dialect.postfetch_lastrowid
- return need_pks, implicit_returning, \
- implicit_return_defaults, postfetch_lastrowid
+ return (
+ need_pks,
+ implicit_returning,
+ implicit_return_defaults,
+ postfetch_lastrowid,
+ )
def _warn_pk_with_no_anticipated_value(c):
@@ -687,8 +770,8 @@ def _warn_pk_with_no_anticipated_value(c):
"nor does it indicate 'autoincrement=True' or 'nullable=True', "
"and no explicit value is passed. "
"Primary key columns typically may not store NULL."
- %
- (c.table.fullname, c.name, c.table.fullname))
+ % (c.table.fullname, c.name, c.table.fullname)
+ )
if len(c.table.primary_key) > 1:
msg += (
" Note that as of SQLAlchemy 1.1, 'autoincrement=True' must be "
@@ -696,5 +779,6 @@ def _warn_pk_with_no_anticipated_value(c):
"keys if AUTO_INCREMENT/SERIAL/IDENTITY "
"behavior is expected for one of the columns in the primary key. "
"CREATE TABLE statements are impacted by this change as well on "
- "most backends.")
+ "most backends."
+ )
util.warn(msg)
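
For context on the warning constructed above: it fires when an INSERT omits a primary key column that has no Python-side default, no server default, and is not the table's autoincrement column. A minimal sketch of the composite-primary-key case the 1.1 note refers to (table and column names here are hypothetical, not taken from this commit):

from sqlalchemy import MetaData, Table, Column, Integer, String

metadata = MetaData()

# Composite primary key: marking the integer column autoincrement=True tells
# the compiler which column is expected to receive a generated value, so an
# INSERT that supplies only "region" and omits "id" does not trigger the
# warning above.  Omitting "region" (which has no default) would still warn.
version_history = Table(
    "version_history",
    metadata,
    Column("id", Integer, primary_key=True, autoincrement=True),
    Column("region", String(20), primary_key=True, autoincrement=False),
)
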
diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py
index 91e93efe7..f21b3d7f0 100644
--- a/lib/sqlalchemy/sql/ddl.py
+++ b/lib/sqlalchemy/sql/ddl.py
@@ -56,8 +56,9 @@ class DDLElement(Executable, _DDLCompiles):
"""
- _execution_options = Executable.\
- _execution_options.union({'autocommit': True})
+ _execution_options = Executable._execution_options.union(
+ {"autocommit": True}
+ )
target = None
on = None
@@ -95,11 +96,13 @@ class DDLElement(Executable, _DDLCompiles):
if self._should_execute(target, bind):
return bind.execute(self.against(target))
else:
- bind.engine.logger.info(
- "DDL execution skipped, criteria not met.")
+ bind.engine.logger.info("DDL execution skipped, criteria not met.")
- @util.deprecated("0.7", "See :class:`.DDLEvents`, as well as "
- ":meth:`.DDLElement.execute_if`.")
+ @util.deprecated(
+ "0.7",
+ "See :class:`.DDLEvents`, as well as "
+ ":meth:`.DDLElement.execute_if`.",
+ )
def execute_at(self, event_name, target):
"""Link execution of this DDL to the DDL lifecycle of a SchemaItem.
@@ -129,11 +132,12 @@ class DDLElement(Executable, _DDLCompiles):
"""
def call_event(target, connection, **kw):
- if self._should_execute_deprecated(event_name,
- target, connection, **kw):
+ if self._should_execute_deprecated(
+ event_name, target, connection, **kw
+ ):
return connection.execute(self.against(target))
- event.listen(target, "" + event_name.replace('-', '_'), call_event)
+ event.listen(target, "" + event_name.replace("-", "_"), call_event)
@_generative
def against(self, target):
@@ -211,8 +215,9 @@ class DDLElement(Executable, _DDLCompiles):
self.state = state
def _should_execute(self, target, bind, **kw):
- if self.on is not None and \
- not self._should_execute_deprecated(None, target, bind, **kw):
+ if self.on is not None and not self._should_execute_deprecated(
+ None, target, bind, **kw
+ ):
return False
if isinstance(self.dialect, util.string_types):
@@ -221,9 +226,9 @@ class DDLElement(Executable, _DDLCompiles):
elif isinstance(self.dialect, (tuple, list, set)):
if bind.engine.name not in self.dialect:
return False
- if (self.callable_ is not None and
- not self.callable_(self, target, bind,
- state=self.state, **kw)):
+ if self.callable_ is not None and not self.callable_(
+ self, target, bind, state=self.state, **kw
+ ):
return False
return True
@@ -245,13 +250,15 @@ class DDLElement(Executable, _DDLCompiles):
return bind.execute(self.against(target))
def _check_ddl_on(self, on):
- if (on is not None and
- (not isinstance(on, util.string_types + (tuple, list, set)) and
- not util.callable(on))):
+ if on is not None and (
+ not isinstance(on, util.string_types + (tuple, list, set))
+ and not util.callable(on)
+ ):
raise exc.ArgumentError(
"Expected the name of a database dialect, a tuple "
"of names, or a callable for "
- "'on' criteria, got type '%s'." % type(on).__name__)
+ "'on' criteria, got type '%s'." % type(on).__name__
+ )
def bind(self):
if self._bind:
@@ -259,6 +266,7 @@ class DDLElement(Executable, _DDLCompiles):
def _set_bind(self, bind):
self._bind = bind
+
bind = property(bind, _set_bind)
def _generate(self):
@@ -375,8 +383,9 @@ class DDL(DDLElement):
if not isinstance(statement, util.string_types):
raise exc.ArgumentError(
- "Expected a string or unicode SQL statement, got '%r'" %
- statement)
+ "Expected a string or unicode SQL statement, got '%r'"
+ % statement
+ )
self.statement = statement
self.context = context or {}
@@ -386,12 +395,18 @@ class DDL(DDLElement):
self._bind = bind
def __repr__(self):
- return '<%s@%s; %s>' % (
- type(self).__name__, id(self),
- ', '.join([repr(self.statement)] +
- ['%s=%r' % (key, getattr(self, key))
- for key in ('on', 'context')
- if getattr(self, key)]))
+ return "<%s@%s; %s>" % (
+ type(self).__name__,
+ id(self),
+ ", ".join(
+ [repr(self.statement)]
+ + [
+ "%s=%r" % (key, getattr(self, key))
+ for key in ("on", "context")
+ if getattr(self, key)
+ ]
+ ),
+ )
class _CreateDropBase(DDLElement):
@@ -464,8 +479,8 @@ class CreateTable(_CreateDropBase):
__visit_name__ = "create_table"
def __init__(
- self, element, on=None, bind=None,
- include_foreign_key_constraints=None):
+ self, element, on=None, bind=None, include_foreign_key_constraints=None
+ ):
"""Create a :class:`.CreateTable` construct.
:param element: a :class:`.Table` that's the subject
@@ -481,9 +496,7 @@ class CreateTable(_CreateDropBase):
"""
super(CreateTable, self).__init__(element, on=on, bind=bind)
- self.columns = [CreateColumn(column)
- for column in element.columns
- ]
+ self.columns = [CreateColumn(column) for column in element.columns]
self.include_foreign_key_constraints = include_foreign_key_constraints
@@ -494,6 +507,7 @@ class _DropView(_CreateDropBase):
This object will eventually be part of a public "view" API.
"""
+
__visit_name__ = "drop_view"
@@ -602,7 +616,8 @@ class CreateColumn(_DDLCompiles):
to support custom column creation styles.
"""
- __visit_name__ = 'create_column'
+
+ __visit_name__ = "create_column"
def __init__(self, element):
self.element = element
@@ -646,7 +661,8 @@ class AddConstraint(_CreateDropBase):
def __init__(self, element, *args, **kw):
super(AddConstraint, self).__init__(element, *args, **kw)
element._create_rule = util.portable_instancemethod(
- self._create_rule_disable)
+ self._create_rule_disable
+ )
class DropConstraint(_CreateDropBase):
@@ -658,7 +674,8 @@ class DropConstraint(_CreateDropBase):
self.cascade = cascade
super(DropConstraint, self).__init__(element, **kw)
element._create_rule = util.portable_instancemethod(
- self._create_rule_disable)
+ self._create_rule_disable
+ )
class SetTableComment(_CreateDropBase):
@@ -691,9 +708,9 @@ class DDLBase(SchemaVisitor):
class SchemaGenerator(DDLBase):
-
- def __init__(self, dialect, connection, checkfirst=False,
- tables=None, **kwargs):
+ def __init__(
+ self, dialect, connection, checkfirst=False, tables=None, **kwargs
+ ):
super(SchemaGenerator, self).__init__(connection, **kwargs)
self.checkfirst = checkfirst
self.tables = tables
@@ -706,25 +723,22 @@ class SchemaGenerator(DDLBase):
effective_schema = self.connection.schema_for_object(table)
if effective_schema:
self.dialect.validate_identifier(effective_schema)
- return not self.checkfirst or \
- not self.dialect.has_table(self.connection,
- table.name, schema=effective_schema)
+ return not self.checkfirst or not self.dialect.has_table(
+ self.connection, table.name, schema=effective_schema
+ )
def _can_create_sequence(self, sequence):
effective_schema = self.connection.schema_for_object(sequence)
- return self.dialect.supports_sequences and \
- (
- (not self.dialect.sequences_optional or
- not sequence.optional) and
- (
- not self.checkfirst or
- not self.dialect.has_sequence(
- self.connection,
- sequence.name,
- schema=effective_schema)
+ return self.dialect.supports_sequences and (
+ (not self.dialect.sequences_optional or not sequence.optional)
+ and (
+ not self.checkfirst
+ or not self.dialect.has_sequence(
+ self.connection, sequence.name, schema=effective_schema
)
)
+ )
def visit_metadata(self, metadata):
if self.tables is not None:
@@ -733,18 +747,23 @@ class SchemaGenerator(DDLBase):
tables = list(metadata.tables.values())
collection = sort_tables_and_constraints(
- [t for t in tables if self._can_create_table(t)])
-
- seq_coll = [s for s in metadata._sequences.values()
- if s.column is None and self._can_create_sequence(s)]
+ [t for t in tables if self._can_create_table(t)]
+ )
- event_collection = [
- t for (t, fks) in collection if t is not None
+ seq_coll = [
+ s
+ for s in metadata._sequences.values()
+ if s.column is None and self._can_create_sequence(s)
]
- metadata.dispatch.before_create(metadata, self.connection,
- tables=event_collection,
- checkfirst=self.checkfirst,
- _ddl_runner=self)
+
+ event_collection = [t for (t, fks) in collection if t is not None]
+ metadata.dispatch.before_create(
+ metadata,
+ self.connection,
+ tables=event_collection,
+ checkfirst=self.checkfirst,
+ _ddl_runner=self,
+ )
for seq in seq_coll:
self.traverse_single(seq, create_ok=True)
@@ -752,30 +771,40 @@ class SchemaGenerator(DDLBase):
for table, fkcs in collection:
if table is not None:
self.traverse_single(
- table, create_ok=True,
+ table,
+ create_ok=True,
include_foreign_key_constraints=fkcs,
- _is_metadata_operation=True)
+ _is_metadata_operation=True,
+ )
else:
for fkc in fkcs:
self.traverse_single(fkc)
- metadata.dispatch.after_create(metadata, self.connection,
- tables=event_collection,
- checkfirst=self.checkfirst,
- _ddl_runner=self)
+ metadata.dispatch.after_create(
+ metadata,
+ self.connection,
+ tables=event_collection,
+ checkfirst=self.checkfirst,
+ _ddl_runner=self,
+ )
def visit_table(
- self, table, create_ok=False,
- include_foreign_key_constraints=None,
- _is_metadata_operation=False):
+ self,
+ table,
+ create_ok=False,
+ include_foreign_key_constraints=None,
+ _is_metadata_operation=False,
+ ):
if not create_ok and not self._can_create_table(table):
return
table.dispatch.before_create(
- table, self.connection,
+ table,
+ self.connection,
checkfirst=self.checkfirst,
_ddl_runner=self,
- _is_metadata_operation=_is_metadata_operation)
+ _is_metadata_operation=_is_metadata_operation,
+ )
for column in table.columns:
if column.default is not None:
@@ -788,10 +817,11 @@ class SchemaGenerator(DDLBase):
self.connection.execute(
CreateTable(
table,
- include_foreign_key_constraints=include_foreign_key_constraints
- ))
+ include_foreign_key_constraints=include_foreign_key_constraints,
+ )
+ )
- if hasattr(table, 'indexes'):
+ if hasattr(table, "indexes"):
for index in table.indexes:
self.traverse_single(index)
@@ -804,10 +834,12 @@ class SchemaGenerator(DDLBase):
self.connection.execute(SetColumnComment(column))
table.dispatch.after_create(
- table, self.connection,
+ table,
+ self.connection,
checkfirst=self.checkfirst,
_ddl_runner=self,
- _is_metadata_operation=_is_metadata_operation)
+ _is_metadata_operation=_is_metadata_operation,
+ )
def visit_foreign_key_constraint(self, constraint):
if not self.dialect.supports_alter:
@@ -824,9 +856,9 @@ class SchemaGenerator(DDLBase):
class SchemaDropper(DDLBase):
-
- def __init__(self, dialect, connection, checkfirst=False,
- tables=None, **kwargs):
+ def __init__(
+ self, dialect, connection, checkfirst=False, tables=None, **kwargs
+ ):
super(SchemaDropper, self).__init__(connection, **kwargs)
self.checkfirst = checkfirst
self.tables = tables
@@ -842,15 +874,17 @@ class SchemaDropper(DDLBase):
try:
unsorted_tables = [t for t in tables if self._can_drop_table(t)]
- collection = list(reversed(
- sort_tables_and_constraints(
- unsorted_tables,
- filter_fn=lambda constraint: False
- if not self.dialect.supports_alter
- or constraint.name is None
- else None
+ collection = list(
+ reversed(
+ sort_tables_and_constraints(
+ unsorted_tables,
+ filter_fn=lambda constraint: False
+ if not self.dialect.supports_alter
+ or constraint.name is None
+ else None,
+ )
)
- ))
+ )
except exc.CircularDependencyError as err2:
if not self.dialect.supports_alter:
util.warn(
@@ -862,16 +896,15 @@ class SchemaDropper(DDLBase):
"ForeignKeyConstraint "
"objects involved in the cycle to mark these as known "
"cycles that will be ignored."
- % (
- ", ".join(sorted([t.fullname for t in err2.cycles]))
- )
+ % (", ".join(sorted([t.fullname for t in err2.cycles])))
)
collection = [(t, ()) for t in unsorted_tables]
else:
util.raise_from_cause(
exc.CircularDependencyError(
err2.args[0],
- err2.cycles, err2.edges,
+ err2.cycles,
+ err2.edges,
msg="Can't sort tables for DROP; an "
"unresolvable foreign key "
"dependency exists between tables: %s. Please ensure "
@@ -880,9 +913,10 @@ class SchemaDropper(DDLBase):
"names so that they can be dropped using "
"DROP CONSTRAINT."
% (
- ", ".join(sorted([t.fullname for t in err2.cycles]))
- )
-
+ ", ".join(
+ sorted([t.fullname for t in err2.cycles])
+ )
+ ),
)
)
@@ -892,18 +926,21 @@ class SchemaDropper(DDLBase):
if s.column is None and self._can_drop_sequence(s)
]
- event_collection = [
- t for (t, fks) in collection if t is not None
- ]
+ event_collection = [t for (t, fks) in collection if t is not None]
metadata.dispatch.before_drop(
- metadata, self.connection, tables=event_collection,
- checkfirst=self.checkfirst, _ddl_runner=self)
+ metadata,
+ self.connection,
+ tables=event_collection,
+ checkfirst=self.checkfirst,
+ _ddl_runner=self,
+ )
for table, fkcs in collection:
if table is not None:
self.traverse_single(
- table, drop_ok=True, _is_metadata_operation=True)
+ table, drop_ok=True, _is_metadata_operation=True
+ )
else:
for fkc in fkcs:
self.traverse_single(fkc)
@@ -912,8 +949,12 @@ class SchemaDropper(DDLBase):
self.traverse_single(seq, drop_ok=True)
metadata.dispatch.after_drop(
- metadata, self.connection, tables=event_collection,
- checkfirst=self.checkfirst, _ddl_runner=self)
+ metadata,
+ self.connection,
+ tables=event_collection,
+ checkfirst=self.checkfirst,
+ _ddl_runner=self,
+ )
def _can_drop_table(self, table):
self.dialect.validate_identifier(table.name)
@@ -921,19 +962,20 @@ class SchemaDropper(DDLBase):
if effective_schema:
self.dialect.validate_identifier(effective_schema)
return not self.checkfirst or self.dialect.has_table(
- self.connection, table.name, schema=effective_schema)
+ self.connection, table.name, schema=effective_schema
+ )
def _can_drop_sequence(self, sequence):
effective_schema = self.connection.schema_for_object(sequence)
- return self.dialect.supports_sequences and \
- ((not self.dialect.sequences_optional or
- not sequence.optional) and
- (not self.checkfirst or
- self.dialect.has_sequence(
- self.connection,
- sequence.name,
- schema=effective_schema))
- )
+ return self.dialect.supports_sequences and (
+ (not self.dialect.sequences_optional or not sequence.optional)
+ and (
+ not self.checkfirst
+ or self.dialect.has_sequence(
+ self.connection, sequence.name, schema=effective_schema
+ )
+ )
+ )
def visit_index(self, index):
self.connection.execute(DropIndex(index))
@@ -943,10 +985,12 @@ class SchemaDropper(DDLBase):
return
table.dispatch.before_drop(
- table, self.connection,
+ table,
+ self.connection,
checkfirst=self.checkfirst,
_ddl_runner=self,
- _is_metadata_operation=_is_metadata_operation)
+ _is_metadata_operation=_is_metadata_operation,
+ )
self.connection.execute(DropTable(table))
@@ -960,10 +1004,12 @@ class SchemaDropper(DDLBase):
self.traverse_single(column.default)
table.dispatch.after_drop(
- table, self.connection,
+ table,
+ self.connection,
checkfirst=self.checkfirst,
_ddl_runner=self,
- _is_metadata_operation=_is_metadata_operation)
+ _is_metadata_operation=_is_metadata_operation,
+ )
def visit_foreign_key_constraint(self, constraint):
if not self.dialect.supports_alter:
@@ -1019,25 +1065,29 @@ def sort_tables(tables, skip_fn=None, extra_dependencies=None):
"""
if skip_fn is not None:
+
def _skip_fn(fkc):
for fk in fkc.elements:
if skip_fn(fk):
return True
else:
return None
+
else:
_skip_fn = None
return [
- t for (t, fkcs) in
- sort_tables_and_constraints(
- tables, filter_fn=_skip_fn, extra_dependencies=extra_dependencies)
+ t
+ for (t, fkcs) in sort_tables_and_constraints(
+ tables, filter_fn=_skip_fn, extra_dependencies=extra_dependencies
+ )
if t is not None
]
def sort_tables_and_constraints(
- tables, filter_fn=None, extra_dependencies=None):
+ tables, filter_fn=None, extra_dependencies=None
+):
"""sort a collection of :class:`.Table` / :class:`.ForeignKeyConstraint`
objects.
@@ -1109,8 +1159,9 @@ def sort_tables_and_constraints(
try:
candidate_sort = list(
topological.sort(
- fixed_dependencies.union(mutable_dependencies), tables,
- deterministic_order=True
+ fixed_dependencies.union(mutable_dependencies),
+ tables,
+ deterministic_order=True,
)
)
except exc.CircularDependencyError as err:
@@ -1118,8 +1169,10 @@ def sort_tables_and_constraints(
if edge in mutable_dependencies:
table = edge[1]
can_remove = [
- fkc for fkc in table.foreign_key_constraints
- if filter_fn is None or filter_fn(fkc) is not False]
+ fkc
+ for fkc in table.foreign_key_constraints
+ if filter_fn is None or filter_fn(fkc) is not False
+ ]
remaining_fkcs.update(can_remove)
for fkc in can_remove:
dependent_on = fkc.referred_table
@@ -1127,8 +1180,9 @@ def sort_tables_and_constraints(
mutable_dependencies.discard((dependent_on, table))
candidate_sort = list(
topological.sort(
- fixed_dependencies.union(mutable_dependencies), tables,
- deterministic_order=True
+ fixed_dependencies.union(mutable_dependencies),
+ tables,
+ deterministic_order=True,
)
)
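
As a usage note for the two sorting helpers reformatted above, here is a minimal sketch (table names are hypothetical) of the dependency-aware ordering they implement: tables come back parents-first, suitable for CREATE, and the list can be reversed for DROP.

from sqlalchemy import MetaData, Table, Column, Integer, ForeignKey
from sqlalchemy.sql.ddl import sort_tables, sort_tables_and_constraints

metadata = MetaData()
parent = Table("parent", metadata, Column("id", Integer, primary_key=True))
child = Table(
    "child",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("parent_id", Integer, ForeignKey("parent.id")),
)

# parents-first ordering; expected: ['parent', 'child']
print([t.name for t in sort_tables(metadata.tables.values())])

# the richer form also reports ForeignKeyConstraint objects that could not be
# inlined into CREATE TABLE; a final (None, [...]) entry carries constraints
# left over from dependency cycles
for table, fkcs in sort_tables_and_constraints(metadata.tables.values()):
    print(table.name if table is not None else None, fkcs)
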
diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py
index 8149f9731..fa0052198 100644
--- a/lib/sqlalchemy/sql/default_comparator.py
+++ b/lib/sqlalchemy/sql/default_comparator.py
@@ -11,19 +11,43 @@
from .. import exc, util
from . import type_api
from . import operators
-from .elements import BindParameter, True_, False_, BinaryExpression, \
- Null, _const_expr, _clause_element_as_expr, \
- ClauseList, ColumnElement, TextClause, UnaryExpression, \
- collate, _is_literal, _literal_as_text, ClauseElement, and_, or_, \
- Slice, Visitable, _literal_as_binds, CollectionAggregate, \
- Tuple
+from .elements import (
+ BindParameter,
+ True_,
+ False_,
+ BinaryExpression,
+ Null,
+ _const_expr,
+ _clause_element_as_expr,
+ ClauseList,
+ ColumnElement,
+ TextClause,
+ UnaryExpression,
+ collate,
+ _is_literal,
+ _literal_as_text,
+ ClauseElement,
+ and_,
+ or_,
+ Slice,
+ Visitable,
+ _literal_as_binds,
+ CollectionAggregate,
+ Tuple,
+)
from .selectable import SelectBase, Alias, Selectable, ScalarSelect
-def _boolean_compare(expr, op, obj, negate=None, reverse=False,
- _python_is_types=(util.NoneType, bool),
- result_type = None,
- **kwargs):
+def _boolean_compare(
+ expr,
+ op,
+ obj,
+ negate=None,
+ reverse=False,
+ _python_is_types=(util.NoneType, bool),
+ result_type=None,
+ **kwargs
+):
if result_type is None:
result_type = type_api.BOOLEANTYPE
@@ -33,57 +57,64 @@ def _boolean_compare(expr, op, obj, negate=None, reverse=False,
# allow x ==/!= True/False to be treated as a literal.
# this comes out to "== / != true/false" or "1/0" if those
# constants aren't supported and works on all platforms
- if op in (operators.eq, operators.ne) and \
- isinstance(obj, (bool, True_, False_)):
- return BinaryExpression(expr,
- _literal_as_text(obj),
- op,
- type_=result_type,
- negate=negate, modifiers=kwargs)
+ if op in (operators.eq, operators.ne) and isinstance(
+ obj, (bool, True_, False_)
+ ):
+ return BinaryExpression(
+ expr,
+ _literal_as_text(obj),
+ op,
+ type_=result_type,
+ negate=negate,
+ modifiers=kwargs,
+ )
elif op in (operators.is_distinct_from, operators.isnot_distinct_from):
- return BinaryExpression(expr,
- _literal_as_text(obj),
- op,
- type_=result_type,
- negate=negate, modifiers=kwargs)
+ return BinaryExpression(
+ expr,
+ _literal_as_text(obj),
+ op,
+ type_=result_type,
+ negate=negate,
+ modifiers=kwargs,
+ )
else:
# all other None/True/False uses IS, IS NOT
if op in (operators.eq, operators.is_):
- return BinaryExpression(expr, _const_expr(obj),
- operators.is_,
- negate=operators.isnot,
- type_=result_type
- )
+ return BinaryExpression(
+ expr,
+ _const_expr(obj),
+ operators.is_,
+ negate=operators.isnot,
+ type_=result_type,
+ )
elif op in (operators.ne, operators.isnot):
- return BinaryExpression(expr, _const_expr(obj),
- operators.isnot,
- negate=operators.is_,
- type_=result_type
- )
+ return BinaryExpression(
+ expr,
+ _const_expr(obj),
+ operators.isnot,
+ negate=operators.is_,
+ type_=result_type,
+ )
else:
raise exc.ArgumentError(
"Only '=', '!=', 'is_()', 'isnot()', "
"'is_distinct_from()', 'isnot_distinct_from()' "
- "operators can be used with None/True/False")
+ "operators can be used with None/True/False"
+ )
else:
obj = _check_literal(expr, op, obj)
if reverse:
- return BinaryExpression(obj,
- expr,
- op,
- type_=result_type,
- negate=negate, modifiers=kwargs)
+ return BinaryExpression(
+ obj, expr, op, type_=result_type, negate=negate, modifiers=kwargs
+ )
else:
- return BinaryExpression(expr,
- obj,
- op,
- type_=result_type,
- negate=negate, modifiers=kwargs)
+ return BinaryExpression(
+ expr, obj, op, type_=result_type, negate=negate, modifiers=kwargs
+ )
-def _custom_op_operate(expr, op, obj, reverse=False, result_type=None,
- **kw):
+def _custom_op_operate(expr, op, obj, reverse=False, result_type=None, **kw):
if result_type is None:
if op.return_type:
result_type = op.return_type
@@ -91,11 +122,11 @@ def _custom_op_operate(expr, op, obj, reverse=False, result_type=None,
result_type = type_api.BOOLEANTYPE
return _binary_operate(
- expr, op, obj, reverse=reverse, result_type=result_type, **kw)
+ expr, op, obj, reverse=reverse, result_type=result_type, **kw
+ )
-def _binary_operate(expr, op, obj, reverse=False, result_type=None,
- **kw):
+def _binary_operate(expr, op, obj, reverse=False, result_type=None, **kw):
obj = _check_literal(expr, op, obj)
if reverse:
@@ -105,10 +136,10 @@ def _binary_operate(expr, op, obj, reverse=False, result_type=None,
if result_type is None:
op, result_type = left.comparator._adapt_expression(
- op, right.comparator)
+ op, right.comparator
+ )
- return BinaryExpression(
- left, right, op, type_=result_type, modifiers=kw)
+ return BinaryExpression(left, right, op, type_=result_type, modifiers=kw)
def _conjunction_operate(expr, op, other, **kw):
@@ -128,8 +159,7 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw):
seq_or_selectable = _clause_element_as_expr(seq_or_selectable)
if isinstance(seq_or_selectable, ScalarSelect):
- return _boolean_compare(expr, op, seq_or_selectable,
- negate=negate_op)
+ return _boolean_compare(expr, op, seq_or_selectable, negate=negate_op)
elif isinstance(seq_or_selectable, SelectBase):
# TODO: if we ever want to support (x, y, z) IN (select x,
@@ -138,32 +168,33 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw):
# does not export itself as a FROM clause
return _boolean_compare(
- expr, op, seq_or_selectable.as_scalar(),
- negate=negate_op, **kw)
+ expr, op, seq_or_selectable.as_scalar(), negate=negate_op, **kw
+ )
elif isinstance(seq_or_selectable, (Selectable, TextClause)):
- return _boolean_compare(expr, op, seq_or_selectable,
- negate=negate_op, **kw)
+ return _boolean_compare(
+ expr, op, seq_or_selectable, negate=negate_op, **kw
+ )
elif isinstance(seq_or_selectable, ClauseElement):
- if isinstance(seq_or_selectable, BindParameter) and \
- seq_or_selectable.expanding:
+ if (
+ isinstance(seq_or_selectable, BindParameter)
+ and seq_or_selectable.expanding
+ ):
if isinstance(expr, Tuple):
- seq_or_selectable = (
- seq_or_selectable._with_expanding_in_types(
- [elem.type for elem in expr]
- )
+ seq_or_selectable = seq_or_selectable._with_expanding_in_types(
+ [elem.type for elem in expr]
)
return _boolean_compare(
- expr, op,
- seq_or_selectable,
- negate=negate_op)
+ expr, op, seq_or_selectable, negate=negate_op
+ )
else:
raise exc.InvalidRequestError(
- 'in_() accepts'
- ' either a list of expressions, '
+ "in_() accepts"
+ " either a list of expressions, "
'a selectable, or an "expanding" bound parameter: %r'
- % seq_or_selectable)
+ % seq_or_selectable
+ )
# Handle non selectable arguments as sequences
args = []
@@ -171,9 +202,10 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw):
if not _is_literal(o):
if not isinstance(o, operators.ColumnOperators):
raise exc.InvalidRequestError(
- 'in_() accepts'
- ' either a list of expressions, '
- 'a selectable, or an "expanding" bound parameter: %r' % o)
+ "in_() accepts"
+ " either a list of expressions, "
+ 'a selectable, or an "expanding" bound parameter: %r' % o
+ )
elif o is None:
o = Null()
else:
@@ -182,15 +214,14 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw):
if len(args) == 0:
op, negate_op = (
- operators.empty_in_op,
- operators.empty_notin_op) if op is operators.in_op \
- else (
- operators.empty_notin_op,
- operators.empty_in_op)
+ (operators.empty_in_op, operators.empty_notin_op)
+ if op is operators.in_op
+ else (operators.empty_notin_op, operators.empty_in_op)
+ )
- return _boolean_compare(expr, op,
- ClauseList(*args).self_group(against=op),
- negate=negate_op)
+ return _boolean_compare(
+ expr, op, ClauseList(*args).self_group(against=op), negate=negate_op
+ )
def _getitem_impl(expr, op, other, **kw):
@@ -202,13 +233,14 @@ def _getitem_impl(expr, op, other, **kw):
def _unsupported_impl(expr, op, *arg, **kw):
- raise NotImplementedError("Operator '%s' is not supported on "
- "this expression" % op.__name__)
+ raise NotImplementedError(
+ "Operator '%s' is not supported on " "this expression" % op.__name__
+ )
def _inv_impl(expr, op, **kw):
"""See :meth:`.ColumnOperators.__inv__`."""
- if hasattr(expr, 'negation_clause'):
+ if hasattr(expr, "negation_clause"):
return expr.negation_clause
else:
return expr._negate()
@@ -223,20 +255,22 @@ def _match_impl(expr, op, other, **kw):
"""See :meth:`.ColumnOperators.match`."""
return _boolean_compare(
- expr, operators.match_op,
- _check_literal(
- expr, operators.match_op, other),
+ expr,
+ operators.match_op,
+ _check_literal(expr, operators.match_op, other),
result_type=type_api.MATCHTYPE,
negate=operators.notmatch_op
- if op is operators.match_op else operators.match_op,
+ if op is operators.match_op
+ else operators.match_op,
**kw
)
def _distinct_impl(expr, op, **kw):
"""See :meth:`.ColumnOperators.distinct`."""
- return UnaryExpression(expr, operator=operators.distinct_op,
- type_=expr.type)
+ return UnaryExpression(
+ expr, operator=operators.distinct_op, type_=expr.type
+ )
def _between_impl(expr, op, cleft, cright, **kw):
@@ -247,17 +281,21 @@ def _between_impl(expr, op, cleft, cright, **kw):
_check_literal(expr, operators.and_, cleft),
_check_literal(expr, operators.and_, cright),
operator=operators.and_,
- group=False, group_contents=False),
+ group=False,
+ group_contents=False,
+ ),
op,
negate=operators.notbetween_op
if op is operators.between_op
else operators.between_op,
- modifiers=kw)
+ modifiers=kw,
+ )
def _collate_impl(expr, op, other, **kw):
return collate(expr, other)
+
# a mapping of operators with the method they use, along with
# their negated operator for comparison operators
operator_lookup = {
@@ -271,8 +309,8 @@ operator_lookup = {
"mod": (_binary_operate,),
"truediv": (_binary_operate,),
"custom_op": (_custom_op_operate,),
- "json_path_getitem_op": (_binary_operate, ),
- "json_getitem_op": (_binary_operate, ),
+ "json_path_getitem_op": (_binary_operate,),
+ "json_getitem_op": (_binary_operate,),
"concat_op": (_binary_operate,),
"any_op": (_scalar, CollectionAggregate._create_any),
"all_op": (_scalar, CollectionAggregate._create_all),
@@ -303,8 +341,8 @@ operator_lookup = {
"match_op": (_match_impl,),
"notmatch_op": (_match_impl,),
"distinct_op": (_distinct_impl,),
- "between_op": (_between_impl, ),
- "notbetween_op": (_between_impl, ),
+ "between_op": (_between_impl,),
+ "notbetween_op": (_between_impl,),
"neg": (_neg_impl,),
"getitem": (_getitem_impl,),
"lshift": (_unsupported_impl,),
@@ -315,12 +353,11 @@ operator_lookup = {
def _check_literal(expr, operator, other, bindparam_type=None):
if isinstance(other, (ColumnElement, TextClause)):
- if isinstance(other, BindParameter) and \
- other.type._isnull:
+ if isinstance(other, BindParameter) and other.type._isnull:
other = other._clone()
other.type = expr.type
return other
- elif hasattr(other, '__clause_element__'):
+ elif hasattr(other, "__clause_element__"):
other = other.__clause_element__()
elif isinstance(other, type_api.TypeEngine.Comparator):
other = other.expr
@@ -331,4 +368,3 @@ def _check_literal(expr, operator, other, bindparam_type=None):
return expr._bind_param(operator, other, type_=bindparam_type)
else:
return other
-
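
A short illustration of the comparison rules handled above in _boolean_compare and _in_impl (the table and column names are hypothetical): comparisons against None render as IS / IS NOT, in_() accepts a list of expressions, a selectable, or an expanding bound parameter, and an empty list is routed through the special empty-IN operators.

from sqlalchemy.sql import table, column, bindparam, select

t = table("t", column("x"))

print(t.c.x == None)   # t.x IS NULL
print(t.c.x != None)   # t.x IS NOT NULL

print(t.c.x.in_([1, 2, 3]))        # t.x IN (:x_1, :x_2, :x_3)
print(t.c.x.in_(select([t.c.x])))  # t.x IN (SELECT t.x FROM t)

# late-bound list via an "expanding" bound parameter (values supplied at
# execution time)
expanded = t.c.x.in_(bindparam("xs", expanding=True))

# empty sequence: handled by the empty_in_op / empty_notin_op pair above
empty = t.c.x.in_([])
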
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index d6890de15..0cea5ccc4 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -9,26 +9,43 @@ Provide :class:`.Insert`, :class:`.Update` and :class:`.Delete`.
"""
-from .base import Executable, _generative, _from_objects, DialectKWArgs, \
- ColumnCollection
-from .elements import ClauseElement, _literal_as_text, Null, and_, _clone, \
- _column_as_key
-from .selectable import _interpret_as_from, _interpret_as_select, \
- HasPrefixes, HasCTE
+from .base import (
+ Executable,
+ _generative,
+ _from_objects,
+ DialectKWArgs,
+ ColumnCollection,
+)
+from .elements import (
+ ClauseElement,
+ _literal_as_text,
+ Null,
+ and_,
+ _clone,
+ _column_as_key,
+)
+from .selectable import (
+ _interpret_as_from,
+ _interpret_as_select,
+ HasPrefixes,
+ HasCTE,
+)
from .. import util
from .. import exc
class UpdateBase(
- HasCTE, DialectKWArgs, HasPrefixes, Executable, ClauseElement):
+ HasCTE, DialectKWArgs, HasPrefixes, Executable, ClauseElement
+):
"""Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements.
"""
- __visit_name__ = 'update_base'
+ __visit_name__ = "update_base"
- _execution_options = \
- Executable._execution_options.union({'autocommit': True})
+ _execution_options = Executable._execution_options.union(
+ {"autocommit": True}
+ )
_hints = util.immutabledict()
_parameter_ordering = None
_prefixes = ()
@@ -37,30 +54,33 @@ class UpdateBase(
def _process_colparams(self, parameters):
def process_single(p):
if isinstance(p, (list, tuple)):
- return dict(
- (c.key, pval)
- for c, pval in zip(self.table.c, p)
- )
+ return dict((c.key, pval) for c, pval in zip(self.table.c, p))
else:
return p
if self._preserve_parameter_order and parameters is not None:
- if not isinstance(parameters, list) or \
- (parameters and not isinstance(parameters[0], tuple)):
+ if not isinstance(parameters, list) or (
+ parameters and not isinstance(parameters[0], tuple)
+ ):
raise ValueError(
"When preserve_parameter_order is True, "
- "values() only accepts a list of 2-tuples")
+ "values() only accepts a list of 2-tuples"
+ )
self._parameter_ordering = [key for key, value in parameters]
return dict(parameters), False
- if (isinstance(parameters, (list, tuple)) and parameters and
- isinstance(parameters[0], (list, tuple, dict))):
+ if (
+ isinstance(parameters, (list, tuple))
+ and parameters
+ and isinstance(parameters[0], (list, tuple, dict))
+ ):
if not self._supports_multi_parameters:
raise exc.InvalidRequestError(
"This construct does not support "
- "multiple parameter sets.")
+ "multiple parameter sets."
+ )
return [process_single(p) for p in parameters], True
else:
@@ -77,7 +97,8 @@ class UpdateBase(
raise NotImplementedError(
"params() is not supported for INSERT/UPDATE/DELETE statements."
" To set the values for an INSERT or UPDATE statement, use"
- " stmt.values(**parameters).")
+ " stmt.values(**parameters)."
+ )
def bind(self):
"""Return a 'bind' linked to this :class:`.UpdateBase`
@@ -88,6 +109,7 @@ class UpdateBase(
def _set_bind(self, bind):
self._bind = bind
+
bind = property(bind, _set_bind)
@_generative
@@ -181,15 +203,14 @@ class UpdateBase(
if selectable is None:
selectable = self.table
- self._hints = self._hints.union(
- {(selectable, dialect_name): text})
+ self._hints = self._hints.union({(selectable, dialect_name): text})
class ValuesBase(UpdateBase):
"""Supplies support for :meth:`.ValuesBase.values` to
INSERT and UPDATE constructs."""
- __visit_name__ = 'values_base'
+ __visit_name__ = "values_base"
_supports_multi_parameters = False
_has_multi_parameters = False
@@ -199,8 +220,9 @@ class ValuesBase(UpdateBase):
def __init__(self, table, values, prefixes):
self.table = _interpret_as_from(table)
- self.parameters, self._has_multi_parameters = \
- self._process_colparams(values)
+ self.parameters, self._has_multi_parameters = self._process_colparams(
+ values
+ )
if prefixes:
self._setup_prefixes(prefixes)
@@ -332,23 +354,27 @@ class ValuesBase(UpdateBase):
"""
if self.select is not None:
raise exc.InvalidRequestError(
- "This construct already inserts from a SELECT")
+ "This construct already inserts from a SELECT"
+ )
if self._has_multi_parameters and kwargs:
raise exc.InvalidRequestError(
- "This construct already has multiple parameter sets.")
+ "This construct already has multiple parameter sets."
+ )
if args:
if len(args) > 1:
raise exc.ArgumentError(
"Only a single dictionary/tuple or list of "
- "dictionaries/tuples is accepted positionally.")
+ "dictionaries/tuples is accepted positionally."
+ )
v = args[0]
else:
v = {}
if self.parameters is None:
- self.parameters, self._has_multi_parameters = \
- self._process_colparams(v)
+ self.parameters, self._has_multi_parameters = self._process_colparams(
+ v
+ )
else:
if self._has_multi_parameters:
self.parameters = list(self.parameters)
@@ -356,7 +382,8 @@ class ValuesBase(UpdateBase):
if not self._has_multi_parameters:
raise exc.ArgumentError(
"Can't mix single-values and multiple values "
- "formats in one statement")
+ "formats in one statement"
+ )
self.parameters.extend(p)
else:
@@ -365,14 +392,16 @@ class ValuesBase(UpdateBase):
if self._has_multi_parameters:
raise exc.ArgumentError(
"Can't mix single-values and multiple values "
- "formats in one statement")
+ "formats in one statement"
+ )
self.parameters.update(p)
if kwargs:
if self._has_multi_parameters:
raise exc.ArgumentError(
"Can't pass kwargs and multiple parameter sets "
- "simultaneously")
+ "simultaneously"
+ )
else:
self.parameters.update(kwargs)
@@ -456,19 +485,22 @@ class Insert(ValuesBase):
:ref:`coretutorial_insert_expressions`
"""
- __visit_name__ = 'insert'
+
+ __visit_name__ = "insert"
_supports_multi_parameters = True
- def __init__(self,
- table,
- values=None,
- inline=False,
- bind=None,
- prefixes=None,
- returning=None,
- return_defaults=False,
- **dialect_kw):
+ def __init__(
+ self,
+ table,
+ values=None,
+ inline=False,
+ bind=None,
+ prefixes=None,
+ returning=None,
+ return_defaults=False,
+ **dialect_kw
+ ):
"""Construct an :class:`.Insert` object.
Similar functionality is available via the
@@ -526,7 +558,7 @@ class Insert(ValuesBase):
def get_children(self, **kwargs):
if self.select is not None:
- return self.select,
+ return (self.select,)
else:
return ()
@@ -578,11 +610,12 @@ class Insert(ValuesBase):
"""
if self.parameters:
raise exc.InvalidRequestError(
- "This construct already inserts value expressions")
+ "This construct already inserts value expressions"
+ )
- self.parameters, self._has_multi_parameters = \
- self._process_colparams(
- {_column_as_key(n): Null() for n in names})
+ self.parameters, self._has_multi_parameters = self._process_colparams(
+ {_column_as_key(n): Null() for n in names}
+ )
self.select_names = names
self.inline = True
@@ -603,19 +636,22 @@ class Update(ValuesBase):
function.
"""
- __visit_name__ = 'update'
-
- def __init__(self,
- table,
- whereclause=None,
- values=None,
- inline=False,
- bind=None,
- prefixes=None,
- returning=None,
- return_defaults=False,
- preserve_parameter_order=False,
- **dialect_kw):
+
+ __visit_name__ = "update"
+
+ def __init__(
+ self,
+ table,
+ whereclause=None,
+ values=None,
+ inline=False,
+ bind=None,
+ prefixes=None,
+ returning=None,
+ return_defaults=False,
+ preserve_parameter_order=False,
+ **dialect_kw
+ ):
r"""Construct an :class:`.Update` object.
E.g.::
@@ -745,7 +781,7 @@ class Update(ValuesBase):
def get_children(self, **kwargs):
if self._whereclause is not None:
- return self._whereclause,
+ return (self._whereclause,)
else:
return ()
@@ -761,8 +797,9 @@ class Update(ValuesBase):
"""
if self._whereclause is not None:
- self._whereclause = and_(self._whereclause,
- _literal_as_text(whereclause))
+ self._whereclause = and_(
+ self._whereclause, _literal_as_text(whereclause)
+ )
else:
self._whereclause = _literal_as_text(whereclause)
@@ -788,15 +825,17 @@ class Delete(UpdateBase):
"""
- __visit_name__ = 'delete'
-
- def __init__(self,
- table,
- whereclause=None,
- bind=None,
- returning=None,
- prefixes=None,
- **dialect_kw):
+ __visit_name__ = "delete"
+
+ def __init__(
+ self,
+ table,
+ whereclause=None,
+ bind=None,
+ returning=None,
+ prefixes=None,
+ **dialect_kw
+ ):
"""Construct :class:`.Delete` object.
Similar functionality is available via the
@@ -847,7 +886,7 @@ class Delete(UpdateBase):
def get_children(self, **kwargs):
if self._whereclause is not None:
- return self._whereclause,
+ return (self._whereclause,)
else:
return ()
@@ -856,8 +895,9 @@ class Delete(UpdateBase):
"""Add the given WHERE clause to a newly returned delete construct."""
if self._whereclause is not None:
- self._whereclause = and_(self._whereclause,
- _literal_as_text(whereclause))
+ self._whereclause = and_(
+ self._whereclause, _literal_as_text(whereclause)
+ )
else:
self._whereclause = _literal_as_text(whereclause)
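
A brief sketch of the two values() behaviors guarded above (table and column names are hypothetical): Insert accepts a list of dictionaries as multiple parameter sets, while preserve_parameter_order=True on Update requires a list of 2-tuples so the SET clause keeps the given order.

from sqlalchemy import MetaData, Table, Column, Integer, String
from sqlalchemy.sql import insert, update

metadata = MetaData()
user = Table(
    "user",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("name", String(50)),
)

# multiple parameter sets; only Insert sets _supports_multi_parameters = True
multi = insert(user).values([{"name": "a"}, {"name": "b"}])

# ordered SET clause; anything other than a list of 2-tuples raises the
# ValueError shown above
ordered = update(user, preserve_parameter_order=True).values(
    [("name", "x"), ("id", 10)]
)
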
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index de3b7992a..e857f2da8 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -51,9 +51,8 @@ def collate(expression, collation):
expr = _literal_as_binds(expression)
return BinaryExpression(
- expr,
- CollationClause(collation),
- operators.collate, type_=expr.type)
+ expr, CollationClause(collation), operators.collate, type_=expr.type
+ )
def between(expr, lower_bound, upper_bound, symmetric=False):
@@ -130,8 +129,6 @@ def literal(value, type_=None):
return BindParameter(None, value, type_=type_, unique=True)
-
-
def outparam(key, type_=None):
"""Create an 'OUT' parameter for usage in functions (stored procedures),
for databases which support them.
@@ -142,8 +139,7 @@ def outparam(key, type_=None):
attribute, which returns a dictionary containing the values.
"""
- return BindParameter(
- key, None, type_=type_, unique=False, isoutparam=True)
+ return BindParameter(key, None, type_=type_, unique=False, isoutparam=True)
def not_(clause):
@@ -163,7 +159,8 @@ class ClauseElement(Visitable):
expression.
"""
- __visit_name__ = 'clause'
+
+ __visit_name__ = "clause"
_annotations = {}
supports_execution = False
@@ -230,7 +227,7 @@ class ClauseElement(Visitable):
def __getstate__(self):
d = self.__dict__.copy()
- d.pop('_is_clone_of', None)
+ d.pop("_is_clone_of", None)
return d
def _annotate(self, values):
@@ -300,7 +297,8 @@ class ClauseElement(Visitable):
kwargs.update(optionaldict[0])
elif len(optionaldict) > 1:
raise exc.ArgumentError(
- "params() takes zero or one positional dictionary argument")
+ "params() takes zero or one positional dictionary argument"
+ )
def visit_bindparam(bind):
if bind.key in kwargs:
@@ -308,7 +306,8 @@ class ClauseElement(Visitable):
bind.required = False
if unique:
bind._convert_to_unique()
- return cloned_traverse(self, {}, {'bindparam': visit_bindparam})
+
+ return cloned_traverse(self, {}, {"bindparam": visit_bindparam})
def compare(self, other, **kw):
r"""Compare this ClauseElement to the given ClauseElement.
@@ -451,7 +450,7 @@ class ClauseElement(Visitable):
if util.py3k:
return str(self.compile())
else:
- return unicode(self.compile()).encode('ascii', 'backslashreplace')
+ return unicode(self.compile()).encode("ascii", "backslashreplace")
def __and__(self, other):
"""'and' at the ClauseElement level.
@@ -472,7 +471,7 @@ class ClauseElement(Visitable):
return or_(self, other)
def __invert__(self):
- if hasattr(self, 'negation_clause'):
+ if hasattr(self, "negation_clause"):
return self.negation_clause
else:
return self._negate()
@@ -481,7 +480,8 @@ class ClauseElement(Visitable):
return UnaryExpression(
self.self_group(against=operators.inv),
operator=operators.inv,
- negate=None)
+ negate=None,
+ )
def __bool__(self):
raise TypeError("Boolean value of this clause is not defined")
@@ -493,8 +493,12 @@ class ClauseElement(Visitable):
if friendly is None:
return object.__repr__(self)
else:
- return '<%s.%s at 0x%x; %s>' % (
- self.__module__, self.__class__.__name__, id(self), friendly)
+ return "<%s.%s at 0x%x; %s>" % (
+ self.__module__,
+ self.__class__.__name__,
+ id(self),
+ friendly,
+ )
class ColumnElement(operators.ColumnOperators, ClauseElement):
@@ -571,7 +575,7 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
"""
- __visit_name__ = 'column_element'
+ __visit_name__ = "column_element"
primary_key = False
foreign_keys = []
@@ -646,11 +650,12 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
_alt_names = ()
def self_group(self, against=None):
- if (against in (operators.and_, operators.or_, operators._asbool) and
- self.type._type_affinity
- is type_api.BOOLEANTYPE._type_affinity):
+ if (
+ against in (operators.and_, operators.or_, operators._asbool)
+ and self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity
+ ):
return AsBoolean(self, operators.istrue, operators.isfalse)
- elif (against in (operators.any_op, operators.all_op)):
+ elif against in (operators.any_op, operators.all_op):
return Grouping(self)
else:
return self
@@ -675,7 +680,8 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
except AttributeError:
raise TypeError(
"Object %r associated with '.type' attribute "
- "is not a TypeEngine class or object" % self.type)
+ "is not a TypeEngine class or object" % self.type
+ )
else:
return comparator_factory(self)
@@ -684,10 +690,8 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
return getattr(self.comparator, key)
except AttributeError:
raise AttributeError(
- 'Neither %r object nor %r object has an attribute %r' % (
- type(self).__name__,
- type(self.comparator).__name__,
- key)
+ "Neither %r object nor %r object has an attribute %r"
+ % (type(self).__name__, type(self.comparator).__name__, key)
)
def operate(self, op, *other, **kwargs):
@@ -697,10 +701,14 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
return op(other, self.comparator, **kwargs)
def _bind_param(self, operator, obj, type_=None):
- return BindParameter(None, obj,
- _compared_to_operator=operator,
- type_=type_,
- _compared_to_type=self.type, unique=True)
+ return BindParameter(
+ None,
+ obj,
+ _compared_to_operator=operator,
+ type_=type_,
+ _compared_to_type=self.type,
+ unique=True,
+ )
@property
def expression(self):
@@ -713,17 +721,18 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
@property
def _select_iterable(self):
- return (self, )
+ return (self,)
@util.memoized_property
def base_columns(self):
- return util.column_set(c for c in self.proxy_set
- if not hasattr(c, '_proxies'))
+ return util.column_set(
+ c for c in self.proxy_set if not hasattr(c, "_proxies")
+ )
@util.memoized_property
def proxy_set(self):
s = util.column_set([self])
- if hasattr(self, '_proxies'):
+ if hasattr(self, "_proxies"):
for c in self._proxies:
s.update(c.proxy_set)
return s
@@ -738,11 +747,15 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
"""Return True if the given column element compares to this one
when targeting within a result row."""
- return hasattr(other, 'name') and hasattr(self, 'name') and \
- other.name == self.name
+ return (
+ hasattr(other, "name")
+ and hasattr(self, "name")
+ and other.name == self.name
+ )
def _make_proxy(
- self, selectable, name=None, name_is_truncatable=False, **kw):
+ self, selectable, name=None, name_is_truncatable=False, **kw
+ ):
"""Create a new :class:`.ColumnElement` representing this
:class:`.ColumnElement` as it appears in the select list of a
descending selectable.
@@ -762,13 +775,12 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
key = name
co = ColumnClause(
_as_truncated(name) if name_is_truncatable else name,
- type_=getattr(self, 'type', None),
- _selectable=selectable
+ type_=getattr(self, "type", None),
+ _selectable=selectable,
)
co._proxies = [self]
if selectable._is_clone_of is not None:
- co._is_clone_of = \
- selectable._is_clone_of.columns.get(key)
+ co._is_clone_of = selectable._is_clone_of.columns.get(key)
selectable._columns[key] = co
return co
@@ -788,7 +800,7 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
this one via foreign key or other criterion.
"""
- to_compare = (other, )
+ to_compare = (other,)
if equivalents and other in equivalents:
to_compare = equivalents[other].union(to_compare)
@@ -838,7 +850,7 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
self = self._is_clone_of
return _anonymous_label(
- '%%(%d %s)s' % (id(self), getattr(self, 'name', 'anon'))
+ "%%(%d %s)s" % (id(self), getattr(self, "name", "anon"))
)
@@ -862,18 +874,25 @@ class BindParameter(ColumnElement):
"""
- __visit_name__ = 'bindparam'
+ __visit_name__ = "bindparam"
_is_crud = False
_expanding_in_types = ()
- def __init__(self, key, value=NO_ARG, type_=None,
- unique=False, required=NO_ARG,
- quote=None, callable_=None,
- expanding=False,
- isoutparam=False,
- _compared_to_operator=None,
- _compared_to_type=None):
+ def __init__(
+ self,
+ key,
+ value=NO_ARG,
+ type_=None,
+ unique=False,
+ required=NO_ARG,
+ quote=None,
+ callable_=None,
+ expanding=False,
+ isoutparam=False,
+ _compared_to_operator=None,
+ _compared_to_type=None,
+ ):
r"""Produce a "bound expression".
The return value is an instance of :class:`.BindParameter`; this
@@ -1093,7 +1112,7 @@ class BindParameter(ColumnElement):
type_ = key.type
key = key.key
if required is NO_ARG:
- required = (value is NO_ARG and callable_ is None)
+ required = value is NO_ARG and callable_ is None
if value is NO_ARG:
value = None
@@ -1101,11 +1120,11 @@ class BindParameter(ColumnElement):
key = quoted_name(key, quote)
if unique:
- self.key = _anonymous_label('%%(%d %s)s' % (id(self), key
- or 'param'))
+ self.key = _anonymous_label(
+ "%%(%d %s)s" % (id(self), key or "param")
+ )
else:
- self.key = key or _anonymous_label('%%(%d param)s'
- % id(self))
+ self.key = key or _anonymous_label("%%(%d param)s" % id(self))
# identifying key that won't change across
# clones, used to identify the bind's logical
@@ -1114,7 +1133,7 @@ class BindParameter(ColumnElement):
# key that was passed in the first place, used to
# generate new keys
- self._orig_key = key or 'param'
+ self._orig_key = key or "param"
self.unique = unique
self.value = value
@@ -1125,9 +1144,9 @@ class BindParameter(ColumnElement):
if type_ is None:
if _compared_to_type is not None:
- self.type = \
- _compared_to_type.coerce_compared_value(
- _compared_to_operator, value)
+ self.type = _compared_to_type.coerce_compared_value(
+ _compared_to_operator, value
+ )
else:
self.type = type_api._resolve_value_to_type(value)
elif isinstance(type_, type):
@@ -1174,24 +1193,28 @@ class BindParameter(ColumnElement):
def _clone(self):
c = ClauseElement._clone(self)
if self.unique:
- c.key = _anonymous_label('%%(%d %s)s' % (id(c), c._orig_key
- or 'param'))
+ c.key = _anonymous_label(
+ "%%(%d %s)s" % (id(c), c._orig_key or "param")
+ )
return c
def _convert_to_unique(self):
if not self.unique:
self.unique = True
self.key = _anonymous_label(
- '%%(%d %s)s' % (id(self), self._orig_key or 'param'))
+ "%%(%d %s)s" % (id(self), self._orig_key or "param")
+ )
def compare(self, other, **kw):
"""Compare this :class:`BindParameter` to the given
clause."""
- return isinstance(other, BindParameter) \
- and self.type._compare_type_affinity(other.type) \
- and self.value == other.value \
+ return (
+ isinstance(other, BindParameter)
+ and self.type._compare_type_affinity(other.type)
+ and self.value == other.value
and self.callable == other.callable
+ )
def __getstate__(self):
"""execute a deferred value for serialization purposes."""
@@ -1200,13 +1223,16 @@ class BindParameter(ColumnElement):
v = self.value
if self.callable:
v = self.callable()
- d['callable'] = None
- d['value'] = v
+ d["callable"] = None
+ d["value"] = v
return d
def __repr__(self):
- return 'BindParameter(%r, %r, type_=%r)' % (self.key,
- self.value, self.type)
+ return "BindParameter(%r, %r, type_=%r)" % (
+ self.key,
+ self.value,
+ self.type,
+ )
class TypeClause(ClauseElement):
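The BindParameter hunks above only restyle the 1.x behaviour: comparing a column to a plain Python value produces a uniquely keyed anonymous bind, while bindparam() names the key explicitly. A minimal illustrative sketch (the names "x" and "ident" are invented for the example, not part of the patch):

    from sqlalchemy import bindparam, column

    # comparing against a literal creates an anonymous, unique BindParameter
    expr = column("x") == 7

    # an explicitly keyed parameter; with no value or callable it becomes "required"
    p = bindparam("ident", value=7)
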
@@ -1216,7 +1242,7 @@ class TypeClause(ClauseElement):
"""
- __visit_name__ = 'typeclause'
+ __visit_name__ = "typeclause"
def __init__(self, type):
self.type = type
@@ -1242,12 +1268,12 @@ class TextClause(Executable, ClauseElement):
"""
- __visit_name__ = 'textclause'
+ __visit_name__ = "textclause"
- _bind_params_regex = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE)
- _execution_options = \
- Executable._execution_options.union(
- {'autocommit': PARSE_AUTOCOMMIT})
+ _bind_params_regex = re.compile(r"(?<![:\w\x5c]):(\w+)(?!:)", re.UNICODE)
+ _execution_options = Executable._execution_options.union(
+ {"autocommit": PARSE_AUTOCOMMIT}
+ )
_is_implicitly_boolean = False
@property
@@ -1268,24 +1294,22 @@ class TextClause(Executable, ClauseElement):
_allow_label_resolve = False
- def __init__(
- self,
- text,
- bind=None):
+ def __init__(self, text, bind=None):
self._bind = bind
self._bindparams = {}
def repl(m):
self._bindparams[m.group(1)] = BindParameter(m.group(1))
- return ':%s' % m.group(1)
+ return ":%s" % m.group(1)
# scan the string and search for bind parameter names, add them
# to the list of bindparams
self.text = self._bind_params_regex.sub(repl, text)
@classmethod
- def _create_text(self, text, bind=None, bindparams=None,
- typemap=None, autocommit=None):
+ def _create_text(
+ self, text, bind=None, bindparams=None, typemap=None, autocommit=None
+ ):
r"""Construct a new :class:`.TextClause` clause, representing
a textual SQL string directly.
@@ -1428,8 +1452,10 @@ class TextClause(Executable, ClauseElement):
if typemap:
stmt = stmt.columns(**typemap)
if autocommit is not None:
- util.warn_deprecated('autocommit on text() is deprecated. '
- 'Use .execution_options(autocommit=True)')
+ util.warn_deprecated(
+ "autocommit on text() is deprecated. "
+ "Use .execution_options(autocommit=True)"
+ )
stmt = stmt.execution_options(autocommit=autocommit)
return stmt
@@ -1513,7 +1539,8 @@ class TextClause(Executable, ClauseElement):
except KeyError:
raise exc.ArgumentError(
"This text() construct doesn't define a "
- "bound parameter named %r" % bind.key)
+ "bound parameter named %r" % bind.key
+ )
else:
new_params[existing.key] = bind
@@ -1523,11 +1550,12 @@ class TextClause(Executable, ClauseElement):
except KeyError:
raise exc.ArgumentError(
"This text() construct doesn't define a "
- "bound parameter named %r" % key)
+ "bound parameter named %r" % key
+ )
else:
new_params[key] = existing._with_value(value)
- @util.dependencies('sqlalchemy.sql.selectable')
+ @util.dependencies("sqlalchemy.sql.selectable")
def columns(self, selectable, *cols, **types):
"""Turn this :class:`.TextClause` object into a :class:`.TextAsFrom`
object that can be embedded into another statement.
@@ -1629,12 +1657,14 @@ class TextClause(Executable, ClauseElement):
for col in cols
]
keyed_input_cols = [
- ColumnClause(key, type_) for key, type_ in types.items()]
+ ColumnClause(key, type_) for key, type_ in types.items()
+ ]
return selectable.TextAsFrom(
self,
positional_input_cols + keyed_input_cols,
- positional=bool(positional_input_cols) and not keyed_input_cols)
+ positional=bool(positional_input_cols) and not keyed_input_cols,
+ )
@property
def type(self):
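The TextClause hunks cover bound-parameter scanning, bindparams() rebinding, and the .columns() conversion into a TextAsFrom. For orientation, a sketch of the public usage being restyled (table and column names are invented; not part of the patch):

    from sqlalchemy import Integer, String, text

    stmt = text("SELECT id, name FROM users WHERE id = :id")
    stmt = stmt.bindparams(id=7)                   # supplies the :id parameter found by the regex scan
    stmt = stmt.columns(id=Integer, name=String)   # returns a TextAsFrom usable like a selectable
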
@@ -1651,8 +1681,9 @@ class TextClause(Executable, ClauseElement):
return self
def _copy_internals(self, clone=_clone, **kw):
- self._bindparams = dict((b.key, clone(b, **kw))
- for b in self._bindparams.values())
+ self._bindparams = dict(
+ (b.key, clone(b, **kw)) for b in self._bindparams.values()
+ )
def get_children(self, **kwargs):
return list(self._bindparams.values())
@@ -1669,7 +1700,7 @@ class Null(ColumnElement):
"""
- __visit_name__ = 'null'
+ __visit_name__ = "null"
@util.memoized_property
def type(self):
@@ -1693,7 +1724,7 @@ class False_(ColumnElement):
"""
- __visit_name__ = 'false'
+ __visit_name__ = "false"
@util.memoized_property
def type(self):
@@ -1752,7 +1783,7 @@ class True_(ColumnElement):
"""
- __visit_name__ = 'true'
+ __visit_name__ = "true"
@util.memoized_property
def type(self):
@@ -1816,23 +1847,23 @@ class ClauseList(ClauseElement):
By default, is comma-separated, such as a column listing.
"""
- __visit_name__ = 'clauselist'
+
+ __visit_name__ = "clauselist"
def __init__(self, *clauses, **kwargs):
- self.operator = kwargs.pop('operator', operators.comma_op)
- self.group = kwargs.pop('group', True)
- self.group_contents = kwargs.pop('group_contents', True)
+ self.operator = kwargs.pop("operator", operators.comma_op)
+ self.group = kwargs.pop("group", True)
+ self.group_contents = kwargs.pop("group_contents", True)
text_converter = kwargs.pop(
- '_literal_as_text',
- _expression_literal_as_text)
+ "_literal_as_text", _expression_literal_as_text
+ )
if self.group_contents:
self.clauses = [
text_converter(clause).self_group(against=self.operator)
- for clause in clauses]
+ for clause in clauses
+ ]
else:
- self.clauses = [
- text_converter(clause)
- for clause in clauses]
+ self.clauses = [text_converter(clause) for clause in clauses]
self._is_implicitly_boolean = operators.is_boolean(self.operator)
def __iter__(self):
@@ -1847,8 +1878,9 @@ class ClauseList(ClauseElement):
def append(self, clause):
if self.group_contents:
- self.clauses.append(_literal_as_text(clause).
- self_group(against=self.operator))
+ self.clauses.append(
+ _literal_as_text(clause).self_group(against=self.operator)
+ )
else:
self.clauses.append(_literal_as_text(clause))
@@ -1875,14 +1907,18 @@ class ClauseList(ClauseElement):
"""
if not isinstance(other, ClauseList) and len(self.clauses) == 1:
return self.clauses[0].compare(other, **kw)
- elif isinstance(other, ClauseList) and \
- len(self.clauses) == len(other.clauses) and \
- self.operator is other.operator:
+ elif (
+ isinstance(other, ClauseList)
+ and len(self.clauses) == len(other.clauses)
+ and self.operator is other.operator
+ ):
if self.operator in (operators.and_, operators.or_):
completed = set()
for clause in self.clauses:
- for other_clause in set(other.clauses).difference(completed):
+ for other_clause in set(other.clauses).difference(
+ completed
+ ):
if clause.compare(other_clause, **kw):
completed.add(other_clause)
break
@@ -1898,11 +1934,12 @@ class ClauseList(ClauseElement):
class BooleanClauseList(ClauseList, ColumnElement):
- __visit_name__ = 'clauselist'
+ __visit_name__ = "clauselist"
def __init__(self, *arg, **kw):
raise NotImplementedError(
- "BooleanClauseList has a private constructor")
+ "BooleanClauseList has a private constructor"
+ )
@classmethod
def _construct(cls, operator, continue_on, skip_on, *clauses, **kw):
@@ -1910,8 +1947,7 @@ class BooleanClauseList(ClauseList, ColumnElement):
clauses = [
_expression_literal_as_text(clause)
- for clause in
- util.coerce_generator_arg(clauses)
+ for clause in util.coerce_generator_arg(clauses)
]
for clause in clauses:
@@ -1927,8 +1963,9 @@ class BooleanClauseList(ClauseList, ColumnElement):
elif not convert_clauses and clauses:
return clauses[0].self_group(against=operators._asbool)
- convert_clauses = [c.self_group(against=operator)
- for c in convert_clauses]
+ convert_clauses = [
+ c.self_group(against=operator) for c in convert_clauses
+ ]
self = cls.__new__(cls)
self.clauses = convert_clauses
@@ -2014,7 +2051,7 @@ class BooleanClauseList(ClauseList, ColumnElement):
@property
def _select_iterable(self):
- return (self, )
+ return (self,)
def self_group(self, against=None):
if not self.clauses:
@@ -2056,22 +2093,31 @@ class Tuple(ClauseList, ColumnElement):
clauses = [_literal_as_binds(c) for c in clauses]
self._type_tuple = [arg.type for arg in clauses]
- self.type = kw.pop('type_', self._type_tuple[0]
- if self._type_tuple else type_api.NULLTYPE)
+ self.type = kw.pop(
+ "type_",
+ self._type_tuple[0] if self._type_tuple else type_api.NULLTYPE,
+ )
super(Tuple, self).__init__(*clauses, **kw)
@property
def _select_iterable(self):
- return (self, )
+ return (self,)
def _bind_param(self, operator, obj, type_=None):
- return Tuple(*[
- BindParameter(None, o, _compared_to_operator=operator,
- _compared_to_type=compared_to_type, unique=True,
- type_=type_)
- for o, compared_to_type in zip(obj, self._type_tuple)
- ]).self_group()
+ return Tuple(
+ *[
+ BindParameter(
+ None,
+ o,
+ _compared_to_operator=operator,
+ _compared_to_type=compared_to_type,
+ unique=True,
+ type_=type_,
+ )
+ for o, compared_to_type in zip(obj, self._type_tuple)
+ ]
+ ).self_group()
class Case(ColumnElement):
@@ -2101,7 +2147,7 @@ class Case(ColumnElement):
"""
- __visit_name__ = 'case'
+ __visit_name__ = "case"
def __init__(self, whens, value=None, else_=None):
r"""Produce a ``CASE`` expression.
@@ -2231,13 +2277,13 @@ class Case(ColumnElement):
if value is not None:
whenlist = [
- (_literal_as_binds(c).self_group(),
- _literal_as_binds(r)) for (c, r) in whens
+ (_literal_as_binds(c).self_group(), _literal_as_binds(r))
+ for (c, r) in whens
]
else:
whenlist = [
- (_no_literals(c).self_group(),
- _literal_as_binds(r)) for (c, r) in whens
+ (_no_literals(c).self_group(), _literal_as_binds(r))
+ for (c, r) in whens
]
if whenlist:
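The Case hunks restyle the whens/value/else_ coercion without changing semantics; in the 1.x call style that code services (column names invented for illustration):

    from sqlalchemy import case, column

    status = case(
        [(column("score") > 90, "high"), (column("score") > 50, "mid")],
        else_="low",
    )

    # the "value" shorthand compares a single expression against each candidate
    kind_label = case({"a": 1, "b": 2}, value=column("kind"), else_=0)
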
@@ -2260,8 +2306,7 @@ class Case(ColumnElement):
def _copy_internals(self, clone=_clone, **kw):
if self.value is not None:
self.value = clone(self.value, **kw)
- self.whens = [(clone(x, **kw), clone(y, **kw))
- for x, y in self.whens]
+ self.whens = [(clone(x, **kw), clone(y, **kw)) for x, y in self.whens]
if self.else_ is not None:
self.else_ = clone(self.else_, **kw)
@@ -2276,8 +2321,9 @@ class Case(ColumnElement):
@property
def _from_objects(self):
- return list(itertools.chain(*[x._from_objects for x in
- self.get_children()]))
+ return list(
+ itertools.chain(*[x._from_objects for x in self.get_children()])
+ )
def literal_column(text, type_=None):
@@ -2333,7 +2379,7 @@ class Cast(ColumnElement):
"""
- __visit_name__ = 'cast'
+ __visit_name__ = "cast"
def __init__(self, expression, type_):
"""Produce a ``CAST`` expression.
@@ -2416,7 +2462,7 @@ class TypeCoerce(ColumnElement):
"""
- __visit_name__ = 'type_coerce'
+ __visit_name__ = "type_coerce"
def __init__(self, expression, type_):
"""Associate a SQL expression with a particular type, without rendering
@@ -2484,10 +2530,10 @@ class TypeCoerce(ColumnElement):
def _copy_internals(self, clone=_clone, **kw):
self.clause = clone(self.clause, **kw)
- self.__dict__.pop('typed_expression', None)
+ self.__dict__.pop("typed_expression", None)
def get_children(self, **kwargs):
- return self.clause,
+ return (self.clause,)
@property
def _from_objects(self):
@@ -2506,7 +2552,7 @@ class TypeCoerce(ColumnElement):
class Extract(ColumnElement):
"""Represent a SQL EXTRACT clause, ``extract(field FROM expr)``."""
- __visit_name__ = 'extract'
+ __visit_name__ = "extract"
def __init__(self, field, expr, **kwargs):
"""Return a :class:`.Extract` construct.
@@ -2524,7 +2570,7 @@ class Extract(ColumnElement):
self.expr = clone(self.expr, **kw)
def get_children(self, **kwargs):
- return self.expr,
+ return (self.expr,)
@property
def _from_objects(self):
@@ -2543,7 +2589,8 @@ class _label_reference(ColumnElement):
within an OVER clause.
"""
- __visit_name__ = 'label_reference'
+
+ __visit_name__ = "label_reference"
def __init__(self, element):
self.element = element
@@ -2557,7 +2604,7 @@ class _label_reference(ColumnElement):
class _textual_label_reference(ColumnElement):
- __visit_name__ = 'textual_label_reference'
+ __visit_name__ = "textual_label_reference"
def __init__(self, element):
self.element = element
@@ -2580,14 +2627,23 @@ class UnaryExpression(ColumnElement):
:func:`.nullsfirst` and :func:`.nullslast`.
"""
- __visit_name__ = 'unary'
- def __init__(self, element, operator=None, modifier=None,
- type_=None, negate=None, wraps_column_expression=False):
+ __visit_name__ = "unary"
+
+ def __init__(
+ self,
+ element,
+ operator=None,
+ modifier=None,
+ type_=None,
+ negate=None,
+ wraps_column_expression=False,
+ ):
self.operator = operator
self.modifier = modifier
self.element = element.self_group(
- against=self.operator or self.modifier)
+ against=self.operator or self.modifier
+ )
self.type = type_api.to_instance(type_)
self.negate = negate
self.wraps_column_expression = wraps_column_expression
@@ -2633,7 +2689,8 @@ class UnaryExpression(ColumnElement):
return UnaryExpression(
_literal_as_label_reference(column),
modifier=operators.nullsfirst_op,
- wraps_column_expression=False)
+ wraps_column_expression=False,
+ )
@classmethod
def _create_nullslast(cls, column):
@@ -2675,7 +2732,8 @@ class UnaryExpression(ColumnElement):
return UnaryExpression(
_literal_as_label_reference(column),
modifier=operators.nullslast_op,
- wraps_column_expression=False)
+ wraps_column_expression=False,
+ )
@classmethod
def _create_desc(cls, column):
@@ -2715,7 +2773,8 @@ class UnaryExpression(ColumnElement):
return UnaryExpression(
_literal_as_label_reference(column),
modifier=operators.desc_op,
- wraps_column_expression=False)
+ wraps_column_expression=False,
+ )
@classmethod
def _create_asc(cls, column):
@@ -2754,7 +2813,8 @@ class UnaryExpression(ColumnElement):
return UnaryExpression(
_literal_as_label_reference(column),
modifier=operators.asc_op,
- wraps_column_expression=False)
+ wraps_column_expression=False,
+ )
@classmethod
def _create_distinct(cls, expr):
@@ -2794,8 +2854,11 @@ class UnaryExpression(ColumnElement):
"""
expr = _literal_as_binds(expr)
return UnaryExpression(
- expr, operator=operators.distinct_op,
- type_=expr.type, wraps_column_expression=False)
+ expr,
+ operator=operators.distinct_op,
+ type_=expr.type,
+ wraps_column_expression=False,
+ )
@property
def _order_by_label_element(self):
@@ -2812,17 +2875,17 @@ class UnaryExpression(ColumnElement):
self.element = clone(self.element, **kw)
def get_children(self, **kwargs):
- return self.element,
+ return (self.element,)
def compare(self, other, **kw):
"""Compare this :class:`UnaryExpression` against the given
:class:`.ClauseElement`."""
return (
- isinstance(other, UnaryExpression) and
- self.operator == other.operator and
- self.modifier == other.modifier and
- self.element.compare(other.element, **kw)
+ isinstance(other, UnaryExpression)
+ and self.operator == other.operator
+ and self.modifier == other.modifier
+ and self.element.compare(other.element, **kw)
)
def _negate(self):
@@ -2833,14 +2896,16 @@ class UnaryExpression(ColumnElement):
negate=self.operator,
modifier=self.modifier,
type_=self.type,
- wraps_column_expression=self.wraps_column_expression)
+ wraps_column_expression=self.wraps_column_expression,
+ )
elif self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity:
return UnaryExpression(
self.self_group(against=operators.inv),
operator=operators.inv,
type_=type_api.BOOLEANTYPE,
wraps_column_expression=self.wraps_column_expression,
- negate=None)
+ negate=None,
+ )
else:
return ClauseElement._negate(self)
@@ -2860,6 +2925,7 @@ class CollectionAggregate(UnaryExpression):
MySQL, they only work for subqueries.
"""
+
@classmethod
def _create_any(cls, expr):
"""Produce an ANY expression.
@@ -2883,12 +2949,15 @@ class CollectionAggregate(UnaryExpression):
expr = _literal_as_binds(expr)
- if expr.is_selectable and hasattr(expr, 'as_scalar'):
+ if expr.is_selectable and hasattr(expr, "as_scalar"):
expr = expr.as_scalar()
expr = expr.self_group()
return CollectionAggregate(
- expr, operator=operators.any_op,
- type_=type_api.NULLTYPE, wraps_column_expression=False)
+ expr,
+ operator=operators.any_op,
+ type_=type_api.NULLTYPE,
+ wraps_column_expression=False,
+ )
@classmethod
def _create_all(cls, expr):
@@ -2912,12 +2981,15 @@ class CollectionAggregate(UnaryExpression):
"""
expr = _literal_as_binds(expr)
- if expr.is_selectable and hasattr(expr, 'as_scalar'):
+ if expr.is_selectable and hasattr(expr, "as_scalar"):
expr = expr.as_scalar()
expr = expr.self_group()
return CollectionAggregate(
- expr, operator=operators.all_op,
- type_=type_api.NULLTYPE, wraps_column_expression=False)
+ expr,
+ operator=operators.all_op,
+ type_=type_api.NULLTYPE,
+ wraps_column_expression=False,
+ )
# operate and reverse_operate are hardwired to
# dispatch onto the type comparator directly, so that we can
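The CollectionAggregate hunks restyle the any_()/all_() constructors; as their docstrings describe, these compare a left-hand expression against every row of a subquery (or array element on backends that support it). An illustrative sketch assuming the 1.1-era column-level API (table and column names invented):

    from sqlalchemy import all_, any_, column, select, table

    t = table("scores", column("value"))
    subq = select([t.c.value])

    cond_any = column("x") == any_(subq)   # x = ANY (SELECT scores.value FROM scores)
    cond_all = column("x") > all_(subq)    # x > ALL (SELECT scores.value FROM scores)
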
@@ -2925,19 +2997,20 @@ class CollectionAggregate(UnaryExpression):
def operate(self, op, *other, **kwargs):
if not operators.is_comparison(op):
raise exc.ArgumentError(
- "Only comparison operators may be used with ANY/ALL")
- kwargs['reverse'] = True
+ "Only comparison operators may be used with ANY/ALL"
+ )
+ kwargs["reverse"] = True
return self.comparator.operate(operators.mirror(op), *other, **kwargs)
def reverse_operate(self, op, other, **kwargs):
# comparison operators should never call reverse_operate
assert not operators.is_comparison(op)
raise exc.ArgumentError(
- "Only comparison operators may be used with ANY/ALL")
+ "Only comparison operators may be used with ANY/ALL"
+ )
class AsBoolean(UnaryExpression):
-
def __init__(self, element, operator, negate):
self.element = element
self.type = type_api.BOOLEANTYPE
@@ -2971,7 +3044,7 @@ class BinaryExpression(ColumnElement):
"""
- __visit_name__ = 'binary'
+ __visit_name__ = "binary"
_is_implicitly_boolean = True
"""Indicates that any database will know this is a boolean expression
@@ -2979,8 +3052,9 @@ class BinaryExpression(ColumnElement):
"""
- def __init__(self, left, right, operator, type_=None,
- negate=None, modifiers=None):
+ def __init__(
+ self, left, right, operator, type_=None, negate=None, modifiers=None
+ ):
# allow compatibility with libraries that
# refer to BinaryExpression directly and pass strings
if isinstance(operator, util.string_types):
@@ -3026,15 +3100,15 @@ class BinaryExpression(ColumnElement):
given :class:`BinaryExpression`."""
return (
- isinstance(other, BinaryExpression) and
- self.operator == other.operator and
- (
- self.left.compare(other.left, **kw) and
- self.right.compare(other.right, **kw) or
- (
- operators.is_commutative(self.operator) and
- self.left.compare(other.right, **kw) and
- self.right.compare(other.left, **kw)
+ isinstance(other, BinaryExpression)
+ and self.operator == other.operator
+ and (
+ self.left.compare(other.left, **kw)
+ and self.right.compare(other.right, **kw)
+ or (
+ operators.is_commutative(self.operator)
+ and self.left.compare(other.right, **kw)
+ and self.right.compare(other.left, **kw)
)
)
)
@@ -3053,7 +3127,8 @@ class BinaryExpression(ColumnElement):
self.negate,
negate=self.operator,
type_=self.type,
- modifiers=self.modifiers)
+ modifiers=self.modifiers,
+ )
else:
return super(BinaryExpression, self)._negate()
@@ -3065,7 +3140,8 @@ class Slice(ColumnElement):
may be interpreted by specific dialects, e.g. PostgreSQL.
"""
- __visit_name__ = 'slice'
+
+ __visit_name__ = "slice"
def __init__(self, start, stop, step):
self.start = start
@@ -3081,17 +3157,18 @@ class Slice(ColumnElement):
class IndexExpression(BinaryExpression):
"""Represent the class of expressions that are like an "index" operation.
"""
+
pass
class Grouping(ColumnElement):
"""Represent a grouping within a column expression"""
- __visit_name__ = 'grouping'
+ __visit_name__ = "grouping"
def __init__(self, element):
self.element = element
- self.type = getattr(element, 'type', type_api.NULLTYPE)
+ self.type = getattr(element, "type", type_api.NULLTYPE)
def self_group(self, against=None):
return self
@@ -3106,13 +3183,13 @@ class Grouping(ColumnElement):
@property
def _label(self):
- return getattr(self.element, '_label', None) or self.anon_label
+ return getattr(self.element, "_label", None) or self.anon_label
def _copy_internals(self, clone=_clone, **kw):
self.element = clone(self.element, **kw)
def get_children(self, **kwargs):
- return self.element,
+ return (self.element,)
@property
def _from_objects(self):
@@ -3122,15 +3199,16 @@ class Grouping(ColumnElement):
return getattr(self.element, attr)
def __getstate__(self):
- return {'element': self.element, 'type': self.type}
+ return {"element": self.element, "type": self.type}
def __setstate__(self, state):
- self.element = state['element']
- self.type = state['type']
+ self.element = state["element"]
+ self.type = state["type"]
def compare(self, other, **kw):
- return isinstance(other, Grouping) and \
- self.element.compare(other.element)
+ return isinstance(other, Grouping) and self.element.compare(
+ other.element
+ )
RANGE_UNBOUNDED = util.symbol("RANGE_UNBOUNDED")
@@ -3147,14 +3225,15 @@ class Over(ColumnElement):
backends.
"""
- __visit_name__ = 'over'
+
+ __visit_name__ = "over"
order_by = None
partition_by = None
def __init__(
- self, element, partition_by=None,
- order_by=None, range_=None, rows=None):
+ self, element, partition_by=None, order_by=None, range_=None, rows=None
+ ):
"""Produce an :class:`.Over` object against a function.
Used against aggregate or so-called "window" functions,
@@ -3237,17 +3316,20 @@ class Over(ColumnElement):
if order_by is not None:
self.order_by = ClauseList(
*util.to_list(order_by),
- _literal_as_text=_literal_as_label_reference)
+ _literal_as_text=_literal_as_label_reference
+ )
if partition_by is not None:
self.partition_by = ClauseList(
*util.to_list(partition_by),
- _literal_as_text=_literal_as_label_reference)
+ _literal_as_text=_literal_as_label_reference
+ )
if range_:
self.range_ = self._interpret_range(range_)
if rows:
raise exc.ArgumentError(
- "'range_' and 'rows' are mutually exclusive")
+ "'range_' and 'rows' are mutually exclusive"
+ )
else:
self.rows = None
elif rows:
@@ -3267,7 +3349,8 @@ class Over(ColumnElement):
lower = int(range_[0])
except ValueError:
raise exc.ArgumentError(
- "Integer or None expected for range value")
+ "Integer or None expected for range value"
+ )
else:
if lower == 0:
lower = RANGE_CURRENT
@@ -3279,7 +3362,8 @@ class Over(ColumnElement):
upper = int(range_[1])
except ValueError:
raise exc.ArgumentError(
- "Integer or None expected for range value")
+ "Integer or None expected for range value"
+ )
else:
if upper == 0:
upper = RANGE_CURRENT
@@ -3303,9 +3387,11 @@ class Over(ColumnElement):
return self.element.type
def get_children(self, **kwargs):
- return [c for c in
- (self.element, self.partition_by, self.order_by)
- if c is not None]
+ return [
+ c
+ for c in (self.element, self.partition_by, self.order_by)
+ if c is not None
+ ]
def _copy_internals(self, clone=_clone, **kw):
self.element = clone(self.element, **kw)
@@ -3316,11 +3402,15 @@ class Over(ColumnElement):
@property
def _from_objects(self):
- return list(itertools.chain(
- *[c._from_objects for c in
- (self.element, self.partition_by, self.order_by)
- if c is not None]
- ))
+ return list(
+ itertools.chain(
+ *[
+ c._from_objects
+ for c in (self.element, self.partition_by, self.order_by)
+ if c is not None
+ ]
+ )
+ )
class WithinGroup(ColumnElement):
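The Over hunks rework only the layout of the window-frame handling (range_ vs. rows, with the RANGE_UNBOUNDED/RANGE_CURRENT symbols). A hedged sketch of what those parameters express (column names invented):

    from sqlalchemy import column, func

    win = func.sum(column("amount")).over(
        partition_by=column("region"),
        order_by=column("posted_at"),
        range_=(None, 0),   # RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
    )
    # rows=(-2, 0) would instead emit ROWS BETWEEN 2 PRECEDING AND CURRENT ROW;
    # passing both range_ and rows raises "'range_' and 'rows' are mutually exclusive"
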
@@ -3339,7 +3429,8 @@ class WithinGroup(ColumnElement):
``None``, the function's ``.type`` is used.
"""
- __visit_name__ = 'withingroup'
+
+ __visit_name__ = "withingroup"
order_by = None
@@ -3383,7 +3474,8 @@ class WithinGroup(ColumnElement):
if order_by is not None:
self.order_by = ClauseList(
*util.to_list(order_by),
- _literal_as_text=_literal_as_label_reference)
+ _literal_as_text=_literal_as_label_reference
+ )
def over(self, partition_by=None, order_by=None, range_=None, rows=None):
"""Produce an OVER clause against this :class:`.WithinGroup`
@@ -3394,8 +3486,12 @@ class WithinGroup(ColumnElement):
"""
return Over(
- self, partition_by=partition_by, order_by=order_by,
- range_=range_, rows=rows)
+ self,
+ partition_by=partition_by,
+ order_by=order_by,
+ range_=range_,
+ rows=rows,
+ )
@util.memoized_property
def type(self):
@@ -3406,9 +3502,7 @@ class WithinGroup(ColumnElement):
return self.element.type
def get_children(self, **kwargs):
- return [c for c in
- (self.element, self.order_by)
- if c is not None]
+ return [c for c in (self.element, self.order_by) if c is not None]
def _copy_internals(self, clone=_clone, **kw):
self.element = clone(self.element, **kw)
@@ -3417,11 +3511,15 @@ class WithinGroup(ColumnElement):
@property
def _from_objects(self):
- return list(itertools.chain(
- *[c._from_objects for c in
- (self.element, self.order_by)
- if c is not None]
- ))
+ return list(
+ itertools.chain(
+ *[
+ c._from_objects
+ for c in (self.element, self.order_by)
+ if c is not None
+ ]
+ )
+ )
class FunctionFilter(ColumnElement):
@@ -3443,7 +3541,8 @@ class FunctionFilter(ColumnElement):
:meth:`.FunctionElement.filter`
"""
- __visit_name__ = 'funcfilter'
+
+ __visit_name__ = "funcfilter"
criterion = None
@@ -3515,17 +3614,19 @@ class FunctionFilter(ColumnElement):
"""
return Over(
- self, partition_by=partition_by, order_by=order_by,
- range_=range_, rows=rows)
+ self,
+ partition_by=partition_by,
+ order_by=order_by,
+ range_=range_,
+ rows=rows,
+ )
@util.memoized_property
def type(self):
return self.func.type
def get_children(self, **kwargs):
- return [c for c in
- (self.func, self.criterion)
- if c is not None]
+ return [c for c in (self.func, self.criterion) if c is not None]
def _copy_internals(self, clone=_clone, **kw):
self.func = clone(self.func, **kw)
@@ -3534,10 +3635,15 @@ class FunctionFilter(ColumnElement):
@property
def _from_objects(self):
- return list(itertools.chain(
- *[c._from_objects for c in (self.func, self.criterion)
- if c is not None]
- ))
+ return list(
+ itertools.chain(
+ *[
+ c._from_objects
+ for c in (self.func, self.criterion)
+ if c is not None
+ ]
+ )
+ )
class Label(ColumnElement):
@@ -3548,7 +3654,7 @@ class Label(ColumnElement):
"""
- __visit_name__ = 'label'
+ __visit_name__ = "label"
def __init__(self, name, element, type_=None):
"""Return a :class:`Label` object for the
@@ -3577,7 +3683,7 @@ class Label(ColumnElement):
self._resolve_label = self.name
else:
self.name = _anonymous_label(
- '%%(%d %s)s' % (id(self), getattr(element, 'name', 'anon'))
+ "%%(%d %s)s" % (id(self), getattr(element, "name", "anon"))
)
self.key = self._label = self._key_label = self.name
@@ -3603,7 +3709,7 @@ class Label(ColumnElement):
@util.memoized_property
def type(self):
return type_api.to_instance(
- self._type or getattr(self._element, 'type', None)
+ self._type or getattr(self._element, "type", None)
)
@util.memoized_property
@@ -3619,9 +3725,7 @@ class Label(ColumnElement):
def _apply_to_inner(self, fn, *arg, **kw):
sub_element = fn(*arg, **kw)
if sub_element is not self._element:
- return Label(self.name,
- sub_element,
- type_=self._type)
+ return Label(self.name, sub_element, type_=self._type)
else:
return self
@@ -3634,16 +3738,16 @@ class Label(ColumnElement):
return self.element.foreign_keys
def get_children(self, **kwargs):
- return self.element,
+ return (self.element,)
def _copy_internals(self, clone=_clone, anonymize_labels=False, **kw):
self._element = clone(self._element, **kw)
- self.__dict__.pop('element', None)
- self.__dict__.pop('_allow_label_resolve', None)
+ self.__dict__.pop("element", None)
+ self.__dict__.pop("_allow_label_resolve", None)
if anonymize_labels:
self.name = self._resolve_label = _anonymous_label(
- '%%(%d %s)s' % (
- id(self), getattr(self.element, 'name', 'anon'))
+ "%%(%d %s)s"
+ % (id(self), getattr(self.element, "name", "anon"))
)
self.key = self._label = self._key_label = self.name
@@ -3652,8 +3756,9 @@ class Label(ColumnElement):
return self.element._from_objects
def _make_proxy(self, selectable, name=None, **kw):
- e = self.element._make_proxy(selectable,
- name=name if name else self.name)
+ e = self.element._make_proxy(
+ selectable, name=name if name else self.name
+ )
e._proxies.append(self)
if self._type is not None:
e.type = self._type
@@ -3694,7 +3799,8 @@ class ColumnClause(Immutable, ColumnElement):
:class:`.Column`
"""
- __visit_name__ = 'column'
+
+ __visit_name__ = "column"
onupdate = default = server_default = server_onupdate = None
@@ -3792,25 +3898,33 @@ class ColumnClause(Immutable, ColumnElement):
self.is_literal = is_literal
def _compare_name_for_result(self, other):
- if self.is_literal or \
- self.table is None or self.table._textual or \
- not hasattr(other, 'proxy_set') or (
- isinstance(other, ColumnClause) and
- (other.is_literal or
- other.table is None or
- other.table._textual)
- ):
- return (hasattr(other, 'name') and self.name == other.name) or \
- (hasattr(other, '_label') and self._label == other._label)
+ if (
+ self.is_literal
+ or self.table is None
+ or self.table._textual
+ or not hasattr(other, "proxy_set")
+ or (
+ isinstance(other, ColumnClause)
+ and (
+ other.is_literal
+ or other.table is None
+ or other.table._textual
+ )
+ )
+ ):
+ return (hasattr(other, "name") and self.name == other.name) or (
+ hasattr(other, "_label") and self._label == other._label
+ )
else:
return other.proxy_set.intersection(self.proxy_set)
def _get_table(self):
- return self.__dict__['table']
+ return self.__dict__["table"]
def _set_table(self, table):
self._memoized_property.expire_instance(self)
- self.__dict__['table'] = table
+ self.__dict__["table"] = table
+
table = property(_get_table, _set_table)
@_memoized_property
@@ -3826,7 +3940,7 @@ class ColumnClause(Immutable, ColumnElement):
if util.py3k:
return self.name
else:
- return self.name.encode('ascii', 'backslashreplace')
+ return self.name.encode("ascii", "backslashreplace")
@_memoized_property
def _key_label(self):
@@ -3850,9 +3964,8 @@ class ColumnClause(Immutable, ColumnElement):
return None
elif t is not None and t.named_with_column:
- if getattr(t, 'schema', None):
- label = t.schema.replace('.', '_') + "_" + \
- t.name + "_" + name
+ if getattr(t, "schema", None):
+ label = t.schema.replace(".", "_") + "_" + t.name + "_" + name
else:
label = t.name + "_" + name
@@ -3884,31 +3997,39 @@ class ColumnClause(Immutable, ColumnElement):
return name
def _bind_param(self, operator, obj, type_=None):
- return BindParameter(self.key, obj,
- _compared_to_operator=operator,
- _compared_to_type=self.type,
- type_=type_,
- unique=True)
-
- def _make_proxy(self, selectable, name=None, attach=True,
- name_is_truncatable=False, **kw):
+ return BindParameter(
+ self.key,
+ obj,
+ _compared_to_operator=operator,
+ _compared_to_type=self.type,
+ type_=type_,
+ unique=True,
+ )
+
+ def _make_proxy(
+ self,
+ selectable,
+ name=None,
+ attach=True,
+ name_is_truncatable=False,
+ **kw
+ ):
# propagate the "is_literal" flag only if we are keeping our name,
# otherwise its considered to be a label
is_literal = self.is_literal and (name is None or name == self.name)
c = self._constructor(
- _as_truncated(name or self.name) if
- name_is_truncatable else
- (name or self.name),
+ _as_truncated(name or self.name)
+ if name_is_truncatable
+ else (name or self.name),
type_=self.type,
_selectable=selectable,
- is_literal=is_literal
+ is_literal=is_literal,
)
if name is None:
c.key = self.key
c._proxies = [self]
if selectable._is_clone_of is not None:
- c._is_clone_of = \
- selectable._is_clone_of.columns.get(c.key)
+ c._is_clone_of = selectable._is_clone_of.columns.get(c.key)
if attach:
selectable._columns[c.key] = c
@@ -3924,24 +4045,25 @@ class CollationClause(ColumnElement):
class _IdentifiedClause(Executable, ClauseElement):
- __visit_name__ = 'identified'
- _execution_options = \
- Executable._execution_options.union({'autocommit': False})
+ __visit_name__ = "identified"
+ _execution_options = Executable._execution_options.union(
+ {"autocommit": False}
+ )
def __init__(self, ident):
self.ident = ident
class SavepointClause(_IdentifiedClause):
- __visit_name__ = 'savepoint'
+ __visit_name__ = "savepoint"
class RollbackToSavepointClause(_IdentifiedClause):
- __visit_name__ = 'rollback_to_savepoint'
+ __visit_name__ = "rollback_to_savepoint"
class ReleaseSavepointClause(_IdentifiedClause):
- __visit_name__ = 'release_savepoint'
+ __visit_name__ = "release_savepoint"
class quoted_name(util.MemoizedSlots, util.text_type):
@@ -3992,7 +4114,7 @@ class quoted_name(util.MemoizedSlots, util.text_type):
"""
- __slots__ = 'quote', 'lower', 'upper'
+ __slots__ = "quote", "lower", "upper"
def __new__(cls, value, quote):
if value is None:
@@ -4026,9 +4148,9 @@ class quoted_name(util.MemoizedSlots, util.text_type):
return util.text_type(self).upper()
def __repr__(self):
- backslashed = self.encode('ascii', 'backslashreplace')
+ backslashed = self.encode("ascii", "backslashreplace")
if not util.py2k:
- backslashed = backslashed.decode('ascii')
+ backslashed = backslashed.decode("ascii")
return "'%s'" % backslashed
@@ -4094,6 +4216,7 @@ class conv(_truncated_label):
:ref:`constraint_naming_conventions`
"""
+
__slots__ = ()
@@ -4102,6 +4225,7 @@ class _defer_name(_truncated_label):
generation.
"""
+
__slots__ = ()
def __new__(cls, value):
@@ -4113,13 +4237,15 @@ class _defer_name(_truncated_label):
return super(_defer_name, cls).__new__(cls, value)
def __reduce__(self):
- return self.__class__, (util.text_type(self), )
+ return self.__class__, (util.text_type(self),)
class _defer_none_name(_defer_name):
"""indicate a 'deferred' name that was ultimately the value None."""
+
__slots__ = ()
+
_NONE_NAME = _defer_none_name("_unnamed_")
# for backwards compatibility in case
@@ -4138,15 +4264,15 @@ class _anonymous_label(_truncated_label):
def __add__(self, other):
return _anonymous_label(
quoted_name(
- util.text_type.__add__(self, util.text_type(other)),
- self.quote)
+ util.text_type.__add__(self, util.text_type(other)), self.quote
+ )
)
def __radd__(self, other):
return _anonymous_label(
quoted_name(
- util.text_type.__add__(util.text_type(other), self),
- self.quote)
+ util.text_type.__add__(util.text_type(other), self), self.quote
+ )
)
def apply_map(self, map_):
@@ -4206,20 +4332,23 @@ def _cloned_intersection(a, b):
"""
all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b))
- return set(elem for elem in a
- if all_overlap.intersection(elem._cloned_set))
+ return set(
+ elem for elem in a if all_overlap.intersection(elem._cloned_set)
+ )
def _cloned_difference(a, b):
all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b))
- return set(elem for elem in a
- if not all_overlap.intersection(elem._cloned_set))
+ return set(
+ elem for elem in a if not all_overlap.intersection(elem._cloned_set)
+ )
@util.dependencies("sqlalchemy.sql.functions")
def _labeled(functions, element):
- if not hasattr(element, 'name') or \
- isinstance(element, functions.FunctionElement):
+ if not hasattr(element, "name") or isinstance(
+ element, functions.FunctionElement
+ ):
return element.label(None)
else:
return element
@@ -4235,7 +4364,7 @@ def _find_columns(clause):
"""locate Column objects within the given expression."""
cols = util.column_set()
- traverse(clause, {}, {'column': cols.add})
+ traverse(clause, {}, {"column": cols.add})
return cols
@@ -4253,7 +4382,7 @@ def _find_columns(clause):
def _column_as_key(element):
if isinstance(element, util.string_types):
return element
- if hasattr(element, '__clause_element__'):
+ if hasattr(element, "__clause_element__"):
element = element.__clause_element__()
try:
return element.key
@@ -4262,7 +4391,7 @@ def _column_as_key(element):
def _clause_element_as_expr(element):
- if hasattr(element, '__clause_element__'):
+ if hasattr(element, "__clause_element__"):
return element.__clause_element__()
else:
return element
@@ -4272,7 +4401,7 @@ def _literal_as_label_reference(element):
if isinstance(element, util.string_types):
return _textual_label_reference(element)
- elif hasattr(element, '__clause_element__'):
+ elif hasattr(element, "__clause_element__"):
element = element.__clause_element__()
return _literal_as_text(element)
@@ -4282,11 +4411,13 @@ def _literal_and_labels_as_label_reference(element):
if isinstance(element, util.string_types):
return _textual_label_reference(element)
- elif hasattr(element, '__clause_element__'):
+ elif hasattr(element, "__clause_element__"):
element = element.__clause_element__()
- if isinstance(element, ColumnElement) and \
- element._order_by_label_element is not None:
+ if (
+ isinstance(element, ColumnElement)
+ and element._order_by_label_element is not None
+ ):
return _label_reference(element)
else:
return _literal_as_text(element)
@@ -4299,14 +4430,15 @@ def _expression_literal_as_text(element):
def _literal_as_text(element, warn=False):
if isinstance(element, Visitable):
return element
- elif hasattr(element, '__clause_element__'):
+ elif hasattr(element, "__clause_element__"):
return element.__clause_element__()
elif isinstance(element, util.string_types):
if warn:
util.warn_limited(
"Textual SQL expression %(expr)r should be "
"explicitly declared as text(%(expr)r)",
- {"expr": util.ellipses_string(element)})
+ {"expr": util.ellipses_string(element)},
+ )
return TextClause(util.text_type(element))
elif isinstance(element, (util.NoneType, bool)):
@@ -4319,20 +4451,23 @@ def _literal_as_text(element, warn=False):
def _no_literals(element):
- if hasattr(element, '__clause_element__'):
+ if hasattr(element, "__clause_element__"):
return element.__clause_element__()
elif not isinstance(element, Visitable):
- raise exc.ArgumentError("Ambiguous literal: %r. Use the 'text()' "
- "function to indicate a SQL expression "
- "literal, or 'literal()' to indicate a "
- "bound value." % (element, ))
+ raise exc.ArgumentError(
+ "Ambiguous literal: %r. Use the 'text()' "
+ "function to indicate a SQL expression "
+ "literal, or 'literal()' to indicate a "
+ "bound value." % (element,)
+ )
else:
return element
def _is_literal(element):
- return not isinstance(element, Visitable) and \
- not hasattr(element, '__clause_element__')
+ return not isinstance(element, Visitable) and not hasattr(
+ element, "__clause_element__"
+ )
def _only_column_elements_or_none(element, name):
@@ -4343,17 +4478,18 @@ def _only_column_elements_or_none(element, name):
def _only_column_elements(element, name):
- if hasattr(element, '__clause_element__'):
+ if hasattr(element, "__clause_element__"):
element = element.__clause_element__()
if not isinstance(element, ColumnElement):
raise exc.ArgumentError(
"Column-based expression object expected for argument "
- "'%s'; got: '%s', type %s" % (name, element, type(element)))
+ "'%s'; got: '%s', type %s" % (name, element, type(element))
+ )
return element
def _literal_as_binds(element, name=None, type_=None):
- if hasattr(element, '__clause_element__'):
+ if hasattr(element, "__clause_element__"):
return element.__clause_element__()
elif not isinstance(element, Visitable):
if element is None:
@@ -4363,13 +4499,14 @@ def _literal_as_binds(element, name=None, type_=None):
else:
return element
-_guess_straight_column = re.compile(r'^\w\S*$', re.I)
+
+_guess_straight_column = re.compile(r"^\w\S*$", re.I)
def _interpret_as_column_or_from(element):
if isinstance(element, Visitable):
return element
- elif hasattr(element, '__clause_element__'):
+ elif hasattr(element, "__clause_element__"):
return element.__clause_element__()
insp = inspection.inspect(element, raiseerr=False)
@@ -4399,11 +4536,11 @@ def _interpret_as_column_or_from(element):
{
"column": util.ellipses_string(element),
"literal_column": "literal_column"
- if guess_is_literal else "column"
- })
- return ColumnClause(
- element,
- is_literal=guess_is_literal)
+ if guess_is_literal
+ else "column",
+ },
+ )
+ return ColumnClause(element, is_literal=guess_is_literal)
def _const_expr(element):
@@ -4416,9 +4553,7 @@ def _const_expr(element):
elif element is True:
return True_()
else:
- raise exc.ArgumentError(
- "Expected None, False, or True"
- )
+ raise exc.ArgumentError("Expected None, False, or True")
def _type_from_args(args):
@@ -4429,18 +4564,15 @@ def _type_from_args(args):
return type_api.NULLTYPE
-def _corresponding_column_or_error(fromclause, column,
- require_embedded=False):
- c = fromclause.corresponding_column(column,
- require_embedded=require_embedded)
+def _corresponding_column_or_error(fromclause, column, require_embedded=False):
+ c = fromclause.corresponding_column(
+ column, require_embedded=require_embedded
+ )
if c is None:
raise exc.InvalidRequestError(
"Given column '%s', attached to table '%s', "
"failed to locate a corresponding column from table '%s'"
- %
- (column,
- getattr(column, 'table', None),
- fromclause.description)
+ % (column, getattr(column, "table", None), fromclause.description)
)
return c
@@ -4449,7 +4581,7 @@ class AnnotatedColumnElement(Annotated):
def __init__(self, element, values):
Annotated.__init__(self, element, values)
ColumnElement.comparator._reset(self)
- for attr in ('name', 'key', 'table'):
+ for attr in ("name", "key", "table"):
if self.__dict__.get(attr, False) is None:
self.__dict__.pop(attr)
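The closing hunks of elements.py restyle _interpret_as_column_or_from, which warns when a plain string has to be guessed as either a column name or a literal SQL fragment. The distinction that warning points at, using the public constructors (names invented):

    from sqlalchemy import column, literal_column, select

    # column() names a single identifier; literal_column() is rendered verbatim
    stmt = select([column("name"), literal_column("count(*)").label("total")])
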
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index b69b6ee8c..aab9f46d4 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -15,43 +15,142 @@ class.
"""
__all__ = [
- 'Alias', 'any_', 'all_', 'ClauseElement', 'ColumnCollection', 'ColumnElement',
- 'CompoundSelect', 'Delete', 'FromClause', 'Insert', 'Join', 'Lateral',
- 'Select',
- 'Selectable', 'TableClause', 'Update', 'alias', 'and_', 'asc', 'between',
- 'bindparam', 'case', 'cast', 'column', 'delete', 'desc', 'distinct',
- 'except_', 'except_all', 'exists', 'extract', 'func', 'modifier',
- 'collate', 'insert', 'intersect', 'intersect_all', 'join', 'label',
- 'lateral', 'literal', 'literal_column', 'not_', 'null', 'nullsfirst',
- 'nullslast',
- 'or_', 'outparam', 'outerjoin', 'over', 'select', 'subquery',
- 'table', 'text',
- 'tuple_', 'type_coerce', 'quoted_name', 'union', 'union_all', 'update',
- 'within_group',
- 'TableSample', 'tablesample']
+ "Alias",
+ "any_",
+ "all_",
+ "ClauseElement",
+ "ColumnCollection",
+ "ColumnElement",
+ "CompoundSelect",
+ "Delete",
+ "FromClause",
+ "Insert",
+ "Join",
+ "Lateral",
+ "Select",
+ "Selectable",
+ "TableClause",
+ "Update",
+ "alias",
+ "and_",
+ "asc",
+ "between",
+ "bindparam",
+ "case",
+ "cast",
+ "column",
+ "delete",
+ "desc",
+ "distinct",
+ "except_",
+ "except_all",
+ "exists",
+ "extract",
+ "func",
+ "modifier",
+ "collate",
+ "insert",
+ "intersect",
+ "intersect_all",
+ "join",
+ "label",
+ "lateral",
+ "literal",
+ "literal_column",
+ "not_",
+ "null",
+ "nullsfirst",
+ "nullslast",
+ "or_",
+ "outparam",
+ "outerjoin",
+ "over",
+ "select",
+ "subquery",
+ "table",
+ "text",
+ "tuple_",
+ "type_coerce",
+ "quoted_name",
+ "union",
+ "union_all",
+ "update",
+ "within_group",
+ "TableSample",
+ "tablesample",
+]
from .visitors import Visitable
from .functions import func, modifier, FunctionElement, Function
from ..util.langhelpers import public_factory
-from .elements import ClauseElement, ColumnElement,\
- BindParameter, CollectionAggregate, UnaryExpression, BooleanClauseList, \
- Label, Cast, Case, ColumnClause, TextClause, Over, Null, \
- True_, False_, BinaryExpression, Tuple, TypeClause, Extract, \
- Grouping, WithinGroup, not_, quoted_name, \
- collate, literal_column, between,\
- literal, outparam, TypeCoerce, ClauseList, FunctionFilter
+from .elements import (
+ ClauseElement,
+ ColumnElement,
+ BindParameter,
+ CollectionAggregate,
+ UnaryExpression,
+ BooleanClauseList,
+ Label,
+ Cast,
+ Case,
+ ColumnClause,
+ TextClause,
+ Over,
+ Null,
+ True_,
+ False_,
+ BinaryExpression,
+ Tuple,
+ TypeClause,
+ Extract,
+ Grouping,
+ WithinGroup,
+ not_,
+ quoted_name,
+ collate,
+ literal_column,
+ between,
+ literal,
+ outparam,
+ TypeCoerce,
+ ClauseList,
+ FunctionFilter,
+)
-from .elements import SavepointClause, RollbackToSavepointClause, \
- ReleaseSavepointClause
+from .elements import (
+ SavepointClause,
+ RollbackToSavepointClause,
+ ReleaseSavepointClause,
+)
-from .base import ColumnCollection, Generative, Executable, \
- PARSE_AUTOCOMMIT
+from .base import ColumnCollection, Generative, Executable, PARSE_AUTOCOMMIT
-from .selectable import Alias, Join, Select, Selectable, TableClause, \
- CompoundSelect, CTE, FromClause, FromGrouping, Lateral, SelectBase, \
- alias, GenerativeSelect, subquery, HasCTE, HasPrefixes, HasSuffixes, \
- lateral, Exists, ScalarSelect, TextAsFrom, TableSample, tablesample
+from .selectable import (
+ Alias,
+ Join,
+ Select,
+ Selectable,
+ TableClause,
+ CompoundSelect,
+ CTE,
+ FromClause,
+ FromGrouping,
+ Lateral,
+ SelectBase,
+ alias,
+ GenerativeSelect,
+ subquery,
+ HasCTE,
+ HasPrefixes,
+ HasSuffixes,
+ lateral,
+ Exists,
+ ScalarSelect,
+ TextAsFrom,
+ TableSample,
+ tablesample,
+)
from .dml import Insert, Update, Delete, UpdateBase, ValuesBase
@@ -79,23 +178,30 @@ extract = public_factory(Extract, ".expression.extract")
tuple_ = public_factory(Tuple, ".expression.tuple_")
except_ = public_factory(CompoundSelect._create_except, ".expression.except_")
except_all = public_factory(
- CompoundSelect._create_except_all, ".expression.except_all")
+ CompoundSelect._create_except_all, ".expression.except_all"
+)
intersect = public_factory(
- CompoundSelect._create_intersect, ".expression.intersect")
+ CompoundSelect._create_intersect, ".expression.intersect"
+)
intersect_all = public_factory(
- CompoundSelect._create_intersect_all, ".expression.intersect_all")
+ CompoundSelect._create_intersect_all, ".expression.intersect_all"
+)
union = public_factory(CompoundSelect._create_union, ".expression.union")
union_all = public_factory(
- CompoundSelect._create_union_all, ".expression.union_all")
+ CompoundSelect._create_union_all, ".expression.union_all"
+)
exists = public_factory(Exists, ".expression.exists")
nullsfirst = public_factory(
- UnaryExpression._create_nullsfirst, ".expression.nullsfirst")
+ UnaryExpression._create_nullsfirst, ".expression.nullsfirst"
+)
nullslast = public_factory(
- UnaryExpression._create_nullslast, ".expression.nullslast")
+ UnaryExpression._create_nullslast, ".expression.nullslast"
+)
asc = public_factory(UnaryExpression._create_asc, ".expression.asc")
desc = public_factory(UnaryExpression._create_desc, ".expression.desc")
distinct = public_factory(
- UnaryExpression._create_distinct, ".expression.distinct")
+ UnaryExpression._create_distinct, ".expression.distinct"
+)
type_coerce = public_factory(TypeCoerce, ".expression.type_coerce")
true = public_factory(True_._instance, ".expression.true")
false = public_factory(False_._instance, ".expression.false")
@@ -105,19 +211,30 @@ outerjoin = public_factory(Join._create_outerjoin, ".expression.outerjoin")
insert = public_factory(Insert, ".expression.insert")
update = public_factory(Update, ".expression.update")
delete = public_factory(Delete, ".expression.delete")
-funcfilter = public_factory(
- FunctionFilter, ".expression.funcfilter")
+funcfilter = public_factory(FunctionFilter, ".expression.funcfilter")
# internal functions still being called from tests and the ORM,
# these might be better off in some other namespace
from .base import _from_objects
-from .elements import _literal_as_text, _clause_element_as_expr,\
- _is_column, _labeled, _only_column_elements, _string_or_unprintable, \
- _truncated_label, _clone, _cloned_difference, _cloned_intersection,\
- _column_as_key, _literal_as_binds, _select_iterables, \
- _corresponding_column_or_error, _literal_as_label_reference, \
- _expression_literal_as_text
+from .elements import (
+ _literal_as_text,
+ _clause_element_as_expr,
+ _is_column,
+ _labeled,
+ _only_column_elements,
+ _string_or_unprintable,
+ _truncated_label,
+ _clone,
+ _cloned_difference,
+ _cloned_intersection,
+ _column_as_key,
+ _literal_as_binds,
+ _select_iterables,
+ _corresponding_column_or_error,
+ _literal_as_label_reference,
+ _expression_literal_as_text,
+)
from .selectable import _interpret_as_from
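The expression.py changes only reflow the __all__ list, the element imports, and the public_factory declarations; the exported call surface is unchanged. A small sketch exercising a few of the re-exported factories (column names invented, not part of the patch):

    from sqlalchemy import column, desc, nullslast, select

    stmt = select([column("name"), column("score")]).order_by(
        nullslast(desc(column("score")))
    )
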
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index 4b4d2d463..883bb8cc3 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -10,10 +10,22 @@
"""
from . import sqltypes, schema
from .base import Executable, ColumnCollection
-from .elements import ClauseList, Cast, Extract, _literal_as_binds, \
- literal_column, _type_from_args, ColumnElement, _clone,\
- Over, BindParameter, FunctionFilter, Grouping, WithinGroup, \
- BinaryExpression
+from .elements import (
+ ClauseList,
+ Cast,
+ Extract,
+ _literal_as_binds,
+ literal_column,
+ _type_from_args,
+ ColumnElement,
+ _clone,
+ Over,
+ BindParameter,
+ FunctionFilter,
+ Grouping,
+ WithinGroup,
+ BinaryExpression,
+)
from .selectable import FromClause, Select, Alias
from . import util as sqlutil
from . import operators
@@ -62,9 +74,8 @@ class FunctionElement(Executable, ColumnElement, FromClause):
args = [_literal_as_binds(c, self.name) for c in clauses]
self._has_args = self._has_args or bool(args)
self.clause_expr = ClauseList(
- operator=operators.comma_op,
- group_contents=True, *args).\
- self_group()
+ operator=operators.comma_op, group_contents=True, *args
+ ).self_group()
def _execute_on_connection(self, connection, multiparams, params):
return connection._execute_function(self, multiparams, params)
@@ -123,7 +134,7 @@ class FunctionElement(Executable, ColumnElement, FromClause):
partition_by=partition_by,
order_by=order_by,
rows=rows,
- range_=range_
+ range_=range_,
)
def within_group(self, *order_by):
@@ -233,16 +244,14 @@ class FunctionElement(Executable, ColumnElement, FromClause):
.. versionadded:: 1.3
"""
- return FunctionAsBinary(
- self, left_index, right_index
- )
+ return FunctionAsBinary(self, left_index, right_index)
@property
def _from_objects(self):
return self.clauses._from_objects
def get_children(self, **kwargs):
- return self.clause_expr,
+ return (self.clause_expr,)
def _copy_internals(self, clone=_clone, **kw):
self.clause_expr = clone(self.clause_expr, **kw)
@@ -336,24 +345,29 @@ class FunctionElement(Executable, ColumnElement, FromClause):
return self.select().execute()
def _bind_param(self, operator, obj, type_=None):
- return BindParameter(None, obj, _compared_to_operator=operator,
- _compared_to_type=self.type, unique=True,
- type_=type_)
+ return BindParameter(
+ None,
+ obj,
+ _compared_to_operator=operator,
+ _compared_to_type=self.type,
+ unique=True,
+ type_=type_,
+ )
def self_group(self, against=None):
# for the moment, we are parenthesizing all array-returning
# expressions against getitem. This may need to be made
# more portable if in the future we support other DBs
# besides postgresql.
- if against is operators.getitem and \
- isinstance(self.type, sqltypes.ARRAY):
+ if against is operators.getitem and isinstance(
+ self.type, sqltypes.ARRAY
+ ):
return Grouping(self)
else:
return super(FunctionElement, self).self_group(against=against)
class FunctionAsBinary(BinaryExpression):
-
def __init__(self, fn, left_index, right_index):
left = fn.clauses.clauses[left_index - 1]
right = fn.clauses.clauses[right_index - 1]
@@ -362,8 +376,11 @@ class FunctionAsBinary(BinaryExpression):
self.right_index = right_index
super(FunctionAsBinary, self).__init__(
- left, right, operators.function_as_comparison_op,
- type_=sqltypes.BOOLEANTYPE)
+ left,
+ right,
+ operators.function_as_comparison_op,
+ type_=sqltypes.BOOLEANTYPE,
+ )
@property
def left(self):
@@ -382,7 +399,7 @@ class FunctionAsBinary(BinaryExpression):
self.sql_function.clauses.clauses[self.right_index - 1] = value
def _copy_internals(self, **kw):
- clone = kw.pop('clone')
+ clone = kw.pop("clone")
self.sql_function = clone(self.sql_function, **kw)
super(FunctionAsBinary, self)._copy_internals(**kw)
@@ -396,13 +413,13 @@ class _FunctionGenerator(object):
def __getattr__(self, name):
# passthru __ attributes; fixes pydoc
- if name.startswith('__'):
+ if name.startswith("__"):
try:
return self.__dict__[name]
except KeyError:
raise AttributeError(name)
- elif name.endswith('_'):
+ elif name.endswith("_"):
name = name[0:-1]
f = _FunctionGenerator(**self.opts)
f.__names = list(self.__names) + [name]
@@ -426,8 +443,9 @@ class _FunctionGenerator(object):
if func is not None:
return func(*c, **o)
- return Function(self.__names[-1],
- packagenames=self.__names[0:-1], *c, **o)
+ return Function(
+ self.__names[-1], packagenames=self.__names[0:-1], *c, **o
+ )
func = _FunctionGenerator()
@@ -523,7 +541,7 @@ class Function(FunctionElement):
"""
- __visit_name__ = 'function'
+ __visit_name__ = "function"
def __init__(self, name, *clauses, **kw):
"""Construct a :class:`.Function`.
@@ -532,30 +550,33 @@ class Function(FunctionElement):
new :class:`.Function` instances.
"""
- self.packagenames = kw.pop('packagenames', None) or []
+ self.packagenames = kw.pop("packagenames", None) or []
self.name = name
- self._bind = kw.get('bind', None)
- self.type = sqltypes.to_instance(kw.get('type_', None))
+ self._bind = kw.get("bind", None)
+ self.type = sqltypes.to_instance(kw.get("type_", None))
FunctionElement.__init__(self, *clauses, **kw)
def _bind_param(self, operator, obj, type_=None):
- return BindParameter(self.name, obj,
- _compared_to_operator=operator,
- _compared_to_type=self.type,
- type_=type_,
- unique=True)
+ return BindParameter(
+ self.name,
+ obj,
+ _compared_to_operator=operator,
+ _compared_to_type=self.type,
+ type_=type_,
+ unique=True,
+ )
class _GenericMeta(VisitableType):
def __init__(cls, clsname, bases, clsdict):
if annotation.Annotated not in cls.__mro__:
- cls.name = name = clsdict.get('name', clsname)
- cls.identifier = identifier = clsdict.get('identifier', name)
- package = clsdict.pop('package', '_default')
+ cls.name = name = clsdict.get("name", clsname)
+ cls.identifier = identifier = clsdict.get("identifier", name)
+ package = clsdict.pop("package", "_default")
# legacy
- if '__return_type__' in clsdict:
- cls.type = clsdict['__return_type__']
+ if "__return_type__" in clsdict:
+ cls.type = clsdict["__return_type__"]
register_function(identifier, cls, package)
super(_GenericMeta, cls).__init__(clsname, bases, clsdict)
@@ -635,17 +656,19 @@ class GenericFunction(util.with_metaclass(_GenericMeta, Function)):
coerce_arguments = True
def __init__(self, *args, **kwargs):
- parsed_args = kwargs.pop('_parsed_args', None)
+ parsed_args = kwargs.pop("_parsed_args", None)
if parsed_args is None:
parsed_args = [_literal_as_binds(c, self.name) for c in args]
self._has_args = self._has_args or bool(parsed_args)
self.packagenames = []
- self._bind = kwargs.get('bind', None)
+ self._bind = kwargs.get("bind", None)
self.clause_expr = ClauseList(
- operator=operators.comma_op,
- group_contents=True, *parsed_args).self_group()
+ operator=operators.comma_op, group_contents=True, *parsed_args
+ ).self_group()
self.type = sqltypes.to_instance(
- kwargs.pop("type_", None) or getattr(self, 'type', None))
+ kwargs.pop("type_", None) or getattr(self, "type", None)
+ )
+
register_function("cast", Cast)
register_function("extract", Extract)
@@ -660,13 +683,15 @@ class next_value(GenericFunction):
that does not provide support for sequences.
"""
+
type = sqltypes.Integer()
name = "next_value"
def __init__(self, seq, **kw):
- assert isinstance(seq, schema.Sequence), \
- "next_value() accepts a Sequence object as input."
- self._bind = kw.get('bind', None)
+ assert isinstance(
+ seq, schema.Sequence
+ ), "next_value() accepts a Sequence object as input."
+ self._bind = kw.get("bind", None)
self.sequence = seq
@property
@@ -684,8 +709,8 @@ class ReturnTypeFromArgs(GenericFunction):
def __init__(self, *args, **kwargs):
args = [_literal_as_binds(c, self.name) for c in args]
- kwargs.setdefault('type_', _type_from_args(args))
- kwargs['_parsed_args'] = args
+ kwargs.setdefault("type_", _type_from_args(args))
+ kwargs["_parsed_args"] = args
super(ReturnTypeFromArgs, self).__init__(*args, **kwargs)
@@ -733,7 +758,7 @@ class count(GenericFunction):
def __init__(self, expression=None, **kwargs):
if expression is None:
- expression = literal_column('*')
+ expression = literal_column("*")
super(count, self).__init__(expression, **kwargs)
@@ -797,15 +822,15 @@ class array_agg(GenericFunction):
def __init__(self, *args, **kwargs):
args = [_literal_as_binds(c) for c in args]
- default_array_type = kwargs.pop('_default_array_type', sqltypes.ARRAY)
- if 'type_' not in kwargs:
+ default_array_type = kwargs.pop("_default_array_type", sqltypes.ARRAY)
+ if "type_" not in kwargs:
type_from_args = _type_from_args(args)
if isinstance(type_from_args, sqltypes.ARRAY):
- kwargs['type_'] = type_from_args
+ kwargs["type_"] = type_from_args
else:
- kwargs['type_'] = default_array_type(type_from_args)
- kwargs['_parsed_args'] = args
+ kwargs["type_"] = default_array_type(type_from_args)
+ kwargs["_parsed_args"] = args
super(array_agg, self).__init__(*args, **kwargs)
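A brief sketch of the type inference handled by ``array_agg.__init__`` above; the column name is illustrative::

    from sqlalchemy import Integer, func
    from sqlalchemy.sql import column

    expr = func.array_agg(column("score", Integer))
    # with no explicit type_, the element type is wrapped in ARRAY,
    # so expr.type is ARRAY(Integer) here
    print(expr.type)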
@@ -883,6 +908,7 @@ class rank(GenericFunction):
.. versionadded:: 1.1
"""
+
type = sqltypes.Integer()
@@ -897,6 +923,7 @@ class dense_rank(GenericFunction):
.. versionadded:: 1.1
"""
+
type = sqltypes.Integer()
@@ -911,6 +938,7 @@ class percent_rank(GenericFunction):
.. versionadded:: 1.1
"""
+
type = sqltypes.Numeric()
@@ -925,6 +953,7 @@ class cume_dist(GenericFunction):
.. versionadded:: 1.1
"""
+
type = sqltypes.Numeric()
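For reference, a minimal sketch of defining a new :class:`.GenericFunction`, which ``_GenericMeta`` registers so it becomes reachable through ``func``; the ``as_utc`` name is illustrative::

    from sqlalchemy import DateTime, func
    from sqlalchemy.sql.functions import GenericFunction

    class as_utc(GenericFunction):
        """Hypothetical generic function rendering as AS_UTC()."""
        type = DateTime()
        name = "as_utc"

    print(func.as_utc())  # as_utc()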
diff --git a/lib/sqlalchemy/sql/naming.py b/lib/sqlalchemy/sql/naming.py
index 0107ce724..144cc4dfc 100644
--- a/lib/sqlalchemy/sql/naming.py
+++ b/lib/sqlalchemy/sql/naming.py
@@ -10,8 +10,16 @@
"""
-from .schema import Constraint, ForeignKeyConstraint, PrimaryKeyConstraint, \
- UniqueConstraint, CheckConstraint, Index, Table, Column
+from .schema import (
+ Constraint,
+ ForeignKeyConstraint,
+ PrimaryKeyConstraint,
+ UniqueConstraint,
+ CheckConstraint,
+ Index,
+ Table,
+ Column,
+)
from .. import event, events
from .. import exc
from .elements import _truncated_label, _defer_name, _defer_none_name, conv
@@ -19,7 +27,6 @@ import re
class ConventionDict(object):
-
def __init__(self, const, table, convention):
self.const = const
self._is_fk = isinstance(const, ForeignKeyConstraint)
@@ -79,8 +86,8 @@ class ConventionDict(object):
def __getitem__(self, key):
if key in self.convention:
return self.convention[key](self.const, self.table)
- elif hasattr(self, '_key_%s' % key):
- return getattr(self, '_key_%s' % key)()
+ elif hasattr(self, "_key_%s" % key):
+ return getattr(self, "_key_%s" % key)()
else:
col_template = re.match(r".*_?column_(\d+)(_?N)?_.+", key)
if col_template:
@@ -108,12 +115,13 @@ class ConventionDict(object):
return getattr(self, attr)(idx)
raise KeyError(key)
+
_prefix_dict = {
Index: "ix",
PrimaryKeyConstraint: "pk",
CheckConstraint: "ck",
UniqueConstraint: "uq",
- ForeignKeyConstraint: "fk"
+ ForeignKeyConstraint: "fk",
}
@@ -134,15 +142,18 @@ def _constraint_name_for_table(const, table):
if isinstance(const.name, conv):
return const.name
- elif convention is not None and \
- not isinstance(const.name, conv) and \
- (
- const.name is None or
- "constraint_name" in convention or
- isinstance(const.name, _defer_name)):
+ elif (
+ convention is not None
+ and not isinstance(const.name, conv)
+ and (
+ const.name is None
+ or "constraint_name" in convention
+ or isinstance(const.name, _defer_name)
+ )
+ ):
return conv(
- convention % ConventionDict(const, table,
- metadata.naming_convention)
+ convention
+ % ConventionDict(const, table, metadata.naming_convention)
)
elif isinstance(convention, _defer_none_name):
return None
@@ -155,9 +166,11 @@ def _constraint_name(const, table):
# for column-attached constraint, set another event
# to link the column attached to the table as this constraint
# associated with the table.
- event.listen(table, "after_parent_attach",
- lambda col, table: _constraint_name(const, table)
- )
+ event.listen(
+ table,
+ "after_parent_attach",
+ lambda col, table: _constraint_name(const, table),
+ )
elif isinstance(table, Table):
if isinstance(const.name, (conv, _defer_name)):
return
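For context, a minimal sketch of the naming-convention machinery these helpers implement; the table and column names are illustrative::

    from sqlalchemy import Column, Integer, MetaData, String, Table

    convention = {
        "ix": "ix_%(column_0_label)s",
        "uq": "uq_%(table_name)s_%(column_0_name)s",
        "ck": "ck_%(table_name)s_%(constraint_name)s",
        "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
        "pk": "pk_%(table_name)s",
    }
    metadata = MetaData(naming_convention=convention)

    user = Table(
        "user",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String(50), unique=True),
    )
    # the UniqueConstraint generated by unique=True is named via the
    # "uq" template, i.e. "uq_user_name"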
diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py
index 5b4a28a06..2b843d751 100644
--- a/lib/sqlalchemy/sql/operators.py
+++ b/lib/sqlalchemy/sql/operators.py
@@ -13,8 +13,25 @@
from .. import util
from operator import (
- and_, or_, inv, add, mul, sub, mod, truediv, lt, le, ne, gt, ge, eq, neg,
- getitem, lshift, rshift, contains
+ and_,
+ or_,
+ inv,
+ add,
+ mul,
+ sub,
+ mod,
+ truediv,
+ lt,
+ le,
+ ne,
+ gt,
+ ge,
+ eq,
+ neg,
+ getitem,
+ lshift,
+ rshift,
+ contains,
)
if util.py2k:
@@ -37,6 +54,7 @@ class Operators(object):
:class:`.ColumnOperators`.
"""
+
__slots__ = ()
def __and__(self, other):
@@ -105,8 +123,8 @@ class Operators(object):
return self.operate(inv)
def op(
- self, opstring, precedence=0, is_comparison=False,
- return_type=None):
+ self, opstring, precedence=0, is_comparison=False, return_type=None
+ ):
"""produce a generic operator function.
e.g.::
@@ -168,6 +186,7 @@ class Operators(object):
def against(other):
return operator(self, other)
+
return against
def bool_op(self, opstring, precedence=0):
@@ -247,12 +266,18 @@ class custom_op(object):
:meth:`.Operators.bool_op`
"""
- __name__ = 'custom_op'
+
+ __name__ = "custom_op"
def __init__(
- self, opstring, precedence=0, is_comparison=False,
- return_type=None, natural_self_precedent=False,
- eager_grouping=False):
+ self,
+ opstring,
+ precedence=0,
+ is_comparison=False,
+ return_type=None,
+ natural_self_precedent=False,
+ eager_grouping=False,
+ ):
self.opstring = opstring
self.precedence = precedence
self.is_comparison = is_comparison
@@ -263,8 +288,7 @@ class custom_op(object):
)
def __eq__(self, other):
- return isinstance(other, custom_op) and \
- other.opstring == self.opstring
+ return isinstance(other, custom_op) and other.opstring == self.opstring
def __hash__(self):
return id(self)
@@ -1138,6 +1162,7 @@ class ColumnOperators(Operators):
"""
return self.reverse_operate(truediv, other)
+
_commutative = {eq, ne, add, mul}
_comparison = {eq, ne, lt, gt, ge, le}
@@ -1261,20 +1286,18 @@ def _escaped_like_impl(fn, other, escape, autoescape):
if autoescape:
if autoescape is not True:
util.warn(
- "The autoescape parameter is now a simple boolean True/False")
+ "The autoescape parameter is now a simple boolean True/False"
+ )
if escape is None:
- escape = '/'
+ escape = "/"
if not isinstance(other, util.compat.string_types):
raise TypeError("String value expected when autoescape=True")
- if escape not in ('%', '_'):
+ if escape not in ("%", "_"):
other = other.replace(escape, escape + escape)
- other = (
- other.replace('%', escape + '%').
- replace('_', escape + '_')
- )
+ other = other.replace("%", escape + "%").replace("_", escape + "_")
return fn(other, escape=escape)
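At the public API level this helper backs the ``autoescape`` flag on ``contains()``, ``startswith()`` and ``endswith()``; a minimal sketch with an illustrative column name::

    from sqlalchemy.sql import column

    # '%' and '_' inside the value are escaped with '/' before LIKE applies
    expr = column("code").contains("10% off", autoescape=True)
    print(expr)  # code LIKE '%' || :code_1 || '%' ESCAPE '/'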
@@ -1362,8 +1385,7 @@ def json_path_getitem_op(a, b):
def is_comparison(op):
- return op in _comparison or \
- isinstance(op, custom_op) and op.is_comparison
+ return op in _comparison or isinstance(op, custom_op) and op.is_comparison
def is_commutative(op):
@@ -1371,13 +1393,16 @@ def is_commutative(op):
def is_ordering_modifier(op):
- return op in (asc_op, desc_op,
- nullsfirst_op, nullslast_op)
+ return op in (asc_op, desc_op, nullsfirst_op, nullslast_op)
def is_natural_self_precedent(op):
- return op in _natural_self_precedent or \
- isinstance(op, custom_op) and op.natural_self_precedent
+ return (
+ op in _natural_self_precedent
+ or isinstance(op, custom_op)
+ and op.natural_self_precedent
+ )
+
_booleans = (inv, istrue, isfalse, and_, or_)
@@ -1385,12 +1410,8 @@ _booleans = (inv, istrue, isfalse, and_, or_)
def is_boolean(op):
return is_comparison(op) or op in _booleans
-_mirror = {
- gt: lt,
- ge: le,
- lt: gt,
- le: ge
-}
+
+_mirror = {gt: lt, ge: le, lt: gt, le: ge}
def mirror(op):
@@ -1404,17 +1425,18 @@ def mirror(op):
_associative = _commutative.union([concat_op, and_, or_]).difference([eq, ne])
-_natural_self_precedent = _associative.union([
- getitem, json_getitem_op, json_path_getitem_op])
+_natural_self_precedent = _associative.union(
+ [getitem, json_getitem_op, json_path_getitem_op]
+)
"""Operators where if we have (a op b) op c, we don't want to
parenthesize (a op b).
"""
-_asbool = util.symbol('_asbool', canonical=-10)
-_smallest = util.symbol('_smallest', canonical=-100)
-_largest = util.symbol('_largest', canonical=100)
+_asbool = util.symbol("_asbool", canonical=-10)
+_smallest = util.symbol("_smallest", canonical=-100)
+_largest = util.symbol("_largest", canonical=100)
_PRECEDENCE = {
from_: 15,
@@ -1424,7 +1446,6 @@ _PRECEDENCE = {
getitem: 15,
json_getitem_op: 15,
json_path_getitem_op: 15,
-
mul: 8,
truediv: 8,
div: 8,
@@ -1432,22 +1453,17 @@ _PRECEDENCE = {
neg: 8,
add: 7,
sub: 7,
-
concat_op: 6,
-
match_op: 5,
notmatch_op: 5,
-
ilike_op: 5,
notilike_op: 5,
like_op: 5,
notlike_op: 5,
in_op: 5,
notin_op: 5,
-
is_: 5,
isnot: 5,
-
eq: 5,
ne: 5,
is_distinct_from: 5,
@@ -1458,7 +1474,6 @@ _PRECEDENCE = {
lt: 5,
ge: 5,
le: 5,
-
between_op: 5,
notbetween_op: 5,
distinct_op: 5,
@@ -1468,17 +1483,14 @@ _PRECEDENCE = {
and_: 3,
or_: 2,
comma_op: -1,
-
desc_op: 3,
asc_op: 3,
collate: 4,
-
as_: -1,
exists: 0,
-
_asbool: -10,
_smallest: _smallest,
- _largest: _largest
+ _largest: _largest,
}
@@ -1486,7 +1498,6 @@ def is_precedent(operator, against):
if operator is against and is_natural_self_precedent(operator):
return False
else:
- return (_PRECEDENCE.get(operator,
- getattr(operator, 'precedence', _smallest)) <=
- _PRECEDENCE.get(against,
- getattr(against, 'precedence', _largest)))
+ return _PRECEDENCE.get(
+ operator, getattr(operator, "precedence", _smallest)
+ ) <= _PRECEDENCE.get(against, getattr(against, "precedence", _largest))
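For context, a minimal sketch of the ``op()`` / ``bool_op()`` entry points whose custom operators feed the precedence lookup above; column names are illustrative::

    from sqlalchemy.sql import column

    # a generic operator; the precedence value participates in is_precedent()
    bitwise = column("flags").op("&", precedence=8)(4)
    print(bitwise)  # flags & :flags_1

    # bool_op() marks the operator as a comparison, suitable for WHERE clauses
    matches = column("doc").bool_op("@@")(column("query"))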
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index 3e9aa174a..d6c3f5000 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -36,25 +36,31 @@ import operator
from . import visitors
from . import type_api
from .base import _bind_or_error, ColumnCollection
-from .elements import ClauseElement, ColumnClause, \
- _as_truncated, TextClause, _literal_as_text,\
- ColumnElement, quoted_name
+from .elements import (
+ ClauseElement,
+ ColumnClause,
+ _as_truncated,
+ TextClause,
+ _literal_as_text,
+ ColumnElement,
+ quoted_name,
+)
from .selectable import TableClause
import collections
import sqlalchemy
from . import ddl
-RETAIN_SCHEMA = util.symbol('retain_schema')
+RETAIN_SCHEMA = util.symbol("retain_schema")
BLANK_SCHEMA = util.symbol(
- 'blank_schema',
+ "blank_schema",
"""Symbol indicating that a :class:`.Table` or :class:`.Sequence`
should have 'None' for its schema, even if the parent
:class:`.MetaData` has specified a schema.
.. versionadded:: 1.0.14
- """
+ """,
)
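A minimal sketch of the ``BLANK_SCHEMA`` symbol defined above; the table names are illustrative::

    from sqlalchemy import Column, Integer, MetaData, Table
    from sqlalchemy.schema import BLANK_SCHEMA

    metadata = MetaData(schema="remote")

    # inherits schema="remote" from the MetaData
    log = Table("log", metadata, Column("id", Integer, primary_key=True))

    # BLANK_SCHEMA forces schema=None despite the MetaData default
    tmp = Table(
        "tmp",
        metadata,
        Column("id", Integer, primary_key=True),
        schema=BLANK_SCHEMA,
    )
    assert tmp.schema is None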
@@ -69,11 +75,15 @@ def _get_table_key(name, schema):
# break an import cycle
def _copy_expression(expression, source_table, target_table):
def replace(col):
- if isinstance(col, Column) and \
- col.table is source_table and col.key in source_table.c:
+ if (
+ isinstance(col, Column)
+ and col.table is source_table
+ and col.key in source_table.c
+ ):
return target_table.c[col.key]
else:
return None
+
return visitors.replacement_traverse(expression, {}, replace)
@@ -81,7 +91,7 @@ def _copy_expression(expression, source_table, target_table):
class SchemaItem(SchemaEventTarget, visitors.Visitable):
"""Base class for items that define a database schema."""
- __visit_name__ = 'schema_item'
+ __visit_name__ = "schema_item"
def _init_items(self, *args):
"""Initialize the list of child items for this SchemaItem."""
@@ -95,10 +105,10 @@ class SchemaItem(SchemaEventTarget, visitors.Visitable):
return []
def __repr__(self):
- return util.generic_repr(self, omit_kwarg=['info'])
+ return util.generic_repr(self, omit_kwarg=["info"])
@property
- @util.deprecated('0.9', 'Use ``<obj>.name.quote``')
+ @util.deprecated("0.9", "Use ``<obj>.name.quote``")
def quote(self):
"""Return the value of the ``quote`` flag passed
to this schema object, for those schema items which
@@ -121,7 +131,7 @@ class SchemaItem(SchemaEventTarget, visitors.Visitable):
return {}
def _schema_item_copy(self, schema_item):
- if 'info' in self.__dict__:
+ if "info" in self.__dict__:
schema_item.info = self.info.copy()
schema_item.dispatch._update(self.dispatch)
return schema_item
@@ -396,7 +406,7 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
"""
- __visit_name__ = 'table'
+ __visit_name__ = "table"
def __new__(cls, *args, **kw):
if not args:
@@ -408,26 +418,26 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
except IndexError:
raise TypeError("Table() takes at least two arguments")
- schema = kw.get('schema', None)
+ schema = kw.get("schema", None)
if schema is None:
schema = metadata.schema
elif schema is BLANK_SCHEMA:
schema = None
- keep_existing = kw.pop('keep_existing', False)
- extend_existing = kw.pop('extend_existing', False)
- if 'useexisting' in kw:
+ keep_existing = kw.pop("keep_existing", False)
+ extend_existing = kw.pop("extend_existing", False)
+ if "useexisting" in kw:
msg = "useexisting is deprecated. Use extend_existing."
util.warn_deprecated(msg)
if extend_existing:
msg = "useexisting is synonymous with extend_existing."
raise exc.ArgumentError(msg)
- extend_existing = kw.pop('useexisting', False)
+ extend_existing = kw.pop("useexisting", False)
if keep_existing and extend_existing:
msg = "keep_existing and extend_existing are mutually exclusive."
raise exc.ArgumentError(msg)
- mustexist = kw.pop('mustexist', False)
+ mustexist = kw.pop("mustexist", False)
key = _get_table_key(name, schema)
if key in metadata.tables:
if not keep_existing and not extend_existing and bool(args):
@@ -436,15 +446,15 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
"instance. Specify 'extend_existing=True' "
"to redefine "
"options and columns on an "
- "existing Table object." % key)
+ "existing Table object." % key
+ )
table = metadata.tables[key]
if extend_existing:
table._init_existing(*args, **kw)
return table
else:
if mustexist:
- raise exc.InvalidRequestError(
- "Table '%s' not defined" % (key))
+ raise exc.InvalidRequestError("Table '%s' not defined" % (key))
table = object.__new__(cls)
table.dispatch.before_parent_attach(table, metadata)
metadata._add_table(name, schema, table)
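For reference, a short sketch of the ``keep_existing`` / ``extend_existing`` behavior implemented in ``Table.__new__`` above; names are illustrative::

    from sqlalchemy import Column, Integer, MetaData, String, Table

    metadata = MetaData()
    users = Table("users", metadata, Column("id", Integer, primary_key=True))

    # a second call with the same key returns the existing Table;
    # extend_existing=True lets the new columns/options be applied to it
    users2 = Table(
        "users",
        metadata,
        Column("name", String(50)),
        extend_existing=True,
    )
    assert users2 is users
    print(users.c.keys())  # ['id', 'name']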
@@ -457,7 +467,7 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
metadata._remove_table(name, schema)
@property
- @util.deprecated('0.9', 'Use ``table.schema.quote``')
+ @util.deprecated("0.9", "Use ``table.schema.quote``")
def quote_schema(self):
"""Return the value of the ``quote_schema`` flag passed
to this :class:`.Table`.
@@ -478,23 +488,25 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
def _init(self, name, metadata, *args, **kwargs):
super(Table, self).__init__(
- quoted_name(name, kwargs.pop('quote', None)))
+ quoted_name(name, kwargs.pop("quote", None))
+ )
self.metadata = metadata
- self.schema = kwargs.pop('schema', None)
+ self.schema = kwargs.pop("schema", None)
if self.schema is None:
self.schema = metadata.schema
elif self.schema is BLANK_SCHEMA:
self.schema = None
else:
- quote_schema = kwargs.pop('quote_schema', None)
+ quote_schema = kwargs.pop("quote_schema", None)
self.schema = quoted_name(self.schema, quote_schema)
self.indexes = set()
self.constraints = set()
self._columns = ColumnCollection()
- PrimaryKeyConstraint(_implicit_generated=True).\
- _set_parent_with_dispatch(self)
+ PrimaryKeyConstraint(
+ _implicit_generated=True
+ )._set_parent_with_dispatch(self)
self.foreign_keys = set()
self._extra_dependencies = set()
if self.schema is not None:
@@ -502,26 +514,26 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
else:
self.fullname = self.name
- autoload_with = kwargs.pop('autoload_with', None)
- autoload = kwargs.pop('autoload', autoload_with is not None)
+ autoload_with = kwargs.pop("autoload_with", None)
+ autoload = kwargs.pop("autoload", autoload_with is not None)
# this argument is only used with _init_existing()
- kwargs.pop('autoload_replace', True)
+ kwargs.pop("autoload_replace", True)
_extend_on = kwargs.pop("_extend_on", None)
- include_columns = kwargs.pop('include_columns', None)
+ include_columns = kwargs.pop("include_columns", None)
- self.implicit_returning = kwargs.pop('implicit_returning', True)
+ self.implicit_returning = kwargs.pop("implicit_returning", True)
- self.comment = kwargs.pop('comment', None)
+ self.comment = kwargs.pop("comment", None)
- if 'info' in kwargs:
- self.info = kwargs.pop('info')
- if 'listeners' in kwargs:
- listeners = kwargs.pop('listeners')
+ if "info" in kwargs:
+ self.info = kwargs.pop("info")
+ if "listeners" in kwargs:
+ listeners = kwargs.pop("listeners")
for evt, fn in listeners:
event.listen(self, evt, fn)
- self._prefixes = kwargs.pop('prefixes', [])
+ self._prefixes = kwargs.pop("prefixes", [])
self._extra_kwargs(**kwargs)
@@ -530,21 +542,29 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
# circular foreign keys
if autoload:
self._autoload(
- metadata, autoload_with,
- include_columns, _extend_on=_extend_on)
+ metadata, autoload_with, include_columns, _extend_on=_extend_on
+ )
# initialize all the column, etc. objects. done after reflection to
# allow user-overrides
self._init_items(*args)
- def _autoload(self, metadata, autoload_with, include_columns,
- exclude_columns=(), _extend_on=None):
+ def _autoload(
+ self,
+ metadata,
+ autoload_with,
+ include_columns,
+ exclude_columns=(),
+ _extend_on=None,
+ ):
if autoload_with:
autoload_with.run_callable(
autoload_with.dialect.reflecttable,
- self, include_columns, exclude_columns,
- _extend_on=_extend_on
+ self,
+ include_columns,
+ exclude_columns,
+ _extend_on=_extend_on,
)
else:
bind = _bind_or_error(
@@ -553,11 +573,14 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
"Pass an engine to the Table via "
"autoload_with=<someengine>, "
"or associate the MetaData with an engine via "
- "metadata.bind=<someengine>")
+ "metadata.bind=<someengine>",
+ )
bind.run_callable(
bind.dialect.reflecttable,
- self, include_columns, exclude_columns,
- _extend_on=_extend_on
+ self,
+ include_columns,
+ exclude_columns,
+ _extend_on=_extend_on,
)
@property
@@ -582,34 +605,36 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
return set(fkc.constraint for fkc in self.foreign_keys)
def _init_existing(self, *args, **kwargs):
- autoload_with = kwargs.pop('autoload_with', None)
- autoload = kwargs.pop('autoload', autoload_with is not None)
- autoload_replace = kwargs.pop('autoload_replace', True)
- schema = kwargs.pop('schema', None)
- _extend_on = kwargs.pop('_extend_on', None)
+ autoload_with = kwargs.pop("autoload_with", None)
+ autoload = kwargs.pop("autoload", autoload_with is not None)
+ autoload_replace = kwargs.pop("autoload_replace", True)
+ schema = kwargs.pop("schema", None)
+ _extend_on = kwargs.pop("_extend_on", None)
if schema and schema != self.schema:
raise exc.ArgumentError(
"Can't change schema of existing table from '%s' to '%s'",
- (self.schema, schema))
+ (self.schema, schema),
+ )
- include_columns = kwargs.pop('include_columns', None)
+ include_columns = kwargs.pop("include_columns", None)
if include_columns is not None:
for c in self.c:
if c.name not in include_columns:
self._columns.remove(c)
- for key in ('quote', 'quote_schema'):
+ for key in ("quote", "quote_schema"):
if key in kwargs:
raise exc.ArgumentError(
- "Can't redefine 'quote' or 'quote_schema' arguments")
+ "Can't redefine 'quote' or 'quote_schema' arguments"
+ )
- if 'comment' in kwargs:
- self.comment = kwargs.pop('comment', None)
+ if "comment" in kwargs:
+ self.comment = kwargs.pop("comment", None)
- if 'info' in kwargs:
- self.info = kwargs.pop('info')
+ if "info" in kwargs:
+ self.info = kwargs.pop("info")
if autoload:
if not autoload_replace:
@@ -620,8 +645,12 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
else:
exclude_columns = ()
self._autoload(
- self.metadata, autoload_with,
- include_columns, exclude_columns, _extend_on=_extend_on)
+ self.metadata,
+ autoload_with,
+ include_columns,
+ exclude_columns,
+ _extend_on=_extend_on,
+ )
self._extra_kwargs(**kwargs)
self._init_items(*args)
@@ -653,10 +682,12 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
return _get_table_key(self.name, self.schema)
def __repr__(self):
- return "Table(%s)" % ', '.join(
- [repr(self.name)] + [repr(self.metadata)] +
- [repr(x) for x in self.columns] +
- ["%s=%s" % (k, repr(getattr(self, k))) for k in ['schema']])
+ return "Table(%s)" % ", ".join(
+ [repr(self.name)]
+ + [repr(self.metadata)]
+ + [repr(x) for x in self.columns]
+ + ["%s=%s" % (k, repr(getattr(self, k))) for k in ["schema"]]
+ )
def __str__(self):
return _get_table_key(self.description, self.schema)
@@ -735,17 +766,19 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
def adapt_listener(target, connection, **kw):
listener(event_name, target, connection)
- event.listen(self, "" + event_name.replace('-', '_'), adapt_listener)
+ event.listen(self, "" + event_name.replace("-", "_"), adapt_listener)
def _set_parent(self, metadata):
metadata._add_table(self.name, self.schema, self)
self.metadata = metadata
- def get_children(self, column_collections=True,
- schema_visitor=False, **kw):
+ def get_children(
+ self, column_collections=True, schema_visitor=False, **kw
+ ):
if not schema_visitor:
return TableClause.get_children(
- self, column_collections=column_collections, **kw)
+ self, column_collections=column_collections, **kw
+ )
else:
if column_collections:
return list(self.columns)
@@ -758,8 +791,9 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
if bind is None:
bind = _bind_or_error(self)
- return bind.run_callable(bind.dialect.has_table,
- self.name, schema=self.schema)
+ return bind.run_callable(
+ bind.dialect.has_table, self.name, schema=self.schema
+ )
def create(self, bind=None, checkfirst=False):
"""Issue a ``CREATE`` statement for this
@@ -774,9 +808,7 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
if bind is None:
bind = _bind_or_error(self)
- bind._run_visitor(ddl.SchemaGenerator,
- self,
- checkfirst=checkfirst)
+ bind._run_visitor(ddl.SchemaGenerator, self, checkfirst=checkfirst)
def drop(self, bind=None, checkfirst=False):
"""Issue a ``DROP`` statement for this
@@ -790,12 +822,15 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
"""
if bind is None:
bind = _bind_or_error(self)
- bind._run_visitor(ddl.SchemaDropper,
- self,
- checkfirst=checkfirst)
-
- def tometadata(self, metadata, schema=RETAIN_SCHEMA,
- referred_schema_fn=None, name=None):
+ bind._run_visitor(ddl.SchemaDropper, self, checkfirst=checkfirst)
+
+ def tometadata(
+ self,
+ metadata,
+ schema=RETAIN_SCHEMA,
+ referred_schema_fn=None,
+ name=None,
+ ):
"""Return a copy of this :class:`.Table` associated with a different
:class:`.MetaData`.
@@ -868,29 +903,37 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
schema = metadata.schema
key = _get_table_key(name, schema)
if key in metadata.tables:
- util.warn("Table '%s' already exists within the given "
- "MetaData - not copying." % self.description)
+ util.warn(
+ "Table '%s' already exists within the given "
+ "MetaData - not copying." % self.description
+ )
return metadata.tables[key]
args = []
for c in self.columns:
args.append(c.copy(schema=schema))
table = Table(
- name, metadata, schema=schema,
+ name,
+ metadata,
+ schema=schema,
comment=self.comment,
- *args, **self.kwargs
+ *args,
+ **self.kwargs
)
for c in self.constraints:
if isinstance(c, ForeignKeyConstraint):
referred_schema = c._referred_schema
if referred_schema_fn:
fk_constraint_schema = referred_schema_fn(
- self, schema, c, referred_schema)
+ self, schema, c, referred_schema
+ )
else:
fk_constraint_schema = (
- schema if referred_schema == self.schema else None)
+ schema if referred_schema == self.schema else None
+ )
table.append_constraint(
- c.copy(schema=fk_constraint_schema, target_table=table))
+ c.copy(schema=fk_constraint_schema, target_table=table)
+ )
elif not c._type_bound:
# skip unique constraints that would be generated
# by the 'unique' flag on Column
@@ -898,25 +941,30 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
continue
table.append_constraint(
- c.copy(schema=schema, target_table=table))
+ c.copy(schema=schema, target_table=table)
+ )
for index in self.indexes:
# skip indexes that would be generated
# by the 'index' flag on Column
if index._column_flag:
continue
- Index(index.name,
- unique=index.unique,
- *[_copy_expression(expr, self, table)
- for expr in index.expressions],
- _table=table,
- **index.kwargs)
+ Index(
+ index.name,
+ unique=index.unique,
+ *[
+ _copy_expression(expr, self, table)
+ for expr in index.expressions
+ ],
+ _table=table,
+ **index.kwargs
+ )
return self._schema_item_copy(table)
class Column(DialectKWArgs, SchemaItem, ColumnClause):
"""Represents a column in a database table."""
- __visit_name__ = 'column'
+ __visit_name__ = "column"
def __init__(self, *args, **kwargs):
r"""
@@ -1192,14 +1240,15 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
"""
- name = kwargs.pop('name', None)
- type_ = kwargs.pop('type_', None)
+ name = kwargs.pop("name", None)
+ type_ = kwargs.pop("type_", None)
args = list(args)
if args:
if isinstance(args[0], util.string_types):
if name is not None:
raise exc.ArgumentError(
- "May not pass name positionally and as a keyword.")
+ "May not pass name positionally and as a keyword."
+ )
name = args.pop(0)
if args:
coltype = args[0]
@@ -1207,40 +1256,42 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
if hasattr(coltype, "_sqla_type"):
if type_ is not None:
raise exc.ArgumentError(
- "May not pass type_ positionally and as a keyword.")
+ "May not pass type_ positionally and as a keyword."
+ )
type_ = args.pop(0)
if name is not None:
- name = quoted_name(name, kwargs.pop('quote', None))
+ name = quoted_name(name, kwargs.pop("quote", None))
elif "quote" in kwargs:
- raise exc.ArgumentError("Explicit 'name' is required when "
- "sending 'quote' argument")
+ raise exc.ArgumentError(
+ "Explicit 'name' is required when " "sending 'quote' argument"
+ )
super(Column, self).__init__(name, type_)
- self.key = kwargs.pop('key', name)
- self.primary_key = kwargs.pop('primary_key', False)
- self.nullable = kwargs.pop('nullable', not self.primary_key)
- self.default = kwargs.pop('default', None)
- self.server_default = kwargs.pop('server_default', None)
- self.server_onupdate = kwargs.pop('server_onupdate', None)
+ self.key = kwargs.pop("key", name)
+ self.primary_key = kwargs.pop("primary_key", False)
+ self.nullable = kwargs.pop("nullable", not self.primary_key)
+ self.default = kwargs.pop("default", None)
+ self.server_default = kwargs.pop("server_default", None)
+ self.server_onupdate = kwargs.pop("server_onupdate", None)
# these default to None because .index and .unique is *not*
# an informational flag about Column - there can still be an
# Index or UniqueConstraint referring to this Column.
- self.index = kwargs.pop('index', None)
- self.unique = kwargs.pop('unique', None)
+ self.index = kwargs.pop("index", None)
+ self.unique = kwargs.pop("unique", None)
- self.system = kwargs.pop('system', False)
- self.doc = kwargs.pop('doc', None)
- self.onupdate = kwargs.pop('onupdate', None)
- self.autoincrement = kwargs.pop('autoincrement', "auto")
+ self.system = kwargs.pop("system", False)
+ self.doc = kwargs.pop("doc", None)
+ self.onupdate = kwargs.pop("onupdate", None)
+ self.autoincrement = kwargs.pop("autoincrement", "auto")
self.constraints = set()
self.foreign_keys = set()
- self.comment = kwargs.pop('comment', None)
+ self.comment = kwargs.pop("comment", None)
# check if this Column is proxying another column
- if '_proxies' in kwargs:
- self._proxies = kwargs.pop('_proxies')
+ if "_proxies" in kwargs:
+ self._proxies = kwargs.pop("_proxies")
# otherwise, add DDL-related events
elif isinstance(self.type, SchemaEventTarget):
self.type._set_parent_with_dispatch(self)
@@ -1249,14 +1300,13 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
if isinstance(self.default, (ColumnDefault, Sequence)):
args.append(self.default)
else:
- if getattr(self.type, '_warn_on_bytestring', False):
+ if getattr(self.type, "_warn_on_bytestring", False):
if isinstance(self.default, util.binary_type):
util.warn(
"Unicode column '%s' has non-unicode "
- "default value %r specified." % (
- self.key,
- self.default
- ))
+ "default value %r specified."
+ % (self.key, self.default)
+ )
args.append(ColumnDefault(self.default))
if self.server_default is not None:
@@ -1275,30 +1325,31 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
if isinstance(self.server_onupdate, FetchedValue):
args.append(self.server_onupdate._as_for_update(True))
else:
- args.append(DefaultClause(self.server_onupdate,
- for_update=True))
+ args.append(
+ DefaultClause(self.server_onupdate, for_update=True)
+ )
self._init_items(*args)
util.set_creation_order(self)
- if 'info' in kwargs:
- self.info = kwargs.pop('info')
+ if "info" in kwargs:
+ self.info = kwargs.pop("info")
self._extra_kwargs(**kwargs)
def _extra_kwargs(self, **kwargs):
self._validate_dialect_kwargs(kwargs)
-# @property
-# def quote(self):
-# return getattr(self.name, "quote", None)
+ # @property
+ # def quote(self):
+ # return getattr(self.name, "quote", None)
def __str__(self):
if self.name is None:
return "(no name)"
elif self.table is not None:
if self.table.named_with_column:
- return (self.table.description + "." + self.description)
+ return self.table.description + "." + self.description
else:
return self.description
else:
@@ -1320,40 +1371,47 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
def __repr__(self):
kwarg = []
if self.key != self.name:
- kwarg.append('key')
+ kwarg.append("key")
if self.primary_key:
- kwarg.append('primary_key')
+ kwarg.append("primary_key")
if not self.nullable:
- kwarg.append('nullable')
+ kwarg.append("nullable")
if self.onupdate:
- kwarg.append('onupdate')
+ kwarg.append("onupdate")
if self.default:
- kwarg.append('default')
+ kwarg.append("default")
if self.server_default:
- kwarg.append('server_default')
- return "Column(%s)" % ', '.join(
- [repr(self.name)] + [repr(self.type)] +
- [repr(x) for x in self.foreign_keys if x is not None] +
- [repr(x) for x in self.constraints] +
- [(self.table is not None and "table=<%s>" %
- self.table.description or "table=None")] +
- ["%s=%s" % (k, repr(getattr(self, k))) for k in kwarg])
+ kwarg.append("server_default")
+ return "Column(%s)" % ", ".join(
+ [repr(self.name)]
+ + [repr(self.type)]
+ + [repr(x) for x in self.foreign_keys if x is not None]
+ + [repr(x) for x in self.constraints]
+ + [
+ (
+ self.table is not None
+ and "table=<%s>" % self.table.description
+ or "table=None"
+ )
+ ]
+ + ["%s=%s" % (k, repr(getattr(self, k))) for k in kwarg]
+ )
def _set_parent(self, table):
if not self.name:
raise exc.ArgumentError(
"Column must be constructed with a non-blank name or "
- "assign a non-blank .name before adding to a Table.")
+ "assign a non-blank .name before adding to a Table."
+ )
if self.key is None:
self.key = self.name
- existing = getattr(self, 'table', None)
+ existing = getattr(self, "table", None)
if existing is not None and existing is not table:
raise exc.ArgumentError(
- "Column object '%s' already assigned to Table '%s'" % (
- self.key,
- existing.description
- ))
+ "Column object '%s' already assigned to Table '%s'"
+ % (self.key, existing.description)
+ )
if self.key in table._columns:
col = table._columns.get(self.key)
@@ -1373,8 +1431,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
elif self.key in table.primary_key:
raise exc.ArgumentError(
"Trying to redefine primary-key column '%s' as a "
- "non-primary-key column on table '%s'" % (
- self.key, table.fullname))
+ "non-primary-key column on table '%s'"
+ % (self.key, table.fullname)
+ )
self.table = table
@@ -1383,7 +1442,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
raise exc.ArgumentError(
"The 'index' keyword argument on Column is boolean only. "
"To create indexes with a specific name, create an "
- "explicit Index object external to the Table.")
+ "explicit Index object external to the Table."
+ )
Index(None, self, unique=bool(self.unique), _column_flag=True)
elif self.unique:
if isinstance(self.unique, util.string_types):
@@ -1392,9 +1452,11 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
"only. To create unique constraints or indexes with a "
"specific name, append an explicit UniqueConstraint to "
"the Table's list of elements, or create an explicit "
- "Index object external to the Table.")
+ "Index object external to the Table."
+ )
table.append_constraint(
- UniqueConstraint(self.key, _column_flag=True))
+ UniqueConstraint(self.key, _column_flag=True)
+ )
self._setup_on_memoized_fks(lambda fk: fk._set_remote_table(table))
@@ -1413,7 +1475,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
if self.table is not None:
fn(self, self.table)
else:
- event.listen(self, 'after_parent_attach', fn)
+ event.listen(self, "after_parent_attach", fn)
def copy(self, **kw):
"""Create a copy of this ``Column``, unitialized.
@@ -1423,9 +1485,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
"""
# Constraint objects plus non-constraint-bound ForeignKey objects
- args = \
- [c.copy(**kw) for c in self.constraints if not c._type_bound] + \
- [c.copy(**kw) for c in self.foreign_keys if not c.constraint]
+ args = [
+ c.copy(**kw) for c in self.constraints if not c._type_bound
+ ] + [c.copy(**kw) for c in self.foreign_keys if not c.constraint]
type_ = self.type
if isinstance(type_, SchemaEventTarget):
@@ -1452,8 +1514,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
)
return self._schema_item_copy(c)
- def _make_proxy(self, selectable, name=None, key=None,
- name_is_truncatable=False, **kw):
+ def _make_proxy(
+ self, selectable, name=None, key=None, name_is_truncatable=False, **kw
+ ):
"""Create a *proxy* for this column.
This is a copy of this ``Column`` referenced by a different parent
@@ -1462,22 +1525,28 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
information is not transferred.
"""
- fk = [ForeignKey(f.column, _constraint=f.constraint)
- for f in self.foreign_keys]
+ fk = [
+ ForeignKey(f.column, _constraint=f.constraint)
+ for f in self.foreign_keys
+ ]
if name is None and self.name is None:
raise exc.InvalidRequestError(
"Cannot initialize a sub-selectable"
" with this Column object until its 'name' has "
- "been assigned.")
+ "been assigned."
+ )
try:
c = self._constructor(
- _as_truncated(name or self.name) if
- name_is_truncatable else (name or self.name),
+ _as_truncated(name or self.name)
+ if name_is_truncatable
+ else (name or self.name),
self.type,
key=key if key else name if name else self.key,
primary_key=self.primary_key,
nullable=self.nullable,
- _proxies=[self], *fk)
+ _proxies=[self],
+ *fk
+ )
except TypeError:
util.raise_from_cause(
TypeError(
@@ -1485,7 +1554,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
"Ensure the class includes a _constructor() "
"attribute or method which accepts the "
"standard Column constructor arguments, or "
- "references the Column class itself." % self.__class__)
+ "references the Column class itself." % self.__class__
+ )
)
c.table = selectable
@@ -1499,9 +1569,11 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
def get_children(self, schema_visitor=False, **kwargs):
if schema_visitor:
- return [x for x in (self.default, self.onupdate)
- if x is not None] + \
- list(self.foreign_keys) + list(self.constraints)
+ return (
+ [x for x in (self.default, self.onupdate) if x is not None]
+ + list(self.foreign_keys)
+ + list(self.constraints)
+ )
else:
return ColumnClause.get_children(self, **kwargs)
@@ -1543,13 +1615,23 @@ class ForeignKey(DialectKWArgs, SchemaItem):
"""
- __visit_name__ = 'foreign_key'
-
- def __init__(self, column, _constraint=None, use_alter=False, name=None,
- onupdate=None, ondelete=None, deferrable=None,
- initially=None, link_to_name=False, match=None,
- info=None,
- **dialect_kw):
+ __visit_name__ = "foreign_key"
+
+ def __init__(
+ self,
+ column,
+ _constraint=None,
+ use_alter=False,
+ name=None,
+ onupdate=None,
+ ondelete=None,
+ deferrable=None,
+ initially=None,
+ link_to_name=False,
+ match=None,
+ info=None,
+ **dialect_kw
+ ):
r"""
Construct a column-level FOREIGN KEY.
@@ -1626,7 +1708,7 @@ class ForeignKey(DialectKWArgs, SchemaItem):
if isinstance(self._colspec, util.string_types):
self._table_column = None
else:
- if hasattr(self._colspec, '__clause_element__'):
+ if hasattr(self._colspec, "__clause_element__"):
self._table_column = self._colspec.__clause_element__()
else:
self._table_column = self._colspec
@@ -1634,9 +1716,11 @@ class ForeignKey(DialectKWArgs, SchemaItem):
if not isinstance(self._table_column, ColumnClause):
raise exc.ArgumentError(
"String, Column, or Column-bound argument "
- "expected, got %r" % self._table_column)
+ "expected, got %r" % self._table_column
+ )
elif not isinstance(
- self._table_column.table, (util.NoneType, TableClause)):
+ self._table_column.table, (util.NoneType, TableClause)
+ ):
raise exc.ArgumentError(
"ForeignKey received Column not bound "
"to a Table, got: %r" % self._table_column.table
@@ -1715,7 +1799,9 @@ class ForeignKey(DialectKWArgs, SchemaItem):
return "%s.%s" % (table_name, colname)
elif self._table_column is not None:
return "%s.%s" % (
- self._table_column.table.fullname, self._table_column.key)
+ self._table_column.table.fullname,
+ self._table_column.key,
+ )
else:
return self._colspec
@@ -1756,12 +1842,12 @@ class ForeignKey(DialectKWArgs, SchemaItem):
def _column_tokens(self):
"""parse a string-based _colspec into its component parts."""
- m = self._get_colspec().split('.')
+ m = self._get_colspec().split(".")
if m is None:
raise exc.ArgumentError(
- "Invalid foreign key column specification: %s" %
- self._colspec)
- if (len(m) == 1):
+ "Invalid foreign key column specification: %s" % self._colspec
+ )
+ if len(m) == 1:
tname = m.pop()
colname = None
else:
@@ -1777,8 +1863,8 @@ class ForeignKey(DialectKWArgs, SchemaItem):
# indirectly related -- Ticket #594. This assumes that '.'
# will never appear *within* any component of the FK.
- if (len(m) > 0):
- schema = '.'.join(m)
+ if len(m) > 0:
+ schema = ".".join(m)
else:
schema = None
return schema, tname, colname
@@ -1787,12 +1873,14 @@ class ForeignKey(DialectKWArgs, SchemaItem):
if self.parent is None:
raise exc.InvalidRequestError(
"this ForeignKey object does not yet have a "
- "parent Column associated with it.")
+ "parent Column associated with it."
+ )
elif self.parent.table is None:
raise exc.InvalidRequestError(
"this ForeignKey's parent column is not yet associated "
- "with a Table.")
+ "with a Table."
+ )
parenttable = self.parent.table
@@ -1817,7 +1905,7 @@ class ForeignKey(DialectKWArgs, SchemaItem):
return parenttable, tablekey, colname
def _link_to_col_by_colstring(self, parenttable, table, colname):
- if not hasattr(self.constraint, '_referred_table'):
+ if not hasattr(self.constraint, "_referred_table"):
self.constraint._referred_table = table
else:
assert self.constraint._referred_table is table
@@ -1843,9 +1931,11 @@ class ForeignKey(DialectKWArgs, SchemaItem):
raise exc.NoReferencedColumnError(
"Could not initialize target column "
"for ForeignKey '%s' on table '%s': "
- "table '%s' has no column named '%s'" %
- (self._colspec, parenttable.name, table.name, key),
- table.name, key)
+ "table '%s' has no column named '%s'"
+ % (self._colspec, parenttable.name, table.name, key),
+ table.name,
+ key,
+ )
self._set_target_column(_column)
@@ -1861,6 +1951,7 @@ class ForeignKey(DialectKWArgs, SchemaItem):
def set_type(fk):
if fk.parent.type._isnull:
fk.parent.type = column.type
+
self.parent._setup_on_memoized_fks(set_type)
self.column = column
@@ -1888,21 +1979,25 @@ class ForeignKey(DialectKWArgs, SchemaItem):
raise exc.NoReferencedTableError(
"Foreign key associated with column '%s' could not find "
"table '%s' with which to generate a "
- "foreign key to target column '%s'" %
- (self.parent, tablekey, colname),
- tablekey)
+ "foreign key to target column '%s'"
+ % (self.parent, tablekey, colname),
+ tablekey,
+ )
elif parenttable.key not in parenttable.metadata:
raise exc.InvalidRequestError(
"Table %s is no longer associated with its "
- "parent MetaData" % parenttable)
+ "parent MetaData" % parenttable
+ )
else:
raise exc.NoReferencedColumnError(
"Could not initialize target column for "
"ForeignKey '%s' on table '%s': "
- "table '%s' has no column named '%s'" % (
- self._colspec, parenttable.name, tablekey, colname),
- tablekey, colname)
- elif hasattr(self._colspec, '__clause_element__'):
+ "table '%s' has no column named '%s'"
+ % (self._colspec, parenttable.name, tablekey, colname),
+ tablekey,
+ colname,
+ )
+ elif hasattr(self._colspec, "__clause_element__"):
_column = self._colspec.__clause_element__()
return _column
else:
@@ -1912,7 +2007,8 @@ class ForeignKey(DialectKWArgs, SchemaItem):
def _set_parent(self, column):
if self.parent is not None and self.parent is not column:
raise exc.InvalidRequestError(
- "This ForeignKey already has a parent !")
+ "This ForeignKey already has a parent !"
+ )
self.parent = column
self.parent.foreign_keys.add(self)
self.parent._on_table_attach(self._set_table)
@@ -1935,9 +2031,14 @@ class ForeignKey(DialectKWArgs, SchemaItem):
# on the hosting Table when attached to the Table.
if self.constraint is None and isinstance(table, Table):
self.constraint = ForeignKeyConstraint(
- [], [], use_alter=self.use_alter, name=self.name,
- onupdate=self.onupdate, ondelete=self.ondelete,
- deferrable=self.deferrable, initially=self.initially,
+ [],
+ [],
+ use_alter=self.use_alter,
+ name=self.name,
+ onupdate=self.onupdate,
+ ondelete=self.ondelete,
+ deferrable=self.deferrable,
+ initially=self.initially,
match=self.match,
**self._unvalidated_dialect_kw
)
@@ -1953,13 +2054,12 @@ class ForeignKey(DialectKWArgs, SchemaItem):
if table_key in parenttable.metadata.tables:
table = parenttable.metadata.tables[table_key]
try:
- self._link_to_col_by_colstring(
- parenttable, table, colname)
+ self._link_to_col_by_colstring(parenttable, table, colname)
except exc.NoReferencedColumnError:
# this is OK, we'll try later
pass
parenttable.metadata._fk_memos[fk_key].append(self)
- elif hasattr(self._colspec, '__clause_element__'):
+ elif hasattr(self._colspec, "__clause_element__"):
_column = self._colspec.__clause_element__()
self._set_target_column(_column)
else:
@@ -1971,7 +2071,8 @@ class _NotAColumnExpr(object):
def _not_a_column_expr(self):
raise exc.InvalidRequestError(
"This %s cannot be used directly "
- "as a column expression." % self.__class__.__name__)
+ "as a column expression." % self.__class__.__name__
+ )
__clause_element__ = self_group = lambda self: self._not_a_column_expr()
_from_objects = property(lambda self: self._not_a_column_expr())
@@ -1980,7 +2081,7 @@ class _NotAColumnExpr(object):
class DefaultGenerator(_NotAColumnExpr, SchemaItem):
"""Base class for column *default* values."""
- __visit_name__ = 'default_generator'
+ __visit_name__ = "default_generator"
is_sequence = False
is_server_default = False
@@ -2007,7 +2108,7 @@ class DefaultGenerator(_NotAColumnExpr, SchemaItem):
@property
def bind(self):
"""Return the connectable associated with this default."""
- if getattr(self, 'column', None) is not None:
+ if getattr(self, "column", None) is not None:
return self.column.table.bind
else:
return None
@@ -2064,7 +2165,8 @@ class ColumnDefault(DefaultGenerator):
super(ColumnDefault, self).__init__(**kwargs)
if isinstance(arg, FetchedValue):
raise exc.ArgumentError(
- "ColumnDefault may not be a server-side default type.")
+ "ColumnDefault may not be a server-side default type."
+ )
if util.callable(arg):
arg = self._maybe_wrap_callable(arg)
self.arg = arg
@@ -2079,9 +2181,11 @@ class ColumnDefault(DefaultGenerator):
@util.memoized_property
def is_scalar(self):
- return not self.is_callable and \
- not self.is_clause_element and \
- not self.is_sequence
+ return (
+ not self.is_callable
+ and not self.is_clause_element
+ and not self.is_sequence
+ )
@util.memoized_property
@util.dependencies("sqlalchemy.sql.sqltypes")
@@ -2114,17 +2218,19 @@ class ColumnDefault(DefaultGenerator):
else:
raise exc.ArgumentError(
"ColumnDefault Python function takes zero or one "
- "positional arguments")
+ "positional arguments"
+ )
def _visit_name(self):
if self.for_update:
return "column_onupdate"
else:
return "column_default"
+
__visit_name__ = property(_visit_name)
def __repr__(self):
- return "ColumnDefault(%r)" % (self.arg, )
+ return "ColumnDefault(%r)" % (self.arg,)
class Sequence(DefaultGenerator):
@@ -2157,15 +2263,29 @@ class Sequence(DefaultGenerator):
"""
- __visit_name__ = 'sequence'
+ __visit_name__ = "sequence"
is_sequence = True
- def __init__(self, name, start=None, increment=None, minvalue=None,
- maxvalue=None, nominvalue=None, nomaxvalue=None, cycle=None,
- schema=None, cache=None, order=None, optional=False,
- quote=None, metadata=None, quote_schema=None,
- for_update=False):
+ def __init__(
+ self,
+ name,
+ start=None,
+ increment=None,
+ minvalue=None,
+ maxvalue=None,
+ nominvalue=None,
+ nomaxvalue=None,
+ cycle=None,
+ schema=None,
+ cache=None,
+ order=None,
+ optional=False,
+ quote=None,
+ metadata=None,
+ quote_schema=None,
+ for_update=False,
+ ):
"""Construct a :class:`.Sequence` object.
:param name: The name of the sequence.
@@ -2353,27 +2473,22 @@ class Sequence(DefaultGenerator):
if bind is None:
bind = _bind_or_error(self)
- bind._run_visitor(ddl.SchemaGenerator,
- self,
- checkfirst=checkfirst)
+ bind._run_visitor(ddl.SchemaGenerator, self, checkfirst=checkfirst)
def drop(self, bind=None, checkfirst=True):
"""Drops this sequence from the database."""
if bind is None:
bind = _bind_or_error(self)
- bind._run_visitor(ddl.SchemaDropper,
- self,
- checkfirst=checkfirst)
+ bind._run_visitor(ddl.SchemaDropper, self, checkfirst=checkfirst)
def _not_a_column_expr(self):
raise exc.InvalidRequestError(
"This %s cannot be used directly "
"as a column expression. Use func.next_value(sequence) "
"to produce a 'next value' function that's usable "
- "as a column element."
- % self.__class__.__name__)
-
+ "as a column element." % self.__class__.__name__
+ )
@inspection._self_inspects
@@ -2396,6 +2511,7 @@ class FetchedValue(_NotAColumnExpr, SchemaEventTarget):
:ref:`triggered_columns`
"""
+
is_server_default = True
reflected = False
has_argument = False
@@ -2412,7 +2528,7 @@ class FetchedValue(_NotAColumnExpr, SchemaEventTarget):
def _clone(self, for_update):
n = self.__class__.__new__(self.__class__)
n.__dict__.update(self.__dict__)
- n.__dict__.pop('column', None)
+ n.__dict__.pop("column", None)
n.for_update = for_update
return n
@@ -2452,16 +2568,15 @@ class DefaultClause(FetchedValue):
has_argument = True
def __init__(self, arg, for_update=False, _reflected=False):
- util.assert_arg_type(arg, (util.string_types[0],
- ClauseElement,
- TextClause), 'arg')
+ util.assert_arg_type(
+ arg, (util.string_types[0], ClauseElement, TextClause), "arg"
+ )
super(DefaultClause, self).__init__(for_update)
self.arg = arg
self.reflected = _reflected
def __repr__(self):
- return "DefaultClause(%r, for_update=%r)" % \
- (self.arg, self.for_update)
+ return "DefaultClause(%r, for_update=%r)" % (self.arg, self.for_update)
class PassiveDefault(DefaultClause):
@@ -2471,10 +2586,13 @@ class PassiveDefault(DefaultClause):
:class:`.PassiveDefault` is deprecated.
Use :class:`.DefaultClause`.
"""
- @util.deprecated("0.6",
- ":class:`.PassiveDefault` is deprecated. "
- "Use :class:`.DefaultClause`.",
- False)
+
+ @util.deprecated(
+ "0.6",
+ ":class:`.PassiveDefault` is deprecated. "
+ "Use :class:`.DefaultClause`.",
+ False,
+ )
def __init__(self, *arg, **kw):
DefaultClause.__init__(self, *arg, **kw)
@@ -2482,11 +2600,18 @@ class PassiveDefault(DefaultClause):
class Constraint(DialectKWArgs, SchemaItem):
"""A table-level SQL constraint."""
- __visit_name__ = 'constraint'
-
- def __init__(self, name=None, deferrable=None, initially=None,
- _create_rule=None, info=None, _type_bound=False,
- **dialect_kw):
+ __visit_name__ = "constraint"
+
+ def __init__(
+ self,
+ name=None,
+ deferrable=None,
+ initially=None,
+ _create_rule=None,
+ info=None,
+ _type_bound=False,
+ **dialect_kw
+ ):
r"""Create a SQL constraint.
:param name:
@@ -2548,7 +2673,8 @@ class Constraint(DialectKWArgs, SchemaItem):
pass
raise exc.InvalidRequestError(
"This constraint is not bound to a table. Did you "
- "mean to call table.append_constraint(constraint) ?")
+ "mean to call table.append_constraint(constraint) ?"
+ )
def _set_parent(self, parent):
self.parent = parent
@@ -2559,7 +2685,7 @@ class Constraint(DialectKWArgs, SchemaItem):
def _to_schema_column(element):
- if hasattr(element, '__clause_element__'):
+ if hasattr(element, "__clause_element__"):
element = element.__clause_element__()
if not isinstance(element, Column):
raise exc.ArgumentError("schema.Column object expected")
@@ -2567,9 +2693,9 @@ def _to_schema_column(element):
def _to_schema_column_or_string(element):
- if hasattr(element, '__clause_element__'):
+ if hasattr(element, "__clause_element__"):
element = element.__clause_element__()
- if not isinstance(element, util.string_types + (ColumnElement, )):
+ if not isinstance(element, util.string_types + (ColumnElement,)):
msg = "Element %r is not a string name or column element"
raise exc.ArgumentError(msg % element)
return element
@@ -2588,11 +2714,12 @@ class ColumnCollectionMixin(object):
_allow_multiple_tables = False
def __init__(self, *columns, **kw):
- _autoattach = kw.pop('_autoattach', True)
- self._column_flag = kw.pop('_column_flag', False)
+ _autoattach = kw.pop("_autoattach", True)
+ self._column_flag = kw.pop("_column_flag", False)
self.columns = ColumnCollection()
- self._pending_colargs = [_to_schema_column_or_string(c)
- for c in columns]
+ self._pending_colargs = [
+ _to_schema_column_or_string(c) for c in columns
+ ]
if _autoattach and self._pending_colargs:
self._check_attach()
@@ -2601,7 +2728,7 @@ class ColumnCollectionMixin(object):
for expr in expressions:
strname = None
column = None
- if hasattr(expr, '__clause_element__'):
+ if hasattr(expr, "__clause_element__"):
expr = expr.__clause_element__()
if not isinstance(expr, (ColumnElement, TextClause)):
@@ -2609,21 +2736,16 @@ class ColumnCollectionMixin(object):
strname = expr
else:
cols = []
- visitors.traverse(expr, {}, {'column': cols.append})
+ visitors.traverse(expr, {}, {"column": cols.append})
if cols:
column = cols[0]
add_element = column if column is not None else strname
yield expr, column, strname, add_element
def _check_attach(self, evt=False):
- col_objs = [
- c for c in self._pending_colargs
- if isinstance(c, Column)
- ]
+ col_objs = [c for c in self._pending_colargs if isinstance(c, Column)]
- cols_w_table = [
- c for c in col_objs if isinstance(c.table, Table)
- ]
+ cols_w_table = [c for c in col_objs if isinstance(c.table, Table)]
cols_wo_table = set(col_objs).difference(cols_w_table)
@@ -2636,6 +2758,7 @@ class ColumnCollectionMixin(object):
# columns are specified as strings.
has_string_cols = set(self._pending_colargs).difference(col_objs)
if not has_string_cols:
+
def _col_attached(column, table):
# this isinstance() corresponds with the
# isinstance() above; only want to count Table-bound
@@ -2644,6 +2767,7 @@ class ColumnCollectionMixin(object):
cols_wo_table.discard(column)
if not cols_wo_table:
self._check_attach(evt=True)
+
self._cols_wo_table = cols_wo_table
for col in cols_wo_table:
col._on_table_attach(_col_attached)
@@ -2659,9 +2783,11 @@ class ColumnCollectionMixin(object):
others = [c for c in columns[1:] if c.table is not table]
if others:
raise exc.ArgumentError(
- "Column(s) %s are not part of table '%s'." %
- (", ".join("'%s'" % c for c in others),
- table.description)
+ "Column(s) %s are not part of table '%s'."
+ % (
+ ", ".join("'%s'" % c for c in others),
+ table.description,
+ )
)
def _set_parent(self, table):
@@ -2694,11 +2820,12 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint):
arguments are propagated to the :class:`.Constraint` superclass.
"""
- _autoattach = kw.pop('_autoattach', True)
- _column_flag = kw.pop('_column_flag', False)
+ _autoattach = kw.pop("_autoattach", True)
+ _column_flag = kw.pop("_column_flag", False)
Constraint.__init__(self, **kw)
ColumnCollectionMixin.__init__(
- self, *columns, _autoattach=_autoattach, _column_flag=_column_flag)
+ self, *columns, _autoattach=_autoattach, _column_flag=_column_flag
+ )
columns = None
"""A :class:`.ColumnCollection` representing the set of columns
@@ -2714,8 +2841,12 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint):
return x in self.columns
def copy(self, **kw):
- c = self.__class__(name=self.name, deferrable=self.deferrable,
- initially=self.initially, *self.columns.keys())
+ c = self.__class__(
+ name=self.name,
+ deferrable=self.deferrable,
+ initially=self.initially,
+ *self.columns.keys()
+ )
return self._schema_item_copy(c)
def contains_column(self, col):
@@ -2747,9 +2878,19 @@ class CheckConstraint(ColumnCollectionConstraint):
_allow_multiple_tables = True
- def __init__(self, sqltext, name=None, deferrable=None,
- initially=None, table=None, info=None, _create_rule=None,
- _autoattach=True, _type_bound=False, **kw):
+ def __init__(
+ self,
+ sqltext,
+ name=None,
+ deferrable=None,
+ initially=None,
+ table=None,
+ info=None,
+ _create_rule=None,
+ _autoattach=True,
+ _type_bound=False,
+ **kw
+ ):
r"""Construct a CHECK constraint.
:param sqltext:
@@ -2781,14 +2922,19 @@ class CheckConstraint(ColumnCollectionConstraint):
self.sqltext = _literal_as_text(sqltext, warn=False)
columns = []
- visitors.traverse(self.sqltext, {}, {'column': columns.append})
-
- super(CheckConstraint, self).\
- __init__(
- name=name, deferrable=deferrable,
- initially=initially, _create_rule=_create_rule, info=info,
- _type_bound=_type_bound, _autoattach=_autoattach,
- *columns, **kw)
+ visitors.traverse(self.sqltext, {}, {"column": columns.append})
+
+ super(CheckConstraint, self).__init__(
+ name=name,
+ deferrable=deferrable,
+ initially=initially,
+ _create_rule=_create_rule,
+ info=info,
+ _type_bound=_type_bound,
+ _autoattach=_autoattach,
+ *columns,
+ **kw
+ )
if table is not None:
self._set_parent_with_dispatch(table)
@@ -2797,22 +2943,24 @@ class CheckConstraint(ColumnCollectionConstraint):
return "check_constraint"
else:
return "column_check_constraint"
+
__visit_name__ = property(__visit_name__)
def copy(self, target_table=None, **kw):
if target_table is not None:
- sqltext = _copy_expression(
- self.sqltext, self.table, target_table)
+ sqltext = _copy_expression(self.sqltext, self.table, target_table)
else:
sqltext = self.sqltext
- c = CheckConstraint(sqltext,
- name=self.name,
- initially=self.initially,
- deferrable=self.deferrable,
- _create_rule=self._create_rule,
- table=target_table,
- _autoattach=False,
- _type_bound=self._type_bound)
+ c = CheckConstraint(
+ sqltext,
+ name=self.name,
+ initially=self.initially,
+ deferrable=self.deferrable,
+ _create_rule=self._create_rule,
+ table=target_table,
+ _autoattach=False,
+ _type_bound=self._type_bound,
+ )
return self._schema_item_copy(c)
@@ -2828,12 +2976,25 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
Examples of foreign key configuration are in :ref:`metadata_foreignkeys`.
"""
- __visit_name__ = 'foreign_key_constraint'
- def __init__(self, columns, refcolumns, name=None, onupdate=None,
- ondelete=None, deferrable=None, initially=None,
- use_alter=False, link_to_name=False, match=None,
- table=None, info=None, **dialect_kw):
+ __visit_name__ = "foreign_key_constraint"
+
+ def __init__(
+ self,
+ columns,
+ refcolumns,
+ name=None,
+ onupdate=None,
+ ondelete=None,
+ deferrable=None,
+ initially=None,
+ use_alter=False,
+ link_to_name=False,
+ match=None,
+ table=None,
+ info=None,
+ **dialect_kw
+ ):
r"""Construct a composite-capable FOREIGN KEY.
:param columns: A sequence of local column names. The named columns
@@ -2905,8 +3066,13 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
"""
Constraint.__init__(
- self, name=name, deferrable=deferrable, initially=initially,
- info=info, **dialect_kw)
+ self,
+ name=name,
+ deferrable=deferrable,
+ initially=initially,
+ info=info,
+ **dialect_kw
+ )
self.onupdate = onupdate
self.ondelete = ondelete
self.link_to_name = link_to_name
@@ -2927,7 +3093,8 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
raise exc.ArgumentError(
"ForeignKeyConstraint number "
"of constrained columns must match the number of "
- "referenced columns.")
+ "referenced columns."
+ )
# standalone ForeignKeyConstraint - create
# associated ForeignKey objects which will be applied to hosted
@@ -2946,7 +3113,8 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
deferrable=self.deferrable,
initially=self.initially,
**self.dialect_kwargs
- ) for refcol in refcolumns
+ )
+ for refcol in refcolumns
]
ColumnCollectionMixin.__init__(self, *columns)
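A hedged sketch of the composite-capable constructor reformatted above; table and column names are assumed for illustration:

    # illustrative only: a two-column (composite) FOREIGN KEY declared table-wide
    from sqlalchemy import Column, ForeignKeyConstraint, Integer, MetaData, Table

    metadata = MetaData()
    parent = Table(
        "parent",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("version", Integer, primary_key=True),
    )
    child = Table(
        "child",
        metadata,
        Column("parent_id", Integer),
        Column("parent_version", Integer),
        ForeignKeyConstraint(
            ["parent_id", "parent_version"],
            ["parent.id", "parent.version"],
            name="fk_child_parent",
            ondelete="CASCADE",
        ),
    )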
@@ -2978,9 +3146,7 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
@property
def _elements(self):
# legacy - provide a dictionary view of (column_key, fk)
- return util.OrderedDict(
- zip(self.column_keys, self.elements)
- )
+ return util.OrderedDict(zip(self.column_keys, self.elements))
@property
def _referred_schema(self):
@@ -3004,18 +3170,14 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
return self.elements[0].column.table
def _validate_dest_table(self, table):
- table_keys = set([elem._table_key()
- for elem in self.elements])
+ table_keys = set([elem._table_key() for elem in self.elements])
if None not in table_keys and len(table_keys) > 1:
elem0, elem1 = sorted(table_keys)[0:2]
raise exc.ArgumentError(
- 'ForeignKeyConstraint on %s(%s) refers to '
- 'multiple remote tables: %s and %s' % (
- table.fullname,
- self._col_description,
- elem0,
- elem1
- ))
+ "ForeignKeyConstraint on %s(%s) refers to "
+ "multiple remote tables: %s and %s"
+ % (table.fullname, self._col_description, elem0, elem1)
+ )
@property
def column_keys(self):
@@ -3034,8 +3196,8 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
return self.columns.keys()
else:
return [
- col.key if isinstance(col, ColumnElement)
- else str(col) for col in self._pending_colargs
+ col.key if isinstance(col, ColumnElement) else str(col)
+ for col in self._pending_colargs
]
@property
@@ -3051,11 +3213,11 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
raise exc.ArgumentError(
"Can't create ForeignKeyConstraint "
"on table '%s': no column "
- "named '%s' is present." % (table.description, ke.args[0]))
+ "named '%s' is present." % (table.description, ke.args[0])
+ )
for col, fk in zip(self.columns, self.elements):
- if not hasattr(fk, 'parent') or \
- fk.parent is not col:
+ if not hasattr(fk, "parent") or fk.parent is not col:
fk._set_parent_with_dispatch(col)
self._validate_dest_table(table)
@@ -3063,13 +3225,16 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
def copy(self, schema=None, target_table=None, **kw):
fkc = ForeignKeyConstraint(
[x.parent.key for x in self.elements],
- [x._get_colspec(
- schema=schema,
- table_name=target_table.name
- if target_table is not None
- and x._table_key() == x.parent.table.key
- else None)
- for x in self.elements],
+ [
+ x._get_colspec(
+ schema=schema,
+ table_name=target_table.name
+ if target_table is not None
+ and x._table_key() == x.parent.table.key
+ else None,
+ )
+ for x in self.elements
+ ],
name=self.name,
onupdate=self.onupdate,
ondelete=self.ondelete,
@@ -3077,11 +3242,9 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
deferrable=self.deferrable,
initially=self.initially,
link_to_name=self.link_to_name,
- match=self.match
+ match=self.match,
)
- for self_fk, other_fk in zip(
- self.elements,
- fkc.elements):
+ for self_fk, other_fk in zip(self.elements, fkc.elements):
self_fk._schema_item_copy(other_fk)
return self._schema_item_copy(fkc)
@@ -3160,10 +3323,10 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint):
"""
- __visit_name__ = 'primary_key_constraint'
+ __visit_name__ = "primary_key_constraint"
def __init__(self, *columns, **kw):
- self._implicit_generated = kw.pop('_implicit_generated', False)
+ self._implicit_generated = kw.pop("_implicit_generated", False)
super(PrimaryKeyConstraint, self).__init__(*columns, **kw)
def _set_parent(self, table):
@@ -3175,18 +3338,21 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint):
table.constraints.add(self)
table_pks = [c for c in table.c if c.primary_key]
- if self.columns and table_pks and \
- set(table_pks) != set(self.columns.values()):
+ if (
+ self.columns
+ and table_pks
+ and set(table_pks) != set(self.columns.values())
+ ):
util.warn(
"Table '%s' specifies columns %s as primary_key=True, "
"not matching locally specified columns %s; setting the "
"current primary key columns to %s. This warning "
- "may become an exception in a future release" %
- (
+ "may become an exception in a future release"
+ % (
table.name,
", ".join("'%s'" % c.name for c in table_pks),
", ".join("'%s'" % c.name for c in self.columns),
- ", ".join("'%s'" % c.name for c in self.columns)
+ ", ".join("'%s'" % c.name for c in self.columns),
)
)
table_pks[:] = []
@@ -3241,28 +3407,28 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint):
@util.memoized_property
def _autoincrement_column(self):
-
def _validate_autoinc(col, autoinc_true):
if col.type._type_affinity is None or not issubclass(
- col.type._type_affinity,
- type_api.INTEGERTYPE._type_affinity):
+ col.type._type_affinity, type_api.INTEGERTYPE._type_affinity
+ ):
if autoinc_true:
raise exc.ArgumentError(
"Column type %s on column '%s' is not "
- "compatible with autoincrement=True" % (
- col.type,
- col
- ))
+ "compatible with autoincrement=True" % (col.type, col)
+ )
else:
return False
- elif not isinstance(col.default, (type(None), Sequence)) and \
- not autoinc_true:
- return False
+ elif (
+ not isinstance(col.default, (type(None), Sequence))
+ and not autoinc_true
+ ):
+ return False
elif col.server_default is not None and not autoinc_true:
return False
- elif (
- col.foreign_keys and col.autoincrement
- not in (True, 'ignore_fk')):
+ elif col.foreign_keys and col.autoincrement not in (
+ True,
+ "ignore_fk",
+ ):
return False
return True
@@ -3272,10 +3438,10 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint):
if col.autoincrement is True:
_validate_autoinc(col, True)
return col
- elif (
- col.autoincrement in ('auto', 'ignore_fk') and
- _validate_autoinc(col, False)
- ):
+ elif col.autoincrement in (
+ "auto",
+ "ignore_fk",
+ ) and _validate_autoinc(col, False):
return col
else:
@@ -3286,8 +3452,8 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint):
if autoinc is not None:
raise exc.ArgumentError(
"Only one Column may be marked "
- "autoincrement=True, found both %s and %s." %
- (col.name, autoinc.name)
+ "autoincrement=True, found both %s and %s."
+ % (col.name, autoinc.name)
)
else:
autoinc = col
@@ -3304,7 +3470,7 @@ class UniqueConstraint(ColumnCollectionConstraint):
UniqueConstraint.
"""
- __visit_name__ = 'unique_constraint'
+ __visit_name__ = "unique_constraint"
class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem):
@@ -3382,7 +3548,7 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem):
"""
- __visit_name__ = 'index'
+ __visit_name__ = "index"
def __init__(self, name, *expressions, **kw):
r"""Construct an index object.
@@ -3420,30 +3586,35 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem):
columns = []
processed_expressions = []
- for expr, column, strname, add_element in self.\
- _extract_col_expression_collection(expressions):
+ for (
+ expr,
+ column,
+ strname,
+ add_element,
+ ) in self._extract_col_expression_collection(expressions):
if add_element is not None:
columns.append(add_element)
processed_expressions.append(expr)
self.expressions = processed_expressions
self.name = quoted_name(name, kw.pop("quote", None))
- self.unique = kw.pop('unique', False)
- _column_flag = kw.pop('_column_flag', False)
- if 'info' in kw:
- self.info = kw.pop('info')
+ self.unique = kw.pop("unique", False)
+ _column_flag = kw.pop("_column_flag", False)
+ if "info" in kw:
+ self.info = kw.pop("info")
# TODO: consider "table" argument being public, but for
# the purpose of the fix here, it starts as private.
- if '_table' in kw:
- table = kw.pop('_table')
+ if "_table" in kw:
+ table = kw.pop("_table")
self._validate_dialect_kwargs(kw)
# will call _set_parent() if table-bound column
# objects are present
ColumnCollectionMixin.__init__(
- self, *columns, _column_flag=_column_flag)
+ self, *columns, _column_flag=_column_flag
+ )
if table is not None:
self._set_parent(table)
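A small sketch of the Index constructor whose expression handling is rewrapped above; names are illustrative:

    # illustrative only: a unique, expression-based index attached via its column
    from sqlalchemy import Column, Index, Integer, MetaData, String, Table, func

    metadata = MetaData()
    user = Table(
        "user",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("email", String(100)),
    )
    Index("ix_user_email_lower", func.lower(user.c.email), unique=True)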
@@ -3454,20 +3625,17 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem):
if self.table is not None and table is not self.table:
raise exc.ArgumentError(
"Index '%s' is against table '%s', and "
- "cannot be associated with table '%s'." % (
- self.name,
- self.table.description,
- table.description
- )
+ "cannot be associated with table '%s'."
+ % (self.name, self.table.description, table.description)
)
self.table = table
table.indexes.add(self)
self.expressions = [
- expr if isinstance(expr, ClauseElement)
- else colexpr
- for expr, colexpr in util.zip_longest(self.expressions,
- self.columns)
+ expr if isinstance(expr, ClauseElement) else colexpr
+ for expr, colexpr in util.zip_longest(
+ self.expressions, self.columns
+ )
]
@property
@@ -3506,17 +3674,16 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem):
bind._run_visitor(ddl.SchemaDropper, self)
def __repr__(self):
- return 'Index(%s)' % (
+ return "Index(%s)" % (
", ".join(
- [repr(self.name)] +
- [repr(e) for e in self.expressions] +
- (self.unique and ["unique=True"] or [])
- ))
+ [repr(self.name)]
+ + [repr(e) for e in self.expressions]
+ + (self.unique and ["unique=True"] or [])
+ )
+ )
-DEFAULT_NAMING_CONVENTION = util.immutabledict({
- "ix": 'ix_%(column_0_label)s'
-})
+DEFAULT_NAMING_CONVENTION = util.immutabledict({"ix": "ix_%(column_0_label)s"})
class MetaData(SchemaItem):
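DEFAULT_NAMING_CONVENTION above covers only the "ix" token; a sketch of passing a fuller convention to MetaData follows. The token keys are the documented convention scheme; the name templates here are illustrative:

    # illustrative only: constraint and index names generated from a convention map
    from sqlalchemy import MetaData

    metadata = MetaData(
        naming_convention={
            "ix": "ix_%(column_0_label)s",
            "uq": "uq_%(table_name)s_%(column_0_name)s",
            "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
            "pk": "pk_%(table_name)s",
        }
    )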
@@ -3542,13 +3709,17 @@ class MetaData(SchemaItem):
"""
- __visit_name__ = 'metadata'
-
- def __init__(self, bind=None, reflect=False, schema=None,
- quote_schema=None,
- naming_convention=DEFAULT_NAMING_CONVENTION,
- info=None
- ):
+ __visit_name__ = "metadata"
+
+ def __init__(
+ self,
+ bind=None,
+ reflect=False,
+ schema=None,
+ quote_schema=None,
+ naming_convention=DEFAULT_NAMING_CONVENTION,
+ info=None,
+ ):
"""Create a new MetaData object.
:param bind:
@@ -3712,12 +3883,15 @@ class MetaData(SchemaItem):
self.bind = bind
if reflect:
- util.warn_deprecated("reflect=True is deprecate; please "
- "use the reflect() method.")
+ util.warn_deprecated(
+                "reflect=True is deprecated; please "
+ "use the reflect() method."
+ )
if not bind:
raise exc.ArgumentError(
"A bind must be supplied in conjunction "
- "with reflect=True")
+ "with reflect=True"
+ )
self.reflect()
tables = None
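Given the deprecation warning preserved above, a sketch of the suggested replacement, calling reflect() explicitly; the database URL is assumed:

    # illustrative only: explicit reflect() call instead of the deprecated reflect=True flag
    from sqlalchemy import MetaData, create_engine

    engine = create_engine("sqlite:///example.db")  # assumed URL
    metadata = MetaData()
    metadata.reflect(bind=engine)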
@@ -3735,7 +3909,7 @@ class MetaData(SchemaItem):
"""
def __repr__(self):
- return 'MetaData(bind=%r)' % self.bind
+ return "MetaData(bind=%r)" % self.bind
def __contains__(self, table_or_key):
if not isinstance(table_or_key, util.string_types):
@@ -3755,27 +3929,32 @@ class MetaData(SchemaItem):
for fk in removed.foreign_keys:
fk._remove_from_metadata(self)
if self._schemas:
- self._schemas = set([t.schema
- for t in self.tables.values()
- if t.schema is not None])
+ self._schemas = set(
+ [
+ t.schema
+ for t in self.tables.values()
+ if t.schema is not None
+ ]
+ )
def __getstate__(self):
- return {'tables': self.tables,
- 'schema': self.schema,
- 'schemas': self._schemas,
- 'sequences': self._sequences,
- 'fk_memos': self._fk_memos,
- 'naming_convention': self.naming_convention
- }
+ return {
+ "tables": self.tables,
+ "schema": self.schema,
+ "schemas": self._schemas,
+ "sequences": self._sequences,
+ "fk_memos": self._fk_memos,
+ "naming_convention": self.naming_convention,
+ }
def __setstate__(self, state):
- self.tables = state['tables']
- self.schema = state['schema']
- self.naming_convention = state['naming_convention']
+ self.tables = state["tables"]
+ self.schema = state["schema"]
+ self.naming_convention = state["naming_convention"]
self._bind = None
- self._sequences = state['sequences']
- self._schemas = state['schemas']
- self._fk_memos = state['fk_memos']
+ self._sequences = state["sequences"]
+ self._schemas = state["schemas"]
+ self._fk_memos = state["fk_memos"]
def is_bound(self):
"""True if this MetaData is bound to an Engine or Connection."""
@@ -3805,10 +3984,11 @@ class MetaData(SchemaItem):
def _bind_to(self, url, bind):
"""Bind this MetaData to an Engine, Connection, string or URL."""
- if isinstance(bind, util.string_types + (url.URL, )):
+ if isinstance(bind, util.string_types + (url.URL,)):
self._bind = sqlalchemy.create_engine(bind)
else:
self._bind = bind
+
bind = property(bind, _bind_to)
def clear(self):
@@ -3858,12 +4038,20 @@ class MetaData(SchemaItem):
"""
- return ddl.sort_tables(sorted(self.tables.values(), key=lambda t: t.key))
+ return ddl.sort_tables(
+ sorted(self.tables.values(), key=lambda t: t.key)
+ )
- def reflect(self, bind=None, schema=None, views=False, only=None,
- extend_existing=False,
- autoload_replace=True,
- **dialect_kwargs):
+ def reflect(
+ self,
+ bind=None,
+ schema=None,
+ views=False,
+ only=None,
+ extend_existing=False,
+ autoload_replace=True,
+ **dialect_kwargs
+ ):
r"""Load all available table definitions from the database.
Automatically creates ``Table`` entries in this ``MetaData`` for any
@@ -3926,11 +4114,11 @@ class MetaData(SchemaItem):
with bind.connect() as conn:
reflect_opts = {
- 'autoload': True,
- 'autoload_with': conn,
- 'extend_existing': extend_existing,
- 'autoload_replace': autoload_replace,
- '_extend_on': set()
+ "autoload": True,
+ "autoload_with": conn,
+ "extend_existing": extend_existing,
+ "autoload_replace": autoload_replace,
+ "_extend_on": set(),
}
reflect_opts.update(dialect_kwargs)
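The reflect_opts dictionary above is driven by the public reflect() arguments; a hedged sketch of a call that exercises several of them (URL, schema, and table names are assumed):

    # illustrative only: reflect a subset of tables, including views, from one schema
    from sqlalchemy import MetaData, create_engine

    engine = create_engine("postgresql://scott:tiger@localhost/test")  # assumed URL
    metadata = MetaData()
    metadata.reflect(
        bind=engine,
        schema="reporting",              # assumed schema name
        views=True,
        only=["orders", "order_items"],  # assumed table names
    )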
@@ -3939,42 +4127,49 @@ class MetaData(SchemaItem):
schema = self.schema
if schema is not None:
- reflect_opts['schema'] = schema
+ reflect_opts["schema"] = schema
available = util.OrderedSet(
- bind.engine.table_names(schema, connection=conn))
+ bind.engine.table_names(schema, connection=conn)
+ )
if views:
- available.update(
- bind.dialect.get_view_names(conn, schema)
- )
+ available.update(bind.dialect.get_view_names(conn, schema))
if schema is not None:
- available_w_schema = util.OrderedSet(["%s.%s" % (schema, name)
- for name in available])
+ available_w_schema = util.OrderedSet(
+ ["%s.%s" % (schema, name) for name in available]
+ )
else:
available_w_schema = available
current = set(self.tables)
if only is None:
- load = [name for name, schname in
- zip(available, available_w_schema)
- if extend_existing or schname not in current]
+ load = [
+ name
+ for name, schname in zip(available, available_w_schema)
+ if extend_existing or schname not in current
+ ]
elif util.callable(only):
- load = [name for name, schname in
- zip(available, available_w_schema)
- if (extend_existing or schname not in current)
- and only(name, self)]
+ load = [
+ name
+ for name, schname in zip(available, available_w_schema)
+ if (extend_existing or schname not in current)
+ and only(name, self)
+ ]
else:
missing = [name for name in only if name not in available]
if missing:
- s = schema and (" schema '%s'" % schema) or ''
+ s = schema and (" schema '%s'" % schema) or ""
raise exc.InvalidRequestError(
- 'Could not reflect: requested table(s) not available '
- 'in %r%s: (%s)' %
- (bind.engine, s, ', '.join(missing)))
- load = [name for name in only if extend_existing or
- name not in current]
+ "Could not reflect: requested table(s) not available "
+ "in %r%s: (%s)" % (bind.engine, s, ", ".join(missing))
+ )
+ load = [
+ name
+ for name in only
+ if extend_existing or name not in current
+ ]
for name in load:
try:
@@ -3989,11 +4184,12 @@ class MetaData(SchemaItem):
See :class:`.DDLEvents`.
"""
+
def adapt_listener(target, connection, **kw):
- tables = kw['tables']
+ tables = kw["tables"]
listener(event, target, connection, tables=tables)
- event.listen(self, "" + event_name.replace('-', '_'), adapt_listener)
+ event.listen(self, "" + event_name.replace("-", "_"), adapt_listener)
def create_all(self, bind=None, tables=None, checkfirst=True):
"""Create all tables stored in this metadata.
@@ -4017,10 +4213,9 @@ class MetaData(SchemaItem):
"""
if bind is None:
bind = _bind_or_error(self)
- bind._run_visitor(ddl.SchemaGenerator,
- self,
- checkfirst=checkfirst,
- tables=tables)
+ bind._run_visitor(
+ ddl.SchemaGenerator, self, checkfirst=checkfirst, tables=tables
+ )
def drop_all(self, bind=None, tables=None, checkfirst=True):
"""Drop all tables stored in this metadata.
@@ -4044,10 +4239,9 @@ class MetaData(SchemaItem):
"""
if bind is None:
bind = _bind_or_error(self)
- bind._run_visitor(ddl.SchemaDropper,
- self,
- checkfirst=checkfirst,
- tables=tables)
+ bind._run_visitor(
+ ddl.SchemaDropper, self, checkfirst=checkfirst, tables=tables
+ )
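A short sketch of both emit methods as reformatted above; create_all() and drop_all() accept the same bind, tables, and checkfirst arguments, and the table here is illustrative:

    # illustrative only: emit CREATE TABLE for everything in this MetaData, then drop it
    from sqlalchemy import Column, Integer, MetaData, Table, create_engine

    engine = create_engine("sqlite://")
    metadata = MetaData()
    Table("user", metadata, Column("id", Integer, primary_key=True))

    metadata.create_all(engine)  # checkfirst=True by default, per the signature above
    metadata.drop_all(engine)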
class ThreadLocalMetaData(MetaData):
@@ -4064,7 +4258,7 @@ class ThreadLocalMetaData(MetaData):
"""
- __visit_name__ = 'metadata'
+ __visit_name__ = "metadata"
def __init__(self):
"""Construct a ThreadLocalMetaData."""
@@ -4080,13 +4274,13 @@ class ThreadLocalMetaData(MetaData):
string or URL to automatically create a basic Engine for this bind
with ``create_engine()``."""
- return getattr(self.context, '_engine', None)
+ return getattr(self.context, "_engine", None)
@util.dependencies("sqlalchemy.engine.url")
def _bind_to(self, url, bind):
"""Bind to a Connectable in the caller's thread."""
- if isinstance(bind, util.string_types + (url.URL, )):
+ if isinstance(bind, util.string_types + (url.URL,)):
try:
self.context._engine = self.__engines[bind]
except KeyError:
@@ -4104,14 +4298,16 @@ class ThreadLocalMetaData(MetaData):
def is_bound(self):
"""True if there is a bind for this thread."""
- return (hasattr(self.context, '_engine') and
- self.context._engine is not None)
+ return (
+ hasattr(self.context, "_engine")
+ and self.context._engine is not None
+ )
def dispose(self):
"""Dispose all bound engines, in all thread contexts."""
for e in self.__engines.values():
- if hasattr(e, 'dispose'):
+ if hasattr(e, "dispose"):
e.dispose()
@@ -4128,22 +4324,25 @@ class _SchemaTranslateMap(object):
"""
- __slots__ = 'map_', '__call__', 'hash_key', 'is_default'
+
+ __slots__ = "map_", "__call__", "hash_key", "is_default"
_default_schema_getter = operator.attrgetter("schema")
def __init__(self, map_):
self.map_ = map_
if map_ is not None:
+
def schema_for_object(obj):
effective_schema = self._default_schema_getter(obj)
effective_schema = obj._translate_schema(
- effective_schema, map_)
+ effective_schema, map_
+ )
return effective_schema
+
self.__call__ = schema_for_object
self.hash_key = ";".join(
- "%s=%s" % (k, map_[k])
- for k in sorted(map_, key=str)
+ "%s=%s" % (k, map_[k]) for k in sorted(map_, key=str)
)
self.is_default = False
else:
@@ -4160,6 +4359,6 @@ class _SchemaTranslateMap(object):
else:
return _SchemaTranslateMap(map_)
+
_default_schema_map = _SchemaTranslateMap(None)
_schema_getter = _SchemaTranslateMap._schema_getter
-
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index f64f152c4..1f1800514 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -10,15 +10,39 @@ SQL tables and derived rowsets.
"""
-from .elements import ClauseElement, TextClause, ClauseList, \
- and_, Grouping, UnaryExpression, literal_column, BindParameter
-from .elements import _clone, \
- _literal_as_text, _interpret_as_column_or_from, _expand_cloned,\
- _select_iterables, _anonymous_label, _clause_element_as_expr,\
- _cloned_intersection, _cloned_difference, True_, \
- _literal_as_label_reference, _literal_and_labels_as_label_reference
-from .base import Immutable, Executable, _generative, \
- ColumnCollection, ColumnSet, _from_objects, Generative
+from .elements import (
+ ClauseElement,
+ TextClause,
+ ClauseList,
+ and_,
+ Grouping,
+ UnaryExpression,
+ literal_column,
+ BindParameter,
+)
+from .elements import (
+ _clone,
+ _literal_as_text,
+ _interpret_as_column_or_from,
+ _expand_cloned,
+ _select_iterables,
+ _anonymous_label,
+ _clause_element_as_expr,
+ _cloned_intersection,
+ _cloned_difference,
+ True_,
+ _literal_as_label_reference,
+ _literal_and_labels_as_label_reference,
+)
+from .base import (
+ Immutable,
+ Executable,
+ _generative,
+ ColumnCollection,
+ ColumnSet,
+ _from_objects,
+ Generative,
+)
from . import type_api
from .. import inspection
from .. import util
@@ -40,7 +64,8 @@ def _interpret_as_from(element):
"Textual SQL FROM expression %(expr)r should be "
"explicitly declared as text(%(expr)r), "
"or use table(%(expr)r) for more specificity",
- {"expr": util.ellipses_string(element)})
+ {"expr": util.ellipses_string(element)},
+ )
return TextClause(util.text_type(element))
try:
@@ -73,7 +98,7 @@ def _offset_or_limit_clause(element, name=None, type_=None):
"""
if element is None:
return None
- elif hasattr(element, '__clause_element__'):
+ elif hasattr(element, "__clause_element__"):
return element.__clause_element__()
elif isinstance(element, Visitable):
return element
@@ -97,7 +122,8 @@ def _offset_or_limit_clause_asint(clause, attrname):
except AttributeError:
raise exc.CompileError(
"This SELECT structure does not use a simple "
- "integer value for %s" % attrname)
+ "integer value for %s" % attrname
+ )
else:
return util.asint(value)
@@ -225,12 +251,14 @@ def tablesample(selectable, sampling, name=None, seed=None):
"""
return _interpret_as_from(selectable).tablesample(
- sampling, name=name, seed=seed)
+ sampling, name=name, seed=seed
+ )
class Selectable(ClauseElement):
"""mark a class as being selectable"""
- __visit_name__ = 'selectable'
+
+ __visit_name__ = "selectable"
is_selectable = True
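For context on the tablesample() helper whose call is rewrapped above, a hedged sketch following the documented bernoulli form; the table and column names are assumed:

    # illustrative only: sample roughly 10 percent of an assumed "user" table
    from sqlalchemy import column, func, select, table

    user = table("user", column("id"), column("name"))
    sampled = user.tablesample(func.bernoulli(10.0), name="u")
    stmt = select([sampled.c.id, sampled.c.name])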
@@ -265,15 +293,17 @@ class HasPrefixes(object):
limit rendering of this prefix to only that dialect.
"""
- dialect = kw.pop('dialect', None)
+ dialect = kw.pop("dialect", None)
if kw:
- raise exc.ArgumentError("Unsupported argument(s): %s" %
- ",".join(kw))
+ raise exc.ArgumentError(
+ "Unsupported argument(s): %s" % ",".join(kw)
+ )
self._setup_prefixes(expr, dialect)
def _setup_prefixes(self, prefixes, dialect=None):
self._prefixes = self._prefixes + tuple(
- [(_literal_as_text(p, warn=False), dialect) for p in prefixes])
+ [(_literal_as_text(p, warn=False), dialect) for p in prefixes]
+ )
class HasSuffixes(object):
@@ -301,15 +331,17 @@ class HasSuffixes(object):
limit rendering of this suffix to only that dialect.
"""
- dialect = kw.pop('dialect', None)
+ dialect = kw.pop("dialect", None)
if kw:
- raise exc.ArgumentError("Unsupported argument(s): %s" %
- ",".join(kw))
+ raise exc.ArgumentError(
+ "Unsupported argument(s): %s" % ",".join(kw)
+ )
self._setup_suffixes(expr, dialect)
def _setup_suffixes(self, suffixes, dialect=None):
self._suffixes = self._suffixes + tuple(
- [(_literal_as_text(p, warn=False), dialect) for p in suffixes])
+ [(_literal_as_text(p, warn=False), dialect) for p in suffixes]
+ )
class FromClause(Selectable):
@@ -330,7 +362,8 @@ class FromClause(Selectable):
"""
- __visit_name__ = 'fromclause'
+
+ __visit_name__ = "fromclause"
named_with_column = False
_hide_froms = []
@@ -359,13 +392,14 @@ class FromClause(Selectable):
_memoized_property = util.group_expirable_memoized_property(["_columns"])
@util.deprecated(
- '1.1',
+ "1.1",
message="``FromClause.count()`` is deprecated. Counting "
"rows requires that the correct column expression and "
"accommodations for joins, DISTINCT, etc. must be made, "
"otherwise results may not be what's expected. "
"Please use an appropriate ``func.count()`` expression "
- "directly.")
+ "directly.",
+ )
@util.dependencies("sqlalchemy.sql.functions")
def count(self, functions, whereclause=None, **params):
"""return a SELECT COUNT generated against this
@@ -392,10 +426,11 @@ class FromClause(Selectable):
else:
col = list(self.columns)[0]
return Select(
- [functions.func.count(col).label('tbl_row_count')],
+ [functions.func.count(col).label("tbl_row_count")],
whereclause,
from_obj=[self],
- **params)
+ **params
+ )
def select(self, whereclause=None, **params):
"""return a SELECT of this :class:`.FromClause`.
@@ -603,8 +638,9 @@ class FromClause(Selectable):
def embedded(expanded_proxy_set, target_set):
for t in target_set.difference(expanded_proxy_set):
- if not set(_expand_cloned([t])
- ).intersection(expanded_proxy_set):
+ if not set(_expand_cloned([t])).intersection(
+ expanded_proxy_set
+ ):
return False
return True
@@ -617,8 +653,10 @@ class FromClause(Selectable):
for c in cols:
expanded_proxy_set = set(_expand_cloned(c.proxy_set))
i = target_set.intersection(expanded_proxy_set)
- if i and (not require_embedded
- or embedded(expanded_proxy_set, target_set)):
+ if i and (
+ not require_embedded
+ or embedded(expanded_proxy_set, target_set)
+ ):
if col is None:
# no corresponding column yet, pick this one.
@@ -646,12 +684,20 @@ class FromClause(Selectable):
col_distance = util.reduce(
operator.add,
- [sc._annotations.get('weight', 1) for sc in
- col.proxy_set if sc.shares_lineage(column)])
+ [
+ sc._annotations.get("weight", 1)
+ for sc in col.proxy_set
+ if sc.shares_lineage(column)
+ ],
+ )
c_distance = util.reduce(
operator.add,
- [sc._annotations.get('weight', 1) for sc in
- c.proxy_set if sc.shares_lineage(column)])
+ [
+ sc._annotations.get("weight", 1)
+ for sc in c.proxy_set
+ if sc.shares_lineage(column)
+ ],
+ )
if c_distance < col_distance:
col, intersect = c, i
return col
@@ -663,7 +709,7 @@ class FromClause(Selectable):
Used primarily for error message formatting.
"""
- return getattr(self, 'name', self.__class__.__name__ + " object")
+ return getattr(self, "name", self.__class__.__name__ + " object")
def _reset_exported(self):
"""delete memoized collections when a FromClause is cloned."""
@@ -683,7 +729,7 @@ class FromClause(Selectable):
"""
- if '_columns' not in self.__dict__:
+ if "_columns" not in self.__dict__:
self._init_collections()
self._populate_column_collection()
return self._columns.as_immutable()
@@ -706,14 +752,16 @@ class FromClause(Selectable):
self._populate_column_collection()
return self.foreign_keys
- c = property(attrgetter('columns'),
- doc="An alias for the :attr:`.columns` attribute.")
- _select_iterable = property(attrgetter('columns'))
+ c = property(
+ attrgetter("columns"),
+ doc="An alias for the :attr:`.columns` attribute.",
+ )
+ _select_iterable = property(attrgetter("columns"))
def _init_collections(self):
- assert '_columns' not in self.__dict__
- assert 'primary_key' not in self.__dict__
- assert 'foreign_keys' not in self.__dict__
+ assert "_columns" not in self.__dict__
+ assert "primary_key" not in self.__dict__
+ assert "foreign_keys" not in self.__dict__
self._columns = ColumnCollection()
self.primary_key = ColumnSet()
@@ -721,7 +769,7 @@ class FromClause(Selectable):
@property
def _cols_populated(self):
- return '_columns' in self.__dict__
+ return "_columns" in self.__dict__
def _populate_column_collection(self):
"""Called on subclasses to establish the .c collection.
@@ -758,8 +806,7 @@ class FromClause(Selectable):
"""
if not self._cols_populated:
return None
- elif (column.key in self.columns and
- self.columns[column.key] is column):
+ elif column.key in self.columns and self.columns[column.key] is column:
return column
else:
return None
@@ -780,7 +827,8 @@ class Join(FromClause):
:meth:`.FromClause.join`
"""
- __visit_name__ = 'join'
+
+ __visit_name__ = "join"
_is_join = True
@@ -829,8 +877,9 @@ class Join(FromClause):
return cls(left, right, onclause, isouter=True, full=full)
@classmethod
- def _create_join(cls, left, right, onclause=None, isouter=False,
- full=False):
+ def _create_join(
+ cls, left, right, onclause=None, isouter=False, full=False
+ ):
"""Produce a :class:`.Join` object, given two :class:`.FromClause`
expressions.
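A minimal sketch of the join produced by _create_join() above, with an explicit ON clause; the tables are illustrative (the ON clause may also be omitted when a single foreign key path exists, per _join_condition below):

    # illustrative only: explicit ON clause join between two lightweight table clauses
    from sqlalchemy import column, select, table

    user = table("user", column("id"), column("name"))
    address = table("address", column("id"), column("user_id"), column("email"))

    j = user.join(address, user.c.id == address.c.user_id)
    stmt = select([user.c.name, address.c.email]).select_from(j)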
@@ -882,26 +931,34 @@ class Join(FromClause):
self.left.description,
id(self.left),
self.right.description,
- id(self.right))
+ id(self.right),
+ )
def is_derived_from(self, fromclause):
- return fromclause is self or \
- self.left.is_derived_from(fromclause) or \
- self.right.is_derived_from(fromclause)
+ return (
+ fromclause is self
+ or self.left.is_derived_from(fromclause)
+ or self.right.is_derived_from(fromclause)
+ )
def self_group(self, against=None):
return FromGrouping(self)
@util.dependencies("sqlalchemy.sql.util")
def _populate_column_collection(self, sqlutil):
- columns = [c for c in self.left.columns] + \
- [c for c in self.right.columns]
+ columns = [c for c in self.left.columns] + [
+ c for c in self.right.columns
+ ]
- self.primary_key.extend(sqlutil.reduce_columns(
- (c for c in columns if c.primary_key), self.onclause))
+ self.primary_key.extend(
+ sqlutil.reduce_columns(
+ (c for c in columns if c.primary_key), self.onclause
+ )
+ )
self._columns.update((col._label, col) for col in columns)
- self.foreign_keys.update(itertools.chain(
- *[col.foreign_keys for col in columns]))
+ self.foreign_keys.update(
+ itertools.chain(*[col.foreign_keys for col in columns])
+ )
def _refresh_for_new_column(self, column):
col = self.left._refresh_for_new_column(column)
@@ -933,9 +990,14 @@ class Join(FromClause):
return self._join_condition(left, right, a_subset=left_right)
@classmethod
- def _join_condition(cls, a, b, ignore_nonexistent_tables=False,
- a_subset=None,
- consider_as_foreign_keys=None):
+ def _join_condition(
+ cls,
+ a,
+ b,
+ ignore_nonexistent_tables=False,
+ a_subset=None,
+ consider_as_foreign_keys=None,
+ ):
"""create a join condition between two tables or selectables.
e.g.::
@@ -963,26 +1025,31 @@ class Join(FromClause):
"""
constraints = cls._joincond_scan_left_right(
- a, a_subset, b, consider_as_foreign_keys)
+ a, a_subset, b, consider_as_foreign_keys
+ )
if len(constraints) > 1:
cls._joincond_trim_constraints(
- a, b, constraints, consider_as_foreign_keys)
+ a, b, constraints, consider_as_foreign_keys
+ )
if len(constraints) == 0:
if isinstance(b, FromGrouping):
- hint = " Perhaps you meant to convert the right side to a "\
+ hint = (
+ " Perhaps you meant to convert the right side to a "
"subquery using alias()?"
+ )
else:
hint = ""
raise exc.NoForeignKeysError(
"Can't find any foreign key relationships "
- "between '%s' and '%s'.%s" %
- (a.description, b.description, hint))
+ "between '%s' and '%s'.%s"
+ % (a.description, b.description, hint)
+ )
crit = [(x == y) for x, y in list(constraints.values())[0]]
if len(crit) == 1:
- return (crit[0])
+ return crit[0]
else:
return and_(*crit)
@@ -994,24 +1061,30 @@ class Join(FromClause):
left_right = None
constraints = cls._joincond_scan_left_right(
- a=left, b=right, a_subset=left_right,
- consider_as_foreign_keys=consider_as_foreign_keys)
+ a=left,
+ b=right,
+ a_subset=left_right,
+ consider_as_foreign_keys=consider_as_foreign_keys,
+ )
return bool(constraints)
@classmethod
def _joincond_scan_left_right(
- cls, a, a_subset, b, consider_as_foreign_keys):
+ cls, a, a_subset, b, consider_as_foreign_keys
+ ):
constraints = collections.defaultdict(list)
for left in (a_subset, a):
if left is None:
continue
for fk in sorted(
- b.foreign_keys,
- key=lambda fk: fk.parent._creation_order):
- if consider_as_foreign_keys is not None and \
- fk.parent not in consider_as_foreign_keys:
+ b.foreign_keys, key=lambda fk: fk.parent._creation_order
+ ):
+ if (
+ consider_as_foreign_keys is not None
+ and fk.parent not in consider_as_foreign_keys
+ ):
continue
try:
col = fk.get_referent(left)
@@ -1025,10 +1098,12 @@ class Join(FromClause):
constraints[fk.constraint].append((col, fk.parent))
if left is not b:
for fk in sorted(
- left.foreign_keys,
- key=lambda fk: fk.parent._creation_order):
- if consider_as_foreign_keys is not None and \
- fk.parent not in consider_as_foreign_keys:
+ left.foreign_keys, key=lambda fk: fk.parent._creation_order
+ ):
+ if (
+ consider_as_foreign_keys is not None
+ and fk.parent not in consider_as_foreign_keys
+ ):
continue
try:
col = fk.get_referent(b)
@@ -1046,14 +1121,16 @@ class Join(FromClause):
@classmethod
def _joincond_trim_constraints(
- cls, a, b, constraints, consider_as_foreign_keys):
+ cls, a, b, constraints, consider_as_foreign_keys
+ ):
# more than one constraint matched. narrow down the list
# to include just those FKCs that match exactly to
# "consider_as_foreign_keys".
if consider_as_foreign_keys:
for const in list(constraints):
if set(f.parent for f in const.elements) != set(
- consider_as_foreign_keys):
+ consider_as_foreign_keys
+ ):
del constraints[const]
# if still multiple constraints, but
@@ -1070,8 +1147,8 @@ class Join(FromClause):
"tables have more than one foreign key "
"constraint relationship between them. "
"Please specify the 'onclause' of this "
- "join explicitly." % (a.description, b.description))
-
+ "join explicitly." % (a.description, b.description)
+ )
def select(self, whereclause=None, **kwargs):
r"""Create a :class:`.Select` from this :class:`.Join`.
@@ -1200,27 +1277,37 @@ class Join(FromClause):
"""
if flat:
assert name is None, "Can't send name argument with flat"
- left_a, right_a = self.left.alias(flat=True), \
- self.right.alias(flat=True)
- adapter = sqlutil.ClauseAdapter(left_a).\
- chain(sqlutil.ClauseAdapter(right_a))
+ left_a, right_a = (
+ self.left.alias(flat=True),
+ self.right.alias(flat=True),
+ )
+ adapter = sqlutil.ClauseAdapter(left_a).chain(
+ sqlutil.ClauseAdapter(right_a)
+ )
- return left_a.join(right_a, adapter.traverse(self.onclause),
- isouter=self.isouter, full=self.full)
+ return left_a.join(
+ right_a,
+ adapter.traverse(self.onclause),
+ isouter=self.isouter,
+ full=self.full,
+ )
else:
return self.select(use_labels=True, correlate=False).alias(name)
@property
def _hide_froms(self):
- return itertools.chain(*[_from_objects(x.left, x.right)
- for x in self._cloned_set])
+ return itertools.chain(
+ *[_from_objects(x.left, x.right) for x in self._cloned_set]
+ )
@property
def _from_objects(self):
- return [self] + \
- self.onclause._from_objects + \
- self.left._from_objects + \
- self.right._from_objects
+ return (
+ [self]
+ + self.onclause._from_objects
+ + self.left._from_objects
+ + self.right._from_objects
+ )
class Alias(FromClause):
@@ -1236,7 +1323,7 @@ class Alias(FromClause):
"""
- __visit_name__ = 'alias'
+ __visit_name__ = "alias"
named_with_column = True
_is_from_container = True
@@ -1252,15 +1339,16 @@ class Alias(FromClause):
self.element = selectable
if name is None:
if self.original.named_with_column:
- name = getattr(self.original, 'name', None)
- name = _anonymous_label('%%(%d %s)s' % (id(self), name
- or 'anon'))
+ name = getattr(self.original, "name", None)
+ name = _anonymous_label("%%(%d %s)s" % (id(self), name or "anon"))
self.name = name
def self_group(self, against=None):
- if isinstance(against, CompoundSelect) and \
- isinstance(self.original, Select) and \
- self.original._needs_parens_for_grouping():
+ if (
+ isinstance(against, CompoundSelect)
+ and isinstance(self.original, Select)
+ and self.original._needs_parens_for_grouping()
+ ):
return FromGrouping(self)
return super(Alias, self).self_group(against=against)
@@ -1270,14 +1358,15 @@ class Alias(FromClause):
if util.py3k:
return self.name
else:
- return self.name.encode('ascii', 'backslashreplace')
+ return self.name.encode("ascii", "backslashreplace")
def as_scalar(self):
try:
return self.element.as_scalar()
except AttributeError:
- raise AttributeError("Element %s does not support "
- "'as_scalar()'" % self.element)
+ raise AttributeError(
+ "Element %s does not support " "'as_scalar()'" % self.element
+ )
def is_derived_from(self, fromclause):
if fromclause in self._cloned_set:
@@ -1344,7 +1433,7 @@ class Lateral(Alias):
"""
- __visit_name__ = 'lateral'
+ __visit_name__ = "lateral"
_is_lateral = True
@@ -1363,11 +1452,9 @@ class TableSample(Alias):
"""
- __visit_name__ = 'tablesample'
+ __visit_name__ = "tablesample"
- def __init__(self, selectable, sampling,
- name=None,
- seed=None):
+ def __init__(self, selectable, sampling, name=None, seed=None):
self.sampling = sampling
self.seed = seed
super(TableSample, self).__init__(selectable, name=name)
@@ -1390,14 +1477,18 @@ class CTE(Generative, HasSuffixes, Alias):
.. versionadded:: 0.7.6
"""
- __visit_name__ = 'cte'
-
- def __init__(self, selectable,
- name=None,
- recursive=False,
- _cte_alias=None,
- _restates=frozenset(),
- _suffixes=None):
+
+ __visit_name__ = "cte"
+
+ def __init__(
+ self,
+ selectable,
+ name=None,
+ recursive=False,
+ _cte_alias=None,
+ _restates=frozenset(),
+ _suffixes=None,
+ ):
self.recursive = recursive
self._cte_alias = _cte_alias
self._restates = _restates
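For orientation, a sketch of how a CTE object is usually obtained, via the generative cte() method rather than this constructor directly; names are illustrative:

    # illustrative only: a non-recursive CTE aggregated and then re-selected
    from sqlalchemy import column, func, select, table

    orders = table("orders", column("region"), column("amount"))
    regional = (
        select([orders.c.region, func.sum(orders.c.amount).label("total")])
        .group_by(orders.c.region)
        .cte("regional_totals")
    )
    stmt = select([regional.c.region]).where(regional.c.total > 100)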
@@ -1409,9 +1500,9 @@ class CTE(Generative, HasSuffixes, Alias):
super(CTE, self)._copy_internals(clone, **kw)
if self._cte_alias is not None:
self._cte_alias = clone(self._cte_alias, **kw)
- self._restates = frozenset([
- clone(elem, **kw) for elem in self._restates
- ])
+ self._restates = frozenset(
+ [clone(elem, **kw) for elem in self._restates]
+ )
@util.dependencies("sqlalchemy.sql.dml")
def _populate_column_collection(self, dml):
@@ -1428,7 +1519,7 @@ class CTE(Generative, HasSuffixes, Alias):
name=name,
recursive=self.recursive,
_cte_alias=self,
- _suffixes=self._suffixes
+ _suffixes=self._suffixes,
)
def union(self, other):
@@ -1437,7 +1528,7 @@ class CTE(Generative, HasSuffixes, Alias):
name=self.name,
recursive=self.recursive,
_restates=self._restates.union([self]),
- _suffixes=self._suffixes
+ _suffixes=self._suffixes,
)
def union_all(self, other):
@@ -1446,7 +1537,7 @@ class CTE(Generative, HasSuffixes, Alias):
name=self.name,
recursive=self.recursive,
_restates=self._restates.union([self]),
- _suffixes=self._suffixes
+ _suffixes=self._suffixes,
)
@@ -1620,7 +1711,8 @@ class HasCTE(object):
class FromGrouping(FromClause):
"""Represent a grouping of a FROM clause"""
- __visit_name__ = 'grouping'
+
+ __visit_name__ = "grouping"
def __init__(self, element):
self.element = element
@@ -1651,7 +1743,7 @@ class FromGrouping(FromClause):
return self.element._hide_froms
def get_children(self, **kwargs):
- return self.element,
+ return (self.element,)
def _copy_internals(self, clone=_clone, **kw):
self.element = clone(self.element, **kw)
@@ -1664,10 +1756,10 @@ class FromGrouping(FromClause):
return getattr(self.element, attr)
def __getstate__(self):
- return {'element': self.element}
+ return {"element": self.element}
def __setstate__(self, state):
- self.element = state['element']
+ self.element = state["element"]
class TableClause(Immutable, FromClause):
@@ -1699,7 +1791,7 @@ class TableClause(Immutable, FromClause):
"""
- __visit_name__ = 'table'
+ __visit_name__ = "table"
named_with_column = True
@@ -1744,7 +1836,7 @@ class TableClause(Immutable, FromClause):
if util.py3k:
return self.name
else:
- return self.name.encode('ascii', 'backslashreplace')
+ return self.name.encode("ascii", "backslashreplace")
def append_column(self, c):
self._columns[c.key] = c
@@ -1773,7 +1865,8 @@ class TableClause(Immutable, FromClause):
@util.dependencies("sqlalchemy.sql.dml")
def update(
- self, dml, whereclause=None, values=None, inline=False, **kwargs):
+ self, dml, whereclause=None, values=None, inline=False, **kwargs
+ ):
"""Generate an :func:`.update` construct against this
:class:`.TableClause`.
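A brief sketch of the update() call described above; the lightweight table() construct and the values are illustrative:

    # illustrative only: UPDATE against a non-metadata TableClause
    from sqlalchemy import column, table

    user = table("user", column("id"), column("name"))
    stmt = user.update().where(user.c.id == 5).values(name="ed")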
@@ -1785,8 +1878,13 @@ class TableClause(Immutable, FromClause):
"""
- return dml.Update(self, whereclause=whereclause,
- values=values, inline=inline, **kwargs)
+ return dml.Update(
+ self,
+ whereclause=whereclause,
+ values=values,
+ inline=inline,
+ **kwargs
+ )
@util.dependencies("sqlalchemy.sql.dml")
def delete(self, dml, whereclause=None, **kwargs):
@@ -1809,7 +1907,6 @@ class TableClause(Immutable, FromClause):
class ForUpdateArg(ClauseElement):
-
@classmethod
def parse_legacy_select(self, arg):
"""Parse the for_update argument of :func:`.select`.
@@ -1836,11 +1933,11 @@ class ForUpdateArg(ClauseElement):
return None
nowait = read = False
- if arg == 'nowait':
+ if arg == "nowait":
nowait = True
- elif arg == 'read':
+ elif arg == "read":
read = True
- elif arg == 'read_nowait':
+ elif arg == "read_nowait":
read = nowait = True
elif arg is not True:
raise exc.ArgumentError("Unknown for_update argument: %r" % arg)
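Summarizing the legacy string handling in the branches above; the select() call itself is illustrative:

    # legacy strings accepted by select(..., for_update=...) map onto flags:
    #   "nowait"      -> nowait=True
    #   "read"        -> read=True
    #   "read_nowait" -> read=True and nowait=True
    #   True          -> plain FOR UPDATE
    from sqlalchemy import column, select, table

    user = table("user", column("id"))
    stmt = select([user], for_update="read_nowait")  # illustrative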
@@ -1860,12 +1957,12 @@ class ForUpdateArg(ClauseElement):
def __eq__(self, other):
return (
- isinstance(other, ForUpdateArg) and
- other.nowait == self.nowait and
- other.read == self.read and
- other.skip_locked == self.skip_locked and
- other.key_share == self.key_share and
- other.of is self.of
+ isinstance(other, ForUpdateArg)
+ and other.nowait == self.nowait
+ and other.read == self.read
+ and other.skip_locked == self.skip_locked
+ and other.key_share == self.key_share
+ and other.of is self.of
)
def __hash__(self):
@@ -1876,8 +1973,13 @@ class ForUpdateArg(ClauseElement):
self.of = [clone(col, **kw) for col in self.of]
def __init__(
- self, nowait=False, read=False, of=None,
- skip_locked=False, key_share=False):
+ self,
+ nowait=False,
+ read=False,
+ of=None,
+ skip_locked=False,
+ key_share=False,
+ ):
"""Represents arguments specified to :meth:`.Select.for_update`.
.. versionadded:: 0.9.0
@@ -1889,8 +1991,9 @@ class ForUpdateArg(ClauseElement):
self.skip_locked = skip_locked
self.key_share = key_share
if of is not None:
- self.of = [_interpret_as_column_or_from(elem)
- for elem in util.to_list(of)]
+ self.of = [
+ _interpret_as_column_or_from(elem) for elem in util.to_list(of)
+ ]
else:
self.of = None
@@ -1930,17 +2033,20 @@ class SelectBase(HasCTE, Executable, FromClause):
return self.as_scalar().label(name)
@_generative
- @util.deprecated('0.6',
- message="``autocommit()`` is deprecated. Use "
- ":meth:`.Executable.execution_options` with the "
- "'autocommit' flag.")
+ @util.deprecated(
+ "0.6",
+ message="``autocommit()`` is deprecated. Use "
+ ":meth:`.Executable.execution_options` with the "
+ "'autocommit' flag.",
+ )
def autocommit(self):
"""return a new selectable with the 'autocommit' flag set to
True.
"""
- self._execution_options = \
- self._execution_options.union({'autocommit': True})
+ self._execution_options = self._execution_options.union(
+ {"autocommit": True}
+ )
def _generate(self):
"""Override the default _generate() method to also clear out
@@ -1973,34 +2079,38 @@ class GenerativeSelect(SelectBase):
used for other SELECT-like objects, e.g. :class:`.TextAsFrom`.
"""
+
_order_by_clause = ClauseList()
_group_by_clause = ClauseList()
_limit_clause = None
_offset_clause = None
_for_update_arg = None
- def __init__(self,
- use_labels=False,
- for_update=False,
- limit=None,
- offset=None,
- order_by=None,
- group_by=None,
- bind=None,
- autocommit=None):
+ def __init__(
+ self,
+ use_labels=False,
+ for_update=False,
+ limit=None,
+ offset=None,
+ order_by=None,
+ group_by=None,
+ bind=None,
+ autocommit=None,
+ ):
self.use_labels = use_labels
if for_update is not False:
- self._for_update_arg = (ForUpdateArg.
- parse_legacy_select(for_update))
+ self._for_update_arg = ForUpdateArg.parse_legacy_select(for_update)
if autocommit is not None:
- util.warn_deprecated('autocommit on select() is '
- 'deprecated. Use .execution_options(a'
- 'utocommit=True)')
- self._execution_options = \
- self._execution_options.union(
- {'autocommit': autocommit})
+ util.warn_deprecated(
+ "autocommit on select() is "
+                "deprecated. Use "
+                ".execution_options(autocommit=True)"
+ )
+ self._execution_options = self._execution_options.union(
+ {"autocommit": autocommit}
+ )
if limit is not None:
self._limit_clause = _offset_or_limit_clause(limit)
if offset is not None:
@@ -2010,11 +2120,13 @@ class GenerativeSelect(SelectBase):
if order_by is not None:
self._order_by_clause = ClauseList(
*util.to_list(order_by),
- _literal_as_text=_literal_and_labels_as_label_reference)
+ _literal_as_text=_literal_and_labels_as_label_reference
+ )
if group_by is not None:
self._group_by_clause = ClauseList(
*util.to_list(group_by),
- _literal_as_text=_literal_as_label_reference)
+ _literal_as_text=_literal_as_label_reference
+ )
@property
def for_update(self):
@@ -2030,8 +2142,14 @@ class GenerativeSelect(SelectBase):
self._for_update_arg = ForUpdateArg.parse_legacy_select(value)
@_generative
- def with_for_update(self, nowait=False, read=False, of=None,
- skip_locked=False, key_share=False):
+ def with_for_update(
+ self,
+ nowait=False,
+ read=False,
+ of=None,
+ skip_locked=False,
+ key_share=False,
+ ):
"""Specify a ``FOR UPDATE`` clause for this :class:`.GenerativeSelect`.
E.g.::
@@ -2079,9 +2197,13 @@ class GenerativeSelect(SelectBase):
.. versionadded:: 1.1.0
"""
- self._for_update_arg = ForUpdateArg(nowait=nowait, read=read, of=of,
- skip_locked=skip_locked,
- key_share=key_share)
+ self._for_update_arg = ForUpdateArg(
+ nowait=nowait,
+ read=read,
+ of=of,
+ skip_locked=skip_locked,
+ key_share=key_share,
+ )
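A short sketch of the with_for_update() flags assembled above; rendering of NOWAIT and OF is dialect-specific, and the table is illustrative:

    # illustrative only: SELECT ... FOR UPDATE OF user NOWAIT, where supported
    from sqlalchemy import column, select, table

    user = table("user", column("id"), column("name"))
    stmt = select([user]).with_for_update(nowait=True, of=user)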
@_generative
def apply_labels(self):
@@ -2209,11 +2331,12 @@ class GenerativeSelect(SelectBase):
if len(clauses) == 1 and clauses[0] is None:
self._order_by_clause = ClauseList()
else:
- if getattr(self, '_order_by_clause', None) is not None:
+ if getattr(self, "_order_by_clause", None) is not None:
clauses = list(self._order_by_clause) + list(clauses)
self._order_by_clause = ClauseList(
*clauses,
- _literal_as_text=_literal_and_labels_as_label_reference)
+ _literal_as_text=_literal_and_labels_as_label_reference
+ )
def append_group_by(self, *clauses):
"""Append the given GROUP BY criterion applied to this selectable.
@@ -2228,10 +2351,11 @@ class GenerativeSelect(SelectBase):
if len(clauses) == 1 and clauses[0] is None:
self._group_by_clause = ClauseList()
else:
- if getattr(self, '_group_by_clause', None) is not None:
+ if getattr(self, "_group_by_clause", None) is not None:
clauses = list(self._group_by_clause) + list(clauses)
self._group_by_clause = ClauseList(
- *clauses, _literal_as_text=_literal_as_label_reference)
+ *clauses, _literal_as_text=_literal_as_label_reference
+ )
@property
def _label_resolve_dict(self):
@@ -2265,19 +2389,19 @@ class CompoundSelect(GenerativeSelect):
"""
- __visit_name__ = 'compound_select'
+ __visit_name__ = "compound_select"
- UNION = util.symbol('UNION')
- UNION_ALL = util.symbol('UNION ALL')
- EXCEPT = util.symbol('EXCEPT')
- EXCEPT_ALL = util.symbol('EXCEPT ALL')
- INTERSECT = util.symbol('INTERSECT')
- INTERSECT_ALL = util.symbol('INTERSECT ALL')
+ UNION = util.symbol("UNION")
+ UNION_ALL = util.symbol("UNION ALL")
+ EXCEPT = util.symbol("EXCEPT")
+ EXCEPT_ALL = util.symbol("EXCEPT ALL")
+ INTERSECT = util.symbol("INTERSECT")
+ INTERSECT_ALL = util.symbol("INTERSECT ALL")
_is_from_container = True
def __init__(self, keyword, *selects, **kwargs):
- self._auto_correlate = kwargs.pop('correlate', False)
+ self._auto_correlate = kwargs.pop("correlate", False)
self.keyword = keyword
self.selects = []
@@ -2291,12 +2415,16 @@ class CompoundSelect(GenerativeSelect):
numcols = len(s.c._all_columns)
elif len(s.c._all_columns) != numcols:
raise exc.ArgumentError(
- 'All selectables passed to '
- 'CompoundSelect must have identical numbers of '
- 'columns; select #%d has %d columns, select '
- '#%d has %d' %
- (1, len(self.selects[0].c._all_columns),
- n + 1, len(s.c._all_columns))
+ "All selectables passed to "
+ "CompoundSelect must have identical numbers of "
+ "columns; select #%d has %d columns, select "
+ "#%d has %d"
+ % (
+ 1,
+ len(self.selects[0].c._all_columns),
+ n + 1,
+ len(s.c._all_columns),
+ )
)
self.selects.append(s.self_group(against=self))
@@ -2305,9 +2433,7 @@ class CompoundSelect(GenerativeSelect):
@property
def _label_resolve_dict(self):
- d = dict(
- (c.key, c) for c in self.c
- )
+ d = dict((c.key, c) for c in self.c)
return d, d, d
@classmethod
@@ -2416,8 +2542,7 @@ class CompoundSelect(GenerativeSelect):
:func:`select`.
"""
- return CompoundSelect(
- CompoundSelect.INTERSECT_ALL, *selects, **kwargs)
+ return CompoundSelect(CompoundSelect.INTERSECT_ALL, *selects, **kwargs)
def _scalar_type(self):
return self.selects[0]._scalar_type()
@@ -2445,8 +2570,10 @@ class CompoundSelect(GenerativeSelect):
# those fks too.
proxy = cols[0]._make_proxy(
- self, name=cols[0]._label if self.use_labels else None,
- key=cols[0]._key_label if self.use_labels else None)
+ self,
+ name=cols[0]._label if self.use_labels else None,
+ key=cols[0]._key_label if self.use_labels else None,
+ )
# hand-construct the "_proxies" collection to include all
# derived columns place a 'weight' annotation corresponding
@@ -2455,7 +2582,8 @@ class CompoundSelect(GenerativeSelect):
# conflicts
proxy._proxies = [
- c._annotate({'weight': i + 1}) for (i, c) in enumerate(cols)]
+ c._annotate({"weight": i + 1}) for (i, c) in enumerate(cols)
+ ]
def _refresh_for_new_column(self, column):
for s in self.selects:
@@ -2464,25 +2592,32 @@ class CompoundSelect(GenerativeSelect):
if not self._cols_populated:
return None
- raise NotImplementedError("CompoundSelect constructs don't support "
- "addition of columns to underlying "
- "selectables")
+ raise NotImplementedError(
+ "CompoundSelect constructs don't support "
+ "addition of columns to underlying "
+ "selectables"
+ )
def _copy_internals(self, clone=_clone, **kw):
super(CompoundSelect, self)._copy_internals(clone, **kw)
self._reset_exported()
self.selects = [clone(s, **kw) for s in self.selects]
- if hasattr(self, '_col_map'):
+ if hasattr(self, "_col_map"):
del self._col_map
for attr in (
- '_order_by_clause', '_group_by_clause', '_for_update_arg'):
+ "_order_by_clause",
+ "_group_by_clause",
+ "_for_update_arg",
+ ):
if getattr(self, attr) is not None:
setattr(self, attr, clone(getattr(self, attr), **kw))
def get_children(self, column_collections=True, **kwargs):
- return (column_collections and list(self.c) or []) \
- + [self._order_by_clause, self._group_by_clause] \
+ return (
+ (column_collections and list(self.c) or [])
+ + [self._order_by_clause, self._group_by_clause]
+ list(self.selects)
+ )
def bind(self):
if self._bind:
@@ -2496,6 +2631,7 @@ class CompoundSelect(GenerativeSelect):
def _set_bind(self, bind):
self._bind = bind
+
bind = property(bind, _set_bind)
@@ -2504,7 +2640,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
"""
- __visit_name__ = 'select'
+ __visit_name__ = "select"
_prefixes = ()
_suffixes = ()
@@ -2517,16 +2653,18 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
_memoized_property = SelectBase._memoized_property
_is_select = True
- def __init__(self,
- columns=None,
- whereclause=None,
- from_obj=None,
- distinct=False,
- having=None,
- correlate=True,
- prefixes=None,
- suffixes=None,
- **kwargs):
+ def __init__(
+ self,
+ columns=None,
+ whereclause=None,
+ from_obj=None,
+ distinct=False,
+ having=None,
+ correlate=True,
+ prefixes=None,
+ suffixes=None,
+ **kwargs
+ ):
"""Construct a new :class:`.Select`.
Similar functionality is also available via the
@@ -2729,22 +2867,23 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
self._distinct = True
else:
self._distinct = [
- _literal_as_text(e)
- for e in util.to_list(distinct)
+ _literal_as_text(e) for e in util.to_list(distinct)
]
if from_obj is not None:
self._from_obj = util.OrderedSet(
- _interpret_as_from(f)
- for f in util.to_list(from_obj))
+ _interpret_as_from(f) for f in util.to_list(from_obj)
+ )
else:
self._from_obj = util.OrderedSet()
try:
cols_present = bool(columns)
except TypeError:
- raise exc.ArgumentError("columns argument to select() must "
- "be a Python list or other iterable")
+ raise exc.ArgumentError(
+ "columns argument to select() must "
+ "be a Python list or other iterable"
+ )
if cols_present:
self._raw_columns = []
@@ -2757,14 +2896,16 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
self._raw_columns = []
if whereclause is not None:
- self._whereclause = _literal_as_text(
- whereclause).self_group(against=operators._asbool)
+ self._whereclause = _literal_as_text(whereclause).self_group(
+ against=operators._asbool
+ )
else:
self._whereclause = None
if having is not None:
- self._having = _literal_as_text(
- having).self_group(against=operators._asbool)
+ self._having = _literal_as_text(having).self_group(
+ against=operators._asbool
+ )
else:
self._having = None
@@ -2789,12 +2930,14 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
for item in itertools.chain(
_from_objects(*self._raw_columns),
_from_objects(self._whereclause)
- if self._whereclause is not None else (),
- self._from_obj
+ if self._whereclause is not None
+ else (),
+ self._from_obj,
):
if item is self:
raise exc.InvalidRequestError(
- "select() construct refers to itself as a FROM")
+ "select() construct refers to itself as a FROM"
+ )
if translate and item in translate:
item = translate[item]
if not seen.intersection(item._cloned_set):
@@ -2803,8 +2946,9 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
return froms
- def _get_display_froms(self, explicit_correlate_froms=None,
- implicit_correlate_froms=None):
+ def _get_display_froms(
+ self, explicit_correlate_froms=None, implicit_correlate_froms=None
+ ):
"""Return the full list of 'from' clauses to be displayed.
Takes into account a set of existing froms which may be
@@ -2815,17 +2959,17 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
"""
froms = self._froms
- toremove = set(itertools.chain(*[
- _expand_cloned(f._hide_froms)
- for f in froms]))
+ toremove = set(
+ itertools.chain(*[_expand_cloned(f._hide_froms) for f in froms])
+ )
if toremove:
# if we're maintaining clones of froms,
# add the copies out to the toremove list. only include
# clones that are lexical equivalents.
if self._from_cloned:
toremove.update(
- self._from_cloned[f] for f in
- toremove.intersection(self._from_cloned)
+ self._from_cloned[f]
+ for f in toremove.intersection(self._from_cloned)
if self._from_cloned[f]._is_lexical_equivalent(f)
)
# filter out to FROM clauses not in the list,
@@ -2836,41 +2980,53 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
to_correlate = self._correlate
if to_correlate:
froms = [
- f for f in froms if f not in
- _cloned_intersection(
+ f
+ for f in froms
+ if f
+ not in _cloned_intersection(
_cloned_intersection(
- froms, explicit_correlate_froms or ()),
- to_correlate
+ froms, explicit_correlate_froms or ()
+ ),
+ to_correlate,
)
]
if self._correlate_except is not None:
froms = [
- f for f in froms if f not in
- _cloned_difference(
+ f
+ for f in froms
+ if f
+ not in _cloned_difference(
_cloned_intersection(
- froms, explicit_correlate_froms or ()),
- self._correlate_except
+ froms, explicit_correlate_froms or ()
+ ),
+ self._correlate_except,
)
]
- if self._auto_correlate and \
- implicit_correlate_froms and \
- len(froms) > 1:
+ if (
+ self._auto_correlate
+ and implicit_correlate_froms
+ and len(froms) > 1
+ ):
froms = [
- f for f in froms if f not in
- _cloned_intersection(froms, implicit_correlate_froms)
+ f
+ for f in froms
+ if f
+ not in _cloned_intersection(froms, implicit_correlate_froms)
]
if not len(froms):
- raise exc.InvalidRequestError("Select statement '%s"
- "' returned no FROM clauses "
- "due to auto-correlation; "
- "specify correlate(<tables>) "
- "to control correlation "
- "manually." % self)
+ raise exc.InvalidRequestError(
+ "Select statement '%s"
+ "' returned no FROM clauses "
+ "due to auto-correlation; "
+ "specify correlate(<tables>) "
+ "to control correlation "
+ "manually." % self
+ )
return froms
@@ -2885,7 +3041,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
return self._get_display_froms()
- def with_statement_hint(self, text, dialect_name='*'):
+ def with_statement_hint(self, text, dialect_name="*"):
"""add a statement hint to this :class:`.Select`.
This method is similar to :meth:`.Select.with_hint` except that
@@ -2906,7 +3062,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
return self.with_hint(None, text, dialect_name)
@_generative
- def with_hint(self, selectable, text, dialect_name='*'):
+ def with_hint(self, selectable, text, dialect_name="*"):
r"""Add an indexing or other executional context hint for the given
selectable to this :class:`.Select`.
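A minimal sketch of with_hint() as described above, using the Oracle index-hint form from the docstring; with_statement_hint() is the same call minus the selectable. The table name is assumed:

    # illustrative only: a per-table hint gated to the Oracle dialect
    from sqlalchemy import column, select, table

    mytable = table("mytable", column("id"))
    stmt = select([mytable]).with_hint(
        mytable, "index(%(name)s ix_mytable)", dialect_name="oracle"
    )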
@@ -2940,17 +3096,18 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
"""
if selectable is None:
- self._statement_hints += ((dialect_name, text), )
+ self._statement_hints += ((dialect_name, text),)
else:
- self._hints = self._hints.union(
- {(selectable, dialect_name): text})
+ self._hints = self._hints.union({(selectable, dialect_name): text})
@property
def type(self):
- raise exc.InvalidRequestError("Select objects don't have a type. "
- "Call as_scalar() on this Select "
- "object to return a 'scalar' version "
- "of this Select.")
+ raise exc.InvalidRequestError(
+ "Select objects don't have a type. "
+ "Call as_scalar() on this Select "
+ "object to return a 'scalar' version "
+ "of this Select."
+ )
@_memoized_property.method
def locate_all_froms(self):
@@ -2977,10 +3134,13 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
with_cols = dict(
(c._resolve_label or c._label or c.key, c)
for c in _select_iterables(self._raw_columns)
- if c._allow_label_resolve)
+ if c._allow_label_resolve
+ )
only_froms = dict(
- (c.key, c) for c in
- _select_iterables(self.froms) if c._allow_label_resolve)
+ (c.key, c)
+ for c in _select_iterables(self.froms)
+ if c._allow_label_resolve
+ )
only_cols = with_cols.copy()
for key, value in only_froms.items():
with_cols.setdefault(key, value)
@@ -3011,11 +3171,13 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
# gets cleared on each generation. previously we were "baking"
# _froms into self._from_obj.
self._from_cloned = from_cloned = dict(
- (f, clone(f, **kw)) for f in self._from_obj.union(self._froms))
+ (f, clone(f, **kw)) for f in self._from_obj.union(self._froms)
+ )
# 3. update persistent _from_obj with the cloned versions.
- self._from_obj = util.OrderedSet(from_cloned[f] for f in
- self._from_obj)
+ self._from_obj = util.OrderedSet(
+ from_cloned[f] for f in self._from_obj
+ )
# the _correlate collection is done separately, what can happen
# here is the same item is _correlate as in _from_obj but the
@@ -3023,16 +3185,22 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
# RelationshipProperty.Comparator._criterion_exists() does
# this). Also keep _correlate liberally open with its previous
# contents, as this set is used for matching, not rendering.
- self._correlate = set(clone(f) for f in
- self._correlate).union(self._correlate)
+ self._correlate = set(clone(f) for f in self._correlate).union(
+ self._correlate
+ )
# 4. clone other things. The difficulty here is that Column
# objects are not actually cloned, and refer to their original
# .table, resulting in the wrong "from" parent after a clone
# operation. Hence _from_cloned and _from_obj supersede what is
# present here.
self._raw_columns = [clone(c, **kw) for c in self._raw_columns]
- for attr in '_whereclause', '_having', '_order_by_clause', \
- '_group_by_clause', '_for_update_arg':
+ for attr in (
+ "_whereclause",
+ "_having",
+ "_order_by_clause",
+ "_group_by_clause",
+ "_for_update_arg",
+ ):
if getattr(self, attr) is not None:
setattr(self, attr, clone(getattr(self, attr), **kw))
@@ -3043,12 +3211,21 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
def get_children(self, column_collections=True, **kwargs):
"""return child elements as per the ClauseElement specification."""
- return (column_collections and list(self.columns) or []) + \
- self._raw_columns + list(self._froms) + \
- [x for x in
- (self._whereclause, self._having,
- self._order_by_clause, self._group_by_clause)
- if x is not None]
+ return (
+ (column_collections and list(self.columns) or [])
+ + self._raw_columns
+ + list(self._froms)
+ + [
+ x
+ for x in (
+ self._whereclause,
+ self._having,
+ self._order_by_clause,
+ self._group_by_clause,
+ )
+ if x is not None
+ ]
+ )
@_generative
def column(self, column):
@@ -3094,7 +3271,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
sqlutil.reduce_columns(
self.inner_columns,
only_synonyms=only_synonyms,
- *(self._whereclause, ) + tuple(self._from_obj)
+ *(self._whereclause,) + tuple(self._from_obj)
)
)
@@ -3307,7 +3484,8 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
self._correlate = ()
else:
self._correlate = set(self._correlate).union(
- _interpret_as_from(f) for f in fromclauses)
+ _interpret_as_from(f) for f in fromclauses
+ )
@_generative
def correlate_except(self, *fromclauses):
@@ -3349,7 +3527,8 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
self._correlate_except = ()
else:
self._correlate_except = set(self._correlate_except or ()).union(
- _interpret_as_from(f) for f in fromclauses)
+ _interpret_as_from(f) for f in fromclauses
+ )
def append_correlation(self, fromclause):
"""append the given correlation expression to this select()
@@ -3363,7 +3542,8 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
self._auto_correlate = False
self._correlate = set(self._correlate).union(
- _interpret_as_from(f) for f in fromclause)
+ _interpret_as_from(f) for f in fromclause
+ )
def append_column(self, column):
"""append the given column expression to the columns clause of this
@@ -3415,8 +3595,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
"""
self._reset_exported()
- self._whereclause = and_(
- True_._ifnone(self._whereclause), whereclause)
+ self._whereclause = and_(True_._ifnone(self._whereclause), whereclause)
def append_having(self, having):
"""append the given expression to this select() construct's HAVING
@@ -3463,19 +3642,17 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
return [
name_for_col(c)
- for c in util.unique_list(
- _select_iterables(self._raw_columns))
+ for c in util.unique_list(_select_iterables(self._raw_columns))
]
else:
return [
(None, c)
- for c in util.unique_list(
- _select_iterables(self._raw_columns))
+ for c in util.unique_list(_select_iterables(self._raw_columns))
]
def _populate_column_collection(self):
for name, c in self._columns_plus_names:
- if not hasattr(c, '_make_proxy'):
+ if not hasattr(c, "_make_proxy"):
continue
if name is None:
key = None
@@ -3486,9 +3663,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
else:
key = None
- c._make_proxy(self, key=key,
- name=name,
- name_is_truncatable=True)
+ c._make_proxy(self, key=key, name=name, name_is_truncatable=True)
def _refresh_for_new_column(self, column):
for fromclause in self._froms:
@@ -3501,15 +3676,16 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
self,
name=col._label if self.use_labels else None,
key=col._key_label if self.use_labels else None,
- name_is_truncatable=True)
+ name_is_truncatable=True,
+ )
return None
return None
def _needs_parens_for_grouping(self):
return (
- self._limit_clause is not None or
- self._offset_clause is not None or
- bool(self._order_by_clause.clauses)
+ self._limit_clause is not None
+ or self._offset_clause is not None
+ or bool(self._order_by_clause.clauses)
)
def self_group(self, against=None):
@@ -3521,8 +3697,10 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
expressions and should not require explicit use.
"""
- if isinstance(against, CompoundSelect) and \
- not self._needs_parens_for_grouping():
+ if (
+ isinstance(against, CompoundSelect)
+ and not self._needs_parens_for_grouping()
+ ):
return self
return FromGrouping(self)
@@ -3586,6 +3764,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
def _set_bind(self, bind):
self._bind = bind
+
bind = property(bind, _set_bind)
@@ -3600,9 +3779,12 @@ class ScalarSelect(Generative, Grouping):
@property
def columns(self):
- raise exc.InvalidRequestError('Scalar Select expression has no '
- 'columns; use this object directly '
- 'within a column-level expression.')
+ raise exc.InvalidRequestError(
+ "Scalar Select expression has no "
+ "columns; use this object directly "
+ "within a column-level expression."
+ )
+
c = columns
@_generative
@@ -3621,6 +3803,7 @@ class Exists(UnaryExpression):
"""Represent an ``EXISTS`` clause.
"""
+
__visit_name__ = UnaryExpression.__visit_name__
_from_objects = []
@@ -3646,12 +3829,16 @@ class Exists(UnaryExpression):
s = args[0]
else:
if not args:
- args = ([literal_column('*')],)
+ args = ([literal_column("*")],)
s = Select(*args, **kwargs).as_scalar().self_group()
- UnaryExpression.__init__(self, s, operator=operators.exists,
- type_=type_api.BOOLEANTYPE,
- wraps_column_expression=True)
+ UnaryExpression.__init__(
+ self,
+ s,
+ operator=operators.exists,
+ type_=type_api.BOOLEANTYPE,
+ wraps_column_expression=True,
+ )
def select(self, whereclause=None, **params):
return Select([self], whereclause, **params)
@@ -3706,6 +3893,7 @@ class TextAsFrom(SelectBase):
:meth:`.TextClause.columns`
"""
+
__visit_name__ = "text_as_from"
_textual = True
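
The Select.with_statement_hint() and Select.with_hint() methods reformatted above keep their existing signatures, both defaulting dialect_name to "*" (hint applies on every dialect). A minimal usage sketch, outside the diff itself; the table, column and hint strings are illustrative only:

    from sqlalchemy import Column, Integer, MetaData, Table, select

    metadata = MetaData()
    user = Table("user", metadata, Column("id", Integer, primary_key=True))

    # statement-wide hint (selectable=None), only applied when compiling for Oracle
    stmt = select([user]).with_statement_hint("FIRST_ROWS(10)", dialect_name="oracle")

    # hint tied to the "user" table, only applied when compiling for MySQL
    stmt = select([user]).with_hint(user, "USE INDEX (ix_user_id)", dialect_name="mysql")
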
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index c5708940b..61fc6d3c9 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -15,10 +15,21 @@ import collections
import json
from . import elements
-from .type_api import TypeEngine, TypeDecorator, to_instance, Variant, \
- Emulated, NativeForEmulated
-from .elements import quoted_name, TypeCoerce as type_coerce, _defer_name, \
- Slice, _literal_as_binds
+from .type_api import (
+ TypeEngine,
+ TypeDecorator,
+ to_instance,
+ Variant,
+ Emulated,
+ NativeForEmulated,
+)
+from .elements import (
+ quoted_name,
+ TypeCoerce as type_coerce,
+ _defer_name,
+ Slice,
+ _literal_as_binds,
+)
from .. import exc, util, processors
from .base import _bind_or_error, SchemaEventTarget
from . import operators
@@ -51,14 +62,15 @@ class _LookupExpressionAdapter(object):
def _adapt_expression(self, op, other_comparator):
othertype = other_comparator.type._type_affinity
lookup = self.type._expression_adaptations.get(
- op, self._blank_dict).get(
- othertype, self.type)
+ op, self._blank_dict
+ ).get(othertype, self.type)
if lookup is othertype:
return (op, other_comparator.type)
elif lookup is self.type._type_affinity:
return (op, self.type)
else:
return (op, to_instance(lookup))
+
comparator_factory = Comparator
@@ -68,17 +80,16 @@ class Concatenable(object):
typically strings."""
class Comparator(TypeEngine.Comparator):
-
def _adapt_expression(self, op, other_comparator):
- if (op is operators.add and
- isinstance(
- other_comparator,
- (Concatenable.Comparator, NullType.Comparator)
- )):
+ if op is operators.add and isinstance(
+ other_comparator,
+ (Concatenable.Comparator, NullType.Comparator),
+ ):
return operators.concat_op, self.expr.type
else:
return super(Concatenable.Comparator, self)._adapt_expression(
- op, other_comparator)
+ op, other_comparator
+ )
comparator_factory = Comparator
@@ -94,17 +105,15 @@ class Indexable(object):
"""
class Comparator(TypeEngine.Comparator):
-
def _setup_getitem(self, index):
raise NotImplementedError()
def __getitem__(self, index):
- adjusted_op, adjusted_right_expr, result_type = \
- self._setup_getitem(index)
+ adjusted_op, adjusted_right_expr, result_type = self._setup_getitem(
+ index
+ )
return self.operate(
- adjusted_op,
- adjusted_right_expr,
- result_type=result_type
+ adjusted_op, adjusted_right_expr, result_type=result_type
)
comparator_factory = Comparator
@@ -124,13 +133,16 @@ class String(Concatenable, TypeEngine):
"""
- __visit_name__ = 'string'
+ __visit_name__ = "string"
- def __init__(self, length=None, collation=None,
- convert_unicode=False,
- unicode_error=None,
- _warn_on_bytestring=False
- ):
+ def __init__(
+ self,
+ length=None,
+ collation=None,
+ convert_unicode=False,
+ unicode_error=None,
+ _warn_on_bytestring=False,
+ ):
"""
Create a string-holding type.
@@ -207,9 +219,10 @@ class String(Concatenable, TypeEngine):
strings from a column with varied or corrupted encodings.
"""
- if unicode_error is not None and convert_unicode != 'force':
- raise exc.ArgumentError("convert_unicode must be 'force' "
- "when unicode_error is set.")
+ if unicode_error is not None and convert_unicode != "force":
+ raise exc.ArgumentError(
+ "convert_unicode must be 'force' " "when unicode_error is set."
+ )
self.length = length
self.collation = collation
@@ -222,23 +235,29 @@ class String(Concatenable, TypeEngine):
value = value.replace("'", "''")
if dialect.identifier_preparer._double_percents:
- value = value.replace('%', '%%')
+ value = value.replace("%", "%%")
return "'%s'" % value
+
return process
def bind_processor(self, dialect):
if self.convert_unicode or dialect.convert_unicode:
- if dialect.supports_unicode_binds and \
- self.convert_unicode != 'force':
+ if (
+ dialect.supports_unicode_binds
+ and self.convert_unicode != "force"
+ ):
if self._warn_on_bytestring:
+
def process(value):
if isinstance(value, util.binary_type):
util.warn_limited(
"Unicode type received non-unicode "
"bind param value %r.",
- (util.ellipses_string(value),))
+ (util.ellipses_string(value),),
+ )
return value
+
return process
else:
return None
@@ -253,29 +272,34 @@ class String(Concatenable, TypeEngine):
util.warn_limited(
"Unicode type received non-unicode bind "
"param value %r.",
- (util.ellipses_string(value),))
+ (util.ellipses_string(value),),
+ )
return value
+
return process
else:
return None
def result_processor(self, dialect, coltype):
wants_unicode = self.convert_unicode or dialect.convert_unicode
- needs_convert = wants_unicode and \
- (dialect.returns_unicode_strings is not True or
- self.convert_unicode in ('force', 'force_nocheck'))
+ needs_convert = wants_unicode and (
+ dialect.returns_unicode_strings is not True
+ or self.convert_unicode in ("force", "force_nocheck")
+ )
needs_isinstance = (
- needs_convert and
- dialect.returns_unicode_strings and
- self.convert_unicode != 'force_nocheck'
+ needs_convert
+ and dialect.returns_unicode_strings
+ and self.convert_unicode != "force_nocheck"
)
if needs_convert:
if needs_isinstance:
return processors.to_conditional_unicode_processor_factory(
- dialect.encoding, self.unicode_error)
+ dialect.encoding, self.unicode_error
+ )
else:
return processors.to_unicode_processor_factory(
- dialect.encoding, self.unicode_error)
+ dialect.encoding, self.unicode_error
+ )
else:
return None
@@ -301,7 +325,8 @@ class Text(String):
argument here, it will be rejected by others.
"""
- __visit_name__ = 'text'
+
+ __visit_name__ = "text"
class Unicode(String):
@@ -360,7 +385,7 @@ class Unicode(String):
"""
- __visit_name__ = 'unicode'
+ __visit_name__ = "unicode"
def __init__(self, length=None, **kwargs):
"""
@@ -371,8 +396,8 @@ class Unicode(String):
defaults to ``True``.
"""
- kwargs.setdefault('convert_unicode', True)
- kwargs.setdefault('_warn_on_bytestring', True)
+ kwargs.setdefault("convert_unicode", True)
+ kwargs.setdefault("_warn_on_bytestring", True)
super(Unicode, self).__init__(length=length, **kwargs)
@@ -389,7 +414,7 @@ class UnicodeText(Text):
"""
- __visit_name__ = 'unicode_text'
+ __visit_name__ = "unicode_text"
def __init__(self, length=None, **kwargs):
"""
@@ -400,8 +425,8 @@ class UnicodeText(Text):
defaults to ``True``.
"""
- kwargs.setdefault('convert_unicode', True)
- kwargs.setdefault('_warn_on_bytestring', True)
+ kwargs.setdefault("convert_unicode", True)
+ kwargs.setdefault("_warn_on_bytestring", True)
super(UnicodeText, self).__init__(length=length, **kwargs)
@@ -409,7 +434,7 @@ class Integer(_LookupExpressionAdapter, TypeEngine):
"""A type for ``int`` integers."""
- __visit_name__ = 'integer'
+ __visit_name__ = "integer"
def get_dbapi_type(self, dbapi):
return dbapi.NUMBER
@@ -421,6 +446,7 @@ class Integer(_LookupExpressionAdapter, TypeEngine):
def literal_processor(self, dialect):
def process(value):
return str(value)
+
return process
@util.memoized_property
@@ -438,18 +464,9 @@ class Integer(_LookupExpressionAdapter, TypeEngine):
Integer: self.__class__,
Numeric: Numeric,
},
- operators.div: {
- Integer: self.__class__,
- Numeric: Numeric,
- },
- operators.truediv: {
- Integer: self.__class__,
- Numeric: Numeric,
- },
- operators.sub: {
- Integer: self.__class__,
- Numeric: Numeric,
- },
+ operators.div: {Integer: self.__class__, Numeric: Numeric},
+ operators.truediv: {Integer: self.__class__, Numeric: Numeric},
+ operators.sub: {Integer: self.__class__, Numeric: Numeric},
}
@@ -462,7 +479,7 @@ class SmallInteger(Integer):
"""
- __visit_name__ = 'small_integer'
+ __visit_name__ = "small_integer"
class BigInteger(Integer):
@@ -474,7 +491,7 @@ class BigInteger(Integer):
"""
- __visit_name__ = 'big_integer'
+ __visit_name__ = "big_integer"
class Numeric(_LookupExpressionAdapter, TypeEngine):
@@ -517,12 +534,17 @@ class Numeric(_LookupExpressionAdapter, TypeEngine):
"""
- __visit_name__ = 'numeric'
+ __visit_name__ = "numeric"
_default_decimal_return_scale = 10
- def __init__(self, precision=None, scale=None,
- decimal_return_scale=None, asdecimal=True):
+ def __init__(
+ self,
+ precision=None,
+ scale=None,
+ decimal_return_scale=None,
+ asdecimal=True,
+ ):
"""
Construct a Numeric.
@@ -587,6 +609,7 @@ class Numeric(_LookupExpressionAdapter, TypeEngine):
def literal_processor(self, dialect):
def process(value):
return str(value)
+
return process
@property
@@ -608,19 +631,23 @@ class Numeric(_LookupExpressionAdapter, TypeEngine):
# we're a "numeric", DBAPI will give us Decimal directly
return None
else:
- util.warn('Dialect %s+%s does *not* support Decimal '
- 'objects natively, and SQLAlchemy must '
- 'convert from floating point - rounding '
- 'errors and other issues may occur. Please '
- 'consider storing Decimal numbers as strings '
- 'or integers on this platform for lossless '
- 'storage.' % (dialect.name, dialect.driver))
+ util.warn(
+ "Dialect %s+%s does *not* support Decimal "
+ "objects natively, and SQLAlchemy must "
+ "convert from floating point - rounding "
+ "errors and other issues may occur. Please "
+ "consider storing Decimal numbers as strings "
+ "or integers on this platform for lossless "
+ "storage." % (dialect.name, dialect.driver)
+ )
# we're a "numeric", DBAPI returns floats, convert.
return processors.to_decimal_processor_factory(
decimal.Decimal,
- self.scale if self.scale is not None
- else self._default_decimal_return_scale)
+ self.scale
+ if self.scale is not None
+ else self._default_decimal_return_scale,
+ )
else:
if dialect.supports_native_decimal:
return processors.to_float
@@ -635,22 +662,13 @@ class Numeric(_LookupExpressionAdapter, TypeEngine):
Numeric: self.__class__,
Integer: self.__class__,
},
- operators.div: {
- Numeric: self.__class__,
- Integer: self.__class__,
- },
+ operators.div: {Numeric: self.__class__, Integer: self.__class__},
operators.truediv: {
Numeric: self.__class__,
Integer: self.__class__,
},
- operators.add: {
- Numeric: self.__class__,
- Integer: self.__class__,
- },
- operators.sub: {
- Numeric: self.__class__,
- Integer: self.__class__,
- }
+ operators.add: {Numeric: self.__class__, Integer: self.__class__},
+ operators.sub: {Numeric: self.__class__, Integer: self.__class__},
}
@@ -675,12 +693,17 @@ class Float(Numeric):
"""
- __visit_name__ = 'float'
+ __visit_name__ = "float"
scale = None
- def __init__(self, precision=None, asdecimal=False,
- decimal_return_scale=None, **kwargs):
+ def __init__(
+ self,
+ precision=None,
+ asdecimal=False,
+ decimal_return_scale=None,
+ **kwargs
+ ):
r"""
Construct a Float.
@@ -713,14 +736,15 @@ class Float(Numeric):
self.asdecimal = asdecimal
self.decimal_return_scale = decimal_return_scale
if kwargs:
- util.warn_deprecated("Additional keyword arguments "
- "passed to Float ignored.")
+ util.warn_deprecated(
+ "Additional keyword arguments " "passed to Float ignored."
+ )
def result_processor(self, dialect, coltype):
if self.asdecimal:
return processors.to_decimal_processor_factory(
- decimal.Decimal,
- self._effective_decimal_return_scale)
+ decimal.Decimal, self._effective_decimal_return_scale
+ )
elif dialect.supports_native_decimal:
return processors.to_float
else:
@@ -746,7 +770,7 @@ class DateTime(_LookupExpressionAdapter, TypeEngine):
"""
- __visit_name__ = 'datetime'
+ __visit_name__ = "datetime"
def __init__(self, timezone=False):
"""Construct a new :class:`.DateTime`.
@@ -777,13 +801,8 @@ class DateTime(_LookupExpressionAdapter, TypeEngine):
# static/functions-datetime.html.
return {
- operators.add: {
- Interval: self.__class__,
- },
- operators.sub: {
- Interval: self.__class__,
- DateTime: Interval,
- },
+ operators.add: {Interval: self.__class__},
+ operators.sub: {Interval: self.__class__, DateTime: Interval},
}
@@ -791,7 +810,7 @@ class Date(_LookupExpressionAdapter, TypeEngine):
"""A type for ``datetime.date()`` objects."""
- __visit_name__ = 'date'
+ __visit_name__ = "date"
def get_dbapi_type(self, dbapi):
return dbapi.DATETIME
@@ -814,12 +833,9 @@ class Date(_LookupExpressionAdapter, TypeEngine):
operators.sub: {
# date - integer = date
Integer: self.__class__,
-
# date - date = integer.
Date: Integer,
-
Interval: DateTime,
-
# date - datetime = interval,
# this one is not in the PG docs
# but works
@@ -832,7 +848,7 @@ class Time(_LookupExpressionAdapter, TypeEngine):
"""A type for ``datetime.time()`` objects."""
- __visit_name__ = 'time'
+ __visit_name__ = "time"
def __init__(self, timezone=False):
self.timezone = timezone
@@ -850,14 +866,8 @@ class Time(_LookupExpressionAdapter, TypeEngine):
# static/functions-datetime.html.
return {
- operators.add: {
- Date: DateTime,
- Interval: self.__class__
- },
- operators.sub: {
- Time: Interval,
- Interval: self.__class__,
- },
+ operators.add: {Date: DateTime, Interval: self.__class__},
+ operators.sub: {Time: Interval, Interval: self.__class__},
}
@@ -872,6 +882,7 @@ class _Binary(TypeEngine):
def process(value):
value = value.decode(dialect.encoding).replace("'", "''")
return "'%s'" % value
+
return process
@property
@@ -891,14 +902,17 @@ class _Binary(TypeEngine):
return DBAPIBinary(value)
else:
return None
+
return process
# Python 3 has native bytes() type
# both sqlite3 and pg8000 seem to return it,
# psycopg2 as of 2.5 returns 'memoryview'
if util.py2k:
+
def result_processor(self, dialect, coltype):
if util.jython:
+
def process(value):
if value is not None:
if isinstance(value, array.array):
@@ -906,15 +920,19 @@ class _Binary(TypeEngine):
return str(value)
else:
return None
+
else:
process = processors.to_str
return process
+
else:
+
def result_processor(self, dialect, coltype):
def process(value):
if value is not None:
value = bytes(value)
return value
+
return process
def coerce_compared_value(self, op, value):
@@ -939,7 +957,7 @@ class LargeBinary(_Binary):
"""
- __visit_name__ = 'large_binary'
+ __visit_name__ = "large_binary"
def __init__(self, length=None):
"""
@@ -958,8 +976,9 @@ class Binary(LargeBinary):
"""Deprecated. Renamed to LargeBinary."""
def __init__(self, *arg, **kw):
- util.warn_deprecated('The Binary type has been renamed to '
- 'LargeBinary.')
+ util.warn_deprecated(
+ "The Binary type has been renamed to " "LargeBinary."
+ )
LargeBinary.__init__(self, *arg, **kw)
@@ -986,8 +1005,15 @@ class SchemaType(SchemaEventTarget):
"""
- def __init__(self, name=None, schema=None, metadata=None,
- inherit_schema=False, quote=None, _create_events=True):
+ def __init__(
+ self,
+ name=None,
+ schema=None,
+ metadata=None,
+ inherit_schema=False,
+ quote=None,
+ _create_events=True,
+ ):
if name is not None:
self.name = quoted_name(name, quote)
else:
@@ -1001,12 +1027,12 @@ class SchemaType(SchemaEventTarget):
event.listen(
self.metadata,
"before_create",
- util.portable_instancemethod(self._on_metadata_create)
+ util.portable_instancemethod(self._on_metadata_create),
)
event.listen(
self.metadata,
"after_drop",
- util.portable_instancemethod(self._on_metadata_drop)
+ util.portable_instancemethod(self._on_metadata_drop),
)
def _translate_schema(self, effective_schema, map_):
@@ -1018,7 +1044,7 @@ class SchemaType(SchemaEventTarget):
def _variant_mapping_for_set_table(self, column):
if isinstance(column.type, Variant):
variant_mapping = column.type.mapping.copy()
- variant_mapping['_default'] = column.type.impl
+ variant_mapping["_default"] = column.type.impl
else:
variant_mapping = None
return variant_mapping
@@ -1036,15 +1062,15 @@ class SchemaType(SchemaEventTarget):
table,
"before_create",
util.portable_instancemethod(
- self._on_table_create,
- {"variant_mapping": variant_mapping})
+ self._on_table_create, {"variant_mapping": variant_mapping}
+ ),
)
event.listen(
table,
"after_drop",
util.portable_instancemethod(
- self._on_table_drop,
- {"variant_mapping": variant_mapping})
+ self._on_table_drop, {"variant_mapping": variant_mapping}
+ ),
)
if self.metadata is None:
# TODO: what's the difference between self.metadata
@@ -1054,29 +1080,33 @@ class SchemaType(SchemaEventTarget):
"before_create",
util.portable_instancemethod(
self._on_metadata_create,
- {"variant_mapping": variant_mapping})
+ {"variant_mapping": variant_mapping},
+ ),
)
event.listen(
table.metadata,
"after_drop",
util.portable_instancemethod(
self._on_metadata_drop,
- {"variant_mapping": variant_mapping})
+ {"variant_mapping": variant_mapping},
+ ),
)
def copy(self, **kw):
return self.adapt(self.__class__, _create_events=True)
def adapt(self, impltype, **kw):
- schema = kw.pop('schema', self.schema)
- metadata = kw.pop('metadata', self.metadata)
- _create_events = kw.pop('_create_events', False)
- return impltype(name=self.name,
- schema=schema,
- inherit_schema=self.inherit_schema,
- metadata=metadata,
- _create_events=_create_events,
- **kw)
+ schema = kw.pop("schema", self.schema)
+ metadata = kw.pop("metadata", self.metadata)
+ _create_events = kw.pop("_create_events", False)
+ return impltype(
+ name=self.name,
+ schema=schema,
+ inherit_schema=self.inherit_schema,
+ metadata=metadata,
+ _create_events=_create_events,
+ **kw
+ )
@property
def bind(self):
@@ -1133,15 +1163,17 @@ class SchemaType(SchemaEventTarget):
t._on_metadata_drop(target, bind, **kw)
def _is_impl_for_variant(self, dialect, kw):
- variant_mapping = kw.pop('variant_mapping', None)
+ variant_mapping = kw.pop("variant_mapping", None)
if variant_mapping is None:
return True
- if dialect.name in variant_mapping and \
- variant_mapping[dialect.name] is self:
+ if (
+ dialect.name in variant_mapping
+ and variant_mapping[dialect.name] is self
+ ):
return True
elif dialect.name not in variant_mapping:
- return variant_mapping['_default'] is self
+ return variant_mapping["_default"] is self
class Enum(Emulated, String, SchemaType):
@@ -1220,7 +1252,8 @@ class Enum(Emulated, String, SchemaType):
:class:`.mysql.ENUM` - MySQL-specific type
"""
- __visit_name__ = 'enum'
+
+ __visit_name__ = "enum"
def __init__(self, *enums, **kw):
r"""Construct an enum.
@@ -1322,15 +1355,15 @@ class Enum(Emulated, String, SchemaType):
other arguments in kw to pass through.
"""
- self.native_enum = kw.pop('native_enum', True)
- self.create_constraint = kw.pop('create_constraint', True)
- self.values_callable = kw.pop('values_callable', None)
+ self.native_enum = kw.pop("native_enum", True)
+ self.create_constraint = kw.pop("create_constraint", True)
+ self.values_callable = kw.pop("values_callable", None)
values, objects = self._parse_into_values(enums, kw)
self._setup_for_values(values, objects, kw)
- convert_unicode = kw.pop('convert_unicode', None)
- self.validate_strings = kw.pop('validate_strings', False)
+ convert_unicode = kw.pop("convert_unicode", None)
+ self.validate_strings = kw.pop("validate_strings", False)
if convert_unicode is None:
for e in self.enums:
@@ -1347,33 +1380,35 @@ class Enum(Emulated, String, SchemaType):
self._valid_lookup[None] = self._object_lookup[None] = None
super(Enum, self).__init__(
- length=length,
- convert_unicode=convert_unicode,
+ length=length, convert_unicode=convert_unicode
)
if self.enum_class:
- kw.setdefault('name', self.enum_class.__name__.lower())
+ kw.setdefault("name", self.enum_class.__name__.lower())
SchemaType.__init__(
self,
- name=kw.pop('name', None),
- schema=kw.pop('schema', None),
- metadata=kw.pop('metadata', None),
- inherit_schema=kw.pop('inherit_schema', False),
- quote=kw.pop('quote', None),
- _create_events=kw.pop('_create_events', True)
+ name=kw.pop("name", None),
+ schema=kw.pop("schema", None),
+ metadata=kw.pop("metadata", None),
+ inherit_schema=kw.pop("inherit_schema", False),
+ quote=kw.pop("quote", None),
+ _create_events=kw.pop("_create_events", True),
)
def _parse_into_values(self, enums, kw):
- if not enums and '_enums' in kw:
- enums = kw.pop('_enums')
+ if not enums and "_enums" in kw:
+ enums = kw.pop("_enums")
- if len(enums) == 1 and hasattr(enums[0], '__members__'):
+ if len(enums) == 1 and hasattr(enums[0], "__members__"):
self.enum_class = enums[0]
if self.values_callable:
values = self.values_callable(self.enum_class)
else:
values = list(self.enum_class.__members__)
- objects = [self.enum_class.__members__[k] for k in self.enum_class.__members__]
+ objects = [
+ self.enum_class.__members__[k]
+ for k in self.enum_class.__members__
+ ]
return values, objects
else:
self.enum_class = None
@@ -1382,18 +1417,16 @@ class Enum(Emulated, String, SchemaType):
def _setup_for_values(self, values, objects, kw):
self.enums = list(values)
- self._valid_lookup = dict(
- zip(reversed(objects), reversed(values))
- )
+ self._valid_lookup = dict(zip(reversed(objects), reversed(values)))
- self._object_lookup = dict(
- zip(values, objects)
- )
+ self._object_lookup = dict(zip(values, objects))
- self._valid_lookup.update([
- (value, self._valid_lookup[self._object_lookup[value]])
- for value in values
- ])
+ self._valid_lookup.update(
+ [
+ (value, self._valid_lookup[self._object_lookup[value]])
+ for value in values
+ ]
+ )
@property
def native(self):
@@ -1411,22 +1444,24 @@ class Enum(Emulated, String, SchemaType):
# here between an INSERT statement and a criteria used in a SELECT,
# for now we're staying conservative w/ behavioral changes (perhaps
# someone has a trigger that handles strings on INSERT)
- if not self.validate_strings and \
- isinstance(elem, compat.string_types):
+ if not self.validate_strings and isinstance(
+ elem, compat.string_types
+ ):
return elem
else:
raise LookupError(
- '"%s" is not among the defined enum values' % elem)
+ '"%s" is not among the defined enum values' % elem
+ )
class Comparator(String.Comparator):
-
def _adapt_expression(self, op, other_comparator):
op, typ = super(Enum.Comparator, self)._adapt_expression(
- op, other_comparator)
+ op, other_comparator
+ )
if op is operators.concat_op:
typ = String(
- self.type.length,
- convert_unicode=self.type.convert_unicode)
+ self.type.length, convert_unicode=self.type.convert_unicode
+ )
return op, typ
comparator_factory = Comparator
@@ -1436,38 +1471,40 @@ class Enum(Emulated, String, SchemaType):
return self._object_lookup[elem]
except KeyError:
raise LookupError(
- '"%s" is not among the defined enum values' % elem)
+ '"%s" is not among the defined enum values' % elem
+ )
def __repr__(self):
return util.generic_repr(
self,
- additional_kw=[('native_enum', True)],
+ additional_kw=[("native_enum", True)],
to_inspect=[Enum, SchemaType],
)
def adapt_to_emulated(self, impltype, **kw):
kw.setdefault("convert_unicode", self.convert_unicode)
kw.setdefault("validate_strings", self.validate_strings)
- kw.setdefault('name', self.name)
- kw.setdefault('schema', self.schema)
- kw.setdefault('inherit_schema', self.inherit_schema)
- kw.setdefault('metadata', self.metadata)
- kw.setdefault('_create_events', False)
- kw.setdefault('native_enum', self.native_enum)
- kw.setdefault('values_callable', self.values_callable)
- kw.setdefault('create_constraint', self.create_constraint)
- assert '_enums' in kw
+ kw.setdefault("name", self.name)
+ kw.setdefault("schema", self.schema)
+ kw.setdefault("inherit_schema", self.inherit_schema)
+ kw.setdefault("metadata", self.metadata)
+ kw.setdefault("_create_events", False)
+ kw.setdefault("native_enum", self.native_enum)
+ kw.setdefault("values_callable", self.values_callable)
+ kw.setdefault("create_constraint", self.create_constraint)
+ assert "_enums" in kw
return impltype(**kw)
def adapt(self, impltype, **kw):
- kw['_enums'] = self._enums_argument
+ kw["_enums"] = self._enums_argument
return super(Enum, self).adapt(impltype, **kw)
def _should_create_constraint(self, compiler, **kw):
if not self._is_impl_for_variant(compiler.dialect, kw):
return False
- return not self.native_enum or \
- not compiler.dialect.supports_native_enum
+ return (
+ not self.native_enum or not compiler.dialect.supports_native_enum
+ )
@util.dependencies("sqlalchemy.sql.schema")
def _set_table(self, schema, column, table):
@@ -1483,20 +1520,21 @@ class Enum(Emulated, String, SchemaType):
name=_defer_name(self.name),
_create_rule=util.portable_instancemethod(
self._should_create_constraint,
- {"variant_mapping": variant_mapping}),
- _type_bound=True
+ {"variant_mapping": variant_mapping},
+ ),
+ _type_bound=True,
)
assert e.table is table
def literal_processor(self, dialect):
- parent_processor = super(
- Enum, self).literal_processor(dialect)
+ parent_processor = super(Enum, self).literal_processor(dialect)
def process(value):
value = self._db_value_for_elem(value)
if parent_processor:
value = parent_processor(value)
return value
+
return process
def bind_processor(self, dialect):
@@ -1510,8 +1548,7 @@ class Enum(Emulated, String, SchemaType):
return process
def result_processor(self, dialect, coltype):
- parent_processor = super(Enum, self).result_processor(
- dialect, coltype)
+ parent_processor = super(Enum, self).result_processor(dialect, coltype)
def process(value):
if parent_processor:
@@ -1548,8 +1585,9 @@ class PickleType(TypeDecorator):
impl = LargeBinary
- def __init__(self, protocol=pickle.HIGHEST_PROTOCOL,
- pickler=None, comparator=None):
+ def __init__(
+ self, protocol=pickle.HIGHEST_PROTOCOL, pickler=None, comparator=None
+ ):
"""
Construct a PickleType.
@@ -1570,40 +1608,46 @@ class PickleType(TypeDecorator):
super(PickleType, self).__init__()
def __reduce__(self):
- return PickleType, (self.protocol,
- None,
- self.comparator)
+ return PickleType, (self.protocol, None, self.comparator)
def bind_processor(self, dialect):
impl_processor = self.impl.bind_processor(dialect)
dumps = self.pickler.dumps
protocol = self.protocol
if impl_processor:
+
def process(value):
if value is not None:
value = dumps(value, protocol)
return impl_processor(value)
+
else:
+
def process(value):
if value is not None:
value = dumps(value, protocol)
return value
+
return process
def result_processor(self, dialect, coltype):
impl_processor = self.impl.result_processor(dialect, coltype)
loads = self.pickler.loads
if impl_processor:
+
def process(value):
value = impl_processor(value)
if value is None:
return None
return loads(value)
+
else:
+
def process(value):
if value is None:
return None
return loads(value)
+
return process
def compare_values(self, x, y):
@@ -1635,11 +1679,10 @@ class Boolean(Emulated, TypeEngine, SchemaType):
"""
- __visit_name__ = 'boolean'
+ __visit_name__ = "boolean"
native = True
- def __init__(
- self, create_constraint=True, name=None, _create_events=True):
+ def __init__(self, create_constraint=True, name=None, _create_events=True):
"""Construct a Boolean.
:param create_constraint: defaults to True. If the boolean
@@ -1657,8 +1700,10 @@ class Boolean(Emulated, TypeEngine, SchemaType):
def _should_create_constraint(self, compiler, **kw):
if not self._is_impl_for_variant(compiler.dialect, kw):
return False
- return not compiler.dialect.supports_native_boolean and \
- compiler.dialect.non_native_boolean_check_constraint
+ return (
+ not compiler.dialect.supports_native_boolean
+ and compiler.dialect.non_native_boolean_check_constraint
+ )
@util.dependencies("sqlalchemy.sql.schema")
def _set_table(self, schema, column, table):
@@ -1672,8 +1717,9 @@ class Boolean(Emulated, TypeEngine, SchemaType):
name=_defer_name(self.name),
_create_rule=util.portable_instancemethod(
self._should_create_constraint,
- {"variant_mapping": variant_mapping}),
- _type_bound=True
+ {"variant_mapping": variant_mapping},
+ ),
+ _type_bound=True,
)
assert e.table is table
@@ -1686,11 +1732,11 @@ class Boolean(Emulated, TypeEngine, SchemaType):
def _strict_as_bool(self, value):
if value not in self._strict_bools:
if not isinstance(value, int):
- raise TypeError(
- "Not a boolean value: %r" % value)
+ raise TypeError("Not a boolean value: %r" % value)
else:
raise ValueError(
- "Value %r is not None, True, or False" % value)
+ "Value %r is not None, True, or False" % value
+ )
return value
def literal_processor(self, dialect):
@@ -1700,6 +1746,7 @@ class Boolean(Emulated, TypeEngine, SchemaType):
def process(value):
return true if self._strict_as_bool(value) else false
+
return process
def bind_processor(self, dialect):
@@ -1714,6 +1761,7 @@ class Boolean(Emulated, TypeEngine, SchemaType):
if value is not None:
value = _coerce(value)
return value
+
return process
def result_processor(self, dialect, coltype):
@@ -1736,18 +1784,10 @@ class _AbstractInterval(_LookupExpressionAdapter, TypeEngine):
DateTime: DateTime,
Time: Time,
},
- operators.sub: {
- Interval: self.__class__
- },
- operators.mul: {
- Numeric: self.__class__
- },
- operators.truediv: {
- Numeric: self.__class__
- },
- operators.div: {
- Numeric: self.__class__
- }
+ operators.sub: {Interval: self.__class__},
+ operators.mul: {Numeric: self.__class__},
+ operators.truediv: {Numeric: self.__class__},
+ operators.div: {Numeric: self.__class__},
}
@property
@@ -1780,9 +1820,7 @@ class Interval(Emulated, _AbstractInterval, TypeDecorator):
impl = DateTime
epoch = dt.datetime.utcfromtimestamp(0)
- def __init__(self, native=True,
- second_precision=None,
- day_precision=None):
+ def __init__(self, native=True, second_precision=None, day_precision=None):
"""Construct an Interval object.
:param native: when True, use the actual
@@ -1815,31 +1853,39 @@ class Interval(Emulated, _AbstractInterval, TypeDecorator):
impl_processor = self.impl.bind_processor(dialect)
epoch = self.epoch
if impl_processor:
+
def process(value):
if value is not None:
value = epoch + value
return impl_processor(value)
+
else:
+
def process(value):
if value is not None:
value = epoch + value
return value
+
return process
def result_processor(self, dialect, coltype):
impl_processor = self.impl.result_processor(dialect, coltype)
epoch = self.epoch
if impl_processor:
+
def process(value):
value = impl_processor(value)
if value is None:
return None
return value - epoch
+
else:
+
def process(value):
if value is None:
return None
return value - epoch
+
return process
@@ -1986,10 +2032,11 @@ class JSON(Indexable, TypeEngine):
"""
- __visit_name__ = 'JSON'
+
+ __visit_name__ = "JSON"
hashable = False
- NULL = util.symbol('JSON_NULL')
+ NULL = util.symbol("JSON_NULL")
"""Describe the json value of NULL.
This value is used to force the JSON value of ``"null"`` to be
@@ -2109,20 +2156,25 @@ class JSON(Indexable, TypeEngine):
class Comparator(Indexable.Comparator, Concatenable.Comparator):
"""Define comparison operations for :class:`.types.JSON`."""
- @util.dependencies('sqlalchemy.sql.default_comparator')
+ @util.dependencies("sqlalchemy.sql.default_comparator")
def _setup_getitem(self, default_comparator, index):
- if not isinstance(index, util.string_types) and \
- isinstance(index, compat.collections_abc.Sequence):
+ if not isinstance(index, util.string_types) and isinstance(
+ index, compat.collections_abc.Sequence
+ ):
index = default_comparator._check_literal(
- self.expr, operators.json_path_getitem_op,
- index, bindparam_type=JSON.JSONPathType
+ self.expr,
+ operators.json_path_getitem_op,
+ index,
+ bindparam_type=JSON.JSONPathType,
)
operator = operators.json_path_getitem_op
else:
index = default_comparator._check_literal(
- self.expr, operators.json_getitem_op,
- index, bindparam_type=JSON.JSONIndexType
+ self.expr,
+ operators.json_getitem_op,
+ index,
+ bindparam_type=JSON.JSONIndexType,
)
operator = operators.json_getitem_op
@@ -2172,6 +2224,7 @@ class JSON(Indexable, TypeEngine):
if string_process:
value = string_process(value)
return json_deserializer(value)
+
return process
@@ -2266,7 +2319,8 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
:class:`.postgresql.ARRAY`
"""
- __visit_name__ = 'ARRAY'
+
+ __visit_name__ = "ARRAY"
zero_indexes = False
"""if True, Python zero-based indexes should be interpreted as one-based
@@ -2285,21 +2339,23 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
if isinstance(index, slice):
return_type = self.type
if self.type.zero_indexes:
- index = slice(
- index.start + 1,
- index.stop + 1,
- index.step
- )
+ index = slice(index.start + 1, index.stop + 1, index.step)
index = Slice(
_literal_as_binds(
- index.start, name=self.expr.key,
- type_=type_api.INTEGERTYPE),
+ index.start,
+ name=self.expr.key,
+ type_=type_api.INTEGERTYPE,
+ ),
_literal_as_binds(
- index.stop, name=self.expr.key,
- type_=type_api.INTEGERTYPE),
+ index.stop,
+ name=self.expr.key,
+ type_=type_api.INTEGERTYPE,
+ ),
_literal_as_binds(
- index.step, name=self.expr.key,
- type_=type_api.INTEGERTYPE)
+ index.step,
+ name=self.expr.key,
+ type_=type_api.INTEGERTYPE,
+ ),
)
else:
if self.type.zero_indexes:
@@ -2307,16 +2363,18 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
if self.type.dimensions is None or self.type.dimensions == 1:
return_type = self.type.item_type
else:
- adapt_kw = {'dimensions': self.type.dimensions - 1}
+ adapt_kw = {"dimensions": self.type.dimensions - 1}
return_type = self.type.adapt(
- self.type.__class__, **adapt_kw)
+ self.type.__class__, **adapt_kw
+ )
return operators.getitem, index, return_type
def contains(self, *arg, **kw):
raise NotImplementedError(
"ARRAY.contains() not implemented for the base "
- "ARRAY type; please use the dialect-specific ARRAY type")
+ "ARRAY type; please use the dialect-specific ARRAY type"
+ )
@util.dependencies("sqlalchemy.sql.elements")
def any(self, elements, other, operator=None):
@@ -2350,7 +2408,7 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
operator = operator if operator else operators.eq
return operator(
elements._literal_as_binds(other),
- elements.CollectionAggregate._create_any(self.expr)
+ elements.CollectionAggregate._create_any(self.expr),
)
@util.dependencies("sqlalchemy.sql.elements")
@@ -2385,13 +2443,14 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
operator = operator if operator else operators.eq
return operator(
elements._literal_as_binds(other),
- elements.CollectionAggregate._create_all(self.expr)
+ elements.CollectionAggregate._create_all(self.expr),
)
comparator_factory = Comparator
- def __init__(self, item_type, as_tuple=False, dimensions=None,
- zero_indexes=False):
+ def __init__(
+ self, item_type, as_tuple=False, dimensions=None, zero_indexes=False
+ ):
"""Construct an :class:`.types.ARRAY`.
E.g.::
@@ -2424,8 +2483,10 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
"""
if isinstance(item_type, ARRAY):
- raise ValueError("Do not nest ARRAY types; ARRAY(basetype) "
- "handles multi-dimensional arrays of basetype")
+ raise ValueError(
+ "Do not nest ARRAY types; ARRAY(basetype) "
+ "handles multi-dimensional arrays of basetype"
+ )
if isinstance(item_type, type):
item_type = item_type()
self.item_type = item_type
@@ -2463,35 +2524,37 @@ class REAL(Float):
"""The SQL REAL type."""
- __visit_name__ = 'REAL'
+ __visit_name__ = "REAL"
class FLOAT(Float):
"""The SQL FLOAT type."""
- __visit_name__ = 'FLOAT'
+ __visit_name__ = "FLOAT"
class NUMERIC(Numeric):
"""The SQL NUMERIC type."""
- __visit_name__ = 'NUMERIC'
+ __visit_name__ = "NUMERIC"
class DECIMAL(Numeric):
"""The SQL DECIMAL type."""
- __visit_name__ = 'DECIMAL'
+ __visit_name__ = "DECIMAL"
class INTEGER(Integer):
"""The SQL INT or INTEGER type."""
- __visit_name__ = 'INTEGER'
+ __visit_name__ = "INTEGER"
+
+
INT = INTEGER
@@ -2499,14 +2562,14 @@ class SMALLINT(SmallInteger):
"""The SQL SMALLINT type."""
- __visit_name__ = 'SMALLINT'
+ __visit_name__ = "SMALLINT"
class BIGINT(BigInteger):
"""The SQL BIGINT type."""
- __visit_name__ = 'BIGINT'
+ __visit_name__ = "BIGINT"
class TIMESTAMP(DateTime):
@@ -2520,7 +2583,7 @@ class TIMESTAMP(DateTime):
"""
- __visit_name__ = 'TIMESTAMP'
+ __visit_name__ = "TIMESTAMP"
def __init__(self, timezone=False):
"""Construct a new :class:`.TIMESTAMP`.
@@ -2543,28 +2606,28 @@ class DATETIME(DateTime):
"""The SQL DATETIME type."""
- __visit_name__ = 'DATETIME'
+ __visit_name__ = "DATETIME"
class DATE(Date):
"""The SQL DATE type."""
- __visit_name__ = 'DATE'
+ __visit_name__ = "DATE"
class TIME(Time):
"""The SQL TIME type."""
- __visit_name__ = 'TIME'
+ __visit_name__ = "TIME"
class TEXT(Text):
"""The SQL TEXT type."""
- __visit_name__ = 'TEXT'
+ __visit_name__ = "TEXT"
class CLOB(Text):
@@ -2574,63 +2637,63 @@ class CLOB(Text):
This type is found in Oracle and Informix.
"""
- __visit_name__ = 'CLOB'
+ __visit_name__ = "CLOB"
class VARCHAR(String):
"""The SQL VARCHAR type."""
- __visit_name__ = 'VARCHAR'
+ __visit_name__ = "VARCHAR"
class NVARCHAR(Unicode):
"""The SQL NVARCHAR type."""
- __visit_name__ = 'NVARCHAR'
+ __visit_name__ = "NVARCHAR"
class CHAR(String):
"""The SQL CHAR type."""
- __visit_name__ = 'CHAR'
+ __visit_name__ = "CHAR"
class NCHAR(Unicode):
"""The SQL NCHAR type."""
- __visit_name__ = 'NCHAR'
+ __visit_name__ = "NCHAR"
class BLOB(LargeBinary):
"""The SQL BLOB type."""
- __visit_name__ = 'BLOB'
+ __visit_name__ = "BLOB"
class BINARY(_Binary):
"""The SQL BINARY type."""
- __visit_name__ = 'BINARY'
+ __visit_name__ = "BINARY"
class VARBINARY(_Binary):
"""The SQL VARBINARY type."""
- __visit_name__ = 'VARBINARY'
+ __visit_name__ = "VARBINARY"
class BOOLEAN(Boolean):
"""The SQL BOOLEAN type."""
- __visit_name__ = 'BOOLEAN'
+ __visit_name__ = "BOOLEAN"
class NullType(TypeEngine):
@@ -2657,7 +2720,8 @@ class NullType(TypeEngine):
construct.
"""
- __visit_name__ = 'null'
+
+ __visit_name__ = "null"
_isnull = True
@@ -2666,16 +2730,18 @@ class NullType(TypeEngine):
def literal_processor(self, dialect):
def process(value):
return "NULL"
+
return process
class Comparator(TypeEngine.Comparator):
-
def _adapt_expression(self, op, other_comparator):
- if isinstance(other_comparator, NullType.Comparator) or \
- not operators.is_commutative(op):
+ if isinstance(
+ other_comparator, NullType.Comparator
+ ) or not operators.is_commutative(op):
return op, self.expr.type
else:
return other_comparator._adapt_expression(op, self)
+
comparator_factory = Comparator
@@ -2694,6 +2760,7 @@ class MatchType(Boolean):
"""
+
NULLTYPE = NullType()
BOOLEANTYPE = Boolean()
STRINGTYPE = String()
@@ -2709,7 +2776,7 @@ _type_map = {
dt.datetime: DateTime(),
dt.time: Time(),
dt.timedelta: Interval(),
- util.NoneType: NULLTYPE
+ util.NoneType: NULLTYPE,
}
if util.py3k:
@@ -2729,19 +2796,23 @@ def _resolve_value_to_type(value):
# objects.
insp = inspection.inspect(value, False)
if (
- insp is not None and
- # foil mock.Mock() and other impostors by ensuring
- # the inspection target itself self-inspects
- insp.__class__ in inspection._registrars
+ insp is not None
+ and
+ # foil mock.Mock() and other impostors by ensuring
+ # the inspection target itself self-inspects
+ insp.__class__ in inspection._registrars
):
raise exc.ArgumentError(
- "Object %r is not legal as a SQL literal value" % value)
+ "Object %r is not legal as a SQL literal value" % value
+ )
return NULLTYPE
else:
return _result_type
+
# back-assign to type_api
from . import type_api
+
type_api.BOOLEANTYPE = BOOLEANTYPE
type_api.STRINGTYPE = STRINGTYPE
type_api.INTEGERTYPE = INTEGERTYPE
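
The Enum handling reformatted above (the `__members__` branch of `_parse_into_values`) accepts either plain string values or a PEP 435 enum class. A small sketch under that assumption; the table and column names are illustrative only:

    import enum
    from sqlalchemy import Column, Enum, Integer, MetaData, Table

    class Mood(enum.Enum):
        happy = 1
        grumpy = 2

    metadata = MetaData()
    person = Table(
        "person",
        metadata,
        Column("id", Integer, primary_key=True),
        # by default the member *names* ("happy", "grumpy") are persisted;
        # values_callable can be passed to store .value instead
        Column("mood", Enum(Mood)),
    )
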
diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py
index a8dfa19be..7fe780783 100644
--- a/lib/sqlalchemy/sql/type_api.py
+++ b/lib/sqlalchemy/sql/type_api.py
@@ -49,7 +49,8 @@ class TypeEngine(Visitable):
"""
- __slots__ = 'expr', 'type'
+
+ __slots__ = "expr", "type"
default_comparator = None
@@ -57,16 +58,15 @@ class TypeEngine(Visitable):
self.expr = expr
self.type = expr.type
- @util.dependencies('sqlalchemy.sql.default_comparator')
+ @util.dependencies("sqlalchemy.sql.default_comparator")
def operate(self, default_comparator, op, *other, **kwargs):
o = default_comparator.operator_lookup[op.__name__]
return o[0](self.expr, op, *(other + o[1:]), **kwargs)
- @util.dependencies('sqlalchemy.sql.default_comparator')
+ @util.dependencies("sqlalchemy.sql.default_comparator")
def reverse_operate(self, default_comparator, op, other, **kwargs):
o = default_comparator.operator_lookup[op.__name__]
- return o[0](self.expr, op, other,
- reverse=True, *o[1:], **kwargs)
+ return o[0](self.expr, op, other, reverse=True, *o[1:], **kwargs)
def _adapt_expression(self, op, other_comparator):
"""evaluate the return type of <self> <op> <othertype>,
@@ -97,7 +97,7 @@ class TypeEngine(Visitable):
return op, self.type
def __reduce__(self):
- return _reconstitute_comparator, (self.expr, )
+ return _reconstitute_comparator, (self.expr,)
hashable = True
"""Flag, if False, means values from this type aren't hashable.
@@ -313,8 +313,10 @@ class TypeEngine(Visitable):
"""
- return self.__class__.column_expression.__code__ \
+ return (
+ self.__class__.column_expression.__code__
is not TypeEngine.column_expression.__code__
+ )
def bind_expression(self, bindvalue):
""""Given a bind value (i.e. a :class:`.BindParameter` instance),
@@ -351,8 +353,10 @@ class TypeEngine(Visitable):
"""
- return self.__class__.bind_expression.__code__ \
+ return (
+ self.__class__.bind_expression.__code__
is not TypeEngine.bind_expression.__code__
+ )
@staticmethod
def _to_instance(cls_or_self):
@@ -441,9 +445,9 @@ class TypeEngine(Visitable):
"""
try:
- return dialect._type_memos[self]['impl']
+ return dialect._type_memos[self]["impl"]
except KeyError:
- return self._dialect_info(dialect)['impl']
+ return self._dialect_info(dialect)["impl"]
def _unwrapped_dialect_impl(self, dialect):
"""Return the 'unwrapped' dialect impl for this type.
@@ -462,20 +466,20 @@ class TypeEngine(Visitable):
def _cached_literal_processor(self, dialect):
"""Return a dialect-specific literal processor for this type."""
try:
- return dialect._type_memos[self]['literal']
+ return dialect._type_memos[self]["literal"]
except KeyError:
d = self._dialect_info(dialect)
- d['literal'] = lp = d['impl'].literal_processor(dialect)
+ d["literal"] = lp = d["impl"].literal_processor(dialect)
return lp
def _cached_bind_processor(self, dialect):
"""Return a dialect-specific bind processor for this type."""
try:
- return dialect._type_memos[self]['bind']
+ return dialect._type_memos[self]["bind"]
except KeyError:
d = self._dialect_info(dialect)
- d['bind'] = bp = d['impl'].bind_processor(dialect)
+ d["bind"] = bp = d["impl"].bind_processor(dialect)
return bp
def _cached_result_processor(self, dialect, coltype):
@@ -488,7 +492,7 @@ class TypeEngine(Visitable):
# key assumption: DBAPI type codes are
# constants. Else this dictionary would
# grow unbounded.
- d[coltype] = rp = d['impl'].result_processor(dialect, coltype)
+ d[coltype] = rp = d["impl"].result_processor(dialect, coltype)
return rp
def _cached_custom_processor(self, dialect, key, fn):
@@ -496,7 +500,7 @@ class TypeEngine(Visitable):
return dialect._type_memos[self][key]
except KeyError:
d = self._dialect_info(dialect)
- impl = d['impl']
+ impl = d["impl"]
d[key] = result = fn(impl)
return result
@@ -513,7 +517,7 @@ class TypeEngine(Visitable):
impl = self.adapt(type(self))
# this can't be self, else we create a cycle
assert impl is not self
- dialect._type_memos[self] = d = {'impl': impl}
+ dialect._type_memos[self] = d = {"impl": impl}
return d
def _gen_dialect_impl(self, dialect):
@@ -549,8 +553,10 @@ class TypeEngine(Visitable):
"""
_coerced_type = _resolve_value_to_type(value)
- if _coerced_type is NULLTYPE or _coerced_type._type_affinity \
- is self._type_affinity:
+ if (
+ _coerced_type is NULLTYPE
+ or _coerced_type._type_affinity is self._type_affinity
+ ):
return self
else:
return _coerced_type
@@ -586,8 +592,7 @@ class TypeEngine(Visitable):
def __str__(self):
if util.py2k:
- return unicode(self.compile()).\
- encode('ascii', 'backslashreplace')
+ return unicode(self.compile()).encode("ascii", "backslashreplace")
else:
return str(self.compile())
@@ -645,15 +650,16 @@ class UserDefinedType(util.with_metaclass(VisitableCheckKWArg, TypeEngine)):
``type_expression``, if it receives ``**kw`` in its signature.
"""
+
__visit_name__ = "user_defined"
- ensure_kwarg = 'get_col_spec'
+ ensure_kwarg = "get_col_spec"
class Comparator(TypeEngine.Comparator):
__slots__ = ()
def _adapt_expression(self, op, other_comparator):
- if hasattr(self.type, 'adapt_operator'):
+ if hasattr(self.type, "adapt_operator"):
util.warn_deprecated(
"UserDefinedType.adapt_operator is deprecated. Create "
"a UserDefinedType.Comparator subclass instead which "
@@ -854,6 +860,7 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
will cause the index value ``'foo'`` to be JSON encoded.
"""
+
__visit_name__ = "type_decorator"
def __init__(self, *args, **kwargs):
@@ -874,14 +881,16 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
"""
- if not hasattr(self.__class__, 'impl'):
- raise AssertionError("TypeDecorator implementations "
- "require a class-level variable "
- "'impl' which refers to the class of "
- "type being decorated")
+ if not hasattr(self.__class__, "impl"):
+ raise AssertionError(
+ "TypeDecorator implementations "
+ "require a class-level variable "
+ "'impl' which refers to the class of "
+ "type being decorated"
+ )
self.impl = to_instance(self.__class__.impl, *args, **kwargs)
- coerce_to_is_types = (util.NoneType, )
+ coerce_to_is_types = (util.NoneType,)
"""Specify those Python types which should be coerced at the expression
level to "IS <constant>" when compared using ``==`` (and same for
``IS NOT`` in conjunction with ``!=``.
@@ -906,24 +915,27 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
__slots__ = ()
def operate(self, op, *other, **kwargs):
- kwargs['_python_is_types'] = self.expr.type.coerce_to_is_types
+ kwargs["_python_is_types"] = self.expr.type.coerce_to_is_types
return super(TypeDecorator.Comparator, self).operate(
- op, *other, **kwargs)
+ op, *other, **kwargs
+ )
def reverse_operate(self, op, other, **kwargs):
- kwargs['_python_is_types'] = self.expr.type.coerce_to_is_types
+ kwargs["_python_is_types"] = self.expr.type.coerce_to_is_types
return super(TypeDecorator.Comparator, self).reverse_operate(
- op, other, **kwargs)
+ op, other, **kwargs
+ )
@property
def comparator_factory(self):
if TypeDecorator.Comparator in self.impl.comparator_factory.__mro__:
return self.impl.comparator_factory
else:
- return type("TDComparator",
- (TypeDecorator.Comparator,
- self.impl.comparator_factory),
- {})
+ return type(
+ "TDComparator",
+ (TypeDecorator.Comparator, self.impl.comparator_factory),
+ {},
+ )
def _gen_dialect_impl(self, dialect):
"""
@@ -939,10 +951,11 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
typedesc = self._unwrapped_dialect_impl(dialect)
tt = self.copy()
if not isinstance(tt, self.__class__):
- raise AssertionError('Type object %s does not properly '
- 'implement the copy() method, it must '
- 'return an object of type %s' %
- (self, self.__class__))
+ raise AssertionError(
+ "Type object %s does not properly "
+ "implement the copy() method, it must "
+ "return an object of type %s" % (self, self.__class__)
+ )
tt.impl = typedesc
return tt
@@ -1099,8 +1112,10 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
"""
- return self.__class__.process_bind_param.__code__ \
+ return (
+ self.__class__.process_bind_param.__code__
is not TypeDecorator.process_bind_param.__code__
+ )
@util.memoized_property
def _has_literal_processor(self):
@@ -1109,8 +1124,10 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
"""
- return self.__class__.process_literal_param.__code__ \
+ return (
+ self.__class__.process_literal_param.__code__
is not TypeDecorator.process_literal_param.__code__
+ )
def literal_processor(self, dialect):
"""Provide a literal processing function for the given
@@ -1147,9 +1164,12 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
if process_param:
impl_processor = self.impl.literal_processor(dialect)
if impl_processor:
+
def process(value):
return impl_processor(process_param(value, dialect))
+
else:
+
def process(value):
return process_param(value, dialect)
@@ -1180,10 +1200,12 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
process_param = self.process_bind_param
impl_processor = self.impl.bind_processor(dialect)
if impl_processor:
+
def process(value):
return impl_processor(process_param(value, dialect))
else:
+
def process(value):
return process_param(value, dialect)
@@ -1200,8 +1222,10 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
exception throw.
"""
- return self.__class__.process_result_value.__code__ \
+ return (
+ self.__class__.process_result_value.__code__
is not TypeDecorator.process_result_value.__code__
+ )
def result_processor(self, dialect, coltype):
"""Provide a result value processing function for the given
@@ -1225,13 +1249,14 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
"""
if self._has_result_processor:
process_value = self.process_result_value
- impl_processor = self.impl.result_processor(dialect,
- coltype)
+ impl_processor = self.impl.result_processor(dialect, coltype)
if impl_processor:
+
def process(value):
return process_value(impl_processor(value), dialect)
else:
+
def process(value):
return process_value(value, dialect)
@@ -1397,7 +1422,8 @@ class Variant(TypeDecorator):
if dialect_name in self.mapping:
raise exc.ArgumentError(
"Dialect '%s' is already present in "
- "the mapping for this Variant" % dialect_name)
+ "the mapping for this Variant" % dialect_name
+ )
mapping = self.mapping.copy()
mapping[dialect_name] = type_
return Variant(self.impl, mapping)
@@ -1439,6 +1465,6 @@ def adapt_type(typeobj, colspecs):
# but it turns out the originally given "generic" type
# is actually a subclass of our resulting type, then we were already
# given a more specific type than that required; so use that.
- if (issubclass(typeobj.__class__, impltype)):
+ if issubclass(typeobj.__class__, impltype):
return typeobj
return typeobj.adapt(impltype)
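(Aside, not part of the diff: the ``bind_processor()`` / ``result_processor()`` wrappers reformatted above are the hooks that invoke a user-defined TypeDecorator's ``process_bind_param()`` and ``process_result_value()``. A minimal sketch of such a type, using a hypothetical ``JSONText`` name:)

# Minimal sketch with a hypothetical JSONText type; not part of this
# changeset, it only illustrates the hooks wrapped above.
import json

from sqlalchemy.types import Text, TypeDecorator


class JSONText(TypeDecorator):
    """Store JSON-serializable values as TEXT."""

    impl = Text

    def process_bind_param(self, value, dialect):
        # wrapped by bind_processor() before the value reaches the DBAPI
        return None if value is None else json.dumps(value)

    def process_result_value(self, value, dialect):
        # wrapped by result_processor() as rows are fetched
        return None if value is None else json.loads(value)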
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 12cfe09d1..4feaf9938 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -15,15 +15,29 @@ from . import operators, visitors
from itertools import chain
from collections import deque
-from .elements import BindParameter, ColumnClause, ColumnElement, \
- Null, UnaryExpression, literal_column, Label, _label_reference, \
- _textual_label_reference
-from .selectable import SelectBase, ScalarSelect, Join, FromClause, FromGrouping
+from .elements import (
+ BindParameter,
+ ColumnClause,
+ ColumnElement,
+ Null,
+ UnaryExpression,
+ literal_column,
+ Label,
+ _label_reference,
+ _textual_label_reference,
+)
+from .selectable import (
+ SelectBase,
+ ScalarSelect,
+ Join,
+ FromClause,
+ FromGrouping,
+)
from .schema import Column
join_condition = util.langhelpers.public_factory(
- Join._join_condition,
- ".sql.util.join_condition")
+ Join._join_condition, ".sql.util.join_condition"
+)
# names that are still being imported from the outside
from .annotation import _shallow_annotate, _deep_annotate, _deep_deannotate
@@ -88,8 +102,9 @@ def find_left_clause_that_matches_given(clauses, join_from):
for idx in liberal_idx:
f = clauses[idx]
for s in selectables:
- if set(surface_selectables(f)).\
- intersection(surface_selectables(s)):
+ if set(surface_selectables(f)).intersection(
+ surface_selectables(s)
+ ):
conservative_idx.append(idx)
break
if conservative_idx:
@@ -184,8 +199,9 @@ def visit_binary_product(fn, expr):
# we don't want to dig into correlated subqueries,
# those are just column elements by themselves
yield element
- elif element.__visit_name__ == 'binary' and \
- operators.is_comparison(element.operator):
+ elif element.__visit_name__ == "binary" and operators.is_comparison(
+ element.operator
+ ):
stack.insert(0, element)
for l in visit(element.left):
for r in visit(element.right):
@@ -199,38 +215,47 @@ def visit_binary_product(fn, expr):
for elem in element.get_children():
for e in visit(elem):
yield e
+
list(visit(expr))
-def find_tables(clause, check_columns=False,
- include_aliases=False, include_joins=False,
- include_selects=False, include_crud=False):
+def find_tables(
+ clause,
+ check_columns=False,
+ include_aliases=False,
+ include_joins=False,
+ include_selects=False,
+ include_crud=False,
+):
"""locate Table objects within the given expression."""
tables = []
_visitors = {}
if include_selects:
- _visitors['select'] = _visitors['compound_select'] = tables.append
+ _visitors["select"] = _visitors["compound_select"] = tables.append
if include_joins:
- _visitors['join'] = tables.append
+ _visitors["join"] = tables.append
if include_aliases:
- _visitors['alias'] = tables.append
+ _visitors["alias"] = tables.append
if include_crud:
- _visitors['insert'] = _visitors['update'] = \
- _visitors['delete'] = lambda ent: tables.append(ent.table)
+ _visitors["insert"] = _visitors["update"] = _visitors[
+ "delete"
+ ] = lambda ent: tables.append(ent.table)
if check_columns:
+
def visit_column(column):
tables.append(column.table)
- _visitors['column'] = visit_column
- _visitors['table'] = tables.append
+ _visitors["column"] = visit_column
- visitors.traverse(clause, {'column_collections': False}, _visitors)
+ _visitors["table"] = tables.append
+
+ visitors.traverse(clause, {"column_collections": False}, _visitors)
return tables
@@ -243,10 +268,9 @@ def unwrap_order_by(clause):
stack = deque([clause])
while stack:
t = stack.popleft()
- if isinstance(t, ColumnElement) and \
- (
- not isinstance(t, UnaryExpression) or
- not operators.is_ordering_modifier(t.modifier)
+ if isinstance(t, ColumnElement) and (
+ not isinstance(t, UnaryExpression)
+ or not operators.is_ordering_modifier(t.modifier)
):
if isinstance(t, _label_reference):
t = t.element
@@ -266,9 +290,7 @@ def unwrap_label_reference(element):
if isinstance(elem, (_label_reference, _textual_label_reference)):
return elem.element
- return visitors.replacement_traverse(
- element, {}, replace
- )
+ return visitors.replacement_traverse(element, {}, replace)
def expand_column_list_from_order_by(collist, order_by):
@@ -278,17 +300,16 @@ def expand_column_list_from_order_by(collist, order_by):
in the collist.
"""
- cols_already_present = set([
- col.element if col._order_by_label_element is not None
- else col for col in collist
- ])
+ cols_already_present = set(
+ [
+ col.element if col._order_by_label_element is not None else col
+ for col in collist
+ ]
+ )
return [
- col for col in
- chain(*[
- unwrap_order_by(o)
- for o in order_by
- ])
+ col
+ for col in chain(*[unwrap_order_by(o) for o in order_by])
if col not in cols_already_present
]
@@ -325,9 +346,9 @@ def surface_column_elements(clause, include_scalar_selects=True):
be addressable in the WHERE clause of a SELECT if this element were
in the columns clause."""
- filter_ = (FromGrouping, )
+ filter_ = (FromGrouping,)
if not include_scalar_selects:
- filter_ += (SelectBase, )
+ filter_ += (SelectBase,)
stack = deque([clause])
while stack:
@@ -343,9 +364,7 @@ def selectables_overlap(left, right):
"""Return True if left/right have some overlapping selectable"""
return bool(
- set(surface_selectables(left)).intersection(
- surface_selectables(right)
- )
+ set(surface_selectables(left)).intersection(surface_selectables(right))
)
@@ -366,7 +385,7 @@ def bind_values(clause):
def visit_bindparam(bind):
v.append(bind.effective_value)
- visitors.traverse(clause, {}, {'bindparam': visit_bindparam})
+ visitors.traverse(clause, {}, {"bindparam": visit_bindparam})
return v
@@ -383,7 +402,7 @@ class _repr_base(object):
_TUPLE = 1
_DICT = 2
- __slots__ = 'max_chars',
+ __slots__ = ("max_chars",)
def trunc(self, value):
rep = repr(value)
@@ -391,10 +410,12 @@ class _repr_base(object):
if lenrep > self.max_chars:
segment_length = self.max_chars // 2
rep = (
- rep[0:segment_length] +
- (" ... (%d characters truncated) ... "
- % (lenrep - self.max_chars)) +
- rep[-segment_length:]
+ rep[0:segment_length]
+ + (
+ " ... (%d characters truncated) ... "
+ % (lenrep - self.max_chars)
+ )
+ + rep[-segment_length:]
)
return rep
@@ -402,7 +423,7 @@ class _repr_base(object):
class _repr_row(_repr_base):
"""Provide a string view of a row."""
- __slots__ = 'row',
+ __slots__ = ("row",)
def __init__(self, row, max_chars=300):
self.row = row
@@ -412,7 +433,7 @@ class _repr_row(_repr_base):
trunc = self.trunc
return "(%s%s)" % (
", ".join(trunc(value) for value in self.row),
- "," if len(self.row) == 1 else ""
+ "," if len(self.row) == 1 else "",
)
@@ -424,7 +445,7 @@ class _repr_params(_repr_base):
"""
- __slots__ = 'params', 'batches',
+ __slots__ = "params", "batches"
def __init__(self, params, batches, max_chars=300):
self.params = params
@@ -435,11 +456,13 @@ class _repr_params(_repr_base):
if isinstance(self.params, list):
typ = self._LIST
ismulti = self.params and isinstance(
- self.params[0], (list, dict, tuple))
+ self.params[0], (list, dict, tuple)
+ )
elif isinstance(self.params, tuple):
typ = self._TUPLE
ismulti = self.params and isinstance(
- self.params[0], (list, dict, tuple))
+ self.params[0], (list, dict, tuple)
+ )
elif isinstance(self.params, dict):
typ = self._DICT
ismulti = False
@@ -448,11 +471,15 @@ class _repr_params(_repr_base):
if ismulti and len(self.params) > self.batches:
msg = " ... displaying %i of %i total bound parameter sets ... "
- return ' '.join((
- self._repr_multi(self.params[:self.batches - 2], typ)[0:-1],
- msg % (self.batches, len(self.params)),
- self._repr_multi(self.params[-2:], typ)[1:]
- ))
+ return " ".join(
+ (
+ self._repr_multi(self.params[: self.batches - 2], typ)[
+ 0:-1
+ ],
+ msg % (self.batches, len(self.params)),
+ self._repr_multi(self.params[-2:], typ)[1:],
+ )
+ )
elif ismulti:
return self._repr_multi(self.params, typ)
else:
@@ -467,12 +494,13 @@ class _repr_params(_repr_base):
elif isinstance(multi_params[0], dict):
elem_type = self._DICT
else:
- assert False, \
- "Unknown parameter type %s" % (type(multi_params[0]))
+ assert False, "Unknown parameter type %s" % (
+ type(multi_params[0])
+ )
elements = ", ".join(
- self._repr_params(params, elem_type)
- for params in multi_params)
+ self._repr_params(params, elem_type) for params in multi_params
+ )
else:
elements = ""
@@ -493,13 +521,10 @@ class _repr_params(_repr_base):
elif typ is self._TUPLE:
return "(%s%s)" % (
", ".join(trunc(value) for value in params),
- "," if len(params) == 1 else ""
-
+ "," if len(params) == 1 else "",
)
else:
- return "[%s]" % (
- ", ".join(trunc(value) for value in params)
- )
+ return "[%s]" % (", ".join(trunc(value) for value in params))
def adapt_criterion_to_null(crit, nulls):
@@ -509,20 +534,24 @@ def adapt_criterion_to_null(crit, nulls):
"""
def visit_binary(binary):
- if isinstance(binary.left, BindParameter) \
- and binary.left._identifying_key in nulls:
+ if (
+ isinstance(binary.left, BindParameter)
+ and binary.left._identifying_key in nulls
+ ):
# reverse order if the NULL is on the left side
binary.left = binary.right
binary.right = Null()
binary.operator = operators.is_
binary.negate = operators.isnot
- elif isinstance(binary.right, BindParameter) \
- and binary.right._identifying_key in nulls:
+ elif (
+ isinstance(binary.right, BindParameter)
+ and binary.right._identifying_key in nulls
+ ):
binary.right = Null()
binary.operator = operators.is_
binary.negate = operators.isnot
- return visitors.cloned_traverse(crit, {}, {'binary': visit_binary})
+ return visitors.cloned_traverse(crit, {}, {"binary": visit_binary})
def splice_joins(left, right, stop_on=None):
@@ -570,8 +599,8 @@ def reduce_columns(columns, *clauses, **kw):
in the selectable to just those that are not repeated.
"""
- ignore_nonexistent_tables = kw.pop('ignore_nonexistent_tables', False)
- only_synonyms = kw.pop('only_synonyms', False)
+ ignore_nonexistent_tables = kw.pop("ignore_nonexistent_tables", False)
+ only_synonyms = kw.pop("only_synonyms", False)
columns = util.ordered_column_set(columns)
@@ -597,39 +626,48 @@ def reduce_columns(columns, *clauses, **kw):
continue
else:
raise
- if fk_col.shares_lineage(c) and \
- (not only_synonyms or
- c.name == col.name):
+ if fk_col.shares_lineage(c) and (
+ not only_synonyms or c.name == col.name
+ ):
omit.add(col)
break
if clauses:
+
def visit_binary(binary):
if binary.operator == operators.eq:
cols = util.column_set(
- chain(*[c.proxy_set for c in columns.difference(omit)]))
+ chain(*[c.proxy_set for c in columns.difference(omit)])
+ )
if binary.left in cols and binary.right in cols:
for c in reversed(columns):
- if c.shares_lineage(binary.right) and \
- (not only_synonyms or
- c.name == binary.left.name):
+ if c.shares_lineage(binary.right) and (
+ not only_synonyms or c.name == binary.left.name
+ ):
omit.add(c)
break
+
for clause in clauses:
if clause is not None:
- visitors.traverse(clause, {}, {'binary': visit_binary})
+ visitors.traverse(clause, {}, {"binary": visit_binary})
return ColumnSet(columns.difference(omit))
-def criterion_as_pairs(expression, consider_as_foreign_keys=None,
- consider_as_referenced_keys=None, any_operator=False):
+def criterion_as_pairs(
+ expression,
+ consider_as_foreign_keys=None,
+ consider_as_referenced_keys=None,
+ any_operator=False,
+):
"""traverse an expression and locate binary criterion pairs."""
if consider_as_foreign_keys and consider_as_referenced_keys:
- raise exc.ArgumentError("Can only specify one of "
- "'consider_as_foreign_keys' or "
- "'consider_as_referenced_keys'")
+ raise exc.ArgumentError(
+ "Can only specify one of "
+ "'consider_as_foreign_keys' or "
+ "'consider_as_referenced_keys'"
+ )
def col_is(a, b):
# return a is b
@@ -638,37 +676,44 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None,
def visit_binary(binary):
if not any_operator and binary.operator is not operators.eq:
return
- if not isinstance(binary.left, ColumnElement) or \
- not isinstance(binary.right, ColumnElement):
+ if not isinstance(binary.left, ColumnElement) or not isinstance(
+ binary.right, ColumnElement
+ ):
return
if consider_as_foreign_keys:
- if binary.left in consider_as_foreign_keys and \
- (col_is(binary.right, binary.left) or
- binary.right not in consider_as_foreign_keys):
+ if binary.left in consider_as_foreign_keys and (
+ col_is(binary.right, binary.left)
+ or binary.right not in consider_as_foreign_keys
+ ):
pairs.append((binary.right, binary.left))
- elif binary.right in consider_as_foreign_keys and \
- (col_is(binary.left, binary.right) or
- binary.left not in consider_as_foreign_keys):
+ elif binary.right in consider_as_foreign_keys and (
+ col_is(binary.left, binary.right)
+ or binary.left not in consider_as_foreign_keys
+ ):
pairs.append((binary.left, binary.right))
elif consider_as_referenced_keys:
- if binary.left in consider_as_referenced_keys and \
- (col_is(binary.right, binary.left) or
- binary.right not in consider_as_referenced_keys):
+ if binary.left in consider_as_referenced_keys and (
+ col_is(binary.right, binary.left)
+ or binary.right not in consider_as_referenced_keys
+ ):
pairs.append((binary.left, binary.right))
- elif binary.right in consider_as_referenced_keys and \
- (col_is(binary.left, binary.right) or
- binary.left not in consider_as_referenced_keys):
+ elif binary.right in consider_as_referenced_keys and (
+ col_is(binary.left, binary.right)
+ or binary.left not in consider_as_referenced_keys
+ ):
pairs.append((binary.right, binary.left))
else:
- if isinstance(binary.left, Column) and \
- isinstance(binary.right, Column):
+ if isinstance(binary.left, Column) and isinstance(
+ binary.right, Column
+ ):
if binary.left.references(binary.right):
pairs.append((binary.right, binary.left))
elif binary.right.references(binary.left):
pairs.append((binary.left, binary.right))
+
pairs = []
- visitors.traverse(expression, {}, {'binary': visit_binary})
+ visitors.traverse(expression, {}, {"binary": visit_binary})
return pairs
@@ -699,28 +744,38 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
"""
- def __init__(self, selectable, equivalents=None,
- include_fn=None, exclude_fn=None,
- adapt_on_names=False, anonymize_labels=False):
+ def __init__(
+ self,
+ selectable,
+ equivalents=None,
+ include_fn=None,
+ exclude_fn=None,
+ adapt_on_names=False,
+ anonymize_labels=False,
+ ):
self.__traverse_options__ = {
- 'stop_on': [selectable],
- 'anonymize_labels': anonymize_labels}
+ "stop_on": [selectable],
+ "anonymize_labels": anonymize_labels,
+ }
self.selectable = selectable
self.include_fn = include_fn
self.exclude_fn = exclude_fn
self.equivalents = util.column_dict(equivalents or {})
self.adapt_on_names = adapt_on_names
- def _corresponding_column(self, col, require_embedded,
- _seen=util.EMPTY_SET):
+ def _corresponding_column(
+ self, col, require_embedded, _seen=util.EMPTY_SET
+ ):
newcol = self.selectable.corresponding_column(
- col,
- require_embedded=require_embedded)
+ col, require_embedded=require_embedded
+ )
if newcol is None and col in self.equivalents and col not in _seen:
for equiv in self.equivalents[col]:
newcol = self._corresponding_column(
- equiv, require_embedded=require_embedded,
- _seen=_seen.union([col]))
+ equiv,
+ require_embedded=require_embedded,
+ _seen=_seen.union([col]),
+ )
if newcol is not None:
return newcol
if self.adapt_on_names and newcol is None:
@@ -728,8 +783,9 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
return newcol
def replace(self, col):
- if isinstance(col, FromClause) and \
- self.selectable.is_derived_from(col):
+ if isinstance(col, FromClause) and self.selectable.is_derived_from(
+ col
+ ):
return self.selectable
elif not isinstance(col, ColumnElement):
return None
@@ -772,16 +828,27 @@ class ColumnAdapter(ClauseAdapter):
"""
- def __init__(self, selectable, equivalents=None,
- chain_to=None, adapt_required=False,
- include_fn=None, exclude_fn=None,
- adapt_on_names=False,
- allow_label_resolve=True,
- anonymize_labels=False):
- ClauseAdapter.__init__(self, selectable, equivalents,
- include_fn=include_fn, exclude_fn=exclude_fn,
- adapt_on_names=adapt_on_names,
- anonymize_labels=anonymize_labels)
+ def __init__(
+ self,
+ selectable,
+ equivalents=None,
+ chain_to=None,
+ adapt_required=False,
+ include_fn=None,
+ exclude_fn=None,
+ adapt_on_names=False,
+ allow_label_resolve=True,
+ anonymize_labels=False,
+ ):
+ ClauseAdapter.__init__(
+ self,
+ selectable,
+ equivalents,
+ include_fn=include_fn,
+ exclude_fn=exclude_fn,
+ adapt_on_names=adapt_on_names,
+ anonymize_labels=anonymize_labels,
+ )
if chain_to:
self.chain(chain_to)
@@ -800,9 +867,7 @@ class ColumnAdapter(ClauseAdapter):
def __getitem__(self, key):
if (
self.parent.include_fn and not self.parent.include_fn(key)
- ) or (
- self.parent.exclude_fn and self.parent.exclude_fn(key)
- ):
+ ) or (self.parent.exclude_fn and self.parent.exclude_fn(key)):
if self.parent._wrap:
return self.parent._wrap.columns[key]
else:
@@ -843,7 +908,7 @@ class ColumnAdapter(ClauseAdapter):
def __getstate__(self):
d = self.__dict__.copy()
- del d['columns']
+ del d["columns"]
return d
def __setstate__(self, state):
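(Aside, not part of the diff: ``util.py`` repeatedly uses the ``visitors.traverse(clause, opts, visitor_dict)`` pattern touched above, where the dict maps ``__visit_name__`` strings to callables. A minimal sketch of that usage:)

# Minimal sketch, not part of this changeset: collect BinaryExpression
# elements from an expression tree the same way util.py does above.
from sqlalchemy import Column, Integer, MetaData, Table
from sqlalchemy.sql import visitors

t = Table("t", MetaData(), Column("x", Integer), Column("y", Integer))
expr = (t.c.x == 1) & (t.c.y == 2)

found = []
visitors.traverse(expr, {}, {"binary": found.append})
# found now holds the two comparison BinaryExpression objects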
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index b39ec8167..bf1743643 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -29,11 +29,20 @@ from .. import util
import operator
from .. import exc
-__all__ = ['VisitableType', 'Visitable', 'ClauseVisitor',
- 'CloningVisitor', 'ReplacingCloningVisitor', 'iterate',
- 'iterate_depthfirst', 'traverse_using', 'traverse',
- 'traverse_depthfirst',
- 'cloned_traverse', 'replacement_traverse']
+__all__ = [
+ "VisitableType",
+ "Visitable",
+ "ClauseVisitor",
+ "CloningVisitor",
+ "ReplacingCloningVisitor",
+ "iterate",
+ "iterate_depthfirst",
+ "traverse_using",
+ "traverse",
+ "traverse_depthfirst",
+ "cloned_traverse",
+ "replacement_traverse",
+]
class VisitableType(type):
@@ -53,8 +62,7 @@ class VisitableType(type):
"""
def __init__(cls, clsname, bases, clsdict):
- if clsname != 'Visitable' and \
- hasattr(cls, '__visit_name__'):
+ if clsname != "Visitable" and hasattr(cls, "__visit_name__"):
_generate_dispatch(cls)
super(VisitableType, cls).__init__(clsname, bases, clsdict)
@@ -64,7 +72,7 @@ def _generate_dispatch(cls):
"""Return an optimized visit dispatch function for the cls
for use by the compiler.
"""
- if '__visit_name__' in cls.__dict__:
+ if "__visit_name__" in cls.__dict__:
visit_name = cls.__visit_name__
if isinstance(visit_name, str):
# There is an optimization opportunity here because the
@@ -79,12 +87,13 @@ def _generate_dispatch(cls):
raise exc.UnsupportedCompilationError(visitor, cls)
else:
return meth(self, **kw)
+
else:
# The optimization opportunity is lost for this case because the
# __visit_name__ is not yet a string. As a result, the visit
# string has to be recalculated with each compilation.
def _compiler_dispatch(self, visitor, **kw):
- visit_attr = 'visit_%s' % self.__visit_name__
+ visit_attr = "visit_%s" % self.__visit_name__
try:
meth = getattr(visitor, visit_attr)
except AttributeError:
@@ -92,8 +101,7 @@ def _generate_dispatch(cls):
else:
return meth(self, **kw)
- _compiler_dispatch.__doc__ = \
- """Look for an attribute named "visit_" + self.__visit_name__
+ _compiler_dispatch.__doc__ = """Look for an attribute named "visit_" + self.__visit_name__
on the visitor, and call it with the same kw params.
"""
cls._compiler_dispatch = _compiler_dispatch
@@ -137,7 +145,7 @@ class ClauseVisitor(object):
visitors = {}
for name in dir(self):
- if name.startswith('visit_'):
+ if name.startswith("visit_"):
visitors[name[6:]] = getattr(self, name)
return visitors
@@ -148,7 +156,7 @@ class ClauseVisitor(object):
v = self
while v:
yield v
- v = getattr(v, '_next', None)
+ v = getattr(v, "_next", None)
def chain(self, visitor):
"""'chain' an additional ClauseVisitor onto this ClauseVisitor.
@@ -178,7 +186,8 @@ class CloningVisitor(ClauseVisitor):
"""traverse and visit the given expression structure."""
return cloned_traverse(
- obj, self.__traverse_options__, self._visitor_dict)
+ obj, self.__traverse_options__, self._visitor_dict
+ )
class ReplacingCloningVisitor(CloningVisitor):
@@ -204,6 +213,7 @@ class ReplacingCloningVisitor(CloningVisitor):
e = v.replace(elem)
if e is not None:
return e
+
return replacement_traverse(obj, self.__traverse_options__, replace)
@@ -282,7 +292,7 @@ def cloned_traverse(obj, opts, visitors):
modifications by visitors."""
cloned = {}
- stop_on = set(opts.get('stop_on', []))
+ stop_on = set(opts.get("stop_on", []))
def clone(elem):
if elem in stop_on:
@@ -306,11 +316,13 @@ def replacement_traverse(obj, opts, replace):
replacement by a given replacement function."""
cloned = {}
- stop_on = {id(x) for x in opts.get('stop_on', [])}
+ stop_on = {id(x) for x in opts.get("stop_on", [])}
def clone(elem, **kw):
- if id(elem) in stop_on or \
- 'no_replacement_traverse' in elem._annotations:
+ if (
+ id(elem) in stop_on
+ or "no_replacement_traverse" in elem._annotations
+ ):
return elem
else:
newelem = replace(elem)