diff options
Diffstat (limited to 'lib/sqlalchemy/sql')
| -rw-r--r-- | lib/sqlalchemy/sql/__init__.py | 11 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/annotation.py | 16 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/base.py | 114 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 2030 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/crud.py | 440 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/ddl.py | 306 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/default_comparator.py | 234 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/dml.py | 194 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 800 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/expression.py | 205 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/functions.py | 139 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/naming.py | 47 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/operators.py | 109 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/schema.py | 1129 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 812 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/sqltypes.py | 649 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/type_api.py | 122 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/util.py | 327 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 48 |
19 files changed, 4603 insertions, 3129 deletions
diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index aa811388b..87e2fb6c3 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -72,7 +72,7 @@ from .expression import ( union, union_all, update, - within_group + within_group, ) from .visitors import ClauseVisitor @@ -84,12 +84,16 @@ def __go(lcls): import inspect as _inspect - __all__ = sorted(name for name, obj in lcls.items() - if not (name.startswith('_') or _inspect.ismodule(obj))) + __all__ = sorted( + name + for name, obj in lcls.items() + if not (name.startswith("_") or _inspect.ismodule(obj)) + ) from .annotation import _prepare_annotations, Annotated from .elements import AnnotatedColumnElement, ClauseList from .selectable import AnnotatedFromClause + _prepare_annotations(ColumnElement, AnnotatedColumnElement) _prepare_annotations(FromClause, AnnotatedFromClause) _prepare_annotations(ClauseList, Annotated) @@ -98,4 +102,5 @@ def __go(lcls): from . import naming + __go(locals()) diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index c1d484d95..64cfa630e 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -76,8 +76,7 @@ class Annotated(object): return self._with_annotations(_values) def _compiler_dispatch(self, visitor, **kw): - return self.__element.__class__._compiler_dispatch( - self, visitor, **kw) + return self.__element.__class__._compiler_dispatch(self, visitor, **kw) @property def _constructor(self): @@ -120,10 +119,13 @@ def _deep_annotate(element, annotations, exclude=None): Elements within the exclude collection will be cloned but not annotated. """ + def clone(elem): - if exclude and \ - hasattr(elem, 'proxy_set') and \ - elem.proxy_set.intersection(exclude): + if ( + exclude + and hasattr(elem, "proxy_set") + and elem.proxy_set.intersection(exclude) + ): newelem = elem._clone() elif annotations != elem._annotations: newelem = elem._annotate(annotations) @@ -191,8 +193,8 @@ def _new_annotation_type(cls, base_cls): break annotated_classes[cls] = anno_cls = type( - "Annotated%s" % cls.__name__, - (base_cls, cls), {}) + "Annotated%s" % cls.__name__, (base_cls, cls), {} + ) globals()["Annotated%s" % cls.__name__] = anno_cls return anno_cls diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 6b9b55753..45db215fe 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -15,8 +15,8 @@ import itertools from .visitors import ClauseVisitor import re -PARSE_AUTOCOMMIT = util.symbol('PARSE_AUTOCOMMIT') -NO_ARG = util.symbol('NO_ARG') +PARSE_AUTOCOMMIT = util.symbol("PARSE_AUTOCOMMIT") +NO_ARG = util.symbol("NO_ARG") class Immutable(object): @@ -77,7 +77,8 @@ class _DialectArgView(util.collections_abc.MutableMapping): dialect, value_key = self._key(key) except KeyError: raise exc.ArgumentError( - "Keys must be of the form <dialectname>_<argname>") + "Keys must be of the form <dialectname>_<argname>" + ) else: self.obj.dialect_options[dialect][value_key] = value @@ -86,15 +87,18 @@ class _DialectArgView(util.collections_abc.MutableMapping): del self.obj.dialect_options[dialect][value_key] def __len__(self): - return sum(len(args._non_defaults) for args in - self.obj.dialect_options.values()) + return sum( + len(args._non_defaults) + for args in self.obj.dialect_options.values() + ) def __iter__(self): return ( util.safe_kwarg("%s_%s" % (dialect_name, value_name)) for dialect_name in self.obj.dialect_options - for value_name in - self.obj.dialect_options[dialect_name]._non_defaults + for value_name in self.obj.dialect_options[ + dialect_name + ]._non_defaults ) @@ -187,8 +191,8 @@ class DialectKWArgs(object): if construct_arg_dictionary is None: raise exc.ArgumentError( "Dialect '%s' does have keyword-argument " - "validation and defaults enabled configured" % - dialect_name) + "validation and defaults enabled configured" % dialect_name + ) if cls not in construct_arg_dictionary: construct_arg_dictionary[cls] = {} construct_arg_dictionary[cls][argument_name] = default @@ -230,6 +234,7 @@ class DialectKWArgs(object): if dialect_cls.construct_arguments is None: return None return dict(dialect_cls.construct_arguments) + _kw_registry = util.PopulateDict(_kw_reg_for_dialect) def _kw_reg_for_dialect_cls(self, dialect_name): @@ -274,11 +279,12 @@ class DialectKWArgs(object): return for k in kwargs: - m = re.match('^(.+?)_(.+)$', k) + m = re.match("^(.+?)_(.+)$", k) if not m: raise TypeError( "Additional arguments should be " - "named <dialectname>_<argument>, got '%s'" % k) + "named <dialectname>_<argument>, got '%s'" % k + ) dialect_name, arg_name = m.group(1, 2) try: @@ -286,20 +292,22 @@ class DialectKWArgs(object): except exc.NoSuchModuleError: util.warn( "Can't validate argument %r; can't " - "locate any SQLAlchemy dialect named %r" % - (k, dialect_name)) + "locate any SQLAlchemy dialect named %r" + % (k, dialect_name) + ) self.dialect_options[dialect_name] = d = _DialectArgDict() d._defaults.update({"*": None}) d._non_defaults[arg_name] = kwargs[k] else: - if "*" not in construct_arg_dictionary and \ - arg_name not in construct_arg_dictionary: + if ( + "*" not in construct_arg_dictionary + and arg_name not in construct_arg_dictionary + ): raise exc.ArgumentError( "Argument %r is not accepted by " - "dialect %r on behalf of %r" % ( - k, - dialect_name, self.__class__ - )) + "dialect %r on behalf of %r" + % (k, dialect_name, self.__class__) + ) else: construct_arg_dictionary[arg_name] = kwargs[k] @@ -359,14 +367,14 @@ class Executable(Generative): :meth:`.Query.execution_options()` """ - if 'isolation_level' in kw: + if "isolation_level" in kw: raise exc.ArgumentError( "'isolation_level' execution option may only be specified " "on Connection.execution_options(), or " "per-engine using the isolation_level " "argument to create_engine()." ) - if 'compiled_cache' in kw: + if "compiled_cache" in kw: raise exc.ArgumentError( "'compiled_cache' execution option may only be specified " "on Connection.execution_options(), not per statement." @@ -377,10 +385,12 @@ class Executable(Generative): """Compile and execute this :class:`.Executable`.""" e = self.bind if e is None: - label = getattr(self, 'description', self.__class__.__name__) - msg = ('This %s is not directly bound to a Connection or Engine. ' - 'Use the .execute() method of a Connection or Engine ' - 'to execute this construct.' % label) + label = getattr(self, "description", self.__class__.__name__) + msg = ( + "This %s is not directly bound to a Connection or Engine. " + "Use the .execute() method of a Connection or Engine " + "to execute this construct." % label + ) raise exc.UnboundExecutionError(msg) return e._execute_clauseelement(self, multiparams, params) @@ -434,7 +444,7 @@ class SchemaEventTarget(object): class SchemaVisitor(ClauseVisitor): """Define the visiting for ``SchemaItem`` objects.""" - __traverse_options__ = {'schema_visitor': True} + __traverse_options__ = {"schema_visitor": True} class ColumnCollection(util.OrderedProperties): @@ -446,11 +456,11 @@ class ColumnCollection(util.OrderedProperties): """ - __slots__ = '_all_columns' + __slots__ = "_all_columns" def __init__(self, *columns): super(ColumnCollection, self).__init__() - object.__setattr__(self, '_all_columns', []) + object.__setattr__(self, "_all_columns", []) for c in columns: self.add(c) @@ -485,8 +495,9 @@ class ColumnCollection(util.OrderedProperties): self._data[column.key] = column if remove_col is not None: - self._all_columns[:] = [column if c is remove_col - else c for c in self._all_columns] + self._all_columns[:] = [ + column if c is remove_col else c for c in self._all_columns + ] else: self._all_columns.append(column) @@ -499,7 +510,8 @@ class ColumnCollection(util.OrderedProperties): """ if not column.key: raise exc.ArgumentError( - "Can't add unnamed column to column collection") + "Can't add unnamed column to column collection" + ) self[column.key] = column def __delitem__(self, key): @@ -521,10 +533,12 @@ class ColumnCollection(util.OrderedProperties): return if not existing.shares_lineage(value): - util.warn('Column %r on table %r being replaced by ' - '%r, which has the same key. Consider ' - 'use_labels for select() statements.' % - (key, getattr(existing, 'table', None), value)) + util.warn( + "Column %r on table %r being replaced by " + "%r, which has the same key. Consider " + "use_labels for select() statements." + % (key, getattr(existing, "table", None), value) + ) # pop out memoized proxy_set as this # operation may very well be occurring @@ -540,13 +554,15 @@ class ColumnCollection(util.OrderedProperties): def remove(self, column): del self._data[column.key] self._all_columns[:] = [ - c for c in self._all_columns if c is not column] + c for c in self._all_columns if c is not column + ] def update(self, iter): cols = list(iter) all_col_set = set(self._all_columns) self._all_columns.extend( - c for label, c in cols if c not in all_col_set) + c for label, c in cols if c not in all_col_set + ) self._data.update((label, c) for label, c in cols) def extend(self, iter): @@ -572,12 +588,11 @@ class ColumnCollection(util.OrderedProperties): return util.OrderedProperties.__contains__(self, other) def __getstate__(self): - return {'_data': self._data, - '_all_columns': self._all_columns} + return {"_data": self._data, "_all_columns": self._all_columns} def __setstate__(self, state): - object.__setattr__(self, '_data', state['_data']) - object.__setattr__(self, '_all_columns', state['_all_columns']) + object.__setattr__(self, "_data", state["_data"]) + object.__setattr__(self, "_all_columns", state["_all_columns"]) def contains_column(self, col): return col in set(self._all_columns) @@ -589,7 +604,7 @@ class ColumnCollection(util.OrderedProperties): class ImmutableColumnCollection(util.ImmutableProperties, ColumnCollection): def __init__(self, data, all_columns): util.ImmutableProperties.__init__(self, data) - object.__setattr__(self, '_all_columns', all_columns) + object.__setattr__(self, "_all_columns", all_columns) extend = remove = util.ImmutableProperties._immutable @@ -622,15 +637,18 @@ def _bind_or_error(schemaitem, msg=None): bind = schemaitem.bind if not bind: name = schemaitem.__class__.__name__ - label = getattr(schemaitem, 'fullname', - getattr(schemaitem, 'name', None)) + label = getattr( + schemaitem, "fullname", getattr(schemaitem, "name", None) + ) if label: - item = '%s object %r' % (name, label) + item = "%s object %r" % (name, label) else: - item = '%s object' % name + item = "%s object" % name if msg is None: - msg = "%s is not bound to an Engine or Connection. "\ - "Execution can not proceed without a database to execute "\ + msg = ( + "%s is not bound to an Engine or Connection. " + "Execution can not proceed without a database to execute " "against." % item + ) raise exc.UnboundExecutionError(msg) return bind diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 80ed707ed..f641d0a84 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -25,133 +25,218 @@ To generate user-defined SQL strings, see import contextlib import re -from . import schema, sqltypes, operators, functions, visitors, \ - elements, selectable, crud +from . import ( + schema, + sqltypes, + operators, + functions, + visitors, + elements, + selectable, + crud, +) from .. import util, exc import itertools -RESERVED_WORDS = set([ - 'all', 'analyse', 'analyze', 'and', 'any', 'array', - 'as', 'asc', 'asymmetric', 'authorization', 'between', - 'binary', 'both', 'case', 'cast', 'check', 'collate', - 'column', 'constraint', 'create', 'cross', 'current_date', - 'current_role', 'current_time', 'current_timestamp', - 'current_user', 'default', 'deferrable', 'desc', - 'distinct', 'do', 'else', 'end', 'except', 'false', - 'for', 'foreign', 'freeze', 'from', 'full', 'grant', - 'group', 'having', 'ilike', 'in', 'initially', 'inner', - 'intersect', 'into', 'is', 'isnull', 'join', 'leading', - 'left', 'like', 'limit', 'localtime', 'localtimestamp', - 'natural', 'new', 'not', 'notnull', 'null', 'off', 'offset', - 'old', 'on', 'only', 'or', 'order', 'outer', 'overlaps', - 'placing', 'primary', 'references', 'right', 'select', - 'session_user', 'set', 'similar', 'some', 'symmetric', 'table', - 'then', 'to', 'trailing', 'true', 'union', 'unique', 'user', - 'using', 'verbose', 'when', 'where']) - -LEGAL_CHARACTERS = re.compile(r'^[A-Z0-9_$]+$', re.I) -ILLEGAL_INITIAL_CHARACTERS = {str(x) for x in range(0, 10)}.union(['$']) - -BIND_PARAMS = re.compile(r'(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])', re.UNICODE) -BIND_PARAMS_ESC = re.compile(r'\x5c(:[\w\$]*)(?![:\w\$])', re.UNICODE) +RESERVED_WORDS = set( + [ + "all", + "analyse", + "analyze", + "and", + "any", + "array", + "as", + "asc", + "asymmetric", + "authorization", + "between", + "binary", + "both", + "case", + "cast", + "check", + "collate", + "column", + "constraint", + "create", + "cross", + "current_date", + "current_role", + "current_time", + "current_timestamp", + "current_user", + "default", + "deferrable", + "desc", + "distinct", + "do", + "else", + "end", + "except", + "false", + "for", + "foreign", + "freeze", + "from", + "full", + "grant", + "group", + "having", + "ilike", + "in", + "initially", + "inner", + "intersect", + "into", + "is", + "isnull", + "join", + "leading", + "left", + "like", + "limit", + "localtime", + "localtimestamp", + "natural", + "new", + "not", + "notnull", + "null", + "off", + "offset", + "old", + "on", + "only", + "or", + "order", + "outer", + "overlaps", + "placing", + "primary", + "references", + "right", + "select", + "session_user", + "set", + "similar", + "some", + "symmetric", + "table", + "then", + "to", + "trailing", + "true", + "union", + "unique", + "user", + "using", + "verbose", + "when", + "where", + ] +) + +LEGAL_CHARACTERS = re.compile(r"^[A-Z0-9_$]+$", re.I) +ILLEGAL_INITIAL_CHARACTERS = {str(x) for x in range(0, 10)}.union(["$"]) + +BIND_PARAMS = re.compile(r"(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])", re.UNICODE) +BIND_PARAMS_ESC = re.compile(r"\x5c(:[\w\$]*)(?![:\w\$])", re.UNICODE) BIND_TEMPLATES = { - 'pyformat': "%%(%(name)s)s", - 'qmark': "?", - 'format': "%%s", - 'numeric': ":[_POSITION]", - 'named': ":%(name)s" + "pyformat": "%%(%(name)s)s", + "qmark": "?", + "format": "%%s", + "numeric": ":[_POSITION]", + "named": ":%(name)s", } OPERATORS = { # binary - operators.and_: ' AND ', - operators.or_: ' OR ', - operators.add: ' + ', - operators.mul: ' * ', - operators.sub: ' - ', - operators.div: ' / ', - operators.mod: ' % ', - operators.truediv: ' / ', - operators.neg: '-', - operators.lt: ' < ', - operators.le: ' <= ', - operators.ne: ' != ', - operators.gt: ' > ', - operators.ge: ' >= ', - operators.eq: ' = ', - operators.is_distinct_from: ' IS DISTINCT FROM ', - operators.isnot_distinct_from: ' IS NOT DISTINCT FROM ', - operators.concat_op: ' || ', - operators.match_op: ' MATCH ', - operators.notmatch_op: ' NOT MATCH ', - operators.in_op: ' IN ', - operators.notin_op: ' NOT IN ', - operators.comma_op: ', ', - operators.from_: ' FROM ', - operators.as_: ' AS ', - operators.is_: ' IS ', - operators.isnot: ' IS NOT ', - operators.collate: ' COLLATE ', - + operators.and_: " AND ", + operators.or_: " OR ", + operators.add: " + ", + operators.mul: " * ", + operators.sub: " - ", + operators.div: " / ", + operators.mod: " % ", + operators.truediv: " / ", + operators.neg: "-", + operators.lt: " < ", + operators.le: " <= ", + operators.ne: " != ", + operators.gt: " > ", + operators.ge: " >= ", + operators.eq: " = ", + operators.is_distinct_from: " IS DISTINCT FROM ", + operators.isnot_distinct_from: " IS NOT DISTINCT FROM ", + operators.concat_op: " || ", + operators.match_op: " MATCH ", + operators.notmatch_op: " NOT MATCH ", + operators.in_op: " IN ", + operators.notin_op: " NOT IN ", + operators.comma_op: ", ", + operators.from_: " FROM ", + operators.as_: " AS ", + operators.is_: " IS ", + operators.isnot: " IS NOT ", + operators.collate: " COLLATE ", # unary - operators.exists: 'EXISTS ', - operators.distinct_op: 'DISTINCT ', - operators.inv: 'NOT ', - operators.any_op: 'ANY ', - operators.all_op: 'ALL ', - + operators.exists: "EXISTS ", + operators.distinct_op: "DISTINCT ", + operators.inv: "NOT ", + operators.any_op: "ANY ", + operators.all_op: "ALL ", # modifiers - operators.desc_op: ' DESC', - operators.asc_op: ' ASC', - operators.nullsfirst_op: ' NULLS FIRST', - operators.nullslast_op: ' NULLS LAST', - + operators.desc_op: " DESC", + operators.asc_op: " ASC", + operators.nullsfirst_op: " NULLS FIRST", + operators.nullslast_op: " NULLS LAST", } FUNCTIONS = { - functions.coalesce: 'coalesce', - functions.current_date: 'CURRENT_DATE', - functions.current_time: 'CURRENT_TIME', - functions.current_timestamp: 'CURRENT_TIMESTAMP', - functions.current_user: 'CURRENT_USER', - functions.localtime: 'LOCALTIME', - functions.localtimestamp: 'LOCALTIMESTAMP', - functions.random: 'random', - functions.sysdate: 'sysdate', - functions.session_user: 'SESSION_USER', - functions.user: 'USER', - functions.cube: 'CUBE', - functions.rollup: 'ROLLUP', - functions.grouping_sets: 'GROUPING SETS', + functions.coalesce: "coalesce", + functions.current_date: "CURRENT_DATE", + functions.current_time: "CURRENT_TIME", + functions.current_timestamp: "CURRENT_TIMESTAMP", + functions.current_user: "CURRENT_USER", + functions.localtime: "LOCALTIME", + functions.localtimestamp: "LOCALTIMESTAMP", + functions.random: "random", + functions.sysdate: "sysdate", + functions.session_user: "SESSION_USER", + functions.user: "USER", + functions.cube: "CUBE", + functions.rollup: "ROLLUP", + functions.grouping_sets: "GROUPING SETS", } EXTRACT_MAP = { - 'month': 'month', - 'day': 'day', - 'year': 'year', - 'second': 'second', - 'hour': 'hour', - 'doy': 'doy', - 'minute': 'minute', - 'quarter': 'quarter', - 'dow': 'dow', - 'week': 'week', - 'epoch': 'epoch', - 'milliseconds': 'milliseconds', - 'microseconds': 'microseconds', - 'timezone_hour': 'timezone_hour', - 'timezone_minute': 'timezone_minute' + "month": "month", + "day": "day", + "year": "year", + "second": "second", + "hour": "hour", + "doy": "doy", + "minute": "minute", + "quarter": "quarter", + "dow": "dow", + "week": "week", + "epoch": "epoch", + "milliseconds": "milliseconds", + "microseconds": "microseconds", + "timezone_hour": "timezone_hour", + "timezone_minute": "timezone_minute", } COMPOUND_KEYWORDS = { - selectable.CompoundSelect.UNION: 'UNION', - selectable.CompoundSelect.UNION_ALL: 'UNION ALL', - selectable.CompoundSelect.EXCEPT: 'EXCEPT', - selectable.CompoundSelect.EXCEPT_ALL: 'EXCEPT ALL', - selectable.CompoundSelect.INTERSECT: 'INTERSECT', - selectable.CompoundSelect.INTERSECT_ALL: 'INTERSECT ALL' + selectable.CompoundSelect.UNION: "UNION", + selectable.CompoundSelect.UNION_ALL: "UNION ALL", + selectable.CompoundSelect.EXCEPT: "EXCEPT", + selectable.CompoundSelect.EXCEPT_ALL: "EXCEPT ALL", + selectable.CompoundSelect.INTERSECT: "INTERSECT", + selectable.CompoundSelect.INTERSECT_ALL: "INTERSECT ALL", } @@ -177,9 +262,14 @@ class Compiled(object): sub-elements of the statement can modify these. """ - def __init__(self, dialect, statement, bind=None, - schema_translate_map=None, - compile_kwargs=util.immutabledict()): + def __init__( + self, + dialect, + statement, + bind=None, + schema_translate_map=None, + compile_kwargs=util.immutabledict(), + ): """Construct a new :class:`.Compiled` object. :param dialect: :class:`.Dialect` to compile against. @@ -209,7 +299,8 @@ class Compiled(object): self.preparer = self.dialect.identifier_preparer if schema_translate_map: self.preparer = self.preparer._with_schema_translate( - schema_translate_map) + schema_translate_map + ) if statement is not None: self.statement = statement @@ -218,8 +309,10 @@ class Compiled(object): self.execution_options = statement._execution_options self.string = self.process(self.statement, **compile_kwargs) - @util.deprecated("0.7", ":class:`.Compiled` objects now compile " - "within the constructor.") + @util.deprecated( + "0.7", + ":class:`.Compiled` objects now compile " "within the constructor.", + ) def compile(self): """Produce the internal string representation of this element. """ @@ -247,7 +340,7 @@ class Compiled(object): def __str__(self): """Return the string text of the generated SQL or DDL.""" - return self.string or '' + return self.string or "" def construct_params(self, params=None): """Return the bind params for this compiled object. @@ -271,7 +364,9 @@ class Compiled(object): if e is None: raise exc.UnboundExecutionError( "This Compiled object is not bound to any Engine " - "or Connection.", code="2afi") + "or Connection.", + code="2afi", + ) return e._execute_compiled(self, multiparams, params) def scalar(self, *multiparams, **params): @@ -284,7 +379,7 @@ class Compiled(object): class TypeCompiler(util.with_metaclass(util.EnsureKWArgType, object)): """Produces DDL specification for TypeEngine objects.""" - ensure_kwarg = r'visit_\w+' + ensure_kwarg = r"visit_\w+" def __init__(self, dialect): self.dialect = dialect @@ -297,8 +392,8 @@ class _CompileLabel(visitors.Visitable): """lightweight label object which acts as an expression.Label.""" - __visit_name__ = 'label' - __slots__ = 'element', 'name' + __visit_name__ = "label" + __slots__ = "element", "name" def __init__(self, col, name, alt_names=()): self.element = col @@ -390,8 +485,9 @@ class SQLCompiler(Compiled): insert_prefetch = update_prefetch = () - def __init__(self, dialect, statement, column_keys=None, - inline=False, **kwargs): + def __init__( + self, dialect, statement, column_keys=None, inline=False, **kwargs + ): """Construct a new :class:`.SQLCompiler` object. :param dialect: :class:`.Dialect` to be used @@ -412,7 +508,7 @@ class SQLCompiler(Compiled): # compile INSERT/UPDATE defaults/sequences inlined (no pre- # execute) - self.inline = inline or getattr(statement, 'inline', False) + self.inline = inline or getattr(statement, "inline", False) # a dictionary of bind parameter keys to BindParameter # instances. @@ -440,8 +536,9 @@ class SQLCompiler(Compiled): self.ctes = None - self.label_length = dialect.label_length \ - or dialect.max_identifier_length + self.label_length = ( + dialect.label_length or dialect.max_identifier_length + ) # a map which tracks "anonymous" identifiers that are created on # the fly here @@ -453,7 +550,7 @@ class SQLCompiler(Compiled): Compiled.__init__(self, dialect, statement, **kwargs) if ( - self.isinsert or self.isupdate or self.isdelete + self.isinsert or self.isupdate or self.isdelete ) and statement._returning: self.returning = statement._returning @@ -482,37 +579,43 @@ class SQLCompiler(Compiled): def _nested_result(self): """special API to support the use case of 'nested result sets'""" result_columns, ordered_columns = ( - self._result_columns, self._ordered_columns) + self._result_columns, + self._ordered_columns, + ) self._result_columns, self._ordered_columns = [], False try: if self.stack: entry = self.stack[-1] - entry['need_result_map_for_nested'] = True + entry["need_result_map_for_nested"] = True else: entry = None yield self._result_columns, self._ordered_columns finally: if entry: - entry.pop('need_result_map_for_nested') + entry.pop("need_result_map_for_nested") self._result_columns, self._ordered_columns = ( - result_columns, ordered_columns) + result_columns, + ordered_columns, + ) def _apply_numbered_params(self): poscount = itertools.count(1) self.string = re.sub( - r'\[_POSITION\]', - lambda m: str(util.next(poscount)), - self.string) + r"\[_POSITION\]", lambda m: str(util.next(poscount)), self.string + ) @util.memoized_property def _bind_processors(self): return dict( - (key, value) for key, value in - ((self.bind_names[bindparam], - bindparam.type._cached_bind_processor(self.dialect) - ) - for bindparam in self.bind_names) + (key, value) + for key, value in ( + ( + self.bind_names[bindparam], + bindparam.type._cached_bind_processor(self.dialect), + ) + for bindparam in self.bind_names + ) if value is not None ) @@ -539,12 +642,16 @@ class SQLCompiler(Compiled): if _group_number: raise exc.InvalidRequestError( "A value is required for bind parameter %r, " - "in parameter group %d" % - (bindparam.key, _group_number), code="cd3x") + "in parameter group %d" + % (bindparam.key, _group_number), + code="cd3x", + ) else: raise exc.InvalidRequestError( "A value is required for bind parameter %r" - % bindparam.key, code="cd3x") + % bindparam.key, + code="cd3x", + ) elif bindparam.callable: pd[name] = bindparam.effective_value @@ -558,12 +665,16 @@ class SQLCompiler(Compiled): if _group_number: raise exc.InvalidRequestError( "A value is required for bind parameter %r, " - "in parameter group %d" % - (bindparam.key, _group_number), code="cd3x") + "in parameter group %d" + % (bindparam.key, _group_number), + code="cd3x", + ) else: raise exc.InvalidRequestError( "A value is required for bind parameter %r" - % bindparam.key, code="cd3x") + % bindparam.key, + code="cd3x", + ) if bindparam.callable: pd[self.bind_names[bindparam]] = bindparam.effective_value @@ -595,9 +706,10 @@ class SQLCompiler(Compiled): return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")" def visit_label_reference( - self, element, within_columns_clause=False, **kwargs): + self, element, within_columns_clause=False, **kwargs + ): if self.stack and self.dialect.supports_simple_order_by_label: - selectable = self.stack[-1]['selectable'] + selectable = self.stack[-1]["selectable"] with_cols, only_froms, only_cols = selectable._label_resolve_dict if within_columns_clause: @@ -611,25 +723,30 @@ class SQLCompiler(Compiled): # to something else like a ColumnClause expression. order_by_elem = element.element._order_by_label_element - if order_by_elem is not None and order_by_elem.name in \ - resolve_dict and \ - order_by_elem.shares_lineage( - resolve_dict[order_by_elem.name]): - kwargs['render_label_as_label'] = \ - element.element._order_by_label_element + if ( + order_by_elem is not None + and order_by_elem.name in resolve_dict + and order_by_elem.shares_lineage( + resolve_dict[order_by_elem.name] + ) + ): + kwargs[ + "render_label_as_label" + ] = element.element._order_by_label_element return self.process( - element.element, within_columns_clause=within_columns_clause, - **kwargs) + element.element, + within_columns_clause=within_columns_clause, + **kwargs + ) def visit_textual_label_reference( - self, element, within_columns_clause=False, **kwargs): + self, element, within_columns_clause=False, **kwargs + ): if not self.stack: # compiling the element outside of the context of a SELECT - return self.process( - element._text_clause - ) + return self.process(element._text_clause) - selectable = self.stack[-1]['selectable'] + selectable = self.stack[-1]["selectable"] with_cols, only_froms, only_cols = selectable._label_resolve_dict try: if within_columns_clause: @@ -640,26 +757,30 @@ class SQLCompiler(Compiled): # treat it like text() util.warn_limited( "Can't resolve label reference %r; converting to text()", - util.ellipses_string(element.element)) - return self.process( - element._text_clause + util.ellipses_string(element.element), ) + return self.process(element._text_clause) else: - kwargs['render_label_as_label'] = col + kwargs["render_label_as_label"] = col return self.process( - col, within_columns_clause=within_columns_clause, **kwargs) - - def visit_label(self, label, - add_to_result_map=None, - within_label_clause=False, - within_columns_clause=False, - render_label_as_label=None, - **kw): + col, within_columns_clause=within_columns_clause, **kwargs + ) + + def visit_label( + self, + label, + add_to_result_map=None, + within_label_clause=False, + within_columns_clause=False, + render_label_as_label=None, + **kw + ): # only render labels within the columns clause # or ORDER BY clause of a select. dialect-specific compilers # can modify this behavior. - render_label_with_as = (within_columns_clause and not - within_label_clause) + render_label_with_as = ( + within_columns_clause and not within_label_clause + ) render_label_only = render_label_as_label is label if render_label_only or render_label_with_as: @@ -673,27 +794,35 @@ class SQLCompiler(Compiled): add_to_result_map( labelname, label.name, - (label, labelname, ) + label._alt_names, - label.type + (label, labelname) + label._alt_names, + label.type, ) - return label.element._compiler_dispatch( - self, within_columns_clause=True, - within_label_clause=True, **kw) + \ - OPERATORS[operators.as_] + \ - self.preparer.format_label(label, labelname) + return ( + label.element._compiler_dispatch( + self, + within_columns_clause=True, + within_label_clause=True, + **kw + ) + + OPERATORS[operators.as_] + + self.preparer.format_label(label, labelname) + ) elif render_label_only: return self.preparer.format_label(label, labelname) else: return label.element._compiler_dispatch( - self, within_columns_clause=False, **kw) + self, within_columns_clause=False, **kw + ) def _fallback_column_name(self, column): - raise exc.CompileError("Cannot compile Column object until " - "its 'name' is assigned.") + raise exc.CompileError( + "Cannot compile Column object until " "its 'name' is assigned." + ) - def visit_column(self, column, add_to_result_map=None, - include_table=True, **kwargs): + def visit_column( + self, column, add_to_result_map=None, include_table=True, **kwargs + ): name = orig_name = column.name if name is None: name = self._fallback_column_name(column) @@ -704,10 +833,7 @@ class SQLCompiler(Compiled): if add_to_result_map is not None: add_to_result_map( - name, - orig_name, - (column, name, column.key), - column.type + name, orig_name, (column, name, column.key), column.type ) if is_literal: @@ -721,17 +847,16 @@ class SQLCompiler(Compiled): effective_schema = self.preparer.schema_for_object(table) if effective_schema: - schema_prefix = self.preparer.quote_schema( - effective_schema) + '.' + schema_prefix = ( + self.preparer.quote_schema(effective_schema) + "." + ) else: - schema_prefix = '' + schema_prefix = "" tablename = table.name if isinstance(tablename, elements._truncated_label): tablename = self._truncated_identifier("alias", tablename) - return schema_prefix + \ - self.preparer.quote(tablename) + \ - "." + name + return schema_prefix + self.preparer.quote(tablename) + "." + name def visit_collation(self, element, **kw): return self.preparer.format_collation(element.collation) @@ -743,17 +868,17 @@ class SQLCompiler(Compiled): return index.name def visit_typeclause(self, typeclause, **kw): - kw['type_expression'] = typeclause + kw["type_expression"] = typeclause return self.dialect.type_compiler.process(typeclause.type, **kw) def post_process_text(self, text): if self.preparer._double_percents: - text = text.replace('%', '%%') + text = text.replace("%", "%%") return text def escape_literal_column(self, text): if self.preparer._double_percents: - text = text.replace('%', '%%') + text = text.replace("%", "%%") return text def visit_textclause(self, textclause, **kw): @@ -771,30 +896,36 @@ class SQLCompiler(Compiled): return BIND_PARAMS_ESC.sub( lambda m: m.group(1), BIND_PARAMS.sub( - do_bindparam, - self.post_process_text(textclause.text)) + do_bindparam, self.post_process_text(textclause.text) + ), ) - def visit_text_as_from(self, taf, - compound_index=None, - asfrom=False, - parens=True, **kw): + def visit_text_as_from( + self, taf, compound_index=None, asfrom=False, parens=True, **kw + ): toplevel = not self.stack entry = self._default_stack_entry if toplevel else self.stack[-1] - populate_result_map = toplevel or \ - ( - compound_index == 0 and entry.get( - 'need_result_map_for_compound', False) - ) or entry.get('need_result_map_for_nested', False) + populate_result_map = ( + toplevel + or ( + compound_index == 0 + and entry.get("need_result_map_for_compound", False) + ) + or entry.get("need_result_map_for_nested", False) + ) if populate_result_map: - self._ordered_columns = \ - self._textual_ordered_columns = taf.positional + self._ordered_columns = ( + self._textual_ordered_columns + ) = taf.positional for c in taf.column_args: - self.process(c, within_columns_clause=True, - add_to_result_map=self._add_to_result_map) + self.process( + c, + within_columns_clause=True, + add_to_result_map=self._add_to_result_map, + ) text = self.process(taf.element, **kw) if asfrom and parens: @@ -802,17 +933,17 @@ class SQLCompiler(Compiled): return text def visit_null(self, expr, **kw): - return 'NULL' + return "NULL" def visit_true(self, expr, **kw): if self.dialect.supports_native_boolean: - return 'true' + return "true" else: return "1" def visit_false(self, expr, **kw): if self.dialect.supports_native_boolean: - return 'false' + return "false" else: return "0" @@ -823,25 +954,29 @@ class SQLCompiler(Compiled): else: sep = OPERATORS[clauselist.operator] return sep.join( - s for s in - ( - c._compiler_dispatch(self, **kw) - for c in clauselist.clauses) - if s) + s + for s in ( + c._compiler_dispatch(self, **kw) for c in clauselist.clauses + ) + if s + ) def visit_case(self, clause, **kwargs): x = "CASE " if clause.value is not None: x += clause.value._compiler_dispatch(self, **kwargs) + " " for cond, result in clause.whens: - x += "WHEN " + cond._compiler_dispatch( - self, **kwargs - ) + " THEN " + result._compiler_dispatch( - self, **kwargs) + " " + x += ( + "WHEN " + + cond._compiler_dispatch(self, **kwargs) + + " THEN " + + result._compiler_dispatch(self, **kwargs) + + " " + ) if clause.else_ is not None: - x += "ELSE " + clause.else_._compiler_dispatch( - self, **kwargs - ) + " " + x += ( + "ELSE " + clause.else_._compiler_dispatch(self, **kwargs) + " " + ) x += "END" return x @@ -849,79 +984,84 @@ class SQLCompiler(Compiled): return type_coerce.typed_expression._compiler_dispatch(self, **kw) def visit_cast(self, cast, **kwargs): - return "CAST(%s AS %s)" % \ - (cast.clause._compiler_dispatch(self, **kwargs), - cast.typeclause._compiler_dispatch(self, **kwargs)) + return "CAST(%s AS %s)" % ( + cast.clause._compiler_dispatch(self, **kwargs), + cast.typeclause._compiler_dispatch(self, **kwargs), + ) def _format_frame_clause(self, range_, **kw): - return '%s AND %s' % ( + return "%s AND %s" % ( "UNBOUNDED PRECEDING" if range_[0] is elements.RANGE_UNBOUNDED - else "CURRENT ROW" if range_[0] is elements.RANGE_CURRENT - else "%s PRECEDING" % ( - self.process(elements.literal(abs(range_[0])), **kw), ) + else "CURRENT ROW" + if range_[0] is elements.RANGE_CURRENT + else "%s PRECEDING" + % (self.process(elements.literal(abs(range_[0])), **kw),) if range_[0] < 0 - else "%s FOLLOWING" % ( - self.process(elements.literal(range_[0]), **kw), ), - + else "%s FOLLOWING" + % (self.process(elements.literal(range_[0]), **kw),), "UNBOUNDED FOLLOWING" if range_[1] is elements.RANGE_UNBOUNDED - else "CURRENT ROW" if range_[1] is elements.RANGE_CURRENT - else "%s PRECEDING" % ( - self.process(elements.literal(abs(range_[1])), **kw), ) + else "CURRENT ROW" + if range_[1] is elements.RANGE_CURRENT + else "%s PRECEDING" + % (self.process(elements.literal(abs(range_[1])), **kw),) if range_[1] < 0 - else "%s FOLLOWING" % ( - self.process(elements.literal(range_[1]), **kw), ), + else "%s FOLLOWING" + % (self.process(elements.literal(range_[1]), **kw),), ) def visit_over(self, over, **kwargs): if over.range_: range_ = "RANGE BETWEEN %s" % self._format_frame_clause( - over.range_, **kwargs) + over.range_, **kwargs + ) elif over.rows: range_ = "ROWS BETWEEN %s" % self._format_frame_clause( - over.rows, **kwargs) + over.rows, **kwargs + ) else: range_ = None return "%s OVER (%s)" % ( over.element._compiler_dispatch(self, **kwargs), - ' '.join([ - '%s BY %s' % ( - word, clause._compiler_dispatch(self, **kwargs) - ) - for word, clause in ( - ('PARTITION', over.partition_by), - ('ORDER', over.order_by) - ) - if clause is not None and len(clause) - ] + ([range_] if range_ else []) - ) + " ".join( + [ + "%s BY %s" + % (word, clause._compiler_dispatch(self, **kwargs)) + for word, clause in ( + ("PARTITION", over.partition_by), + ("ORDER", over.order_by), + ) + if clause is not None and len(clause) + ] + + ([range_] if range_ else []) + ), ) def visit_withingroup(self, withingroup, **kwargs): return "%s WITHIN GROUP (ORDER BY %s)" % ( withingroup.element._compiler_dispatch(self, **kwargs), - withingroup.order_by._compiler_dispatch(self, **kwargs) + withingroup.order_by._compiler_dispatch(self, **kwargs), ) def visit_funcfilter(self, funcfilter, **kwargs): return "%s FILTER (WHERE %s)" % ( funcfilter.func._compiler_dispatch(self, **kwargs), - funcfilter.criterion._compiler_dispatch(self, **kwargs) + funcfilter.criterion._compiler_dispatch(self, **kwargs), ) def visit_extract(self, extract, **kwargs): field = self.extract_map.get(extract.field, extract.field) return "EXTRACT(%s FROM %s)" % ( - field, extract.expr._compiler_dispatch(self, **kwargs)) + field, + extract.expr._compiler_dispatch(self, **kwargs), + ) def visit_function(self, func, add_to_result_map=None, **kwargs): if add_to_result_map is not None: - add_to_result_map( - func.name, func.name, (), func.type - ) + add_to_result_map(func.name, func.name, (), func.type) disp = getattr(self, "visit_%s_func" % func.name.lower(), None) if disp: @@ -933,51 +1073,63 @@ class SQLCompiler(Compiled): name += "%(expr)s" else: name = func.name + "%(expr)s" - return ".".join(list(func.packagenames) + [name]) % \ - {'expr': self.function_argspec(func, **kwargs)} + return ".".join(list(func.packagenames) + [name]) % { + "expr": self.function_argspec(func, **kwargs) + } def visit_next_value_func(self, next_value, **kw): return self.visit_sequence(next_value.sequence) def visit_sequence(self, sequence, **kw): raise NotImplementedError( - "Dialect '%s' does not support sequence increments." % - self.dialect.name + "Dialect '%s' does not support sequence increments." + % self.dialect.name ) def function_argspec(self, func, **kwargs): return func.clause_expr._compiler_dispatch(self, **kwargs) - def visit_compound_select(self, cs, asfrom=False, - parens=True, compound_index=0, **kwargs): + def visit_compound_select( + self, cs, asfrom=False, parens=True, compound_index=0, **kwargs + ): toplevel = not self.stack entry = self._default_stack_entry if toplevel else self.stack[-1] - need_result_map = toplevel or \ - (compound_index == 0 - and entry.get('need_result_map_for_compound', False)) + need_result_map = toplevel or ( + compound_index == 0 + and entry.get("need_result_map_for_compound", False) + ) self.stack.append( { - 'correlate_froms': entry['correlate_froms'], - 'asfrom_froms': entry['asfrom_froms'], - 'selectable': cs, - 'need_result_map_for_compound': need_result_map - }) + "correlate_froms": entry["correlate_froms"], + "asfrom_froms": entry["asfrom_froms"], + "selectable": cs, + "need_result_map_for_compound": need_result_map, + } + ) keyword = self.compound_keywords.get(cs.keyword) text = (" " + keyword + " ").join( - (c._compiler_dispatch(self, - asfrom=asfrom, parens=False, - compound_index=i, **kwargs) - for i, c in enumerate(cs.selects)) + ( + c._compiler_dispatch( + self, + asfrom=asfrom, + parens=False, + compound_index=i, + **kwargs + ) + for i, c in enumerate(cs.selects) + ) ) text += self.group_by_clause(cs, **dict(asfrom=asfrom, **kwargs)) text += self.order_by_clause(cs, **kwargs) - text += (cs._limit_clause is not None - or cs._offset_clause is not None) and \ - self.limit_clause(cs, **kwargs) or "" + text += ( + (cs._limit_clause is not None or cs._offset_clause is not None) + and self.limit_clause(cs, **kwargs) + or "" + ) if self.ctes and toplevel: text = self._render_cte_clause() + text @@ -990,8 +1142,10 @@ class SQLCompiler(Compiled): def _get_operator_dispatch(self, operator_, qualifier1, qualifier2): attrname = "visit_%s_%s%s" % ( - operator_.__name__, qualifier1, - "_" + qualifier2 if qualifier2 else "") + operator_.__name__, + qualifier1, + "_" + qualifier2 if qualifier2 else "", + ) return getattr(self, attrname, None) def visit_unary(self, unary, **kw): @@ -999,51 +1153,63 @@ class SQLCompiler(Compiled): if unary.modifier: raise exc.CompileError( "Unary expression does not support operator " - "and modifier simultaneously") + "and modifier simultaneously" + ) disp = self._get_operator_dispatch( - unary.operator, "unary", "operator") + unary.operator, "unary", "operator" + ) if disp: return disp(unary, unary.operator, **kw) else: return self._generate_generic_unary_operator( - unary, OPERATORS[unary.operator], **kw) + unary, OPERATORS[unary.operator], **kw + ) elif unary.modifier: disp = self._get_operator_dispatch( - unary.modifier, "unary", "modifier") + unary.modifier, "unary", "modifier" + ) if disp: return disp(unary, unary.modifier, **kw) else: return self._generate_generic_unary_modifier( - unary, OPERATORS[unary.modifier], **kw) + unary, OPERATORS[unary.modifier], **kw + ) else: raise exc.CompileError( - "Unary expression has no operator or modifier") + "Unary expression has no operator or modifier" + ) def visit_istrue_unary_operator(self, element, operator, **kw): - if element._is_implicitly_boolean or \ - self.dialect.supports_native_boolean: + if ( + element._is_implicitly_boolean + or self.dialect.supports_native_boolean + ): return self.process(element.element, **kw) else: return "%s = 1" % self.process(element.element, **kw) def visit_isfalse_unary_operator(self, element, operator, **kw): - if element._is_implicitly_boolean or \ - self.dialect.supports_native_boolean: + if ( + element._is_implicitly_boolean + or self.dialect.supports_native_boolean + ): return "NOT %s" % self.process(element.element, **kw) else: return "%s = 0" % self.process(element.element, **kw) def visit_notmatch_op_binary(self, binary, operator, **kw): return "NOT %s" % self.visit_binary( - binary, override_operator=operators.match_op) + binary, override_operator=operators.match_op + ) def _emit_empty_in_warning(self): util.warn( - 'The IN-predicate was invoked with an ' - 'empty sequence. This results in a ' - 'contradiction, which nonetheless can be ' - 'expensive to evaluate. Consider alternative ' - 'strategies for improved performance.') + "The IN-predicate was invoked with an " + "empty sequence. This results in a " + "contradiction, which nonetheless can be " + "expensive to evaluate. Consider alternative " + "strategies for improved performance." + ) def visit_empty_in_op_binary(self, binary, operator, **kw): if self.dialect._use_static_in: @@ -1063,18 +1229,21 @@ class SQLCompiler(Compiled): def visit_empty_set_expr(self, element_types): raise NotImplementedError( - "Dialect '%s' does not support empty set expression." % - self.dialect.name + "Dialect '%s' does not support empty set expression." + % self.dialect.name ) - def visit_binary(self, binary, override_operator=None, - eager_grouping=False, **kw): + def visit_binary( + self, binary, override_operator=None, eager_grouping=False, **kw + ): # don't allow "? = ?" to render - if self.ansi_bind_rules and \ - isinstance(binary.left, elements.BindParameter) and \ - isinstance(binary.right, elements.BindParameter): - kw['literal_binds'] = True + if ( + self.ansi_bind_rules + and isinstance(binary.left, elements.BindParameter) + and isinstance(binary.right, elements.BindParameter) + ): + kw["literal_binds"] = True operator_ = override_operator or binary.operator disp = self._get_operator_dispatch(operator_, "binary", None) @@ -1093,36 +1262,50 @@ class SQLCompiler(Compiled): def visit_mod_binary(self, binary, operator, **kw): if self.preparer._double_percents: - return self.process(binary.left, **kw) + " %% " + \ - self.process(binary.right, **kw) + return ( + self.process(binary.left, **kw) + + " %% " + + self.process(binary.right, **kw) + ) else: - return self.process(binary.left, **kw) + " % " + \ - self.process(binary.right, **kw) + return ( + self.process(binary.left, **kw) + + " % " + + self.process(binary.right, **kw) + ) def visit_custom_op_binary(self, element, operator, **kw): - kw['eager_grouping'] = operator.eager_grouping + kw["eager_grouping"] = operator.eager_grouping return self._generate_generic_binary( - element, " " + operator.opstring + " ", **kw) + element, " " + operator.opstring + " ", **kw + ) def visit_custom_op_unary_operator(self, element, operator, **kw): return self._generate_generic_unary_operator( - element, operator.opstring + " ", **kw) + element, operator.opstring + " ", **kw + ) def visit_custom_op_unary_modifier(self, element, operator, **kw): return self._generate_generic_unary_modifier( - element, " " + operator.opstring, **kw) + element, " " + operator.opstring, **kw + ) def _generate_generic_binary( - self, binary, opstring, eager_grouping=False, **kw): + self, binary, opstring, eager_grouping=False, **kw + ): - _in_binary = kw.get('_in_binary', False) + _in_binary = kw.get("_in_binary", False) - kw['_in_binary'] = True - text = binary.left._compiler_dispatch( - self, eager_grouping=eager_grouping, **kw) + \ - opstring + \ - binary.right._compiler_dispatch( - self, eager_grouping=eager_grouping, **kw) + kw["_in_binary"] = True + text = ( + binary.left._compiler_dispatch( + self, eager_grouping=eager_grouping, **kw + ) + + opstring + + binary.right._compiler_dispatch( + self, eager_grouping=eager_grouping, **kw + ) + ) if _in_binary and eager_grouping: text = "(%s)" % text @@ -1153,17 +1336,13 @@ class SQLCompiler(Compiled): def visit_startswith_op_binary(self, binary, operator, **kw): binary = binary._clone() percent = self._like_percent_literal - binary.right = percent.__radd__( - binary.right - ) + binary.right = percent.__radd__(binary.right) return self.visit_like_op_binary(binary, operator, **kw) def visit_notstartswith_op_binary(self, binary, operator, **kw): binary = binary._clone() percent = self._like_percent_literal - binary.right = percent.__radd__( - binary.right - ) + binary.right = percent.__radd__(binary.right) return self.visit_notlike_op_binary(binary, operator, **kw) def visit_endswith_op_binary(self, binary, operator, **kw): @@ -1182,98 +1361,105 @@ class SQLCompiler(Compiled): escape = binary.modifiers.get("escape", None) # TODO: use ternary here, not "and"/ "or" - return '%s LIKE %s' % ( + return "%s LIKE %s" % ( binary.left._compiler_dispatch(self, **kw), - binary.right._compiler_dispatch(self, **kw)) \ - + ( - ' ESCAPE ' + - self.render_literal_value(escape, sqltypes.STRINGTYPE) - if escape else '' - ) + binary.right._compiler_dispatch(self, **kw), + ) + ( + " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape + else "" + ) def visit_notlike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) - return '%s NOT LIKE %s' % ( + return "%s NOT LIKE %s" % ( binary.left._compiler_dispatch(self, **kw), - binary.right._compiler_dispatch(self, **kw)) \ - + ( - ' ESCAPE ' + - self.render_literal_value(escape, sqltypes.STRINGTYPE) - if escape else '' - ) + binary.right._compiler_dispatch(self, **kw), + ) + ( + " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape + else "" + ) def visit_ilike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) - return 'lower(%s) LIKE lower(%s)' % ( + return "lower(%s) LIKE lower(%s)" % ( binary.left._compiler_dispatch(self, **kw), - binary.right._compiler_dispatch(self, **kw)) \ - + ( - ' ESCAPE ' + - self.render_literal_value(escape, sqltypes.STRINGTYPE) - if escape else '' - ) + binary.right._compiler_dispatch(self, **kw), + ) + ( + " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape + else "" + ) def visit_notilike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) - return 'lower(%s) NOT LIKE lower(%s)' % ( + return "lower(%s) NOT LIKE lower(%s)" % ( binary.left._compiler_dispatch(self, **kw), - binary.right._compiler_dispatch(self, **kw)) \ - + ( - ' ESCAPE ' + - self.render_literal_value(escape, sqltypes.STRINGTYPE) - if escape else '' - ) + binary.right._compiler_dispatch(self, **kw), + ) + ( + " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape + else "" + ) def visit_between_op_binary(self, binary, operator, **kw): symmetric = binary.modifiers.get("symmetric", False) return self._generate_generic_binary( - binary, " BETWEEN SYMMETRIC " - if symmetric else " BETWEEN ", **kw) + binary, " BETWEEN SYMMETRIC " if symmetric else " BETWEEN ", **kw + ) def visit_notbetween_op_binary(self, binary, operator, **kw): symmetric = binary.modifiers.get("symmetric", False) return self._generate_generic_binary( - binary, " NOT BETWEEN SYMMETRIC " - if symmetric else " NOT BETWEEN ", **kw) + binary, + " NOT BETWEEN SYMMETRIC " if symmetric else " NOT BETWEEN ", + **kw + ) - def visit_bindparam(self, bindparam, within_columns_clause=False, - literal_binds=False, - skip_bind_expression=False, - **kwargs): + def visit_bindparam( + self, + bindparam, + within_columns_clause=False, + literal_binds=False, + skip_bind_expression=False, + **kwargs + ): if not skip_bind_expression: impl = bindparam.type.dialect_impl(self.dialect) if impl._has_bind_expression: bind_expression = impl.bind_expression(bindparam) return self.process( - bind_expression, skip_bind_expression=True, + bind_expression, + skip_bind_expression=True, within_columns_clause=within_columns_clause, literal_binds=literal_binds, **kwargs ) - if literal_binds or \ - (within_columns_clause and - self.ansi_bind_rules): + if literal_binds or (within_columns_clause and self.ansi_bind_rules): if bindparam.value is None and bindparam.callable is None: - raise exc.CompileError("Bind parameter '%s' without a " - "renderable value not allowed here." - % bindparam.key) + raise exc.CompileError( + "Bind parameter '%s' without a " + "renderable value not allowed here." % bindparam.key + ) return self.render_literal_bindparam( - bindparam, within_columns_clause=True, **kwargs) + bindparam, within_columns_clause=True, **kwargs + ) name = self._truncate_bindparam(bindparam) if name in self.binds: existing = self.binds[name] if existing is not bindparam: - if (existing.unique or bindparam.unique) and \ - not existing.proxy_set.intersection( - bindparam.proxy_set): + if ( + existing.unique or bindparam.unique + ) and not existing.proxy_set.intersection(bindparam.proxy_set): raise exc.CompileError( "Bind parameter '%s' conflicts with " - "unique bind parameter of the same name" % - bindparam.key + "unique bind parameter of the same name" + % bindparam.key ) elif existing._is_crud or bindparam._is_crud: raise exc.CompileError( @@ -1282,14 +1468,15 @@ class SQLCompiler(Compiled): "clause of this " "insert/update statement. Please use a " "name other than column name when using bindparam() " - "with insert() or update() (for example, 'b_%s')." % - (bindparam.key, bindparam.key) + "with insert() or update() (for example, 'b_%s')." + % (bindparam.key, bindparam.key) ) self.binds[bindparam.key] = self.binds[name] = bindparam return self.bindparam_string( - name, expanding=bindparam.expanding, **kwargs) + name, expanding=bindparam.expanding, **kwargs + ) def render_literal_bindparam(self, bindparam, **kw): value = bindparam.effective_value @@ -1311,7 +1498,8 @@ class SQLCompiler(Compiled): return processor(value) else: raise NotImplementedError( - "Don't know how to literal-quote value %r" % value) + "Don't know how to literal-quote value %r" % value + ) def _truncate_bindparam(self, bindparam): if bindparam in self.bind_names: @@ -1334,8 +1522,11 @@ class SQLCompiler(Compiled): if len(anonname) > self.label_length - 6: counter = self.truncated_names.get(ident_class, 1) - truncname = anonname[0:max(self.label_length - 6, 0)] + \ - "_" + hex(counter)[2:] + truncname = ( + anonname[0 : max(self.label_length - 6, 0)] + + "_" + + hex(counter)[2:] + ) self.truncated_names[ident_class] = counter + 1 else: truncname = anonname @@ -1346,13 +1537,14 @@ class SQLCompiler(Compiled): return name % self.anon_map def _process_anon(self, key): - (ident, derived) = key.split(' ', 1) + (ident, derived) = key.split(" ", 1) anonymous_counter = self.anon_map.get(derived, 1) self.anon_map[derived] = anonymous_counter + 1 return derived + "_" + str(anonymous_counter) def bindparam_string( - self, name, positional_names=None, expanding=False, **kw): + self, name, positional_names=None, expanding=False, **kw + ): if self.positional: if positional_names is not None: positional_names.append(name) @@ -1362,14 +1554,20 @@ class SQLCompiler(Compiled): self.contains_expanding_parameters = True return "([EXPANDING_%s])" % name else: - return self.bindtemplate % {'name': name} - - def visit_cte(self, cte, asfrom=False, ashint=False, - fromhints=None, visiting_cte=None, - **kwargs): + return self.bindtemplate % {"name": name} + + def visit_cte( + self, + cte, + asfrom=False, + ashint=False, + fromhints=None, + visiting_cte=None, + **kwargs + ): self._init_cte_state() - kwargs['visiting_cte'] = cte + kwargs["visiting_cte"] = cte if isinstance(cte.name, elements._truncated_label): cte_name = self._truncated_identifier("alias", cte.name) else: @@ -1394,8 +1592,8 @@ class SQLCompiler(Compiled): else: raise exc.CompileError( "Multiple, unrelated CTEs found with " - "the same name: %r" % - cte_name) + "the same name: %r" % cte_name + ) if asfrom or is_new_cte: if cte._cte_alias is not None: @@ -1403,7 +1601,8 @@ class SQLCompiler(Compiled): cte_pre_alias_name = cte._cte_alias.name if isinstance(cte_pre_alias_name, elements._truncated_label): cte_pre_alias_name = self._truncated_identifier( - "alias", cte_pre_alias_name) + "alias", cte_pre_alias_name + ) else: pre_alias_cte = cte cte_pre_alias_name = None @@ -1412,11 +1611,17 @@ class SQLCompiler(Compiled): self.ctes_by_name[cte_name] = cte # look for embedded DML ctes and propagate autocommit - if 'autocommit' in cte.element._execution_options and \ - 'autocommit' not in self.execution_options: + if ( + "autocommit" in cte.element._execution_options + and "autocommit" not in self.execution_options + ): self.execution_options = self.execution_options.union( - {"autocommit": - cte.element._execution_options['autocommit']}) + { + "autocommit": cte.element._execution_options[ + "autocommit" + ] + } + ) if pre_alias_cte not in self.ctes: self.visit_cte(pre_alias_cte, **kwargs) @@ -1432,25 +1637,30 @@ class SQLCompiler(Compiled): col_source = cte.original.selects[0] else: assert False - recur_cols = [c for c in - util.unique_list(col_source.inner_columns) - if c is not None] - - text += "(%s)" % (", ".join( - self.preparer.format_column(ident) - for ident in recur_cols)) + recur_cols = [ + c + for c in util.unique_list(col_source.inner_columns) + if c is not None + ] + + text += "(%s)" % ( + ", ".join( + self.preparer.format_column(ident) + for ident in recur_cols + ) + ) if self.positional: - kwargs['positional_names'] = self.cte_positional[cte] = [] + kwargs["positional_names"] = self.cte_positional[cte] = [] - text += " AS \n" + \ - cte.original._compiler_dispatch( - self, asfrom=True, **kwargs - ) + text += " AS \n" + cte.original._compiler_dispatch( + self, asfrom=True, **kwargs + ) if cte._suffixes: text += " " + self._generate_prefixes( - cte, cte._suffixes, **kwargs) + cte, cte._suffixes, **kwargs + ) self.ctes[cte] = text @@ -1467,9 +1677,15 @@ class SQLCompiler(Compiled): else: return self.preparer.format_alias(cte, cte_name) - def visit_alias(self, alias, asfrom=False, ashint=False, - iscrud=False, - fromhints=None, **kwargs): + def visit_alias( + self, + alias, + asfrom=False, + ashint=False, + iscrud=False, + fromhints=None, + **kwargs + ): if asfrom or ashint: if isinstance(alias.name, elements._truncated_label): alias_name = self._truncated_identifier("alias", alias.name) @@ -1479,31 +1695,35 @@ class SQLCompiler(Compiled): if ashint: return self.preparer.format_alias(alias, alias_name) elif asfrom: - ret = alias.original._compiler_dispatch(self, - asfrom=True, **kwargs) + \ - self.get_render_as_alias_suffix( - self.preparer.format_alias(alias, alias_name)) + ret = alias.original._compiler_dispatch( + self, asfrom=True, **kwargs + ) + self.get_render_as_alias_suffix( + self.preparer.format_alias(alias, alias_name) + ) if fromhints and alias in fromhints: - ret = self.format_from_hint_text(ret, alias, - fromhints[alias], iscrud) + ret = self.format_from_hint_text( + ret, alias, fromhints[alias], iscrud + ) return ret else: return alias.original._compiler_dispatch(self, **kwargs) def visit_lateral(self, lateral, **kw): - kw['lateral'] = True + kw["lateral"] = True return "LATERAL %s" % self.visit_alias(lateral, **kw) def visit_tablesample(self, tablesample, asfrom=False, **kw): text = "%s TABLESAMPLE %s" % ( self.visit_alias(tablesample, asfrom=True, **kw), - tablesample._get_method()._compiler_dispatch(self, **kw)) + tablesample._get_method()._compiler_dispatch(self, **kw), + ) if tablesample.seed is not None: text += " REPEATABLE (%s)" % ( - tablesample.seed._compiler_dispatch(self, **kw)) + tablesample.seed._compiler_dispatch(self, **kw) + ) return text @@ -1513,22 +1733,27 @@ class SQLCompiler(Compiled): def _add_to_result_map(self, keyname, name, objects, type_): self._result_columns.append((keyname, name, objects, type_)) - def _label_select_column(self, select, column, - populate_result_map, - asfrom, column_clause_args, - name=None, - within_columns_clause=True): + def _label_select_column( + self, + select, + column, + populate_result_map, + asfrom, + column_clause_args, + name=None, + within_columns_clause=True, + ): """produce labeled columns present in a select().""" impl = column.type.dialect_impl(self.dialect) - if impl._has_column_expression and \ - populate_result_map: + if impl._has_column_expression and populate_result_map: col_expr = impl.column_expression(column) def add_to_result_map(keyname, name, objects, type_): self._add_to_result_map( - keyname, name, - (column,) + objects, type_) + keyname, name, (column,) + objects, type_ + ) + else: col_expr = column if populate_result_map: @@ -1541,58 +1766,56 @@ class SQLCompiler(Compiled): elif isinstance(column, elements.Label): if col_expr is not column: result_expr = _CompileLabel( - col_expr, - column.name, - alt_names=(column.element,) + col_expr, column.name, alt_names=(column.element,) ) else: result_expr = col_expr elif select is not None and name: result_expr = _CompileLabel( + col_expr, name, alt_names=(column._key_label,) + ) + + elif ( + asfrom + and isinstance(column, elements.ColumnClause) + and not column.is_literal + and column.table is not None + and not isinstance(column.table, selectable.Select) + ): + result_expr = _CompileLabel( col_expr, - name, - alt_names=(column._key_label,) - ) - - elif \ - asfrom and \ - isinstance(column, elements.ColumnClause) and \ - not column.is_literal and \ - column.table is not None and \ - not isinstance(column.table, selectable.Select): - result_expr = _CompileLabel(col_expr, - elements._as_truncated(column.name), - alt_names=(column.key,)) + elements._as_truncated(column.name), + alt_names=(column.key,), + ) elif ( - not isinstance(column, elements.TextClause) and - ( - not isinstance(column, elements.UnaryExpression) or - column.wraps_column_expression - ) and - ( - not hasattr(column, 'name') or - isinstance(column, functions.Function) + not isinstance(column, elements.TextClause) + and ( + not isinstance(column, elements.UnaryExpression) + or column.wraps_column_expression + ) + and ( + not hasattr(column, "name") + or isinstance(column, functions.Function) ) ): result_expr = _CompileLabel(col_expr, column.anon_label) elif col_expr is not column: # TODO: are we sure "column" has a .name and .key here ? # assert isinstance(column, elements.ColumnClause) - result_expr = _CompileLabel(col_expr, - elements._as_truncated(column.name), - alt_names=(column.key,)) + result_expr = _CompileLabel( + col_expr, + elements._as_truncated(column.name), + alt_names=(column.key,), + ) else: result_expr = col_expr column_clause_args.update( within_columns_clause=within_columns_clause, - add_to_result_map=add_to_result_map - ) - return result_expr._compiler_dispatch( - self, - **column_clause_args + add_to_result_map=add_to_result_map, ) + return result_expr._compiler_dispatch(self, **column_clause_args) def format_from_hint_text(self, sqltext, table, hint, iscrud): hinttext = self.get_from_hint_text(table, hint) @@ -1631,8 +1854,11 @@ class SQLCompiler(Compiled): newelem = cloned[element] = element._clone() - if newelem.is_selectable and newelem._is_join and \ - isinstance(newelem.right, selectable.FromGrouping): + if ( + newelem.is_selectable + and newelem._is_join + and isinstance(newelem.right, selectable.FromGrouping) + ): newelem._reset_exported() newelem.left = visit(newelem.left, **kw) @@ -1640,8 +1866,8 @@ class SQLCompiler(Compiled): right = visit(newelem.right, **kw) selectable_ = selectable.Select( - [right.element], - use_labels=True).alias() + [right.element], use_labels=True + ).alias() for c in selectable_.c: c._key_label = c.key @@ -1680,17 +1906,18 @@ class SQLCompiler(Compiled): elif newelem._is_from_container: # if we hit an Alias, CompoundSelect or ScalarSelect, put a # marker in the stack. - kw['transform_clue'] = 'select_container' + kw["transform_clue"] = "select_container" newelem._copy_internals(clone=visit, **kw) elif newelem.is_selectable and newelem._is_select: - barrier_select = kw.get('transform_clue', None) == \ - 'select_container' + barrier_select = ( + kw.get("transform_clue", None) == "select_container" + ) # if we're still descended from an # Alias/CompoundSelect/ScalarSelect, we're # in a FROM clause, so start with a new translate collection if barrier_select: column_translate.append({}) - kw['transform_clue'] = 'inside_select' + kw["transform_clue"] = "inside_select" newelem._copy_internals(clone=visit, **kw) if barrier_select: del column_translate[-1] @@ -1702,24 +1929,22 @@ class SQLCompiler(Compiled): return visit(select) def _transform_result_map_for_nested_joins( - self, select, transformed_select): - inner_col = dict((c._key_label, c) for - c in transformed_select.inner_columns) - - d = dict( - (inner_col[c._key_label], c) - for c in select.inner_columns + self, select, transformed_select + ): + inner_col = dict( + (c._key_label, c) for c in transformed_select.inner_columns ) + d = dict((inner_col[c._key_label], c) for c in select.inner_columns) + self._result_columns = [ (key, name, tuple([d.get(col, col) for col in objs]), typ) for key, name, objs, typ in self._result_columns ] - _default_stack_entry = util.immutabledict([ - ('correlate_froms', frozenset()), - ('asfrom_froms', frozenset()) - ]) + _default_stack_entry = util.immutabledict( + [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())] + ) def _display_froms_for_select(self, select, asfrom, lateral=False): # utility method to help external dialects @@ -1729,72 +1954,88 @@ class SQLCompiler(Compiled): toplevel = not self.stack entry = self._default_stack_entry if toplevel else self.stack[-1] - correlate_froms = entry['correlate_froms'] - asfrom_froms = entry['asfrom_froms'] + correlate_froms = entry["correlate_froms"] + asfrom_froms = entry["asfrom_froms"] if asfrom and not lateral: froms = select._get_display_froms( explicit_correlate_froms=correlate_froms.difference( - asfrom_froms), - implicit_correlate_froms=()) + asfrom_froms + ), + implicit_correlate_froms=(), + ) else: froms = select._get_display_froms( explicit_correlate_froms=correlate_froms, - implicit_correlate_froms=asfrom_froms) + implicit_correlate_froms=asfrom_froms, + ) return froms - def visit_select(self, select, asfrom=False, parens=True, - fromhints=None, - compound_index=0, - nested_join_translation=False, - select_wraps_for=None, - lateral=False, - **kwargs): - - needs_nested_translation = \ - select.use_labels and \ - not nested_join_translation and \ - not self.stack and \ - not self.dialect.supports_right_nested_joins + def visit_select( + self, + select, + asfrom=False, + parens=True, + fromhints=None, + compound_index=0, + nested_join_translation=False, + select_wraps_for=None, + lateral=False, + **kwargs + ): + + needs_nested_translation = ( + select.use_labels + and not nested_join_translation + and not self.stack + and not self.dialect.supports_right_nested_joins + ) if needs_nested_translation: transformed_select = self._transform_select_for_nested_joins( - select) + select + ) text = self.visit_select( - transformed_select, asfrom=asfrom, parens=parens, + transformed_select, + asfrom=asfrom, + parens=parens, fromhints=fromhints, compound_index=compound_index, - nested_join_translation=True, **kwargs + nested_join_translation=True, + **kwargs ) toplevel = not self.stack entry = self._default_stack_entry if toplevel else self.stack[-1] - populate_result_map = toplevel or \ - ( - compound_index == 0 and entry.get( - 'need_result_map_for_compound', False) - ) or entry.get('need_result_map_for_nested', False) + populate_result_map = ( + toplevel + or ( + compound_index == 0 + and entry.get("need_result_map_for_compound", False) + ) + or entry.get("need_result_map_for_nested", False) + ) # this was first proposed as part of #3372; however, it is not # reached in current tests and could possibly be an assertion # instead. - if not populate_result_map and 'add_to_result_map' in kwargs: - del kwargs['add_to_result_map'] + if not populate_result_map and "add_to_result_map" in kwargs: + del kwargs["add_to_result_map"] if needs_nested_translation: if populate_result_map: self._transform_result_map_for_nested_joins( - select, transformed_select) + select, transformed_select + ) return text froms = self._setup_select_stack(select, entry, asfrom, lateral) column_clause_args = kwargs.copy() - column_clause_args.update({ - 'within_label_clause': False, - 'within_columns_clause': False - }) + column_clause_args.update( + {"within_label_clause": False, "within_columns_clause": False} + ) text = "SELECT " # we're off to a good start ! @@ -1806,19 +2047,21 @@ class SQLCompiler(Compiled): byfrom = None if select._prefixes: - text += self._generate_prefixes( - select, select._prefixes, **kwargs) + text += self._generate_prefixes(select, select._prefixes, **kwargs) text += self.get_select_precolumns(select, **kwargs) # the actual list of columns to print in the SELECT column list. inner_columns = [ - c for c in [ + c + for c in [ self._label_select_column( select, column, - populate_result_map, asfrom, + populate_result_map, + asfrom, column_clause_args, - name=name) + name=name, + ) for name, column in select._columns_plus_names ] if c is not None @@ -1831,8 +2074,11 @@ class SQLCompiler(Compiled): translate = dict( zip( [name for (key, name) in select._columns_plus_names], - [name for (key, name) in - select_wraps_for._columns_plus_names]) + [ + name + for (key, name) in select_wraps_for._columns_plus_names + ], + ) ) self._result_columns = [ @@ -1841,13 +2087,14 @@ class SQLCompiler(Compiled): ] text = self._compose_select_body( - text, select, inner_columns, froms, byfrom, kwargs) + text, select, inner_columns, froms, byfrom, kwargs + ) if select._statement_hints: per_dialect = [ - ht for (dialect_name, ht) - in select._statement_hints - if dialect_name in ('*', self.dialect.name) + ht + for (dialect_name, ht) in select._statement_hints + if dialect_name in ("*", self.dialect.name) ] if per_dialect: text += " " + self.get_statement_hint_text(per_dialect) @@ -1857,7 +2104,8 @@ class SQLCompiler(Compiled): if select._suffixes: text += " " + self._generate_prefixes( - select, select._suffixes, **kwargs) + select, select._suffixes, **kwargs + ) self.stack.pop(-1) @@ -1867,60 +2115,73 @@ class SQLCompiler(Compiled): return text def _setup_select_hints(self, select): - byfrom = dict([ - (from_, hinttext % { - 'name': from_._compiler_dispatch( - self, ashint=True) - }) - for (from_, dialect), hinttext in - select._hints.items() - if dialect in ('*', self.dialect.name) - ]) + byfrom = dict( + [ + ( + from_, + hinttext + % {"name": from_._compiler_dispatch(self, ashint=True)}, + ) + for (from_, dialect), hinttext in select._hints.items() + if dialect in ("*", self.dialect.name) + ] + ) hint_text = self.get_select_hint_text(byfrom) return hint_text, byfrom def _setup_select_stack(self, select, entry, asfrom, lateral): - correlate_froms = entry['correlate_froms'] - asfrom_froms = entry['asfrom_froms'] + correlate_froms = entry["correlate_froms"] + asfrom_froms = entry["asfrom_froms"] if asfrom and not lateral: froms = select._get_display_froms( explicit_correlate_froms=correlate_froms.difference( - asfrom_froms), - implicit_correlate_froms=()) + asfrom_froms + ), + implicit_correlate_froms=(), + ) else: froms = select._get_display_froms( explicit_correlate_froms=correlate_froms, - implicit_correlate_froms=asfrom_froms) + implicit_correlate_froms=asfrom_froms, + ) new_correlate_froms = set(selectable._from_objects(*froms)) all_correlate_froms = new_correlate_froms.union(correlate_froms) new_entry = { - 'asfrom_froms': new_correlate_froms, - 'correlate_froms': all_correlate_froms, - 'selectable': select, + "asfrom_froms": new_correlate_froms, + "correlate_froms": all_correlate_froms, + "selectable": select, } self.stack.append(new_entry) return froms def _compose_select_body( - self, text, select, inner_columns, froms, byfrom, kwargs): - text += ', '.join(inner_columns) + self, text, select, inner_columns, froms, byfrom, kwargs + ): + text += ", ".join(inner_columns) if froms: text += " \nFROM " if select._hints: - text += ', '.join( - [f._compiler_dispatch(self, asfrom=True, - fromhints=byfrom, **kwargs) - for f in froms]) + text += ", ".join( + [ + f._compiler_dispatch( + self, asfrom=True, fromhints=byfrom, **kwargs + ) + for f in froms + ] + ) else: - text += ', '.join( - [f._compiler_dispatch(self, asfrom=True, **kwargs) - for f in froms]) + text += ", ".join( + [ + f._compiler_dispatch(self, asfrom=True, **kwargs) + for f in froms + ] + ) else: text += self.default_from() @@ -1940,8 +2201,10 @@ class SQLCompiler(Compiled): if select._order_by_clause.clauses: text += self.order_by_clause(select, **kwargs) - if (select._limit_clause is not None or - select._offset_clause is not None): + if ( + select._limit_clause is not None + or select._offset_clause is not None + ): text += self.limit_clause(select, **kwargs) if select._for_update_arg is not None: @@ -1953,8 +2216,7 @@ class SQLCompiler(Compiled): clause = " ".join( prefix._compiler_dispatch(self, **kw) for prefix, dialect_name in prefixes - if dialect_name is None or - dialect_name == self.dialect.name + if dialect_name is None or dialect_name == self.dialect.name ) if clause: clause += " " @@ -1962,14 +2224,12 @@ class SQLCompiler(Compiled): def _render_cte_clause(self): if self.positional: - self.positiontup = sum([ - self.cte_positional[cte] - for cte in self.ctes], []) + \ - self.positiontup + self.positiontup = ( + sum([self.cte_positional[cte] for cte in self.ctes], []) + + self.positiontup + ) cte_text = self.get_cte_preamble(self.ctes_recursive) + " " - cte_text += ", \n".join( - [txt for txt in self.ctes.values()] - ) + cte_text += ", \n".join([txt for txt in self.ctes.values()]) cte_text += "\n " return cte_text @@ -2010,7 +2270,8 @@ class SQLCompiler(Compiled): def returning_clause(self, stmt, returning_cols): raise exc.CompileError( "RETURNING is not supported by this " - "dialect's statement compiler.") + "dialect's statement compiler." + ) def limit_clause(self, select, **kw): text = "" @@ -2022,19 +2283,31 @@ class SQLCompiler(Compiled): text += " OFFSET " + self.process(select._offset_clause, **kw) return text - def visit_table(self, table, asfrom=False, iscrud=False, ashint=False, - fromhints=None, use_schema=True, **kwargs): + def visit_table( + self, + table, + asfrom=False, + iscrud=False, + ashint=False, + fromhints=None, + use_schema=True, + **kwargs + ): if asfrom or ashint: effective_schema = self.preparer.schema_for_object(table) if use_schema and effective_schema: - ret = self.preparer.quote_schema(effective_schema) + \ - "." + self.preparer.quote(table.name) + ret = ( + self.preparer.quote_schema(effective_schema) + + "." + + self.preparer.quote(table.name) + ) else: ret = self.preparer.quote(table.name) if fromhints and table in fromhints: - ret = self.format_from_hint_text(ret, table, - fromhints[table], iscrud) + ret = self.format_from_hint_text( + ret, table, fromhints[table], iscrud + ) return ret else: return "" @@ -2047,26 +2320,24 @@ class SQLCompiler(Compiled): else: join_type = " JOIN " return ( - join.left._compiler_dispatch(self, asfrom=True, **kwargs) + - join_type + - join.right._compiler_dispatch(self, asfrom=True, **kwargs) + - " ON " + - join.onclause._compiler_dispatch(self, **kwargs) + join.left._compiler_dispatch(self, asfrom=True, **kwargs) + + join_type + + join.right._compiler_dispatch(self, asfrom=True, **kwargs) + + " ON " + + join.onclause._compiler_dispatch(self, **kwargs) ) def _setup_crud_hints(self, stmt, table_text): - dialect_hints = dict([ - (table, hint_text) - for (table, dialect), hint_text in - stmt._hints.items() - if dialect in ('*', self.dialect.name) - ]) + dialect_hints = dict( + [ + (table, hint_text) + for (table, dialect), hint_text in stmt._hints.items() + if dialect in ("*", self.dialect.name) + ] + ) if stmt.table in dialect_hints: table_text = self.format_from_hint_text( - table_text, - stmt.table, - dialect_hints[stmt.table], - True + table_text, stmt.table, dialect_hints[stmt.table], True ) return dialect_hints, table_text @@ -2074,28 +2345,35 @@ class SQLCompiler(Compiled): toplevel = not self.stack self.stack.append( - {'correlate_froms': set(), - "asfrom_froms": set(), - "selectable": insert_stmt}) + { + "correlate_froms": set(), + "asfrom_froms": set(), + "selectable": insert_stmt, + } + ) crud_params = crud._setup_crud_params( - self, insert_stmt, crud.ISINSERT, **kw) + self, insert_stmt, crud.ISINSERT, **kw + ) - if not crud_params and \ - not self.dialect.supports_default_values and \ - not self.dialect.supports_empty_insert: - raise exc.CompileError("The '%s' dialect with current database " - "version settings does not support empty " - "inserts." % - self.dialect.name) + if ( + not crud_params + and not self.dialect.supports_default_values + and not self.dialect.supports_empty_insert + ): + raise exc.CompileError( + "The '%s' dialect with current database " + "version settings does not support empty " + "inserts." % self.dialect.name + ) if insert_stmt._has_multi_parameters: if not self.dialect.supports_multivalues_insert: raise exc.CompileError( "The '%s' dialect with current database " "version settings does not support " - "in-place multirow inserts." % - self.dialect.name) + "in-place multirow inserts." % self.dialect.name + ) crud_params_single = crud_params[0] else: crud_params_single = crud_params @@ -2106,27 +2384,31 @@ class SQLCompiler(Compiled): text = "INSERT " if insert_stmt._prefixes: - text += self._generate_prefixes(insert_stmt, - insert_stmt._prefixes, **kw) + text += self._generate_prefixes( + insert_stmt, insert_stmt._prefixes, **kw + ) text += "INTO " table_text = preparer.format_table(insert_stmt.table) if insert_stmt._hints: dialect_hints, table_text = self._setup_crud_hints( - insert_stmt, table_text) + insert_stmt, table_text + ) else: dialect_hints = None text += table_text if crud_params_single or not supports_default_values: - text += " (%s)" % ', '.join([preparer.format_column(c[0]) - for c in crud_params_single]) + text += " (%s)" % ", ".join( + [preparer.format_column(c[0]) for c in crud_params_single] + ) if self.returning or insert_stmt._returning: returning_clause = self.returning_clause( - insert_stmt, self.returning or insert_stmt._returning) + insert_stmt, self.returning or insert_stmt._returning + ) if self.returning_precedes_values: text += " " + returning_clause @@ -2145,19 +2427,17 @@ class SQLCompiler(Compiled): elif insert_stmt._has_multi_parameters: text += " VALUES %s" % ( ", ".join( - "(%s)" % ( - ', '.join(c[1] for c in crud_param_set) - ) + "(%s)" % (", ".join(c[1] for c in crud_param_set)) for crud_param_set in crud_params ) ) else: - text += " VALUES (%s)" % \ - ', '.join([c[1] for c in crud_params]) + text += " VALUES (%s)" % ", ".join([c[1] for c in crud_params]) if insert_stmt._post_values_clause is not None: post_values_clause = self.process( - insert_stmt._post_values_clause, **kw) + insert_stmt._post_values_clause, **kw + ) if post_values_clause: text += " " + post_values_clause @@ -2178,21 +2458,19 @@ class SQLCompiler(Compiled): """Provide a hook for MySQL to add LIMIT to the UPDATE""" return None - def update_tables_clause(self, update_stmt, from_table, - extra_froms, **kw): + def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw): """Provide a hook to override the initial table clause in an UPDATE statement. MySQL overrides this. """ - kw['asfrom'] = True + kw["asfrom"] = True return from_table._compiler_dispatch(self, iscrud=True, **kw) - def update_from_clause(self, update_stmt, - from_table, extra_froms, - from_hints, - **kw): + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): """Provide a hook to override the generation of an UPDATE..FROM clause. @@ -2201,7 +2479,8 @@ class SQLCompiler(Compiled): """ raise NotImplementedError( "This backend does not support multiple-table " - "criteria within UPDATE") + "criteria within UPDATE" + ) def visit_update(self, update_stmt, asfrom=False, **kw): toplevel = not self.stack @@ -2221,49 +2500,61 @@ class SQLCompiler(Compiled): correlate_froms = {update_stmt.table} self.stack.append( - {'correlate_froms': correlate_froms, - "asfrom_froms": correlate_froms, - "selectable": update_stmt}) + { + "correlate_froms": correlate_froms, + "asfrom_froms": correlate_froms, + "selectable": update_stmt, + } + ) text = "UPDATE " if update_stmt._prefixes: - text += self._generate_prefixes(update_stmt, - update_stmt._prefixes, **kw) + text += self._generate_prefixes( + update_stmt, update_stmt._prefixes, **kw + ) - table_text = self.update_tables_clause(update_stmt, update_stmt.table, - render_extra_froms, **kw) + table_text = self.update_tables_clause( + update_stmt, update_stmt.table, render_extra_froms, **kw + ) crud_params = crud._setup_crud_params( - self, update_stmt, crud.ISUPDATE, **kw) + self, update_stmt, crud.ISUPDATE, **kw + ) if update_stmt._hints: dialect_hints, table_text = self._setup_crud_hints( - update_stmt, table_text) + update_stmt, table_text + ) else: dialect_hints = None text += table_text - text += ' SET ' - include_table = is_multitable and \ - self.render_table_with_column_in_update_from - text += ', '.join( - c[0]._compiler_dispatch(self, - include_table=include_table) + - '=' + c[1] for c in crud_params + text += " SET " + include_table = ( + is_multitable and self.render_table_with_column_in_update_from + ) + text += ", ".join( + c[0]._compiler_dispatch(self, include_table=include_table) + + "=" + + c[1] + for c in crud_params ) if self.returning or update_stmt._returning: if self.returning_precedes_values: text += " " + self.returning_clause( - update_stmt, self.returning or update_stmt._returning) + update_stmt, self.returning or update_stmt._returning + ) if extra_froms: extra_from_text = self.update_from_clause( update_stmt, update_stmt.table, render_extra_froms, - dialect_hints, **kw) + dialect_hints, + **kw + ) if extra_from_text: text += " " + extra_from_text @@ -2276,10 +2567,12 @@ class SQLCompiler(Compiled): if limit_clause: text += " " + limit_clause - if (self.returning or update_stmt._returning) and \ - not self.returning_precedes_values: + if ( + self.returning or update_stmt._returning + ) and not self.returning_precedes_values: text += " " + self.returning_clause( - update_stmt, self.returning or update_stmt._returning) + update_stmt, self.returning or update_stmt._returning + ) if self.ctes and toplevel: text = self._render_cte_clause() + text @@ -2295,9 +2588,9 @@ class SQLCompiler(Compiled): def _key_getters_for_crud_column(self): return crud._key_getters_for_crud_column(self, self.statement) - def delete_extra_from_clause(self, update_stmt, - from_table, extra_froms, - from_hints, **kw): + def delete_extra_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): """Provide a hook to override the generation of an DELETE..FROM clause. @@ -2308,10 +2601,10 @@ class SQLCompiler(Compiled): """ raise NotImplementedError( "This backend does not support multiple-table " - "criteria within DELETE") + "criteria within DELETE" + ) - def delete_table_clause(self, delete_stmt, from_table, - extra_froms): + def delete_table_clause(self, delete_stmt, from_table, extra_froms): return from_table._compiler_dispatch(self, asfrom=True, iscrud=True) def visit_delete(self, delete_stmt, asfrom=False, **kw): @@ -2322,23 +2615,30 @@ class SQLCompiler(Compiled): extra_froms = delete_stmt._extra_froms correlate_froms = {delete_stmt.table}.union(extra_froms) - self.stack.append({'correlate_froms': correlate_froms, - "asfrom_froms": correlate_froms, - "selectable": delete_stmt}) + self.stack.append( + { + "correlate_froms": correlate_froms, + "asfrom_froms": correlate_froms, + "selectable": delete_stmt, + } + ) text = "DELETE " if delete_stmt._prefixes: - text += self._generate_prefixes(delete_stmt, - delete_stmt._prefixes, **kw) + text += self._generate_prefixes( + delete_stmt, delete_stmt._prefixes, **kw + ) text += "FROM " - table_text = self.delete_table_clause(delete_stmt, delete_stmt.table, - extra_froms) + table_text = self.delete_table_clause( + delete_stmt, delete_stmt.table, extra_froms + ) if delete_stmt._hints: dialect_hints, table_text = self._setup_crud_hints( - delete_stmt, table_text) + delete_stmt, table_text + ) else: dialect_hints = None @@ -2347,14 +2647,17 @@ class SQLCompiler(Compiled): if delete_stmt._returning: if self.returning_precedes_values: text += " " + self.returning_clause( - delete_stmt, delete_stmt._returning) + delete_stmt, delete_stmt._returning + ) if extra_froms: extra_from_text = self.delete_extra_from_clause( delete_stmt, delete_stmt.table, extra_froms, - dialect_hints, **kw) + dialect_hints, + **kw + ) if extra_from_text: text += " " + extra_from_text @@ -2365,7 +2668,8 @@ class SQLCompiler(Compiled): if delete_stmt._returning and not self.returning_precedes_values: text += " " + self.returning_clause( - delete_stmt, delete_stmt._returning) + delete_stmt, delete_stmt._returning + ) if self.ctes and toplevel: text = self._render_cte_clause() + text @@ -2381,12 +2685,14 @@ class SQLCompiler(Compiled): return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) def visit_rollback_to_savepoint(self, savepoint_stmt): - return "ROLLBACK TO SAVEPOINT %s" % \ - self.preparer.format_savepoint(savepoint_stmt) + return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint( + savepoint_stmt + ) def visit_release_savepoint(self, savepoint_stmt): - return "RELEASE SAVEPOINT %s" % \ - self.preparer.format_savepoint(savepoint_stmt) + return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint( + savepoint_stmt + ) class StrSQLCompiler(SQLCompiler): @@ -2403,7 +2709,7 @@ class StrSQLCompiler(SQLCompiler): def visit_getitem_binary(self, binary, operator, **kw): return "%s[%s]" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw) + self.process(binary.right, **kw), ) def visit_json_getitem_op_binary(self, binary, operator, **kw): @@ -2421,29 +2727,26 @@ class StrSQLCompiler(SQLCompiler): for c in elements._select_iterables(returning_cols) ] - return 'RETURNING ' + ', '.join(columns) + return "RETURNING " + ", ".join(columns) - def update_from_clause(self, update_stmt, - from_table, extra_froms, - from_hints, - **kw): - return "FROM " + ', '.join( - t._compiler_dispatch(self, asfrom=True, - fromhints=from_hints, **kw) - for t in extra_froms) + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): + return "FROM " + ", ".join( + t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + for t in extra_froms + ) - def delete_extra_from_clause(self, update_stmt, - from_table, extra_froms, - from_hints, - **kw): - return ', ' + ', '.join( - t._compiler_dispatch(self, asfrom=True, - fromhints=from_hints, **kw) - for t in extra_froms) + def delete_extra_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): + return ", " + ", ".join( + t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + for t in extra_froms + ) class DDLCompiler(Compiled): - @util.memoized_property def sql_compiler(self): return self.dialect.statement_compiler(self.dialect, None) @@ -2464,13 +2767,13 @@ class DDLCompiler(Compiled): preparer = self.preparer path = preparer.format_table_seq(ddl.target) if len(path) == 1: - table, sch = path[0], '' + table, sch = path[0], "" else: table, sch = path[-1], path[0] - context.setdefault('table', table) - context.setdefault('schema', sch) - context.setdefault('fullname', preparer.format_table(ddl.target)) + context.setdefault("table", table) + context.setdefault("schema", sch) + context.setdefault("fullname", preparer.format_table(ddl.target)) return self.sql_compiler.post_process_text(ddl.statement % context) @@ -2507,9 +2810,9 @@ class DDLCompiler(Compiled): for create_column in create.columns: column = create_column.element try: - processed = self.process(create_column, - first_pk=column.primary_key - and not first_pk) + processed = self.process( + create_column, first_pk=column.primary_key and not first_pk + ) if processed is not None: text += separator separator = ", \n" @@ -2519,13 +2822,15 @@ class DDLCompiler(Compiled): except exc.CompileError as ce: util.raise_from_cause( exc.CompileError( - util.u("(in table '%s', column '%s'): %s") % - (table.description, column.name, ce.args[0]) - )) + util.u("(in table '%s', column '%s'): %s") + % (table.description, column.name, ce.args[0]) + ) + ) const = self.create_table_constraints( - table, _include_foreign_key_constraints= # noqa - create.include_foreign_key_constraints) + table, + _include_foreign_key_constraints=create.include_foreign_key_constraints, # noqa + ) if const: text += separator + "\t" + const @@ -2538,20 +2843,18 @@ class DDLCompiler(Compiled): if column.system: return None - text = self.get_column_specification( - column, - first_pk=first_pk + text = self.get_column_specification(column, first_pk=first_pk) + const = " ".join( + self.process(constraint) for constraint in column.constraints ) - const = " ".join(self.process(constraint) - for constraint in column.constraints) if const: text += " " + const return text def create_table_constraints( - self, table, - _include_foreign_key_constraints=None): + self, table, _include_foreign_key_constraints=None + ): # On some DB order is significant: visit PK first, then the # other constraints (engine.ReflectionTest.testbasic failed on FB2) @@ -2565,21 +2868,29 @@ class DDLCompiler(Compiled): else: omit_fkcs = set() - constraints.extend([c for c in table._sorted_constraints - if c is not table.primary_key and - c not in omit_fkcs]) + constraints.extend( + [ + c + for c in table._sorted_constraints + if c is not table.primary_key and c not in omit_fkcs + ] + ) return ", \n\t".join( - p for p in - (self.process(constraint) + p + for p in ( + self.process(constraint) for constraint in constraints if ( - constraint._create_rule is None or - constraint._create_rule(self)) + constraint._create_rule is None + or constraint._create_rule(self) + ) and ( - not self.dialect.supports_alter or - not getattr(constraint, 'use_alter', False) - )) if p is not None + not self.dialect.supports_alter + or not getattr(constraint, "use_alter", False) + ) + ) + if p is not None ) def visit_drop_table(self, drop): @@ -2590,34 +2901,38 @@ class DDLCompiler(Compiled): def _verify_index_table(self, index): if index.table is None: - raise exc.CompileError("Index '%s' is not associated " - "with any table." % index.name) + raise exc.CompileError( + "Index '%s' is not associated " "with any table." % index.name + ) - def visit_create_index(self, create, include_schema=False, - include_table_schema=True): + def visit_create_index( + self, create, include_schema=False, include_table_schema=True + ): index = create.element self._verify_index_table(index) preparer = self.preparer text = "CREATE " if index.unique: text += "UNIQUE " - text += "INDEX %s ON %s (%s)" \ - % ( - self._prepared_index_name(index, - include_schema=include_schema), - preparer.format_table(index.table, - use_schema=include_table_schema), - ', '.join( - self.sql_compiler.process( - expr, include_table=False, literal_binds=True) for - expr in index.expressions) - ) + text += "INDEX %s ON %s (%s)" % ( + self._prepared_index_name(index, include_schema=include_schema), + preparer.format_table( + index.table, use_schema=include_table_schema + ), + ", ".join( + self.sql_compiler.process( + expr, include_table=False, literal_binds=True + ) + for expr in index.expressions + ), + ) return text def visit_drop_index(self, drop): index = drop.element return "\nDROP INDEX " + self._prepared_index_name( - index, include_schema=True) + index, include_schema=True + ) def _prepared_index_name(self, index, include_schema=False): if index.table is not None: @@ -2638,35 +2953,41 @@ class DDLCompiler(Compiled): def visit_add_constraint(self, create): return "ALTER TABLE %s ADD %s" % ( self.preparer.format_table(create.element.table), - self.process(create.element) + self.process(create.element), ) def visit_set_table_comment(self, create): return "COMMENT ON TABLE %s IS %s" % ( self.preparer.format_table(create.element), self.sql_compiler.render_literal_value( - create.element.comment, sqltypes.String()) + create.element.comment, sqltypes.String() + ), ) def visit_drop_table_comment(self, drop): - return "COMMENT ON TABLE %s IS NULL" % \ - self.preparer.format_table(drop.element) + return "COMMENT ON TABLE %s IS NULL" % self.preparer.format_table( + drop.element + ) def visit_set_column_comment(self, create): return "COMMENT ON COLUMN %s IS %s" % ( self.preparer.format_column( - create.element, use_table=True, use_schema=True), + create.element, use_table=True, use_schema=True + ), self.sql_compiler.render_literal_value( - create.element.comment, sqltypes.String()) + create.element.comment, sqltypes.String() + ), ) def visit_drop_column_comment(self, drop): - return "COMMENT ON COLUMN %s IS NULL" % \ - self.preparer.format_column(drop.element, use_table=True) + return "COMMENT ON COLUMN %s IS NULL" % self.preparer.format_column( + drop.element, use_table=True + ) def visit_create_sequence(self, create): - text = "CREATE SEQUENCE %s" % \ - self.preparer.format_sequence(create.element) + text = "CREATE SEQUENCE %s" % self.preparer.format_sequence( + create.element + ) if create.element.increment is not None: text += " INCREMENT BY %d" % create.element.increment if create.element.start is not None: @@ -2688,8 +3009,7 @@ class DDLCompiler(Compiled): return text def visit_drop_sequence(self, drop): - return "DROP SEQUENCE %s" % \ - self.preparer.format_sequence(drop.element) + return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element) def visit_drop_constraint(self, drop): constraint = drop.element @@ -2701,17 +3021,22 @@ class DDLCompiler(Compiled): if formatted_name is None: raise exc.CompileError( "Can't emit DROP CONSTRAINT for constraint %r; " - "it has no name" % drop.element) + "it has no name" % drop.element + ) return "ALTER TABLE %s DROP CONSTRAINT %s%s" % ( self.preparer.format_table(drop.element.table), formatted_name, - drop.cascade and " CASCADE" or "" + drop.cascade and " CASCADE" or "", ) def get_column_specification(self, column, **kwargs): - colspec = self.preparer.format_column(column) + " " + \ - self.dialect.type_compiler.process( - column.type, type_expression=column) + colspec = ( + self.preparer.format_column(column) + + " " + + self.dialect.type_compiler.process( + column.type, type_expression=column + ) + ) default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default @@ -2721,19 +3046,21 @@ class DDLCompiler(Compiled): return colspec def create_table_suffix(self, table): - return '' + return "" def post_create_table(self, table): - return '' + return "" def get_column_default_string(self, column): if isinstance(column.server_default, schema.DefaultClause): if isinstance(column.server_default.arg, util.string_types): return self.sql_compiler.render_literal_value( - column.server_default.arg, sqltypes.STRINGTYPE) + column.server_default.arg, sqltypes.STRINGTYPE + ) else: return self.sql_compiler.process( - column.server_default.arg, literal_binds=True) + column.server_default.arg, literal_binds=True + ) else: return None @@ -2743,9 +3070,9 @@ class DDLCompiler(Compiled): formatted_name = self.preparer.format_constraint(constraint) if formatted_name is not None: text += "CONSTRAINT %s " % formatted_name - text += "CHECK (%s)" % self.sql_compiler.process(constraint.sqltext, - include_table=False, - literal_binds=True) + text += "CHECK (%s)" % self.sql_compiler.process( + constraint.sqltext, include_table=False, literal_binds=True + ) text += self.define_constraint_deferrability(constraint) return text @@ -2755,25 +3082,29 @@ class DDLCompiler(Compiled): formatted_name = self.preparer.format_constraint(constraint) if formatted_name is not None: text += "CONSTRAINT %s " % formatted_name - text += "CHECK (%s)" % self.sql_compiler.process(constraint.sqltext, - include_table=False, - literal_binds=True) + text += "CHECK (%s)" % self.sql_compiler.process( + constraint.sqltext, include_table=False, literal_binds=True + ) text += self.define_constraint_deferrability(constraint) return text def visit_primary_key_constraint(self, constraint): if len(constraint) == 0: - return '' + return "" text = "" if constraint.name is not None: formatted_name = self.preparer.format_constraint(constraint) if formatted_name is not None: text += "CONSTRAINT %s " % formatted_name text += "PRIMARY KEY " - text += "(%s)" % ', '.join(self.preparer.quote(c.name) - for c in (constraint.columns_autoinc_first - if constraint._implicit_generated - else constraint.columns)) + text += "(%s)" % ", ".join( + self.preparer.quote(c.name) + for c in ( + constraint.columns_autoinc_first + if constraint._implicit_generated + else constraint.columns + ) + ) text += self.define_constraint_deferrability(constraint) return text @@ -2786,12 +3117,15 @@ class DDLCompiler(Compiled): text += "CONSTRAINT %s " % formatted_name remote_table = list(constraint.elements)[0].column.table text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % ( - ', '.join(preparer.quote(f.parent.name) - for f in constraint.elements), + ", ".join( + preparer.quote(f.parent.name) for f in constraint.elements + ), self.define_constraint_remote_table( - constraint, remote_table, preparer), - ', '.join(preparer.quote(f.column.name) - for f in constraint.elements) + constraint, remote_table, preparer + ), + ", ".join( + preparer.quote(f.column.name) for f in constraint.elements + ), ) text += self.define_constraint_match(constraint) text += self.define_constraint_cascades(constraint) @@ -2805,14 +3139,14 @@ class DDLCompiler(Compiled): def visit_unique_constraint(self, constraint): if len(constraint) == 0: - return '' + return "" text = "" if constraint.name is not None: formatted_name = self.preparer.format_constraint(constraint) text += "CONSTRAINT %s " % formatted_name text += "UNIQUE (%s)" % ( - ', '.join(self.preparer.quote(c.name) - for c in constraint)) + ", ".join(self.preparer.quote(c.name) for c in constraint) + ) text += self.define_constraint_deferrability(constraint) return text @@ -2843,7 +3177,6 @@ class DDLCompiler(Compiled): class GenericTypeCompiler(TypeCompiler): - def visit_FLOAT(self, type_, **kw): return "FLOAT" @@ -2854,23 +3187,23 @@ class GenericTypeCompiler(TypeCompiler): if type_.precision is None: return "NUMERIC" elif type_.scale is None: - return "NUMERIC(%(precision)s)" % \ - {'precision': type_.precision} + return "NUMERIC(%(precision)s)" % {"precision": type_.precision} else: - return "NUMERIC(%(precision)s, %(scale)s)" % \ - {'precision': type_.precision, - 'scale': type_.scale} + return "NUMERIC(%(precision)s, %(scale)s)" % { + "precision": type_.precision, + "scale": type_.scale, + } def visit_DECIMAL(self, type_, **kw): if type_.precision is None: return "DECIMAL" elif type_.scale is None: - return "DECIMAL(%(precision)s)" % \ - {'precision': type_.precision} + return "DECIMAL(%(precision)s)" % {"precision": type_.precision} else: - return "DECIMAL(%(precision)s, %(scale)s)" % \ - {'precision': type_.precision, - 'scale': type_.scale} + return "DECIMAL(%(precision)s, %(scale)s)" % { + "precision": type_.precision, + "scale": type_.scale, + } def visit_INTEGER(self, type_, **kw): return "INTEGER" @@ -2882,7 +3215,7 @@ class GenericTypeCompiler(TypeCompiler): return "BIGINT" def visit_TIMESTAMP(self, type_, **kw): - return 'TIMESTAMP' + return "TIMESTAMP" def visit_DATETIME(self, type_, **kw): return "DATETIME" @@ -2984,9 +3317,11 @@ class GenericTypeCompiler(TypeCompiler): return self.visit_VARCHAR(type_, **kw) def visit_null(self, type_, **kw): - raise exc.CompileError("Can't generate DDL for %r; " - "did you forget to specify a " - "type on this Column?" % type_) + raise exc.CompileError( + "Can't generate DDL for %r; " + "did you forget to specify a " + "type on this Column?" % type_ + ) def visit_type_decorator(self, type_, **kw): return self.process(type_.type_engine(self.dialect), **kw) @@ -3018,9 +3353,15 @@ class IdentifierPreparer(object): schema_for_object = schema._schema_getter(None) - def __init__(self, dialect, initial_quote='"', - final_quote=None, escape_quote='"', - quote_case_sensitive_collations=True, omit_schema=False): + def __init__( + self, + dialect, + initial_quote='"', + final_quote=None, + escape_quote='"', + quote_case_sensitive_collations=True, + omit_schema=False, + ): """Construct a new ``IdentifierPreparer`` object. initial_quote @@ -3043,7 +3384,10 @@ class IdentifierPreparer(object): self.omit_schema = omit_schema self.quote_case_sensitive_collations = quote_case_sensitive_collations self._strings = {} - self._double_percents = self.dialect.paramstyle in ('format', 'pyformat') + self._double_percents = self.dialect.paramstyle in ( + "format", + "pyformat", + ) def _with_schema_translate(self, schema_translate_map): prep = self.__class__.__new__(self.__class__) @@ -3060,7 +3404,7 @@ class IdentifierPreparer(object): value = value.replace(self.escape_quote, self.escape_to_quote) if self._double_percents: - value = value.replace('%', '%%') + value = value.replace("%", "%%") return value def _unescape_identifier(self, value): @@ -3079,17 +3423,21 @@ class IdentifierPreparer(object): quoting behavior. """ - return self.initial_quote + \ - self._escape_identifier(value) + \ - self.final_quote + return ( + self.initial_quote + + self._escape_identifier(value) + + self.final_quote + ) def _requires_quotes(self, value): """Return True if the given identifier requires quoting.""" lc_value = value.lower() - return (lc_value in self.reserved_words - or value[0] in self.illegal_initial_characters - or not self.legal_characters.match(util.text_type(value)) - or (lc_value != value)) + return ( + lc_value in self.reserved_words + or value[0] in self.illegal_initial_characters + or not self.legal_characters.match(util.text_type(value)) + or (lc_value != value) + ) def quote_schema(self, schema, force=None): """Conditionally quote a schema. @@ -3135,8 +3483,11 @@ class IdentifierPreparer(object): effective_schema = self.schema_for_object(sequence) - if (not self.omit_schema and use_schema and - effective_schema is not None): + if ( + not self.omit_schema + and use_schema + and effective_schema is not None + ): name = self.quote_schema(effective_schema) + "." + name return name @@ -3159,7 +3510,8 @@ class IdentifierPreparer(object): def format_constraint(self, naming, constraint): if isinstance(constraint.name, elements._defer_name): name = naming._constraint_name_for_table( - constraint, constraint.table) + constraint, constraint.table + ) if name is None: if isinstance(constraint.name, elements._defer_none_name): @@ -3170,14 +3522,15 @@ class IdentifierPreparer(object): name = constraint.name if isinstance(name, elements._truncated_label): - if constraint.__visit_name__ == 'index': - max_ = self.dialect.max_index_name_length or \ - self.dialect.max_identifier_length + if constraint.__visit_name__ == "index": + max_ = ( + self.dialect.max_index_name_length + or self.dialect.max_identifier_length + ) else: max_ = self.dialect.max_identifier_length if len(name) > max_: - name = name[0:max_ - 8] + \ - "_" + util.md5_hex(name)[-4:] + name = name[0 : max_ - 8] + "_" + util.md5_hex(name)[-4:] else: self.dialect.validate_identifier(name) @@ -3195,8 +3548,7 @@ class IdentifierPreparer(object): effective_schema = self.schema_for_object(table) - if not self.omit_schema and use_schema \ - and effective_schema: + if not self.omit_schema and use_schema and effective_schema: result = self.quote_schema(effective_schema) + "." + result return result @@ -3205,17 +3557,27 @@ class IdentifierPreparer(object): return self.quote(name, quote) - def format_column(self, column, use_table=False, - name=None, table_name=None, use_schema=False): + def format_column( + self, + column, + use_table=False, + name=None, + table_name=None, + use_schema=False, + ): """Prepare a quoted column name.""" if name is None: name = column.name - if not getattr(column, 'is_literal', False): + if not getattr(column, "is_literal", False): if use_table: - return self.format_table( - column.table, use_schema=use_schema, - name=table_name) + "." + self.quote(name) + return ( + self.format_table( + column.table, use_schema=use_schema, name=table_name + ) + + "." + + self.quote(name) + ) else: return self.quote(name) else: @@ -3223,9 +3585,13 @@ class IdentifierPreparer(object): # which shouldn't get quoted if use_table: - return self.format_table( - column.table, use_schema=use_schema, - name=table_name) + '.' + name + return ( + self.format_table( + column.table, use_schema=use_schema, name=table_name + ) + + "." + + name + ) else: return name @@ -3238,31 +3604,37 @@ class IdentifierPreparer(object): effective_schema = self.schema_for_object(table) - if not self.omit_schema and use_schema and \ - effective_schema: - return (self.quote_schema(effective_schema), - self.format_table(table, use_schema=False)) + if not self.omit_schema and use_schema and effective_schema: + return ( + self.quote_schema(effective_schema), + self.format_table(table, use_schema=False), + ) else: - return (self.format_table(table, use_schema=False), ) + return (self.format_table(table, use_schema=False),) @util.memoized_property def _r_identifiers(self): - initial, final, escaped_final = \ - [re.escape(s) for s in - (self.initial_quote, self.final_quote, - self._escape_identifier(self.final_quote))] + initial, final, escaped_final = [ + re.escape(s) + for s in ( + self.initial_quote, + self.final_quote, + self._escape_identifier(self.final_quote), + ) + ] r = re.compile( - r'(?:' - r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s' - r'|([^\.]+))(?=\.|$))+' % - {'initial': initial, - 'final': final, - 'escaped': escaped_final}) + r"(?:" + r"(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s" + r"|([^\.]+))(?=\.|$))+" + % {"initial": initial, "final": final, "escaped": escaped_final} + ) return r def unformat_identifiers(self, identifiers): """Unpack 'schema.table.column'-like strings into components.""" r = self._r_identifiers - return [self._unescape_identifier(i) - for i in [a or b for a, b in r.findall(identifiers)]] + return [ + self._unescape_identifier(i) + for i in [a or b for a, b in r.findall(identifiers)] + ] diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 999d48a55..602b91a25 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -15,7 +15,9 @@ from . import dml from . import elements import operator -REQUIRED = util.symbol('REQUIRED', """ +REQUIRED = util.symbol( + "REQUIRED", + """ Placeholder for the value within a :class:`.BindParameter` which is required to be present when the statement is passed to :meth:`.Connection.execute`. @@ -24,11 +26,12 @@ This symbol is typically used when a :func:`.expression.insert` or :func:`.expression.update` statement is compiled without parameter values present. -""") +""", +) -ISINSERT = util.symbol('ISINSERT') -ISUPDATE = util.symbol('ISUPDATE') -ISDELETE = util.symbol('ISDELETE') +ISINSERT = util.symbol("ISINSERT") +ISUPDATE = util.symbol("ISUPDATE") +ISDELETE = util.symbol("ISDELETE") def _setup_crud_params(compiler, stmt, local_stmt_type, **kw): @@ -82,8 +85,7 @@ def _get_crud_params(compiler, stmt, **kw): # compiled params - return binds for all columns if compiler.column_keys is None and stmt.parameters is None: return [ - (c, _create_bind_param( - compiler, c, None, required=True)) + (c, _create_bind_param(compiler, c, None, required=True)) for c in stmt.table.columns ] @@ -95,26 +97,28 @@ def _get_crud_params(compiler, stmt, **kw): # getters - these are normally just column.key, # but in the case of mysql multi-table update, the rules for # .key must conditionally take tablename into account - _column_as_key, _getattr_col_key, _col_bind_name = \ - _key_getters_for_crud_column(compiler, stmt) + _column_as_key, _getattr_col_key, _col_bind_name = _key_getters_for_crud_column( + compiler, stmt + ) # if we have statement parameters - set defaults in the # compiled params if compiler.column_keys is None: parameters = {} else: - parameters = dict((_column_as_key(key), REQUIRED) - for key in compiler.column_keys - if not stmt_parameters or - key not in stmt_parameters) + parameters = dict( + (_column_as_key(key), REQUIRED) + for key in compiler.column_keys + if not stmt_parameters or key not in stmt_parameters + ) # create a list of column assignment clauses as tuples values = [] if stmt_parameters is not None: _get_stmt_parameters_params( - compiler, - parameters, stmt_parameters, _column_as_key, values, kw) + compiler, parameters, stmt_parameters, _column_as_key, values, kw + ) check_columns = {} @@ -122,28 +126,51 @@ def _get_crud_params(compiler, stmt, **kw): # statements if compiler.isupdate and stmt._extra_froms and stmt_parameters: _get_multitable_params( - compiler, stmt, stmt_parameters, check_columns, - _col_bind_name, _getattr_col_key, values, kw) + compiler, + stmt, + stmt_parameters, + check_columns, + _col_bind_name, + _getattr_col_key, + values, + kw, + ) if compiler.isinsert and stmt.select_names: _scan_insert_from_select_cols( - compiler, stmt, parameters, - _getattr_col_key, _column_as_key, - _col_bind_name, check_columns, values, kw) + compiler, + stmt, + parameters, + _getattr_col_key, + _column_as_key, + _col_bind_name, + check_columns, + values, + kw, + ) else: _scan_cols( - compiler, stmt, parameters, - _getattr_col_key, _column_as_key, - _col_bind_name, check_columns, values, kw) + compiler, + stmt, + parameters, + _getattr_col_key, + _column_as_key, + _col_bind_name, + check_columns, + values, + kw, + ) if parameters and stmt_parameters: - check = set(parameters).intersection( - _column_as_key(k) for k in stmt_parameters - ).difference(check_columns) + check = ( + set(parameters) + .intersection(_column_as_key(k) for k in stmt_parameters) + .difference(check_columns) + ) if check: raise exc.CompileError( - "Unconsumed column names: %s" % - (", ".join("%s" % c for c in check)) + "Unconsumed column names: %s" + % (", ".join("%s" % c for c in check)) ) if stmt._has_multi_parameters: @@ -153,12 +180,13 @@ def _get_crud_params(compiler, stmt, **kw): def _create_bind_param( - compiler, col, value, process=True, - required=False, name=None, **kw): + compiler, col, value, process=True, required=False, name=None, **kw +): if name is None: name = col.key bindparam = elements.BindParameter( - name, value, type_=col.type, required=required) + name, value, type_=col.type, required=required + ) bindparam._is_crud = True if process: bindparam = bindparam._compiler_dispatch(compiler, **kw) @@ -177,7 +205,7 @@ def _key_getters_for_crud_column(compiler, stmt): def _column_as_key(key): str_key = elements._column_as_key(key) - if hasattr(key, 'table') and key.table in _et: + if hasattr(key, "table") and key.table in _et: return (key.table.name, str_key) else: return str_key @@ -202,15 +230,22 @@ def _key_getters_for_crud_column(compiler, stmt): def _scan_insert_from_select_cols( - compiler, stmt, parameters, _getattr_col_key, - _column_as_key, _col_bind_name, check_columns, values, kw): - - need_pks, implicit_returning, \ - implicit_return_defaults, postfetch_lastrowid = \ - _get_returning_modifiers(compiler, stmt) + compiler, + stmt, + parameters, + _getattr_col_key, + _column_as_key, + _col_bind_name, + check_columns, + values, + kw, +): + + need_pks, implicit_returning, implicit_return_defaults, postfetch_lastrowid = _get_returning_modifiers( + compiler, stmt + ) - cols = [stmt.table.c[_column_as_key(name)] - for name in stmt.select_names] + cols = [stmt.table.c[_column_as_key(name)] for name in stmt.select_names] compiler._insert_from_select = stmt.select @@ -228,32 +263,39 @@ def _scan_insert_from_select_cols( values.append((c, None)) else: _append_param_insert_select_hasdefault( - compiler, stmt, c, add_select_cols, kw) + compiler, stmt, c, add_select_cols, kw + ) if add_select_cols: values.extend(add_select_cols) compiler._insert_from_select = compiler._insert_from_select._generate() - compiler._insert_from_select._raw_columns = \ - tuple(compiler._insert_from_select._raw_columns) + tuple( - expr for col, expr in add_select_cols) + compiler._insert_from_select._raw_columns = tuple( + compiler._insert_from_select._raw_columns + ) + tuple(expr for col, expr in add_select_cols) def _scan_cols( - compiler, stmt, parameters, _getattr_col_key, - _column_as_key, _col_bind_name, check_columns, values, kw): - - need_pks, implicit_returning, \ - implicit_return_defaults, postfetch_lastrowid = \ - _get_returning_modifiers(compiler, stmt) + compiler, + stmt, + parameters, + _getattr_col_key, + _column_as_key, + _col_bind_name, + check_columns, + values, + kw, +): + + need_pks, implicit_returning, implicit_return_defaults, postfetch_lastrowid = _get_returning_modifiers( + compiler, stmt + ) if stmt._parameter_ordering: parameter_ordering = [ _column_as_key(key) for key in stmt._parameter_ordering ] ordered_keys = set(parameter_ordering) - cols = [ - stmt.table.c[key] for key in parameter_ordering - ] + [ + cols = [stmt.table.c[key] for key in parameter_ordering] + [ c for c in stmt.table.c if c.key not in ordered_keys ] else: @@ -265,72 +307,95 @@ def _scan_cols( if col_key in parameters and col_key not in check_columns: _append_param_parameter( - compiler, stmt, c, col_key, parameters, _col_bind_name, - implicit_returning, implicit_return_defaults, values, kw) + compiler, + stmt, + c, + col_key, + parameters, + _col_bind_name, + implicit_returning, + implicit_return_defaults, + values, + kw, + ) elif compiler.isinsert: - if c.primary_key and \ - need_pks and \ - ( - implicit_returning or - not postfetch_lastrowid or - c is not stmt.table._autoincrement_column - ): + if ( + c.primary_key + and need_pks + and ( + implicit_returning + or not postfetch_lastrowid + or c is not stmt.table._autoincrement_column + ) + ): if implicit_returning: _append_param_insert_pk_returning( - compiler, stmt, c, values, kw) + compiler, stmt, c, values, kw + ) else: _append_param_insert_pk(compiler, stmt, c, values, kw) elif c.default is not None: _append_param_insert_hasdefault( - compiler, stmt, c, implicit_return_defaults, - values, kw) + compiler, stmt, c, implicit_return_defaults, values, kw + ) elif c.server_default is not None: - if implicit_return_defaults and \ - c in implicit_return_defaults: + if implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) elif not c.primary_key: compiler.postfetch.append(c) - elif implicit_return_defaults and \ - c in implicit_return_defaults: + elif implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) - elif c.primary_key and \ - c is not stmt.table._autoincrement_column and \ - not c.nullable: + elif ( + c.primary_key + and c is not stmt.table._autoincrement_column + and not c.nullable + ): _warn_pk_with_no_anticipated_value(c) elif compiler.isupdate: _append_param_update( - compiler, stmt, c, implicit_return_defaults, values, kw) + compiler, stmt, c, implicit_return_defaults, values, kw + ) def _append_param_parameter( - compiler, stmt, c, col_key, parameters, _col_bind_name, - implicit_returning, implicit_return_defaults, values, kw): + compiler, + stmt, + c, + col_key, + parameters, + _col_bind_name, + implicit_returning, + implicit_return_defaults, + values, + kw, +): value = parameters.pop(col_key) if elements._is_literal(value): value = _create_bind_param( - compiler, c, value, required=value is REQUIRED, + compiler, + c, + value, + required=value is REQUIRED, name=_col_bind_name(c) if not stmt._has_multi_parameters else "%s_m0" % _col_bind_name(c), **kw ) else: - if isinstance(value, elements.BindParameter) and \ - value.type._isnull: + if isinstance(value, elements.BindParameter) and value.type._isnull: value = value._clone() value.type = c.type if c.primary_key and implicit_returning: compiler.returning.append(c) value = compiler.process(value.self_group(), **kw) - elif implicit_return_defaults and \ - c in implicit_return_defaults: + elif implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) value = compiler.process(value.self_group(), **kw) else: @@ -358,22 +423,20 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): """ if c.default is not None: if c.default.is_sequence: - if compiler.dialect.supports_sequences and \ - (not c.default.optional or - not compiler.dialect.sequences_optional): + if compiler.dialect.supports_sequences and ( + not c.default.optional + or not compiler.dialect.sequences_optional + ): proc = compiler.process(c.default, **kw) values.append((c, proc)) compiler.returning.append(c) elif c.default.is_clause_element: values.append( - (c, compiler.process( - c.default.arg.self_group(), **kw)) + (c, compiler.process(c.default.arg.self_group(), **kw)) ) compiler.returning.append(c) else: - values.append( - (c, _create_insert_prefetch_bind_param(compiler, c)) - ) + values.append((c, _create_insert_prefetch_bind_param(compiler, c))) elif c is stmt.table._autoincrement_column or c.server_default is not None: compiler.returning.append(c) elif not c.nullable: @@ -405,9 +468,11 @@ class _multiparam_column(elements.ColumnElement): self.type = original.type def __eq__(self, other): - return isinstance(other, _multiparam_column) and \ - other.key == self.key and \ - other.original == self.original + return ( + isinstance(other, _multiparam_column) + and other.key == self.key + and other.original == self.original + ) def _process_multiparam_default_bind(compiler, stmt, c, index, kw): @@ -416,7 +481,8 @@ def _process_multiparam_default_bind(compiler, stmt, c, index, kw): raise exc.CompileError( "INSERT value for column %s is explicitly rendered as a bound" "parameter in the VALUES clause; " - "a Python-side value or SQL expression is required" % c) + "a Python-side value or SQL expression is required" % c + ) elif c.default.is_clause_element: return compiler.process(c.default.arg.self_group(), **kw) else: @@ -440,30 +506,24 @@ def _append_param_insert_pk(compiler, stmt, c, values, kw): """ if ( - ( - # column has a Python-side default - c.default is not None and - ( - # and it won't be a Sequence - not c.default.is_sequence or - compiler.dialect.supports_sequences - ) - ) - or - ( - # column is the "autoincrement column" - c is stmt.table._autoincrement_column and - ( - # and it's either a "sequence" or a - # pre-executable "autoincrement" sequence - compiler.dialect.supports_sequences or - compiler.dialect.preexecute_autoincrement_sequences - ) - ) - ): - values.append( - (c, _create_insert_prefetch_bind_param(compiler, c)) + # column has a Python-side default + c.default is not None + and ( + # and it won't be a Sequence + not c.default.is_sequence + or compiler.dialect.supports_sequences ) + ) or ( + # column is the "autoincrement column" + c is stmt.table._autoincrement_column + and ( + # and it's either a "sequence" or a + # pre-executable "autoincrement" sequence + compiler.dialect.supports_sequences + or compiler.dialect.preexecute_autoincrement_sequences + ) + ): + values.append((c, _create_insert_prefetch_bind_param(compiler, c))) elif c.default is None and c.server_default is None and not c.nullable: # no .default, no .server_default, not autoincrement, we have # no indication this primary key column will have any value @@ -471,16 +531,16 @@ def _append_param_insert_pk(compiler, stmt, c, values, kw): def _append_param_insert_hasdefault( - compiler, stmt, c, implicit_return_defaults, values, kw): + compiler, stmt, c, implicit_return_defaults, values, kw +): if c.default.is_sequence: - if compiler.dialect.supports_sequences and \ - (not c.default.optional or - not compiler.dialect.sequences_optional): + if compiler.dialect.supports_sequences and ( + not c.default.optional or not compiler.dialect.sequences_optional + ): proc = compiler.process(c.default, **kw) values.append((c, proc)) - if implicit_return_defaults and \ - c in implicit_return_defaults: + if implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) elif not c.primary_key: compiler.postfetch.append(c) @@ -488,25 +548,21 @@ def _append_param_insert_hasdefault( proc = compiler.process(c.default.arg.self_group(), **kw) values.append((c, proc)) - if implicit_return_defaults and \ - c in implicit_return_defaults: + if implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) elif not c.primary_key: # don't add primary key column to postfetch compiler.postfetch.append(c) else: - values.append( - (c, _create_insert_prefetch_bind_param(compiler, c)) - ) + values.append((c, _create_insert_prefetch_bind_param(compiler, c))) -def _append_param_insert_select_hasdefault( - compiler, stmt, c, values, kw): +def _append_param_insert_select_hasdefault(compiler, stmt, c, values, kw): if c.default.is_sequence: - if compiler.dialect.supports_sequences and \ - (not c.default.optional or - not compiler.dialect.sequences_optional): + if compiler.dialect.supports_sequences and ( + not c.default.optional or not compiler.dialect.sequences_optional + ): proc = c.default values.append((c, proc.next_value())) elif c.default.is_clause_element: @@ -519,38 +575,43 @@ def _append_param_insert_select_hasdefault( def _append_param_update( - compiler, stmt, c, implicit_return_defaults, values, kw): + compiler, stmt, c, implicit_return_defaults, values, kw +): if c.onupdate is not None and not c.onupdate.is_sequence: if c.onupdate.is_clause_element: values.append( - (c, compiler.process( - c.onupdate.arg.self_group(), **kw)) + (c, compiler.process(c.onupdate.arg.self_group(), **kw)) ) - if implicit_return_defaults and \ - c in implicit_return_defaults: + if implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) else: compiler.postfetch.append(c) else: - values.append( - (c, _create_update_prefetch_bind_param(compiler, c)) - ) + values.append((c, _create_update_prefetch_bind_param(compiler, c))) elif c.server_onupdate is not None: - if implicit_return_defaults and \ - c in implicit_return_defaults: + if implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) else: compiler.postfetch.append(c) - elif implicit_return_defaults and \ - stmt._return_defaults is not True and \ - c in implicit_return_defaults: + elif ( + implicit_return_defaults + and stmt._return_defaults is not True + and c in implicit_return_defaults + ): compiler.returning.append(c) def _get_multitable_params( - compiler, stmt, stmt_parameters, check_columns, - _col_bind_name, _getattr_col_key, values, kw): + compiler, + stmt, + stmt_parameters, + check_columns, + _col_bind_name, + _getattr_col_key, + values, + kw, +): normalized_params = dict( (elements._clause_element_as_expr(c), param) @@ -565,8 +626,12 @@ def _get_multitable_params( value = normalized_params[c] if elements._is_literal(value): value = _create_bind_param( - compiler, c, value, required=value is REQUIRED, - name=_col_bind_name(c)) + compiler, + c, + value, + required=value is REQUIRED, + name=_col_bind_name(c), + ) else: compiler.postfetch.append(c) value = compiler.process(value.self_group(), **kw) @@ -577,20 +642,25 @@ def _get_multitable_params( for c in t.c: if c in normalized_params: continue - elif (c.onupdate is not None and not - c.onupdate.is_sequence): + elif c.onupdate is not None and not c.onupdate.is_sequence: if c.onupdate.is_clause_element: values.append( - (c, compiler.process( - c.onupdate.arg.self_group(), - **kw) - ) + ( + c, + compiler.process( + c.onupdate.arg.self_group(), **kw + ), + ) ) compiler.postfetch.append(c) else: values.append( - (c, _create_update_prefetch_bind_param( - compiler, c, name=_col_bind_name(c))) + ( + c, + _create_update_prefetch_bind_param( + compiler, c, name=_col_bind_name(c) + ), + ) ) elif c.server_onupdate is not None: compiler.postfetch.append(c) @@ -608,8 +678,11 @@ def _extend_values_for_multiparams(compiler, stmt, values, kw): if elements._is_literal(row[key]): new_param = _create_bind_param( - compiler, col, row[key], - name="%s_m%d" % (col.key, i + 1), **kw + compiler, + col, + row[key], + name="%s_m%d" % (col.key, i + 1), + **kw ) else: new_param = compiler.process(row[key].self_group(), **kw) @@ -626,7 +699,8 @@ def _extend_values_for_multiparams(compiler, stmt, values, kw): def _get_stmt_parameters_params( - compiler, parameters, stmt_parameters, _column_as_key, values, kw): + compiler, parameters, stmt_parameters, _column_as_key, values, kw +): for k, v in stmt_parameters.items(): colkey = _column_as_key(k) if colkey is not None: @@ -637,8 +711,8 @@ def _get_stmt_parameters_params( # coercing right side to bound param if elements._is_literal(v): v = compiler.process( - elements.BindParameter(None, v, type_=k.type), - **kw) + elements.BindParameter(None, v, type_=k.type), **kw + ) else: v = compiler.process(v.self_group(), **kw) @@ -646,22 +720,27 @@ def _get_stmt_parameters_params( def _get_returning_modifiers(compiler, stmt): - need_pks = compiler.isinsert and \ - not compiler.inline and \ - not stmt._returning and \ - not stmt._has_multi_parameters + need_pks = ( + compiler.isinsert + and not compiler.inline + and not stmt._returning + and not stmt._has_multi_parameters + ) - implicit_returning = need_pks and \ - compiler.dialect.implicit_returning and \ - stmt.table.implicit_returning + implicit_returning = ( + need_pks + and compiler.dialect.implicit_returning + and stmt.table.implicit_returning + ) if compiler.isinsert: - implicit_return_defaults = (implicit_returning and - stmt._return_defaults) + implicit_return_defaults = implicit_returning and stmt._return_defaults elif compiler.isupdate: - implicit_return_defaults = (compiler.dialect.implicit_returning and - stmt.table.implicit_returning and - stmt._return_defaults) + implicit_return_defaults = ( + compiler.dialect.implicit_returning + and stmt.table.implicit_returning + and stmt._return_defaults + ) else: # this line is unused, currently we are always # isinsert or isupdate @@ -675,8 +754,12 @@ def _get_returning_modifiers(compiler, stmt): postfetch_lastrowid = need_pks and compiler.dialect.postfetch_lastrowid - return need_pks, implicit_returning, \ - implicit_return_defaults, postfetch_lastrowid + return ( + need_pks, + implicit_returning, + implicit_return_defaults, + postfetch_lastrowid, + ) def _warn_pk_with_no_anticipated_value(c): @@ -687,8 +770,8 @@ def _warn_pk_with_no_anticipated_value(c): "nor does it indicate 'autoincrement=True' or 'nullable=True', " "and no explicit value is passed. " "Primary key columns typically may not store NULL." - % - (c.table.fullname, c.name, c.table.fullname)) + % (c.table.fullname, c.name, c.table.fullname) + ) if len(c.table.primary_key) > 1: msg += ( " Note that as of SQLAlchemy 1.1, 'autoincrement=True' must be " @@ -696,5 +779,6 @@ def _warn_pk_with_no_anticipated_value(c): "keys if AUTO_INCREMENT/SERIAL/IDENTITY " "behavior is expected for one of the columns in the primary key. " "CREATE TABLE statements are impacted by this change as well on " - "most backends.") + "most backends." + ) util.warn(msg) diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index 91e93efe7..f21b3d7f0 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -56,8 +56,9 @@ class DDLElement(Executable, _DDLCompiles): """ - _execution_options = Executable.\ - _execution_options.union({'autocommit': True}) + _execution_options = Executable._execution_options.union( + {"autocommit": True} + ) target = None on = None @@ -95,11 +96,13 @@ class DDLElement(Executable, _DDLCompiles): if self._should_execute(target, bind): return bind.execute(self.against(target)) else: - bind.engine.logger.info( - "DDL execution skipped, criteria not met.") + bind.engine.logger.info("DDL execution skipped, criteria not met.") - @util.deprecated("0.7", "See :class:`.DDLEvents`, as well as " - ":meth:`.DDLElement.execute_if`.") + @util.deprecated( + "0.7", + "See :class:`.DDLEvents`, as well as " + ":meth:`.DDLElement.execute_if`.", + ) def execute_at(self, event_name, target): """Link execution of this DDL to the DDL lifecycle of a SchemaItem. @@ -129,11 +132,12 @@ class DDLElement(Executable, _DDLCompiles): """ def call_event(target, connection, **kw): - if self._should_execute_deprecated(event_name, - target, connection, **kw): + if self._should_execute_deprecated( + event_name, target, connection, **kw + ): return connection.execute(self.against(target)) - event.listen(target, "" + event_name.replace('-', '_'), call_event) + event.listen(target, "" + event_name.replace("-", "_"), call_event) @_generative def against(self, target): @@ -211,8 +215,9 @@ class DDLElement(Executable, _DDLCompiles): self.state = state def _should_execute(self, target, bind, **kw): - if self.on is not None and \ - not self._should_execute_deprecated(None, target, bind, **kw): + if self.on is not None and not self._should_execute_deprecated( + None, target, bind, **kw + ): return False if isinstance(self.dialect, util.string_types): @@ -221,9 +226,9 @@ class DDLElement(Executable, _DDLCompiles): elif isinstance(self.dialect, (tuple, list, set)): if bind.engine.name not in self.dialect: return False - if (self.callable_ is not None and - not self.callable_(self, target, bind, - state=self.state, **kw)): + if self.callable_ is not None and not self.callable_( + self, target, bind, state=self.state, **kw + ): return False return True @@ -245,13 +250,15 @@ class DDLElement(Executable, _DDLCompiles): return bind.execute(self.against(target)) def _check_ddl_on(self, on): - if (on is not None and - (not isinstance(on, util.string_types + (tuple, list, set)) and - not util.callable(on))): + if on is not None and ( + not isinstance(on, util.string_types + (tuple, list, set)) + and not util.callable(on) + ): raise exc.ArgumentError( "Expected the name of a database dialect, a tuple " "of names, or a callable for " - "'on' criteria, got type '%s'." % type(on).__name__) + "'on' criteria, got type '%s'." % type(on).__name__ + ) def bind(self): if self._bind: @@ -259,6 +266,7 @@ class DDLElement(Executable, _DDLCompiles): def _set_bind(self, bind): self._bind = bind + bind = property(bind, _set_bind) def _generate(self): @@ -375,8 +383,9 @@ class DDL(DDLElement): if not isinstance(statement, util.string_types): raise exc.ArgumentError( - "Expected a string or unicode SQL statement, got '%r'" % - statement) + "Expected a string or unicode SQL statement, got '%r'" + % statement + ) self.statement = statement self.context = context or {} @@ -386,12 +395,18 @@ class DDL(DDLElement): self._bind = bind def __repr__(self): - return '<%s@%s; %s>' % ( - type(self).__name__, id(self), - ', '.join([repr(self.statement)] + - ['%s=%r' % (key, getattr(self, key)) - for key in ('on', 'context') - if getattr(self, key)])) + return "<%s@%s; %s>" % ( + type(self).__name__, + id(self), + ", ".join( + [repr(self.statement)] + + [ + "%s=%r" % (key, getattr(self, key)) + for key in ("on", "context") + if getattr(self, key) + ] + ), + ) class _CreateDropBase(DDLElement): @@ -464,8 +479,8 @@ class CreateTable(_CreateDropBase): __visit_name__ = "create_table" def __init__( - self, element, on=None, bind=None, - include_foreign_key_constraints=None): + self, element, on=None, bind=None, include_foreign_key_constraints=None + ): """Create a :class:`.CreateTable` construct. :param element: a :class:`.Table` that's the subject @@ -481,9 +496,7 @@ class CreateTable(_CreateDropBase): """ super(CreateTable, self).__init__(element, on=on, bind=bind) - self.columns = [CreateColumn(column) - for column in element.columns - ] + self.columns = [CreateColumn(column) for column in element.columns] self.include_foreign_key_constraints = include_foreign_key_constraints @@ -494,6 +507,7 @@ class _DropView(_CreateDropBase): This object will eventually be part of a public "view" API. """ + __visit_name__ = "drop_view" @@ -602,7 +616,8 @@ class CreateColumn(_DDLCompiles): to support custom column creation styles. """ - __visit_name__ = 'create_column' + + __visit_name__ = "create_column" def __init__(self, element): self.element = element @@ -646,7 +661,8 @@ class AddConstraint(_CreateDropBase): def __init__(self, element, *args, **kw): super(AddConstraint, self).__init__(element, *args, **kw) element._create_rule = util.portable_instancemethod( - self._create_rule_disable) + self._create_rule_disable + ) class DropConstraint(_CreateDropBase): @@ -658,7 +674,8 @@ class DropConstraint(_CreateDropBase): self.cascade = cascade super(DropConstraint, self).__init__(element, **kw) element._create_rule = util.portable_instancemethod( - self._create_rule_disable) + self._create_rule_disable + ) class SetTableComment(_CreateDropBase): @@ -691,9 +708,9 @@ class DDLBase(SchemaVisitor): class SchemaGenerator(DDLBase): - - def __init__(self, dialect, connection, checkfirst=False, - tables=None, **kwargs): + def __init__( + self, dialect, connection, checkfirst=False, tables=None, **kwargs + ): super(SchemaGenerator, self).__init__(connection, **kwargs) self.checkfirst = checkfirst self.tables = tables @@ -706,25 +723,22 @@ class SchemaGenerator(DDLBase): effective_schema = self.connection.schema_for_object(table) if effective_schema: self.dialect.validate_identifier(effective_schema) - return not self.checkfirst or \ - not self.dialect.has_table(self.connection, - table.name, schema=effective_schema) + return not self.checkfirst or not self.dialect.has_table( + self.connection, table.name, schema=effective_schema + ) def _can_create_sequence(self, sequence): effective_schema = self.connection.schema_for_object(sequence) - return self.dialect.supports_sequences and \ - ( - (not self.dialect.sequences_optional or - not sequence.optional) and - ( - not self.checkfirst or - not self.dialect.has_sequence( - self.connection, - sequence.name, - schema=effective_schema) + return self.dialect.supports_sequences and ( + (not self.dialect.sequences_optional or not sequence.optional) + and ( + not self.checkfirst + or not self.dialect.has_sequence( + self.connection, sequence.name, schema=effective_schema ) ) + ) def visit_metadata(self, metadata): if self.tables is not None: @@ -733,18 +747,23 @@ class SchemaGenerator(DDLBase): tables = list(metadata.tables.values()) collection = sort_tables_and_constraints( - [t for t in tables if self._can_create_table(t)]) - - seq_coll = [s for s in metadata._sequences.values() - if s.column is None and self._can_create_sequence(s)] + [t for t in tables if self._can_create_table(t)] + ) - event_collection = [ - t for (t, fks) in collection if t is not None + seq_coll = [ + s + for s in metadata._sequences.values() + if s.column is None and self._can_create_sequence(s) ] - metadata.dispatch.before_create(metadata, self.connection, - tables=event_collection, - checkfirst=self.checkfirst, - _ddl_runner=self) + + event_collection = [t for (t, fks) in collection if t is not None] + metadata.dispatch.before_create( + metadata, + self.connection, + tables=event_collection, + checkfirst=self.checkfirst, + _ddl_runner=self, + ) for seq in seq_coll: self.traverse_single(seq, create_ok=True) @@ -752,30 +771,40 @@ class SchemaGenerator(DDLBase): for table, fkcs in collection: if table is not None: self.traverse_single( - table, create_ok=True, + table, + create_ok=True, include_foreign_key_constraints=fkcs, - _is_metadata_operation=True) + _is_metadata_operation=True, + ) else: for fkc in fkcs: self.traverse_single(fkc) - metadata.dispatch.after_create(metadata, self.connection, - tables=event_collection, - checkfirst=self.checkfirst, - _ddl_runner=self) + metadata.dispatch.after_create( + metadata, + self.connection, + tables=event_collection, + checkfirst=self.checkfirst, + _ddl_runner=self, + ) def visit_table( - self, table, create_ok=False, - include_foreign_key_constraints=None, - _is_metadata_operation=False): + self, + table, + create_ok=False, + include_foreign_key_constraints=None, + _is_metadata_operation=False, + ): if not create_ok and not self._can_create_table(table): return table.dispatch.before_create( - table, self.connection, + table, + self.connection, checkfirst=self.checkfirst, _ddl_runner=self, - _is_metadata_operation=_is_metadata_operation) + _is_metadata_operation=_is_metadata_operation, + ) for column in table.columns: if column.default is not None: @@ -788,10 +817,11 @@ class SchemaGenerator(DDLBase): self.connection.execute( CreateTable( table, - include_foreign_key_constraints=include_foreign_key_constraints - )) + include_foreign_key_constraints=include_foreign_key_constraints, + ) + ) - if hasattr(table, 'indexes'): + if hasattr(table, "indexes"): for index in table.indexes: self.traverse_single(index) @@ -804,10 +834,12 @@ class SchemaGenerator(DDLBase): self.connection.execute(SetColumnComment(column)) table.dispatch.after_create( - table, self.connection, + table, + self.connection, checkfirst=self.checkfirst, _ddl_runner=self, - _is_metadata_operation=_is_metadata_operation) + _is_metadata_operation=_is_metadata_operation, + ) def visit_foreign_key_constraint(self, constraint): if not self.dialect.supports_alter: @@ -824,9 +856,9 @@ class SchemaGenerator(DDLBase): class SchemaDropper(DDLBase): - - def __init__(self, dialect, connection, checkfirst=False, - tables=None, **kwargs): + def __init__( + self, dialect, connection, checkfirst=False, tables=None, **kwargs + ): super(SchemaDropper, self).__init__(connection, **kwargs) self.checkfirst = checkfirst self.tables = tables @@ -842,15 +874,17 @@ class SchemaDropper(DDLBase): try: unsorted_tables = [t for t in tables if self._can_drop_table(t)] - collection = list(reversed( - sort_tables_and_constraints( - unsorted_tables, - filter_fn=lambda constraint: False - if not self.dialect.supports_alter - or constraint.name is None - else None + collection = list( + reversed( + sort_tables_and_constraints( + unsorted_tables, + filter_fn=lambda constraint: False + if not self.dialect.supports_alter + or constraint.name is None + else None, + ) ) - )) + ) except exc.CircularDependencyError as err2: if not self.dialect.supports_alter: util.warn( @@ -862,16 +896,15 @@ class SchemaDropper(DDLBase): "ForeignKeyConstraint " "objects involved in the cycle to mark these as known " "cycles that will be ignored." - % ( - ", ".join(sorted([t.fullname for t in err2.cycles])) - ) + % (", ".join(sorted([t.fullname for t in err2.cycles]))) ) collection = [(t, ()) for t in unsorted_tables] else: util.raise_from_cause( exc.CircularDependencyError( err2.args[0], - err2.cycles, err2.edges, + err2.cycles, + err2.edges, msg="Can't sort tables for DROP; an " "unresolvable foreign key " "dependency exists between tables: %s. Please ensure " @@ -880,9 +913,10 @@ class SchemaDropper(DDLBase): "names so that they can be dropped using " "DROP CONSTRAINT." % ( - ", ".join(sorted([t.fullname for t in err2.cycles])) - ) - + ", ".join( + sorted([t.fullname for t in err2.cycles]) + ) + ), ) ) @@ -892,18 +926,21 @@ class SchemaDropper(DDLBase): if s.column is None and self._can_drop_sequence(s) ] - event_collection = [ - t for (t, fks) in collection if t is not None - ] + event_collection = [t for (t, fks) in collection if t is not None] metadata.dispatch.before_drop( - metadata, self.connection, tables=event_collection, - checkfirst=self.checkfirst, _ddl_runner=self) + metadata, + self.connection, + tables=event_collection, + checkfirst=self.checkfirst, + _ddl_runner=self, + ) for table, fkcs in collection: if table is not None: self.traverse_single( - table, drop_ok=True, _is_metadata_operation=True) + table, drop_ok=True, _is_metadata_operation=True + ) else: for fkc in fkcs: self.traverse_single(fkc) @@ -912,8 +949,12 @@ class SchemaDropper(DDLBase): self.traverse_single(seq, drop_ok=True) metadata.dispatch.after_drop( - metadata, self.connection, tables=event_collection, - checkfirst=self.checkfirst, _ddl_runner=self) + metadata, + self.connection, + tables=event_collection, + checkfirst=self.checkfirst, + _ddl_runner=self, + ) def _can_drop_table(self, table): self.dialect.validate_identifier(table.name) @@ -921,19 +962,20 @@ class SchemaDropper(DDLBase): if effective_schema: self.dialect.validate_identifier(effective_schema) return not self.checkfirst or self.dialect.has_table( - self.connection, table.name, schema=effective_schema) + self.connection, table.name, schema=effective_schema + ) def _can_drop_sequence(self, sequence): effective_schema = self.connection.schema_for_object(sequence) - return self.dialect.supports_sequences and \ - ((not self.dialect.sequences_optional or - not sequence.optional) and - (not self.checkfirst or - self.dialect.has_sequence( - self.connection, - sequence.name, - schema=effective_schema)) - ) + return self.dialect.supports_sequences and ( + (not self.dialect.sequences_optional or not sequence.optional) + and ( + not self.checkfirst + or self.dialect.has_sequence( + self.connection, sequence.name, schema=effective_schema + ) + ) + ) def visit_index(self, index): self.connection.execute(DropIndex(index)) @@ -943,10 +985,12 @@ class SchemaDropper(DDLBase): return table.dispatch.before_drop( - table, self.connection, + table, + self.connection, checkfirst=self.checkfirst, _ddl_runner=self, - _is_metadata_operation=_is_metadata_operation) + _is_metadata_operation=_is_metadata_operation, + ) self.connection.execute(DropTable(table)) @@ -960,10 +1004,12 @@ class SchemaDropper(DDLBase): self.traverse_single(column.default) table.dispatch.after_drop( - table, self.connection, + table, + self.connection, checkfirst=self.checkfirst, _ddl_runner=self, - _is_metadata_operation=_is_metadata_operation) + _is_metadata_operation=_is_metadata_operation, + ) def visit_foreign_key_constraint(self, constraint): if not self.dialect.supports_alter: @@ -1019,25 +1065,29 @@ def sort_tables(tables, skip_fn=None, extra_dependencies=None): """ if skip_fn is not None: + def _skip_fn(fkc): for fk in fkc.elements: if skip_fn(fk): return True else: return None + else: _skip_fn = None return [ - t for (t, fkcs) in - sort_tables_and_constraints( - tables, filter_fn=_skip_fn, extra_dependencies=extra_dependencies) + t + for (t, fkcs) in sort_tables_and_constraints( + tables, filter_fn=_skip_fn, extra_dependencies=extra_dependencies + ) if t is not None ] def sort_tables_and_constraints( - tables, filter_fn=None, extra_dependencies=None): + tables, filter_fn=None, extra_dependencies=None +): """sort a collection of :class:`.Table` / :class:`.ForeignKeyConstraint` objects. @@ -1109,8 +1159,9 @@ def sort_tables_and_constraints( try: candidate_sort = list( topological.sort( - fixed_dependencies.union(mutable_dependencies), tables, - deterministic_order=True + fixed_dependencies.union(mutable_dependencies), + tables, + deterministic_order=True, ) ) except exc.CircularDependencyError as err: @@ -1118,8 +1169,10 @@ def sort_tables_and_constraints( if edge in mutable_dependencies: table = edge[1] can_remove = [ - fkc for fkc in table.foreign_key_constraints - if filter_fn is None or filter_fn(fkc) is not False] + fkc + for fkc in table.foreign_key_constraints + if filter_fn is None or filter_fn(fkc) is not False + ] remaining_fkcs.update(can_remove) for fkc in can_remove: dependent_on = fkc.referred_table @@ -1127,8 +1180,9 @@ def sort_tables_and_constraints( mutable_dependencies.discard((dependent_on, table)) candidate_sort = list( topological.sort( - fixed_dependencies.union(mutable_dependencies), tables, - deterministic_order=True + fixed_dependencies.union(mutable_dependencies), + tables, + deterministic_order=True, ) ) diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py index 8149f9731..fa0052198 100644 --- a/lib/sqlalchemy/sql/default_comparator.py +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -11,19 +11,43 @@ from .. import exc, util from . import type_api from . import operators -from .elements import BindParameter, True_, False_, BinaryExpression, \ - Null, _const_expr, _clause_element_as_expr, \ - ClauseList, ColumnElement, TextClause, UnaryExpression, \ - collate, _is_literal, _literal_as_text, ClauseElement, and_, or_, \ - Slice, Visitable, _literal_as_binds, CollectionAggregate, \ - Tuple +from .elements import ( + BindParameter, + True_, + False_, + BinaryExpression, + Null, + _const_expr, + _clause_element_as_expr, + ClauseList, + ColumnElement, + TextClause, + UnaryExpression, + collate, + _is_literal, + _literal_as_text, + ClauseElement, + and_, + or_, + Slice, + Visitable, + _literal_as_binds, + CollectionAggregate, + Tuple, +) from .selectable import SelectBase, Alias, Selectable, ScalarSelect -def _boolean_compare(expr, op, obj, negate=None, reverse=False, - _python_is_types=(util.NoneType, bool), - result_type = None, - **kwargs): +def _boolean_compare( + expr, + op, + obj, + negate=None, + reverse=False, + _python_is_types=(util.NoneType, bool), + result_type=None, + **kwargs +): if result_type is None: result_type = type_api.BOOLEANTYPE @@ -33,57 +57,64 @@ def _boolean_compare(expr, op, obj, negate=None, reverse=False, # allow x ==/!= True/False to be treated as a literal. # this comes out to "== / != true/false" or "1/0" if those # constants aren't supported and works on all platforms - if op in (operators.eq, operators.ne) and \ - isinstance(obj, (bool, True_, False_)): - return BinaryExpression(expr, - _literal_as_text(obj), - op, - type_=result_type, - negate=negate, modifiers=kwargs) + if op in (operators.eq, operators.ne) and isinstance( + obj, (bool, True_, False_) + ): + return BinaryExpression( + expr, + _literal_as_text(obj), + op, + type_=result_type, + negate=negate, + modifiers=kwargs, + ) elif op in (operators.is_distinct_from, operators.isnot_distinct_from): - return BinaryExpression(expr, - _literal_as_text(obj), - op, - type_=result_type, - negate=negate, modifiers=kwargs) + return BinaryExpression( + expr, + _literal_as_text(obj), + op, + type_=result_type, + negate=negate, + modifiers=kwargs, + ) else: # all other None/True/False uses IS, IS NOT if op in (operators.eq, operators.is_): - return BinaryExpression(expr, _const_expr(obj), - operators.is_, - negate=operators.isnot, - type_=result_type - ) + return BinaryExpression( + expr, + _const_expr(obj), + operators.is_, + negate=operators.isnot, + type_=result_type, + ) elif op in (operators.ne, operators.isnot): - return BinaryExpression(expr, _const_expr(obj), - operators.isnot, - negate=operators.is_, - type_=result_type - ) + return BinaryExpression( + expr, + _const_expr(obj), + operators.isnot, + negate=operators.is_, + type_=result_type, + ) else: raise exc.ArgumentError( "Only '=', '!=', 'is_()', 'isnot()', " "'is_distinct_from()', 'isnot_distinct_from()' " - "operators can be used with None/True/False") + "operators can be used with None/True/False" + ) else: obj = _check_literal(expr, op, obj) if reverse: - return BinaryExpression(obj, - expr, - op, - type_=result_type, - negate=negate, modifiers=kwargs) + return BinaryExpression( + obj, expr, op, type_=result_type, negate=negate, modifiers=kwargs + ) else: - return BinaryExpression(expr, - obj, - op, - type_=result_type, - negate=negate, modifiers=kwargs) + return BinaryExpression( + expr, obj, op, type_=result_type, negate=negate, modifiers=kwargs + ) -def _custom_op_operate(expr, op, obj, reverse=False, result_type=None, - **kw): +def _custom_op_operate(expr, op, obj, reverse=False, result_type=None, **kw): if result_type is None: if op.return_type: result_type = op.return_type @@ -91,11 +122,11 @@ def _custom_op_operate(expr, op, obj, reverse=False, result_type=None, result_type = type_api.BOOLEANTYPE return _binary_operate( - expr, op, obj, reverse=reverse, result_type=result_type, **kw) + expr, op, obj, reverse=reverse, result_type=result_type, **kw + ) -def _binary_operate(expr, op, obj, reverse=False, result_type=None, - **kw): +def _binary_operate(expr, op, obj, reverse=False, result_type=None, **kw): obj = _check_literal(expr, op, obj) if reverse: @@ -105,10 +136,10 @@ def _binary_operate(expr, op, obj, reverse=False, result_type=None, if result_type is None: op, result_type = left.comparator._adapt_expression( - op, right.comparator) + op, right.comparator + ) - return BinaryExpression( - left, right, op, type_=result_type, modifiers=kw) + return BinaryExpression(left, right, op, type_=result_type, modifiers=kw) def _conjunction_operate(expr, op, other, **kw): @@ -128,8 +159,7 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw): seq_or_selectable = _clause_element_as_expr(seq_or_selectable) if isinstance(seq_or_selectable, ScalarSelect): - return _boolean_compare(expr, op, seq_or_selectable, - negate=negate_op) + return _boolean_compare(expr, op, seq_or_selectable, negate=negate_op) elif isinstance(seq_or_selectable, SelectBase): # TODO: if we ever want to support (x, y, z) IN (select x, @@ -138,32 +168,33 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw): # does not export itself as a FROM clause return _boolean_compare( - expr, op, seq_or_selectable.as_scalar(), - negate=negate_op, **kw) + expr, op, seq_or_selectable.as_scalar(), negate=negate_op, **kw + ) elif isinstance(seq_or_selectable, (Selectable, TextClause)): - return _boolean_compare(expr, op, seq_or_selectable, - negate=negate_op, **kw) + return _boolean_compare( + expr, op, seq_or_selectable, negate=negate_op, **kw + ) elif isinstance(seq_or_selectable, ClauseElement): - if isinstance(seq_or_selectable, BindParameter) and \ - seq_or_selectable.expanding: + if ( + isinstance(seq_or_selectable, BindParameter) + and seq_or_selectable.expanding + ): if isinstance(expr, Tuple): - seq_or_selectable = ( - seq_or_selectable._with_expanding_in_types( - [elem.type for elem in expr] - ) + seq_or_selectable = seq_or_selectable._with_expanding_in_types( + [elem.type for elem in expr] ) return _boolean_compare( - expr, op, - seq_or_selectable, - negate=negate_op) + expr, op, seq_or_selectable, negate=negate_op + ) else: raise exc.InvalidRequestError( - 'in_() accepts' - ' either a list of expressions, ' + "in_() accepts" + " either a list of expressions, " 'a selectable, or an "expanding" bound parameter: %r' - % seq_or_selectable) + % seq_or_selectable + ) # Handle non selectable arguments as sequences args = [] @@ -171,9 +202,10 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw): if not _is_literal(o): if not isinstance(o, operators.ColumnOperators): raise exc.InvalidRequestError( - 'in_() accepts' - ' either a list of expressions, ' - 'a selectable, or an "expanding" bound parameter: %r' % o) + "in_() accepts" + " either a list of expressions, " + 'a selectable, or an "expanding" bound parameter: %r' % o + ) elif o is None: o = Null() else: @@ -182,15 +214,14 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw): if len(args) == 0: op, negate_op = ( - operators.empty_in_op, - operators.empty_notin_op) if op is operators.in_op \ - else ( - operators.empty_notin_op, - operators.empty_in_op) + (operators.empty_in_op, operators.empty_notin_op) + if op is operators.in_op + else (operators.empty_notin_op, operators.empty_in_op) + ) - return _boolean_compare(expr, op, - ClauseList(*args).self_group(against=op), - negate=negate_op) + return _boolean_compare( + expr, op, ClauseList(*args).self_group(against=op), negate=negate_op + ) def _getitem_impl(expr, op, other, **kw): @@ -202,13 +233,14 @@ def _getitem_impl(expr, op, other, **kw): def _unsupported_impl(expr, op, *arg, **kw): - raise NotImplementedError("Operator '%s' is not supported on " - "this expression" % op.__name__) + raise NotImplementedError( + "Operator '%s' is not supported on " "this expression" % op.__name__ + ) def _inv_impl(expr, op, **kw): """See :meth:`.ColumnOperators.__inv__`.""" - if hasattr(expr, 'negation_clause'): + if hasattr(expr, "negation_clause"): return expr.negation_clause else: return expr._negate() @@ -223,20 +255,22 @@ def _match_impl(expr, op, other, **kw): """See :meth:`.ColumnOperators.match`.""" return _boolean_compare( - expr, operators.match_op, - _check_literal( - expr, operators.match_op, other), + expr, + operators.match_op, + _check_literal(expr, operators.match_op, other), result_type=type_api.MATCHTYPE, negate=operators.notmatch_op - if op is operators.match_op else operators.match_op, + if op is operators.match_op + else operators.match_op, **kw ) def _distinct_impl(expr, op, **kw): """See :meth:`.ColumnOperators.distinct`.""" - return UnaryExpression(expr, operator=operators.distinct_op, - type_=expr.type) + return UnaryExpression( + expr, operator=operators.distinct_op, type_=expr.type + ) def _between_impl(expr, op, cleft, cright, **kw): @@ -247,17 +281,21 @@ def _between_impl(expr, op, cleft, cright, **kw): _check_literal(expr, operators.and_, cleft), _check_literal(expr, operators.and_, cright), operator=operators.and_, - group=False, group_contents=False), + group=False, + group_contents=False, + ), op, negate=operators.notbetween_op if op is operators.between_op else operators.between_op, - modifiers=kw) + modifiers=kw, + ) def _collate_impl(expr, op, other, **kw): return collate(expr, other) + # a mapping of operators with the method they use, along with # their negated operator for comparison operators operator_lookup = { @@ -271,8 +309,8 @@ operator_lookup = { "mod": (_binary_operate,), "truediv": (_binary_operate,), "custom_op": (_custom_op_operate,), - "json_path_getitem_op": (_binary_operate, ), - "json_getitem_op": (_binary_operate, ), + "json_path_getitem_op": (_binary_operate,), + "json_getitem_op": (_binary_operate,), "concat_op": (_binary_operate,), "any_op": (_scalar, CollectionAggregate._create_any), "all_op": (_scalar, CollectionAggregate._create_all), @@ -303,8 +341,8 @@ operator_lookup = { "match_op": (_match_impl,), "notmatch_op": (_match_impl,), "distinct_op": (_distinct_impl,), - "between_op": (_between_impl, ), - "notbetween_op": (_between_impl, ), + "between_op": (_between_impl,), + "notbetween_op": (_between_impl,), "neg": (_neg_impl,), "getitem": (_getitem_impl,), "lshift": (_unsupported_impl,), @@ -315,12 +353,11 @@ operator_lookup = { def _check_literal(expr, operator, other, bindparam_type=None): if isinstance(other, (ColumnElement, TextClause)): - if isinstance(other, BindParameter) and \ - other.type._isnull: + if isinstance(other, BindParameter) and other.type._isnull: other = other._clone() other.type = expr.type return other - elif hasattr(other, '__clause_element__'): + elif hasattr(other, "__clause_element__"): other = other.__clause_element__() elif isinstance(other, type_api.TypeEngine.Comparator): other = other.expr @@ -331,4 +368,3 @@ def _check_literal(expr, operator, other, bindparam_type=None): return expr._bind_param(operator, other, type_=bindparam_type) else: return other - diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index d6890de15..0cea5ccc4 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -9,26 +9,43 @@ Provide :class:`.Insert`, :class:`.Update` and :class:`.Delete`. """ -from .base import Executable, _generative, _from_objects, DialectKWArgs, \ - ColumnCollection -from .elements import ClauseElement, _literal_as_text, Null, and_, _clone, \ - _column_as_key -from .selectable import _interpret_as_from, _interpret_as_select, \ - HasPrefixes, HasCTE +from .base import ( + Executable, + _generative, + _from_objects, + DialectKWArgs, + ColumnCollection, +) +from .elements import ( + ClauseElement, + _literal_as_text, + Null, + and_, + _clone, + _column_as_key, +) +from .selectable import ( + _interpret_as_from, + _interpret_as_select, + HasPrefixes, + HasCTE, +) from .. import util from .. import exc class UpdateBase( - HasCTE, DialectKWArgs, HasPrefixes, Executable, ClauseElement): + HasCTE, DialectKWArgs, HasPrefixes, Executable, ClauseElement +): """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements. """ - __visit_name__ = 'update_base' + __visit_name__ = "update_base" - _execution_options = \ - Executable._execution_options.union({'autocommit': True}) + _execution_options = Executable._execution_options.union( + {"autocommit": True} + ) _hints = util.immutabledict() _parameter_ordering = None _prefixes = () @@ -37,30 +54,33 @@ class UpdateBase( def _process_colparams(self, parameters): def process_single(p): if isinstance(p, (list, tuple)): - return dict( - (c.key, pval) - for c, pval in zip(self.table.c, p) - ) + return dict((c.key, pval) for c, pval in zip(self.table.c, p)) else: return p if self._preserve_parameter_order and parameters is not None: - if not isinstance(parameters, list) or \ - (parameters and not isinstance(parameters[0], tuple)): + if not isinstance(parameters, list) or ( + parameters and not isinstance(parameters[0], tuple) + ): raise ValueError( "When preserve_parameter_order is True, " - "values() only accepts a list of 2-tuples") + "values() only accepts a list of 2-tuples" + ) self._parameter_ordering = [key for key, value in parameters] return dict(parameters), False - if (isinstance(parameters, (list, tuple)) and parameters and - isinstance(parameters[0], (list, tuple, dict))): + if ( + isinstance(parameters, (list, tuple)) + and parameters + and isinstance(parameters[0], (list, tuple, dict)) + ): if not self._supports_multi_parameters: raise exc.InvalidRequestError( "This construct does not support " - "multiple parameter sets.") + "multiple parameter sets." + ) return [process_single(p) for p in parameters], True else: @@ -77,7 +97,8 @@ class UpdateBase( raise NotImplementedError( "params() is not supported for INSERT/UPDATE/DELETE statements." " To set the values for an INSERT or UPDATE statement, use" - " stmt.values(**parameters).") + " stmt.values(**parameters)." + ) def bind(self): """Return a 'bind' linked to this :class:`.UpdateBase` @@ -88,6 +109,7 @@ class UpdateBase( def _set_bind(self, bind): self._bind = bind + bind = property(bind, _set_bind) @_generative @@ -181,15 +203,14 @@ class UpdateBase( if selectable is None: selectable = self.table - self._hints = self._hints.union( - {(selectable, dialect_name): text}) + self._hints = self._hints.union({(selectable, dialect_name): text}) class ValuesBase(UpdateBase): """Supplies support for :meth:`.ValuesBase.values` to INSERT and UPDATE constructs.""" - __visit_name__ = 'values_base' + __visit_name__ = "values_base" _supports_multi_parameters = False _has_multi_parameters = False @@ -199,8 +220,9 @@ class ValuesBase(UpdateBase): def __init__(self, table, values, prefixes): self.table = _interpret_as_from(table) - self.parameters, self._has_multi_parameters = \ - self._process_colparams(values) + self.parameters, self._has_multi_parameters = self._process_colparams( + values + ) if prefixes: self._setup_prefixes(prefixes) @@ -332,23 +354,27 @@ class ValuesBase(UpdateBase): """ if self.select is not None: raise exc.InvalidRequestError( - "This construct already inserts from a SELECT") + "This construct already inserts from a SELECT" + ) if self._has_multi_parameters and kwargs: raise exc.InvalidRequestError( - "This construct already has multiple parameter sets.") + "This construct already has multiple parameter sets." + ) if args: if len(args) > 1: raise exc.ArgumentError( "Only a single dictionary/tuple or list of " - "dictionaries/tuples is accepted positionally.") + "dictionaries/tuples is accepted positionally." + ) v = args[0] else: v = {} if self.parameters is None: - self.parameters, self._has_multi_parameters = \ - self._process_colparams(v) + self.parameters, self._has_multi_parameters = self._process_colparams( + v + ) else: if self._has_multi_parameters: self.parameters = list(self.parameters) @@ -356,7 +382,8 @@ class ValuesBase(UpdateBase): if not self._has_multi_parameters: raise exc.ArgumentError( "Can't mix single-values and multiple values " - "formats in one statement") + "formats in one statement" + ) self.parameters.extend(p) else: @@ -365,14 +392,16 @@ class ValuesBase(UpdateBase): if self._has_multi_parameters: raise exc.ArgumentError( "Can't mix single-values and multiple values " - "formats in one statement") + "formats in one statement" + ) self.parameters.update(p) if kwargs: if self._has_multi_parameters: raise exc.ArgumentError( "Can't pass kwargs and multiple parameter sets " - "simultaneously") + "simultaneously" + ) else: self.parameters.update(kwargs) @@ -456,19 +485,22 @@ class Insert(ValuesBase): :ref:`coretutorial_insert_expressions` """ - __visit_name__ = 'insert' + + __visit_name__ = "insert" _supports_multi_parameters = True - def __init__(self, - table, - values=None, - inline=False, - bind=None, - prefixes=None, - returning=None, - return_defaults=False, - **dialect_kw): + def __init__( + self, + table, + values=None, + inline=False, + bind=None, + prefixes=None, + returning=None, + return_defaults=False, + **dialect_kw + ): """Construct an :class:`.Insert` object. Similar functionality is available via the @@ -526,7 +558,7 @@ class Insert(ValuesBase): def get_children(self, **kwargs): if self.select is not None: - return self.select, + return (self.select,) else: return () @@ -578,11 +610,12 @@ class Insert(ValuesBase): """ if self.parameters: raise exc.InvalidRequestError( - "This construct already inserts value expressions") + "This construct already inserts value expressions" + ) - self.parameters, self._has_multi_parameters = \ - self._process_colparams( - {_column_as_key(n): Null() for n in names}) + self.parameters, self._has_multi_parameters = self._process_colparams( + {_column_as_key(n): Null() for n in names} + ) self.select_names = names self.inline = True @@ -603,19 +636,22 @@ class Update(ValuesBase): function. """ - __visit_name__ = 'update' - - def __init__(self, - table, - whereclause=None, - values=None, - inline=False, - bind=None, - prefixes=None, - returning=None, - return_defaults=False, - preserve_parameter_order=False, - **dialect_kw): + + __visit_name__ = "update" + + def __init__( + self, + table, + whereclause=None, + values=None, + inline=False, + bind=None, + prefixes=None, + returning=None, + return_defaults=False, + preserve_parameter_order=False, + **dialect_kw + ): r"""Construct an :class:`.Update` object. E.g.:: @@ -745,7 +781,7 @@ class Update(ValuesBase): def get_children(self, **kwargs): if self._whereclause is not None: - return self._whereclause, + return (self._whereclause,) else: return () @@ -761,8 +797,9 @@ class Update(ValuesBase): """ if self._whereclause is not None: - self._whereclause = and_(self._whereclause, - _literal_as_text(whereclause)) + self._whereclause = and_( + self._whereclause, _literal_as_text(whereclause) + ) else: self._whereclause = _literal_as_text(whereclause) @@ -788,15 +825,17 @@ class Delete(UpdateBase): """ - __visit_name__ = 'delete' - - def __init__(self, - table, - whereclause=None, - bind=None, - returning=None, - prefixes=None, - **dialect_kw): + __visit_name__ = "delete" + + def __init__( + self, + table, + whereclause=None, + bind=None, + returning=None, + prefixes=None, + **dialect_kw + ): """Construct :class:`.Delete` object. Similar functionality is available via the @@ -847,7 +886,7 @@ class Delete(UpdateBase): def get_children(self, **kwargs): if self._whereclause is not None: - return self._whereclause, + return (self._whereclause,) else: return () @@ -856,8 +895,9 @@ class Delete(UpdateBase): """Add the given WHERE clause to a newly returned delete construct.""" if self._whereclause is not None: - self._whereclause = and_(self._whereclause, - _literal_as_text(whereclause)) + self._whereclause = and_( + self._whereclause, _literal_as_text(whereclause) + ) else: self._whereclause = _literal_as_text(whereclause) diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index de3b7992a..e857f2da8 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -51,9 +51,8 @@ def collate(expression, collation): expr = _literal_as_binds(expression) return BinaryExpression( - expr, - CollationClause(collation), - operators.collate, type_=expr.type) + expr, CollationClause(collation), operators.collate, type_=expr.type + ) def between(expr, lower_bound, upper_bound, symmetric=False): @@ -130,8 +129,6 @@ def literal(value, type_=None): return BindParameter(None, value, type_=type_, unique=True) - - def outparam(key, type_=None): """Create an 'OUT' parameter for usage in functions (stored procedures), for databases which support them. @@ -142,8 +139,7 @@ def outparam(key, type_=None): attribute, which returns a dictionary containing the values. """ - return BindParameter( - key, None, type_=type_, unique=False, isoutparam=True) + return BindParameter(key, None, type_=type_, unique=False, isoutparam=True) def not_(clause): @@ -163,7 +159,8 @@ class ClauseElement(Visitable): expression. """ - __visit_name__ = 'clause' + + __visit_name__ = "clause" _annotations = {} supports_execution = False @@ -230,7 +227,7 @@ class ClauseElement(Visitable): def __getstate__(self): d = self.__dict__.copy() - d.pop('_is_clone_of', None) + d.pop("_is_clone_of", None) return d def _annotate(self, values): @@ -300,7 +297,8 @@ class ClauseElement(Visitable): kwargs.update(optionaldict[0]) elif len(optionaldict) > 1: raise exc.ArgumentError( - "params() takes zero or one positional dictionary argument") + "params() takes zero or one positional dictionary argument" + ) def visit_bindparam(bind): if bind.key in kwargs: @@ -308,7 +306,8 @@ class ClauseElement(Visitable): bind.required = False if unique: bind._convert_to_unique() - return cloned_traverse(self, {}, {'bindparam': visit_bindparam}) + + return cloned_traverse(self, {}, {"bindparam": visit_bindparam}) def compare(self, other, **kw): r"""Compare this ClauseElement to the given ClauseElement. @@ -451,7 +450,7 @@ class ClauseElement(Visitable): if util.py3k: return str(self.compile()) else: - return unicode(self.compile()).encode('ascii', 'backslashreplace') + return unicode(self.compile()).encode("ascii", "backslashreplace") def __and__(self, other): """'and' at the ClauseElement level. @@ -472,7 +471,7 @@ class ClauseElement(Visitable): return or_(self, other) def __invert__(self): - if hasattr(self, 'negation_clause'): + if hasattr(self, "negation_clause"): return self.negation_clause else: return self._negate() @@ -481,7 +480,8 @@ class ClauseElement(Visitable): return UnaryExpression( self.self_group(against=operators.inv), operator=operators.inv, - negate=None) + negate=None, + ) def __bool__(self): raise TypeError("Boolean value of this clause is not defined") @@ -493,8 +493,12 @@ class ClauseElement(Visitable): if friendly is None: return object.__repr__(self) else: - return '<%s.%s at 0x%x; %s>' % ( - self.__module__, self.__class__.__name__, id(self), friendly) + return "<%s.%s at 0x%x; %s>" % ( + self.__module__, + self.__class__.__name__, + id(self), + friendly, + ) class ColumnElement(operators.ColumnOperators, ClauseElement): @@ -571,7 +575,7 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): """ - __visit_name__ = 'column_element' + __visit_name__ = "column_element" primary_key = False foreign_keys = [] @@ -646,11 +650,12 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): _alt_names = () def self_group(self, against=None): - if (against in (operators.and_, operators.or_, operators._asbool) and - self.type._type_affinity - is type_api.BOOLEANTYPE._type_affinity): + if ( + against in (operators.and_, operators.or_, operators._asbool) + and self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity + ): return AsBoolean(self, operators.istrue, operators.isfalse) - elif (against in (operators.any_op, operators.all_op)): + elif against in (operators.any_op, operators.all_op): return Grouping(self) else: return self @@ -675,7 +680,8 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): except AttributeError: raise TypeError( "Object %r associated with '.type' attribute " - "is not a TypeEngine class or object" % self.type) + "is not a TypeEngine class or object" % self.type + ) else: return comparator_factory(self) @@ -684,10 +690,8 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): return getattr(self.comparator, key) except AttributeError: raise AttributeError( - 'Neither %r object nor %r object has an attribute %r' % ( - type(self).__name__, - type(self.comparator).__name__, - key) + "Neither %r object nor %r object has an attribute %r" + % (type(self).__name__, type(self.comparator).__name__, key) ) def operate(self, op, *other, **kwargs): @@ -697,10 +701,14 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): return op(other, self.comparator, **kwargs) def _bind_param(self, operator, obj, type_=None): - return BindParameter(None, obj, - _compared_to_operator=operator, - type_=type_, - _compared_to_type=self.type, unique=True) + return BindParameter( + None, + obj, + _compared_to_operator=operator, + type_=type_, + _compared_to_type=self.type, + unique=True, + ) @property def expression(self): @@ -713,17 +721,18 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): @property def _select_iterable(self): - return (self, ) + return (self,) @util.memoized_property def base_columns(self): - return util.column_set(c for c in self.proxy_set - if not hasattr(c, '_proxies')) + return util.column_set( + c for c in self.proxy_set if not hasattr(c, "_proxies") + ) @util.memoized_property def proxy_set(self): s = util.column_set([self]) - if hasattr(self, '_proxies'): + if hasattr(self, "_proxies"): for c in self._proxies: s.update(c.proxy_set) return s @@ -738,11 +747,15 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): """Return True if the given column element compares to this one when targeting within a result row.""" - return hasattr(other, 'name') and hasattr(self, 'name') and \ - other.name == self.name + return ( + hasattr(other, "name") + and hasattr(self, "name") + and other.name == self.name + ) def _make_proxy( - self, selectable, name=None, name_is_truncatable=False, **kw): + self, selectable, name=None, name_is_truncatable=False, **kw + ): """Create a new :class:`.ColumnElement` representing this :class:`.ColumnElement` as it appears in the select list of a descending selectable. @@ -762,13 +775,12 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): key = name co = ColumnClause( _as_truncated(name) if name_is_truncatable else name, - type_=getattr(self, 'type', None), - _selectable=selectable + type_=getattr(self, "type", None), + _selectable=selectable, ) co._proxies = [self] if selectable._is_clone_of is not None: - co._is_clone_of = \ - selectable._is_clone_of.columns.get(key) + co._is_clone_of = selectable._is_clone_of.columns.get(key) selectable._columns[key] = co return co @@ -788,7 +800,7 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): this one via foreign key or other criterion. """ - to_compare = (other, ) + to_compare = (other,) if equivalents and other in equivalents: to_compare = equivalents[other].union(to_compare) @@ -838,7 +850,7 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): self = self._is_clone_of return _anonymous_label( - '%%(%d %s)s' % (id(self), getattr(self, 'name', 'anon')) + "%%(%d %s)s" % (id(self), getattr(self, "name", "anon")) ) @@ -862,18 +874,25 @@ class BindParameter(ColumnElement): """ - __visit_name__ = 'bindparam' + __visit_name__ = "bindparam" _is_crud = False _expanding_in_types = () - def __init__(self, key, value=NO_ARG, type_=None, - unique=False, required=NO_ARG, - quote=None, callable_=None, - expanding=False, - isoutparam=False, - _compared_to_operator=None, - _compared_to_type=None): + def __init__( + self, + key, + value=NO_ARG, + type_=None, + unique=False, + required=NO_ARG, + quote=None, + callable_=None, + expanding=False, + isoutparam=False, + _compared_to_operator=None, + _compared_to_type=None, + ): r"""Produce a "bound expression". The return value is an instance of :class:`.BindParameter`; this @@ -1093,7 +1112,7 @@ class BindParameter(ColumnElement): type_ = key.type key = key.key if required is NO_ARG: - required = (value is NO_ARG and callable_ is None) + required = value is NO_ARG and callable_ is None if value is NO_ARG: value = None @@ -1101,11 +1120,11 @@ class BindParameter(ColumnElement): key = quoted_name(key, quote) if unique: - self.key = _anonymous_label('%%(%d %s)s' % (id(self), key - or 'param')) + self.key = _anonymous_label( + "%%(%d %s)s" % (id(self), key or "param") + ) else: - self.key = key or _anonymous_label('%%(%d param)s' - % id(self)) + self.key = key or _anonymous_label("%%(%d param)s" % id(self)) # identifying key that won't change across # clones, used to identify the bind's logical @@ -1114,7 +1133,7 @@ class BindParameter(ColumnElement): # key that was passed in the first place, used to # generate new keys - self._orig_key = key or 'param' + self._orig_key = key or "param" self.unique = unique self.value = value @@ -1125,9 +1144,9 @@ class BindParameter(ColumnElement): if type_ is None: if _compared_to_type is not None: - self.type = \ - _compared_to_type.coerce_compared_value( - _compared_to_operator, value) + self.type = _compared_to_type.coerce_compared_value( + _compared_to_operator, value + ) else: self.type = type_api._resolve_value_to_type(value) elif isinstance(type_, type): @@ -1174,24 +1193,28 @@ class BindParameter(ColumnElement): def _clone(self): c = ClauseElement._clone(self) if self.unique: - c.key = _anonymous_label('%%(%d %s)s' % (id(c), c._orig_key - or 'param')) + c.key = _anonymous_label( + "%%(%d %s)s" % (id(c), c._orig_key or "param") + ) return c def _convert_to_unique(self): if not self.unique: self.unique = True self.key = _anonymous_label( - '%%(%d %s)s' % (id(self), self._orig_key or 'param')) + "%%(%d %s)s" % (id(self), self._orig_key or "param") + ) def compare(self, other, **kw): """Compare this :class:`BindParameter` to the given clause.""" - return isinstance(other, BindParameter) \ - and self.type._compare_type_affinity(other.type) \ - and self.value == other.value \ + return ( + isinstance(other, BindParameter) + and self.type._compare_type_affinity(other.type) + and self.value == other.value and self.callable == other.callable + ) def __getstate__(self): """execute a deferred value for serialization purposes.""" @@ -1200,13 +1223,16 @@ class BindParameter(ColumnElement): v = self.value if self.callable: v = self.callable() - d['callable'] = None - d['value'] = v + d["callable"] = None + d["value"] = v return d def __repr__(self): - return 'BindParameter(%r, %r, type_=%r)' % (self.key, - self.value, self.type) + return "BindParameter(%r, %r, type_=%r)" % ( + self.key, + self.value, + self.type, + ) class TypeClause(ClauseElement): @@ -1216,7 +1242,7 @@ class TypeClause(ClauseElement): """ - __visit_name__ = 'typeclause' + __visit_name__ = "typeclause" def __init__(self, type): self.type = type @@ -1242,12 +1268,12 @@ class TextClause(Executable, ClauseElement): """ - __visit_name__ = 'textclause' + __visit_name__ = "textclause" - _bind_params_regex = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE) - _execution_options = \ - Executable._execution_options.union( - {'autocommit': PARSE_AUTOCOMMIT}) + _bind_params_regex = re.compile(r"(?<![:\w\x5c]):(\w+)(?!:)", re.UNICODE) + _execution_options = Executable._execution_options.union( + {"autocommit": PARSE_AUTOCOMMIT} + ) _is_implicitly_boolean = False @property @@ -1268,24 +1294,22 @@ class TextClause(Executable, ClauseElement): _allow_label_resolve = False - def __init__( - self, - text, - bind=None): + def __init__(self, text, bind=None): self._bind = bind self._bindparams = {} def repl(m): self._bindparams[m.group(1)] = BindParameter(m.group(1)) - return ':%s' % m.group(1) + return ":%s" % m.group(1) # scan the string and search for bind parameter names, add them # to the list of bindparams self.text = self._bind_params_regex.sub(repl, text) @classmethod - def _create_text(self, text, bind=None, bindparams=None, - typemap=None, autocommit=None): + def _create_text( + self, text, bind=None, bindparams=None, typemap=None, autocommit=None + ): r"""Construct a new :class:`.TextClause` clause, representing a textual SQL string directly. @@ -1428,8 +1452,10 @@ class TextClause(Executable, ClauseElement): if typemap: stmt = stmt.columns(**typemap) if autocommit is not None: - util.warn_deprecated('autocommit on text() is deprecated. ' - 'Use .execution_options(autocommit=True)') + util.warn_deprecated( + "autocommit on text() is deprecated. " + "Use .execution_options(autocommit=True)" + ) stmt = stmt.execution_options(autocommit=autocommit) return stmt @@ -1513,7 +1539,8 @@ class TextClause(Executable, ClauseElement): except KeyError: raise exc.ArgumentError( "This text() construct doesn't define a " - "bound parameter named %r" % bind.key) + "bound parameter named %r" % bind.key + ) else: new_params[existing.key] = bind @@ -1523,11 +1550,12 @@ class TextClause(Executable, ClauseElement): except KeyError: raise exc.ArgumentError( "This text() construct doesn't define a " - "bound parameter named %r" % key) + "bound parameter named %r" % key + ) else: new_params[key] = existing._with_value(value) - @util.dependencies('sqlalchemy.sql.selectable') + @util.dependencies("sqlalchemy.sql.selectable") def columns(self, selectable, *cols, **types): """Turn this :class:`.TextClause` object into a :class:`.TextAsFrom` object that can be embedded into another statement. @@ -1629,12 +1657,14 @@ class TextClause(Executable, ClauseElement): for col in cols ] keyed_input_cols = [ - ColumnClause(key, type_) for key, type_ in types.items()] + ColumnClause(key, type_) for key, type_ in types.items() + ] return selectable.TextAsFrom( self, positional_input_cols + keyed_input_cols, - positional=bool(positional_input_cols) and not keyed_input_cols) + positional=bool(positional_input_cols) and not keyed_input_cols, + ) @property def type(self): @@ -1651,8 +1681,9 @@ class TextClause(Executable, ClauseElement): return self def _copy_internals(self, clone=_clone, **kw): - self._bindparams = dict((b.key, clone(b, **kw)) - for b in self._bindparams.values()) + self._bindparams = dict( + (b.key, clone(b, **kw)) for b in self._bindparams.values() + ) def get_children(self, **kwargs): return list(self._bindparams.values()) @@ -1669,7 +1700,7 @@ class Null(ColumnElement): """ - __visit_name__ = 'null' + __visit_name__ = "null" @util.memoized_property def type(self): @@ -1693,7 +1724,7 @@ class False_(ColumnElement): """ - __visit_name__ = 'false' + __visit_name__ = "false" @util.memoized_property def type(self): @@ -1752,7 +1783,7 @@ class True_(ColumnElement): """ - __visit_name__ = 'true' + __visit_name__ = "true" @util.memoized_property def type(self): @@ -1816,23 +1847,23 @@ class ClauseList(ClauseElement): By default, is comma-separated, such as a column listing. """ - __visit_name__ = 'clauselist' + + __visit_name__ = "clauselist" def __init__(self, *clauses, **kwargs): - self.operator = kwargs.pop('operator', operators.comma_op) - self.group = kwargs.pop('group', True) - self.group_contents = kwargs.pop('group_contents', True) + self.operator = kwargs.pop("operator", operators.comma_op) + self.group = kwargs.pop("group", True) + self.group_contents = kwargs.pop("group_contents", True) text_converter = kwargs.pop( - '_literal_as_text', - _expression_literal_as_text) + "_literal_as_text", _expression_literal_as_text + ) if self.group_contents: self.clauses = [ text_converter(clause).self_group(against=self.operator) - for clause in clauses] + for clause in clauses + ] else: - self.clauses = [ - text_converter(clause) - for clause in clauses] + self.clauses = [text_converter(clause) for clause in clauses] self._is_implicitly_boolean = operators.is_boolean(self.operator) def __iter__(self): @@ -1847,8 +1878,9 @@ class ClauseList(ClauseElement): def append(self, clause): if self.group_contents: - self.clauses.append(_literal_as_text(clause). - self_group(against=self.operator)) + self.clauses.append( + _literal_as_text(clause).self_group(against=self.operator) + ) else: self.clauses.append(_literal_as_text(clause)) @@ -1875,14 +1907,18 @@ class ClauseList(ClauseElement): """ if not isinstance(other, ClauseList) and len(self.clauses) == 1: return self.clauses[0].compare(other, **kw) - elif isinstance(other, ClauseList) and \ - len(self.clauses) == len(other.clauses) and \ - self.operator is other.operator: + elif ( + isinstance(other, ClauseList) + and len(self.clauses) == len(other.clauses) + and self.operator is other.operator + ): if self.operator in (operators.and_, operators.or_): completed = set() for clause in self.clauses: - for other_clause in set(other.clauses).difference(completed): + for other_clause in set(other.clauses).difference( + completed + ): if clause.compare(other_clause, **kw): completed.add(other_clause) break @@ -1898,11 +1934,12 @@ class ClauseList(ClauseElement): class BooleanClauseList(ClauseList, ColumnElement): - __visit_name__ = 'clauselist' + __visit_name__ = "clauselist" def __init__(self, *arg, **kw): raise NotImplementedError( - "BooleanClauseList has a private constructor") + "BooleanClauseList has a private constructor" + ) @classmethod def _construct(cls, operator, continue_on, skip_on, *clauses, **kw): @@ -1910,8 +1947,7 @@ class BooleanClauseList(ClauseList, ColumnElement): clauses = [ _expression_literal_as_text(clause) - for clause in - util.coerce_generator_arg(clauses) + for clause in util.coerce_generator_arg(clauses) ] for clause in clauses: @@ -1927,8 +1963,9 @@ class BooleanClauseList(ClauseList, ColumnElement): elif not convert_clauses and clauses: return clauses[0].self_group(against=operators._asbool) - convert_clauses = [c.self_group(against=operator) - for c in convert_clauses] + convert_clauses = [ + c.self_group(against=operator) for c in convert_clauses + ] self = cls.__new__(cls) self.clauses = convert_clauses @@ -2014,7 +2051,7 @@ class BooleanClauseList(ClauseList, ColumnElement): @property def _select_iterable(self): - return (self, ) + return (self,) def self_group(self, against=None): if not self.clauses: @@ -2056,22 +2093,31 @@ class Tuple(ClauseList, ColumnElement): clauses = [_literal_as_binds(c) for c in clauses] self._type_tuple = [arg.type for arg in clauses] - self.type = kw.pop('type_', self._type_tuple[0] - if self._type_tuple else type_api.NULLTYPE) + self.type = kw.pop( + "type_", + self._type_tuple[0] if self._type_tuple else type_api.NULLTYPE, + ) super(Tuple, self).__init__(*clauses, **kw) @property def _select_iterable(self): - return (self, ) + return (self,) def _bind_param(self, operator, obj, type_=None): - return Tuple(*[ - BindParameter(None, o, _compared_to_operator=operator, - _compared_to_type=compared_to_type, unique=True, - type_=type_) - for o, compared_to_type in zip(obj, self._type_tuple) - ]).self_group() + return Tuple( + *[ + BindParameter( + None, + o, + _compared_to_operator=operator, + _compared_to_type=compared_to_type, + unique=True, + type_=type_, + ) + for o, compared_to_type in zip(obj, self._type_tuple) + ] + ).self_group() class Case(ColumnElement): @@ -2101,7 +2147,7 @@ class Case(ColumnElement): """ - __visit_name__ = 'case' + __visit_name__ = "case" def __init__(self, whens, value=None, else_=None): r"""Produce a ``CASE`` expression. @@ -2231,13 +2277,13 @@ class Case(ColumnElement): if value is not None: whenlist = [ - (_literal_as_binds(c).self_group(), - _literal_as_binds(r)) for (c, r) in whens + (_literal_as_binds(c).self_group(), _literal_as_binds(r)) + for (c, r) in whens ] else: whenlist = [ - (_no_literals(c).self_group(), - _literal_as_binds(r)) for (c, r) in whens + (_no_literals(c).self_group(), _literal_as_binds(r)) + for (c, r) in whens ] if whenlist: @@ -2260,8 +2306,7 @@ class Case(ColumnElement): def _copy_internals(self, clone=_clone, **kw): if self.value is not None: self.value = clone(self.value, **kw) - self.whens = [(clone(x, **kw), clone(y, **kw)) - for x, y in self.whens] + self.whens = [(clone(x, **kw), clone(y, **kw)) for x, y in self.whens] if self.else_ is not None: self.else_ = clone(self.else_, **kw) @@ -2276,8 +2321,9 @@ class Case(ColumnElement): @property def _from_objects(self): - return list(itertools.chain(*[x._from_objects for x in - self.get_children()])) + return list( + itertools.chain(*[x._from_objects for x in self.get_children()]) + ) def literal_column(text, type_=None): @@ -2333,7 +2379,7 @@ class Cast(ColumnElement): """ - __visit_name__ = 'cast' + __visit_name__ = "cast" def __init__(self, expression, type_): """Produce a ``CAST`` expression. @@ -2416,7 +2462,7 @@ class TypeCoerce(ColumnElement): """ - __visit_name__ = 'type_coerce' + __visit_name__ = "type_coerce" def __init__(self, expression, type_): """Associate a SQL expression with a particular type, without rendering @@ -2484,10 +2530,10 @@ class TypeCoerce(ColumnElement): def _copy_internals(self, clone=_clone, **kw): self.clause = clone(self.clause, **kw) - self.__dict__.pop('typed_expression', None) + self.__dict__.pop("typed_expression", None) def get_children(self, **kwargs): - return self.clause, + return (self.clause,) @property def _from_objects(self): @@ -2506,7 +2552,7 @@ class TypeCoerce(ColumnElement): class Extract(ColumnElement): """Represent a SQL EXTRACT clause, ``extract(field FROM expr)``.""" - __visit_name__ = 'extract' + __visit_name__ = "extract" def __init__(self, field, expr, **kwargs): """Return a :class:`.Extract` construct. @@ -2524,7 +2570,7 @@ class Extract(ColumnElement): self.expr = clone(self.expr, **kw) def get_children(self, **kwargs): - return self.expr, + return (self.expr,) @property def _from_objects(self): @@ -2543,7 +2589,8 @@ class _label_reference(ColumnElement): within an OVER clause. """ - __visit_name__ = 'label_reference' + + __visit_name__ = "label_reference" def __init__(self, element): self.element = element @@ -2557,7 +2604,7 @@ class _label_reference(ColumnElement): class _textual_label_reference(ColumnElement): - __visit_name__ = 'textual_label_reference' + __visit_name__ = "textual_label_reference" def __init__(self, element): self.element = element @@ -2580,14 +2627,23 @@ class UnaryExpression(ColumnElement): :func:`.nullsfirst` and :func:`.nullslast`. """ - __visit_name__ = 'unary' - def __init__(self, element, operator=None, modifier=None, - type_=None, negate=None, wraps_column_expression=False): + __visit_name__ = "unary" + + def __init__( + self, + element, + operator=None, + modifier=None, + type_=None, + negate=None, + wraps_column_expression=False, + ): self.operator = operator self.modifier = modifier self.element = element.self_group( - against=self.operator or self.modifier) + against=self.operator or self.modifier + ) self.type = type_api.to_instance(type_) self.negate = negate self.wraps_column_expression = wraps_column_expression @@ -2633,7 +2689,8 @@ class UnaryExpression(ColumnElement): return UnaryExpression( _literal_as_label_reference(column), modifier=operators.nullsfirst_op, - wraps_column_expression=False) + wraps_column_expression=False, + ) @classmethod def _create_nullslast(cls, column): @@ -2675,7 +2732,8 @@ class UnaryExpression(ColumnElement): return UnaryExpression( _literal_as_label_reference(column), modifier=operators.nullslast_op, - wraps_column_expression=False) + wraps_column_expression=False, + ) @classmethod def _create_desc(cls, column): @@ -2715,7 +2773,8 @@ class UnaryExpression(ColumnElement): return UnaryExpression( _literal_as_label_reference(column), modifier=operators.desc_op, - wraps_column_expression=False) + wraps_column_expression=False, + ) @classmethod def _create_asc(cls, column): @@ -2754,7 +2813,8 @@ class UnaryExpression(ColumnElement): return UnaryExpression( _literal_as_label_reference(column), modifier=operators.asc_op, - wraps_column_expression=False) + wraps_column_expression=False, + ) @classmethod def _create_distinct(cls, expr): @@ -2794,8 +2854,11 @@ class UnaryExpression(ColumnElement): """ expr = _literal_as_binds(expr) return UnaryExpression( - expr, operator=operators.distinct_op, - type_=expr.type, wraps_column_expression=False) + expr, + operator=operators.distinct_op, + type_=expr.type, + wraps_column_expression=False, + ) @property def _order_by_label_element(self): @@ -2812,17 +2875,17 @@ class UnaryExpression(ColumnElement): self.element = clone(self.element, **kw) def get_children(self, **kwargs): - return self.element, + return (self.element,) def compare(self, other, **kw): """Compare this :class:`UnaryExpression` against the given :class:`.ClauseElement`.""" return ( - isinstance(other, UnaryExpression) and - self.operator == other.operator and - self.modifier == other.modifier and - self.element.compare(other.element, **kw) + isinstance(other, UnaryExpression) + and self.operator == other.operator + and self.modifier == other.modifier + and self.element.compare(other.element, **kw) ) def _negate(self): @@ -2833,14 +2896,16 @@ class UnaryExpression(ColumnElement): negate=self.operator, modifier=self.modifier, type_=self.type, - wraps_column_expression=self.wraps_column_expression) + wraps_column_expression=self.wraps_column_expression, + ) elif self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity: return UnaryExpression( self.self_group(against=operators.inv), operator=operators.inv, type_=type_api.BOOLEANTYPE, wraps_column_expression=self.wraps_column_expression, - negate=None) + negate=None, + ) else: return ClauseElement._negate(self) @@ -2860,6 +2925,7 @@ class CollectionAggregate(UnaryExpression): MySQL, they only work for subqueries. """ + @classmethod def _create_any(cls, expr): """Produce an ANY expression. @@ -2883,12 +2949,15 @@ class CollectionAggregate(UnaryExpression): expr = _literal_as_binds(expr) - if expr.is_selectable and hasattr(expr, 'as_scalar'): + if expr.is_selectable and hasattr(expr, "as_scalar"): expr = expr.as_scalar() expr = expr.self_group() return CollectionAggregate( - expr, operator=operators.any_op, - type_=type_api.NULLTYPE, wraps_column_expression=False) + expr, + operator=operators.any_op, + type_=type_api.NULLTYPE, + wraps_column_expression=False, + ) @classmethod def _create_all(cls, expr): @@ -2912,12 +2981,15 @@ class CollectionAggregate(UnaryExpression): """ expr = _literal_as_binds(expr) - if expr.is_selectable and hasattr(expr, 'as_scalar'): + if expr.is_selectable and hasattr(expr, "as_scalar"): expr = expr.as_scalar() expr = expr.self_group() return CollectionAggregate( - expr, operator=operators.all_op, - type_=type_api.NULLTYPE, wraps_column_expression=False) + expr, + operator=operators.all_op, + type_=type_api.NULLTYPE, + wraps_column_expression=False, + ) # operate and reverse_operate are hardwired to # dispatch onto the type comparator directly, so that we can @@ -2925,19 +2997,20 @@ class CollectionAggregate(UnaryExpression): def operate(self, op, *other, **kwargs): if not operators.is_comparison(op): raise exc.ArgumentError( - "Only comparison operators may be used with ANY/ALL") - kwargs['reverse'] = True + "Only comparison operators may be used with ANY/ALL" + ) + kwargs["reverse"] = True return self.comparator.operate(operators.mirror(op), *other, **kwargs) def reverse_operate(self, op, other, **kwargs): # comparison operators should never call reverse_operate assert not operators.is_comparison(op) raise exc.ArgumentError( - "Only comparison operators may be used with ANY/ALL") + "Only comparison operators may be used with ANY/ALL" + ) class AsBoolean(UnaryExpression): - def __init__(self, element, operator, negate): self.element = element self.type = type_api.BOOLEANTYPE @@ -2971,7 +3044,7 @@ class BinaryExpression(ColumnElement): """ - __visit_name__ = 'binary' + __visit_name__ = "binary" _is_implicitly_boolean = True """Indicates that any database will know this is a boolean expression @@ -2979,8 +3052,9 @@ class BinaryExpression(ColumnElement): """ - def __init__(self, left, right, operator, type_=None, - negate=None, modifiers=None): + def __init__( + self, left, right, operator, type_=None, negate=None, modifiers=None + ): # allow compatibility with libraries that # refer to BinaryExpression directly and pass strings if isinstance(operator, util.string_types): @@ -3026,15 +3100,15 @@ class BinaryExpression(ColumnElement): given :class:`BinaryExpression`.""" return ( - isinstance(other, BinaryExpression) and - self.operator == other.operator and - ( - self.left.compare(other.left, **kw) and - self.right.compare(other.right, **kw) or - ( - operators.is_commutative(self.operator) and - self.left.compare(other.right, **kw) and - self.right.compare(other.left, **kw) + isinstance(other, BinaryExpression) + and self.operator == other.operator + and ( + self.left.compare(other.left, **kw) + and self.right.compare(other.right, **kw) + or ( + operators.is_commutative(self.operator) + and self.left.compare(other.right, **kw) + and self.right.compare(other.left, **kw) ) ) ) @@ -3053,7 +3127,8 @@ class BinaryExpression(ColumnElement): self.negate, negate=self.operator, type_=self.type, - modifiers=self.modifiers) + modifiers=self.modifiers, + ) else: return super(BinaryExpression, self)._negate() @@ -3065,7 +3140,8 @@ class Slice(ColumnElement): may be interpreted by specific dialects, e.g. PostgreSQL. """ - __visit_name__ = 'slice' + + __visit_name__ = "slice" def __init__(self, start, stop, step): self.start = start @@ -3081,17 +3157,18 @@ class Slice(ColumnElement): class IndexExpression(BinaryExpression): """Represent the class of expressions that are like an "index" operation. """ + pass class Grouping(ColumnElement): """Represent a grouping within a column expression""" - __visit_name__ = 'grouping' + __visit_name__ = "grouping" def __init__(self, element): self.element = element - self.type = getattr(element, 'type', type_api.NULLTYPE) + self.type = getattr(element, "type", type_api.NULLTYPE) def self_group(self, against=None): return self @@ -3106,13 +3183,13 @@ class Grouping(ColumnElement): @property def _label(self): - return getattr(self.element, '_label', None) or self.anon_label + return getattr(self.element, "_label", None) or self.anon_label def _copy_internals(self, clone=_clone, **kw): self.element = clone(self.element, **kw) def get_children(self, **kwargs): - return self.element, + return (self.element,) @property def _from_objects(self): @@ -3122,15 +3199,16 @@ class Grouping(ColumnElement): return getattr(self.element, attr) def __getstate__(self): - return {'element': self.element, 'type': self.type} + return {"element": self.element, "type": self.type} def __setstate__(self, state): - self.element = state['element'] - self.type = state['type'] + self.element = state["element"] + self.type = state["type"] def compare(self, other, **kw): - return isinstance(other, Grouping) and \ - self.element.compare(other.element) + return isinstance(other, Grouping) and self.element.compare( + other.element + ) RANGE_UNBOUNDED = util.symbol("RANGE_UNBOUNDED") @@ -3147,14 +3225,15 @@ class Over(ColumnElement): backends. """ - __visit_name__ = 'over' + + __visit_name__ = "over" order_by = None partition_by = None def __init__( - self, element, partition_by=None, - order_by=None, range_=None, rows=None): + self, element, partition_by=None, order_by=None, range_=None, rows=None + ): """Produce an :class:`.Over` object against a function. Used against aggregate or so-called "window" functions, @@ -3237,17 +3316,20 @@ class Over(ColumnElement): if order_by is not None: self.order_by = ClauseList( *util.to_list(order_by), - _literal_as_text=_literal_as_label_reference) + _literal_as_text=_literal_as_label_reference + ) if partition_by is not None: self.partition_by = ClauseList( *util.to_list(partition_by), - _literal_as_text=_literal_as_label_reference) + _literal_as_text=_literal_as_label_reference + ) if range_: self.range_ = self._interpret_range(range_) if rows: raise exc.ArgumentError( - "'range_' and 'rows' are mutually exclusive") + "'range_' and 'rows' are mutually exclusive" + ) else: self.rows = None elif rows: @@ -3267,7 +3349,8 @@ class Over(ColumnElement): lower = int(range_[0]) except ValueError: raise exc.ArgumentError( - "Integer or None expected for range value") + "Integer or None expected for range value" + ) else: if lower == 0: lower = RANGE_CURRENT @@ -3279,7 +3362,8 @@ class Over(ColumnElement): upper = int(range_[1]) except ValueError: raise exc.ArgumentError( - "Integer or None expected for range value") + "Integer or None expected for range value" + ) else: if upper == 0: upper = RANGE_CURRENT @@ -3303,9 +3387,11 @@ class Over(ColumnElement): return self.element.type def get_children(self, **kwargs): - return [c for c in - (self.element, self.partition_by, self.order_by) - if c is not None] + return [ + c + for c in (self.element, self.partition_by, self.order_by) + if c is not None + ] def _copy_internals(self, clone=_clone, **kw): self.element = clone(self.element, **kw) @@ -3316,11 +3402,15 @@ class Over(ColumnElement): @property def _from_objects(self): - return list(itertools.chain( - *[c._from_objects for c in - (self.element, self.partition_by, self.order_by) - if c is not None] - )) + return list( + itertools.chain( + *[ + c._from_objects + for c in (self.element, self.partition_by, self.order_by) + if c is not None + ] + ) + ) class WithinGroup(ColumnElement): @@ -3339,7 +3429,8 @@ class WithinGroup(ColumnElement): ``None``, the function's ``.type`` is used. """ - __visit_name__ = 'withingroup' + + __visit_name__ = "withingroup" order_by = None @@ -3383,7 +3474,8 @@ class WithinGroup(ColumnElement): if order_by is not None: self.order_by = ClauseList( *util.to_list(order_by), - _literal_as_text=_literal_as_label_reference) + _literal_as_text=_literal_as_label_reference + ) def over(self, partition_by=None, order_by=None, range_=None, rows=None): """Produce an OVER clause against this :class:`.WithinGroup` @@ -3394,8 +3486,12 @@ class WithinGroup(ColumnElement): """ return Over( - self, partition_by=partition_by, order_by=order_by, - range_=range_, rows=rows) + self, + partition_by=partition_by, + order_by=order_by, + range_=range_, + rows=rows, + ) @util.memoized_property def type(self): @@ -3406,9 +3502,7 @@ class WithinGroup(ColumnElement): return self.element.type def get_children(self, **kwargs): - return [c for c in - (self.element, self.order_by) - if c is not None] + return [c for c in (self.element, self.order_by) if c is not None] def _copy_internals(self, clone=_clone, **kw): self.element = clone(self.element, **kw) @@ -3417,11 +3511,15 @@ class WithinGroup(ColumnElement): @property def _from_objects(self): - return list(itertools.chain( - *[c._from_objects for c in - (self.element, self.order_by) - if c is not None] - )) + return list( + itertools.chain( + *[ + c._from_objects + for c in (self.element, self.order_by) + if c is not None + ] + ) + ) class FunctionFilter(ColumnElement): @@ -3443,7 +3541,8 @@ class FunctionFilter(ColumnElement): :meth:`.FunctionElement.filter` """ - __visit_name__ = 'funcfilter' + + __visit_name__ = "funcfilter" criterion = None @@ -3515,17 +3614,19 @@ class FunctionFilter(ColumnElement): """ return Over( - self, partition_by=partition_by, order_by=order_by, - range_=range_, rows=rows) + self, + partition_by=partition_by, + order_by=order_by, + range_=range_, + rows=rows, + ) @util.memoized_property def type(self): return self.func.type def get_children(self, **kwargs): - return [c for c in - (self.func, self.criterion) - if c is not None] + return [c for c in (self.func, self.criterion) if c is not None] def _copy_internals(self, clone=_clone, **kw): self.func = clone(self.func, **kw) @@ -3534,10 +3635,15 @@ class FunctionFilter(ColumnElement): @property def _from_objects(self): - return list(itertools.chain( - *[c._from_objects for c in (self.func, self.criterion) - if c is not None] - )) + return list( + itertools.chain( + *[ + c._from_objects + for c in (self.func, self.criterion) + if c is not None + ] + ) + ) class Label(ColumnElement): @@ -3548,7 +3654,7 @@ class Label(ColumnElement): """ - __visit_name__ = 'label' + __visit_name__ = "label" def __init__(self, name, element, type_=None): """Return a :class:`Label` object for the @@ -3577,7 +3683,7 @@ class Label(ColumnElement): self._resolve_label = self.name else: self.name = _anonymous_label( - '%%(%d %s)s' % (id(self), getattr(element, 'name', 'anon')) + "%%(%d %s)s" % (id(self), getattr(element, "name", "anon")) ) self.key = self._label = self._key_label = self.name @@ -3603,7 +3709,7 @@ class Label(ColumnElement): @util.memoized_property def type(self): return type_api.to_instance( - self._type or getattr(self._element, 'type', None) + self._type or getattr(self._element, "type", None) ) @util.memoized_property @@ -3619,9 +3725,7 @@ class Label(ColumnElement): def _apply_to_inner(self, fn, *arg, **kw): sub_element = fn(*arg, **kw) if sub_element is not self._element: - return Label(self.name, - sub_element, - type_=self._type) + return Label(self.name, sub_element, type_=self._type) else: return self @@ -3634,16 +3738,16 @@ class Label(ColumnElement): return self.element.foreign_keys def get_children(self, **kwargs): - return self.element, + return (self.element,) def _copy_internals(self, clone=_clone, anonymize_labels=False, **kw): self._element = clone(self._element, **kw) - self.__dict__.pop('element', None) - self.__dict__.pop('_allow_label_resolve', None) + self.__dict__.pop("element", None) + self.__dict__.pop("_allow_label_resolve", None) if anonymize_labels: self.name = self._resolve_label = _anonymous_label( - '%%(%d %s)s' % ( - id(self), getattr(self.element, 'name', 'anon')) + "%%(%d %s)s" + % (id(self), getattr(self.element, "name", "anon")) ) self.key = self._label = self._key_label = self.name @@ -3652,8 +3756,9 @@ class Label(ColumnElement): return self.element._from_objects def _make_proxy(self, selectable, name=None, **kw): - e = self.element._make_proxy(selectable, - name=name if name else self.name) + e = self.element._make_proxy( + selectable, name=name if name else self.name + ) e._proxies.append(self) if self._type is not None: e.type = self._type @@ -3694,7 +3799,8 @@ class ColumnClause(Immutable, ColumnElement): :class:`.Column` """ - __visit_name__ = 'column' + + __visit_name__ = "column" onupdate = default = server_default = server_onupdate = None @@ -3792,25 +3898,33 @@ class ColumnClause(Immutable, ColumnElement): self.is_literal = is_literal def _compare_name_for_result(self, other): - if self.is_literal or \ - self.table is None or self.table._textual or \ - not hasattr(other, 'proxy_set') or ( - isinstance(other, ColumnClause) and - (other.is_literal or - other.table is None or - other.table._textual) - ): - return (hasattr(other, 'name') and self.name == other.name) or \ - (hasattr(other, '_label') and self._label == other._label) + if ( + self.is_literal + or self.table is None + or self.table._textual + or not hasattr(other, "proxy_set") + or ( + isinstance(other, ColumnClause) + and ( + other.is_literal + or other.table is None + or other.table._textual + ) + ) + ): + return (hasattr(other, "name") and self.name == other.name) or ( + hasattr(other, "_label") and self._label == other._label + ) else: return other.proxy_set.intersection(self.proxy_set) def _get_table(self): - return self.__dict__['table'] + return self.__dict__["table"] def _set_table(self, table): self._memoized_property.expire_instance(self) - self.__dict__['table'] = table + self.__dict__["table"] = table + table = property(_get_table, _set_table) @_memoized_property @@ -3826,7 +3940,7 @@ class ColumnClause(Immutable, ColumnElement): if util.py3k: return self.name else: - return self.name.encode('ascii', 'backslashreplace') + return self.name.encode("ascii", "backslashreplace") @_memoized_property def _key_label(self): @@ -3850,9 +3964,8 @@ class ColumnClause(Immutable, ColumnElement): return None elif t is not None and t.named_with_column: - if getattr(t, 'schema', None): - label = t.schema.replace('.', '_') + "_" + \ - t.name + "_" + name + if getattr(t, "schema", None): + label = t.schema.replace(".", "_") + "_" + t.name + "_" + name else: label = t.name + "_" + name @@ -3884,31 +3997,39 @@ class ColumnClause(Immutable, ColumnElement): return name def _bind_param(self, operator, obj, type_=None): - return BindParameter(self.key, obj, - _compared_to_operator=operator, - _compared_to_type=self.type, - type_=type_, - unique=True) - - def _make_proxy(self, selectable, name=None, attach=True, - name_is_truncatable=False, **kw): + return BindParameter( + self.key, + obj, + _compared_to_operator=operator, + _compared_to_type=self.type, + type_=type_, + unique=True, + ) + + def _make_proxy( + self, + selectable, + name=None, + attach=True, + name_is_truncatable=False, + **kw + ): # propagate the "is_literal" flag only if we are keeping our name, # otherwise its considered to be a label is_literal = self.is_literal and (name is None or name == self.name) c = self._constructor( - _as_truncated(name or self.name) if - name_is_truncatable else - (name or self.name), + _as_truncated(name or self.name) + if name_is_truncatable + else (name or self.name), type_=self.type, _selectable=selectable, - is_literal=is_literal + is_literal=is_literal, ) if name is None: c.key = self.key c._proxies = [self] if selectable._is_clone_of is not None: - c._is_clone_of = \ - selectable._is_clone_of.columns.get(c.key) + c._is_clone_of = selectable._is_clone_of.columns.get(c.key) if attach: selectable._columns[c.key] = c @@ -3924,24 +4045,25 @@ class CollationClause(ColumnElement): class _IdentifiedClause(Executable, ClauseElement): - __visit_name__ = 'identified' - _execution_options = \ - Executable._execution_options.union({'autocommit': False}) + __visit_name__ = "identified" + _execution_options = Executable._execution_options.union( + {"autocommit": False} + ) def __init__(self, ident): self.ident = ident class SavepointClause(_IdentifiedClause): - __visit_name__ = 'savepoint' + __visit_name__ = "savepoint" class RollbackToSavepointClause(_IdentifiedClause): - __visit_name__ = 'rollback_to_savepoint' + __visit_name__ = "rollback_to_savepoint" class ReleaseSavepointClause(_IdentifiedClause): - __visit_name__ = 'release_savepoint' + __visit_name__ = "release_savepoint" class quoted_name(util.MemoizedSlots, util.text_type): @@ -3992,7 +4114,7 @@ class quoted_name(util.MemoizedSlots, util.text_type): """ - __slots__ = 'quote', 'lower', 'upper' + __slots__ = "quote", "lower", "upper" def __new__(cls, value, quote): if value is None: @@ -4026,9 +4148,9 @@ class quoted_name(util.MemoizedSlots, util.text_type): return util.text_type(self).upper() def __repr__(self): - backslashed = self.encode('ascii', 'backslashreplace') + backslashed = self.encode("ascii", "backslashreplace") if not util.py2k: - backslashed = backslashed.decode('ascii') + backslashed = backslashed.decode("ascii") return "'%s'" % backslashed @@ -4094,6 +4216,7 @@ class conv(_truncated_label): :ref:`constraint_naming_conventions` """ + __slots__ = () @@ -4102,6 +4225,7 @@ class _defer_name(_truncated_label): generation. """ + __slots__ = () def __new__(cls, value): @@ -4113,13 +4237,15 @@ class _defer_name(_truncated_label): return super(_defer_name, cls).__new__(cls, value) def __reduce__(self): - return self.__class__, (util.text_type(self), ) + return self.__class__, (util.text_type(self),) class _defer_none_name(_defer_name): """indicate a 'deferred' name that was ultimately the value None.""" + __slots__ = () + _NONE_NAME = _defer_none_name("_unnamed_") # for backwards compatibility in case @@ -4138,15 +4264,15 @@ class _anonymous_label(_truncated_label): def __add__(self, other): return _anonymous_label( quoted_name( - util.text_type.__add__(self, util.text_type(other)), - self.quote) + util.text_type.__add__(self, util.text_type(other)), self.quote + ) ) def __radd__(self, other): return _anonymous_label( quoted_name( - util.text_type.__add__(util.text_type(other), self), - self.quote) + util.text_type.__add__(util.text_type(other), self), self.quote + ) ) def apply_map(self, map_): @@ -4206,20 +4332,23 @@ def _cloned_intersection(a, b): """ all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) - return set(elem for elem in a - if all_overlap.intersection(elem._cloned_set)) + return set( + elem for elem in a if all_overlap.intersection(elem._cloned_set) + ) def _cloned_difference(a, b): all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) - return set(elem for elem in a - if not all_overlap.intersection(elem._cloned_set)) + return set( + elem for elem in a if not all_overlap.intersection(elem._cloned_set) + ) @util.dependencies("sqlalchemy.sql.functions") def _labeled(functions, element): - if not hasattr(element, 'name') or \ - isinstance(element, functions.FunctionElement): + if not hasattr(element, "name") or isinstance( + element, functions.FunctionElement + ): return element.label(None) else: return element @@ -4235,7 +4364,7 @@ def _find_columns(clause): """locate Column objects within the given expression.""" cols = util.column_set() - traverse(clause, {}, {'column': cols.add}) + traverse(clause, {}, {"column": cols.add}) return cols @@ -4253,7 +4382,7 @@ def _find_columns(clause): def _column_as_key(element): if isinstance(element, util.string_types): return element - if hasattr(element, '__clause_element__'): + if hasattr(element, "__clause_element__"): element = element.__clause_element__() try: return element.key @@ -4262,7 +4391,7 @@ def _column_as_key(element): def _clause_element_as_expr(element): - if hasattr(element, '__clause_element__'): + if hasattr(element, "__clause_element__"): return element.__clause_element__() else: return element @@ -4272,7 +4401,7 @@ def _literal_as_label_reference(element): if isinstance(element, util.string_types): return _textual_label_reference(element) - elif hasattr(element, '__clause_element__'): + elif hasattr(element, "__clause_element__"): element = element.__clause_element__() return _literal_as_text(element) @@ -4282,11 +4411,13 @@ def _literal_and_labels_as_label_reference(element): if isinstance(element, util.string_types): return _textual_label_reference(element) - elif hasattr(element, '__clause_element__'): + elif hasattr(element, "__clause_element__"): element = element.__clause_element__() - if isinstance(element, ColumnElement) and \ - element._order_by_label_element is not None: + if ( + isinstance(element, ColumnElement) + and element._order_by_label_element is not None + ): return _label_reference(element) else: return _literal_as_text(element) @@ -4299,14 +4430,15 @@ def _expression_literal_as_text(element): def _literal_as_text(element, warn=False): if isinstance(element, Visitable): return element - elif hasattr(element, '__clause_element__'): + elif hasattr(element, "__clause_element__"): return element.__clause_element__() elif isinstance(element, util.string_types): if warn: util.warn_limited( "Textual SQL expression %(expr)r should be " "explicitly declared as text(%(expr)r)", - {"expr": util.ellipses_string(element)}) + {"expr": util.ellipses_string(element)}, + ) return TextClause(util.text_type(element)) elif isinstance(element, (util.NoneType, bool)): @@ -4319,20 +4451,23 @@ def _literal_as_text(element, warn=False): def _no_literals(element): - if hasattr(element, '__clause_element__'): + if hasattr(element, "__clause_element__"): return element.__clause_element__() elif not isinstance(element, Visitable): - raise exc.ArgumentError("Ambiguous literal: %r. Use the 'text()' " - "function to indicate a SQL expression " - "literal, or 'literal()' to indicate a " - "bound value." % (element, )) + raise exc.ArgumentError( + "Ambiguous literal: %r. Use the 'text()' " + "function to indicate a SQL expression " + "literal, or 'literal()' to indicate a " + "bound value." % (element,) + ) else: return element def _is_literal(element): - return not isinstance(element, Visitable) and \ - not hasattr(element, '__clause_element__') + return not isinstance(element, Visitable) and not hasattr( + element, "__clause_element__" + ) def _only_column_elements_or_none(element, name): @@ -4343,17 +4478,18 @@ def _only_column_elements_or_none(element, name): def _only_column_elements(element, name): - if hasattr(element, '__clause_element__'): + if hasattr(element, "__clause_element__"): element = element.__clause_element__() if not isinstance(element, ColumnElement): raise exc.ArgumentError( "Column-based expression object expected for argument " - "'%s'; got: '%s', type %s" % (name, element, type(element))) + "'%s'; got: '%s', type %s" % (name, element, type(element)) + ) return element def _literal_as_binds(element, name=None, type_=None): - if hasattr(element, '__clause_element__'): + if hasattr(element, "__clause_element__"): return element.__clause_element__() elif not isinstance(element, Visitable): if element is None: @@ -4363,13 +4499,14 @@ def _literal_as_binds(element, name=None, type_=None): else: return element -_guess_straight_column = re.compile(r'^\w\S*$', re.I) + +_guess_straight_column = re.compile(r"^\w\S*$", re.I) def _interpret_as_column_or_from(element): if isinstance(element, Visitable): return element - elif hasattr(element, '__clause_element__'): + elif hasattr(element, "__clause_element__"): return element.__clause_element__() insp = inspection.inspect(element, raiseerr=False) @@ -4399,11 +4536,11 @@ def _interpret_as_column_or_from(element): { "column": util.ellipses_string(element), "literal_column": "literal_column" - if guess_is_literal else "column" - }) - return ColumnClause( - element, - is_literal=guess_is_literal) + if guess_is_literal + else "column", + }, + ) + return ColumnClause(element, is_literal=guess_is_literal) def _const_expr(element): @@ -4416,9 +4553,7 @@ def _const_expr(element): elif element is True: return True_() else: - raise exc.ArgumentError( - "Expected None, False, or True" - ) + raise exc.ArgumentError("Expected None, False, or True") def _type_from_args(args): @@ -4429,18 +4564,15 @@ def _type_from_args(args): return type_api.NULLTYPE -def _corresponding_column_or_error(fromclause, column, - require_embedded=False): - c = fromclause.corresponding_column(column, - require_embedded=require_embedded) +def _corresponding_column_or_error(fromclause, column, require_embedded=False): + c = fromclause.corresponding_column( + column, require_embedded=require_embedded + ) if c is None: raise exc.InvalidRequestError( "Given column '%s', attached to table '%s', " "failed to locate a corresponding column from table '%s'" - % - (column, - getattr(column, 'table', None), - fromclause.description) + % (column, getattr(column, "table", None), fromclause.description) ) return c @@ -4449,7 +4581,7 @@ class AnnotatedColumnElement(Annotated): def __init__(self, element, values): Annotated.__init__(self, element, values) ColumnElement.comparator._reset(self) - for attr in ('name', 'key', 'table'): + for attr in ("name", "key", "table"): if self.__dict__.get(attr, False) is None: self.__dict__.pop(attr) diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index b69b6ee8c..aab9f46d4 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -15,43 +15,142 @@ class. """ __all__ = [ - 'Alias', 'any_', 'all_', 'ClauseElement', 'ColumnCollection', 'ColumnElement', - 'CompoundSelect', 'Delete', 'FromClause', 'Insert', 'Join', 'Lateral', - 'Select', - 'Selectable', 'TableClause', 'Update', 'alias', 'and_', 'asc', 'between', - 'bindparam', 'case', 'cast', 'column', 'delete', 'desc', 'distinct', - 'except_', 'except_all', 'exists', 'extract', 'func', 'modifier', - 'collate', 'insert', 'intersect', 'intersect_all', 'join', 'label', - 'lateral', 'literal', 'literal_column', 'not_', 'null', 'nullsfirst', - 'nullslast', - 'or_', 'outparam', 'outerjoin', 'over', 'select', 'subquery', - 'table', 'text', - 'tuple_', 'type_coerce', 'quoted_name', 'union', 'union_all', 'update', - 'within_group', - 'TableSample', 'tablesample'] + "Alias", + "any_", + "all_", + "ClauseElement", + "ColumnCollection", + "ColumnElement", + "CompoundSelect", + "Delete", + "FromClause", + "Insert", + "Join", + "Lateral", + "Select", + "Selectable", + "TableClause", + "Update", + "alias", + "and_", + "asc", + "between", + "bindparam", + "case", + "cast", + "column", + "delete", + "desc", + "distinct", + "except_", + "except_all", + "exists", + "extract", + "func", + "modifier", + "collate", + "insert", + "intersect", + "intersect_all", + "join", + "label", + "lateral", + "literal", + "literal_column", + "not_", + "null", + "nullsfirst", + "nullslast", + "or_", + "outparam", + "outerjoin", + "over", + "select", + "subquery", + "table", + "text", + "tuple_", + "type_coerce", + "quoted_name", + "union", + "union_all", + "update", + "within_group", + "TableSample", + "tablesample", +] from .visitors import Visitable from .functions import func, modifier, FunctionElement, Function from ..util.langhelpers import public_factory -from .elements import ClauseElement, ColumnElement,\ - BindParameter, CollectionAggregate, UnaryExpression, BooleanClauseList, \ - Label, Cast, Case, ColumnClause, TextClause, Over, Null, \ - True_, False_, BinaryExpression, Tuple, TypeClause, Extract, \ - Grouping, WithinGroup, not_, quoted_name, \ - collate, literal_column, between,\ - literal, outparam, TypeCoerce, ClauseList, FunctionFilter +from .elements import ( + ClauseElement, + ColumnElement, + BindParameter, + CollectionAggregate, + UnaryExpression, + BooleanClauseList, + Label, + Cast, + Case, + ColumnClause, + TextClause, + Over, + Null, + True_, + False_, + BinaryExpression, + Tuple, + TypeClause, + Extract, + Grouping, + WithinGroup, + not_, + quoted_name, + collate, + literal_column, + between, + literal, + outparam, + TypeCoerce, + ClauseList, + FunctionFilter, +) -from .elements import SavepointClause, RollbackToSavepointClause, \ - ReleaseSavepointClause +from .elements import ( + SavepointClause, + RollbackToSavepointClause, + ReleaseSavepointClause, +) -from .base import ColumnCollection, Generative, Executable, \ - PARSE_AUTOCOMMIT +from .base import ColumnCollection, Generative, Executable, PARSE_AUTOCOMMIT -from .selectable import Alias, Join, Select, Selectable, TableClause, \ - CompoundSelect, CTE, FromClause, FromGrouping, Lateral, SelectBase, \ - alias, GenerativeSelect, subquery, HasCTE, HasPrefixes, HasSuffixes, \ - lateral, Exists, ScalarSelect, TextAsFrom, TableSample, tablesample +from .selectable import ( + Alias, + Join, + Select, + Selectable, + TableClause, + CompoundSelect, + CTE, + FromClause, + FromGrouping, + Lateral, + SelectBase, + alias, + GenerativeSelect, + subquery, + HasCTE, + HasPrefixes, + HasSuffixes, + lateral, + Exists, + ScalarSelect, + TextAsFrom, + TableSample, + tablesample, +) from .dml import Insert, Update, Delete, UpdateBase, ValuesBase @@ -79,23 +178,30 @@ extract = public_factory(Extract, ".expression.extract") tuple_ = public_factory(Tuple, ".expression.tuple_") except_ = public_factory(CompoundSelect._create_except, ".expression.except_") except_all = public_factory( - CompoundSelect._create_except_all, ".expression.except_all") + CompoundSelect._create_except_all, ".expression.except_all" +) intersect = public_factory( - CompoundSelect._create_intersect, ".expression.intersect") + CompoundSelect._create_intersect, ".expression.intersect" +) intersect_all = public_factory( - CompoundSelect._create_intersect_all, ".expression.intersect_all") + CompoundSelect._create_intersect_all, ".expression.intersect_all" +) union = public_factory(CompoundSelect._create_union, ".expression.union") union_all = public_factory( - CompoundSelect._create_union_all, ".expression.union_all") + CompoundSelect._create_union_all, ".expression.union_all" +) exists = public_factory(Exists, ".expression.exists") nullsfirst = public_factory( - UnaryExpression._create_nullsfirst, ".expression.nullsfirst") + UnaryExpression._create_nullsfirst, ".expression.nullsfirst" +) nullslast = public_factory( - UnaryExpression._create_nullslast, ".expression.nullslast") + UnaryExpression._create_nullslast, ".expression.nullslast" +) asc = public_factory(UnaryExpression._create_asc, ".expression.asc") desc = public_factory(UnaryExpression._create_desc, ".expression.desc") distinct = public_factory( - UnaryExpression._create_distinct, ".expression.distinct") + UnaryExpression._create_distinct, ".expression.distinct" +) type_coerce = public_factory(TypeCoerce, ".expression.type_coerce") true = public_factory(True_._instance, ".expression.true") false = public_factory(False_._instance, ".expression.false") @@ -105,19 +211,30 @@ outerjoin = public_factory(Join._create_outerjoin, ".expression.outerjoin") insert = public_factory(Insert, ".expression.insert") update = public_factory(Update, ".expression.update") delete = public_factory(Delete, ".expression.delete") -funcfilter = public_factory( - FunctionFilter, ".expression.funcfilter") +funcfilter = public_factory(FunctionFilter, ".expression.funcfilter") # internal functions still being called from tests and the ORM, # these might be better off in some other namespace from .base import _from_objects -from .elements import _literal_as_text, _clause_element_as_expr,\ - _is_column, _labeled, _only_column_elements, _string_or_unprintable, \ - _truncated_label, _clone, _cloned_difference, _cloned_intersection,\ - _column_as_key, _literal_as_binds, _select_iterables, \ - _corresponding_column_or_error, _literal_as_label_reference, \ - _expression_literal_as_text +from .elements import ( + _literal_as_text, + _clause_element_as_expr, + _is_column, + _labeled, + _only_column_elements, + _string_or_unprintable, + _truncated_label, + _clone, + _cloned_difference, + _cloned_intersection, + _column_as_key, + _literal_as_binds, + _select_iterables, + _corresponding_column_or_error, + _literal_as_label_reference, + _expression_literal_as_text, +) from .selectable import _interpret_as_from diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 4b4d2d463..883bb8cc3 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -10,10 +10,22 @@ """ from . import sqltypes, schema from .base import Executable, ColumnCollection -from .elements import ClauseList, Cast, Extract, _literal_as_binds, \ - literal_column, _type_from_args, ColumnElement, _clone,\ - Over, BindParameter, FunctionFilter, Grouping, WithinGroup, \ - BinaryExpression +from .elements import ( + ClauseList, + Cast, + Extract, + _literal_as_binds, + literal_column, + _type_from_args, + ColumnElement, + _clone, + Over, + BindParameter, + FunctionFilter, + Grouping, + WithinGroup, + BinaryExpression, +) from .selectable import FromClause, Select, Alias from . import util as sqlutil from . import operators @@ -62,9 +74,8 @@ class FunctionElement(Executable, ColumnElement, FromClause): args = [_literal_as_binds(c, self.name) for c in clauses] self._has_args = self._has_args or bool(args) self.clause_expr = ClauseList( - operator=operators.comma_op, - group_contents=True, *args).\ - self_group() + operator=operators.comma_op, group_contents=True, *args + ).self_group() def _execute_on_connection(self, connection, multiparams, params): return connection._execute_function(self, multiparams, params) @@ -123,7 +134,7 @@ class FunctionElement(Executable, ColumnElement, FromClause): partition_by=partition_by, order_by=order_by, rows=rows, - range_=range_ + range_=range_, ) def within_group(self, *order_by): @@ -233,16 +244,14 @@ class FunctionElement(Executable, ColumnElement, FromClause): .. versionadded:: 1.3 """ - return FunctionAsBinary( - self, left_index, right_index - ) + return FunctionAsBinary(self, left_index, right_index) @property def _from_objects(self): return self.clauses._from_objects def get_children(self, **kwargs): - return self.clause_expr, + return (self.clause_expr,) def _copy_internals(self, clone=_clone, **kw): self.clause_expr = clone(self.clause_expr, **kw) @@ -336,24 +345,29 @@ class FunctionElement(Executable, ColumnElement, FromClause): return self.select().execute() def _bind_param(self, operator, obj, type_=None): - return BindParameter(None, obj, _compared_to_operator=operator, - _compared_to_type=self.type, unique=True, - type_=type_) + return BindParameter( + None, + obj, + _compared_to_operator=operator, + _compared_to_type=self.type, + unique=True, + type_=type_, + ) def self_group(self, against=None): # for the moment, we are parenthesizing all array-returning # expressions against getitem. This may need to be made # more portable if in the future we support other DBs # besides postgresql. - if against is operators.getitem and \ - isinstance(self.type, sqltypes.ARRAY): + if against is operators.getitem and isinstance( + self.type, sqltypes.ARRAY + ): return Grouping(self) else: return super(FunctionElement, self).self_group(against=against) class FunctionAsBinary(BinaryExpression): - def __init__(self, fn, left_index, right_index): left = fn.clauses.clauses[left_index - 1] right = fn.clauses.clauses[right_index - 1] @@ -362,8 +376,11 @@ class FunctionAsBinary(BinaryExpression): self.right_index = right_index super(FunctionAsBinary, self).__init__( - left, right, operators.function_as_comparison_op, - type_=sqltypes.BOOLEANTYPE) + left, + right, + operators.function_as_comparison_op, + type_=sqltypes.BOOLEANTYPE, + ) @property def left(self): @@ -382,7 +399,7 @@ class FunctionAsBinary(BinaryExpression): self.sql_function.clauses.clauses[self.right_index - 1] = value def _copy_internals(self, **kw): - clone = kw.pop('clone') + clone = kw.pop("clone") self.sql_function = clone(self.sql_function, **kw) super(FunctionAsBinary, self)._copy_internals(**kw) @@ -396,13 +413,13 @@ class _FunctionGenerator(object): def __getattr__(self, name): # passthru __ attributes; fixes pydoc - if name.startswith('__'): + if name.startswith("__"): try: return self.__dict__[name] except KeyError: raise AttributeError(name) - elif name.endswith('_'): + elif name.endswith("_"): name = name[0:-1] f = _FunctionGenerator(**self.opts) f.__names = list(self.__names) + [name] @@ -426,8 +443,9 @@ class _FunctionGenerator(object): if func is not None: return func(*c, **o) - return Function(self.__names[-1], - packagenames=self.__names[0:-1], *c, **o) + return Function( + self.__names[-1], packagenames=self.__names[0:-1], *c, **o + ) func = _FunctionGenerator() @@ -523,7 +541,7 @@ class Function(FunctionElement): """ - __visit_name__ = 'function' + __visit_name__ = "function" def __init__(self, name, *clauses, **kw): """Construct a :class:`.Function`. @@ -532,30 +550,33 @@ class Function(FunctionElement): new :class:`.Function` instances. """ - self.packagenames = kw.pop('packagenames', None) or [] + self.packagenames = kw.pop("packagenames", None) or [] self.name = name - self._bind = kw.get('bind', None) - self.type = sqltypes.to_instance(kw.get('type_', None)) + self._bind = kw.get("bind", None) + self.type = sqltypes.to_instance(kw.get("type_", None)) FunctionElement.__init__(self, *clauses, **kw) def _bind_param(self, operator, obj, type_=None): - return BindParameter(self.name, obj, - _compared_to_operator=operator, - _compared_to_type=self.type, - type_=type_, - unique=True) + return BindParameter( + self.name, + obj, + _compared_to_operator=operator, + _compared_to_type=self.type, + type_=type_, + unique=True, + ) class _GenericMeta(VisitableType): def __init__(cls, clsname, bases, clsdict): if annotation.Annotated not in cls.__mro__: - cls.name = name = clsdict.get('name', clsname) - cls.identifier = identifier = clsdict.get('identifier', name) - package = clsdict.pop('package', '_default') + cls.name = name = clsdict.get("name", clsname) + cls.identifier = identifier = clsdict.get("identifier", name) + package = clsdict.pop("package", "_default") # legacy - if '__return_type__' in clsdict: - cls.type = clsdict['__return_type__'] + if "__return_type__" in clsdict: + cls.type = clsdict["__return_type__"] register_function(identifier, cls, package) super(_GenericMeta, cls).__init__(clsname, bases, clsdict) @@ -635,17 +656,19 @@ class GenericFunction(util.with_metaclass(_GenericMeta, Function)): coerce_arguments = True def __init__(self, *args, **kwargs): - parsed_args = kwargs.pop('_parsed_args', None) + parsed_args = kwargs.pop("_parsed_args", None) if parsed_args is None: parsed_args = [_literal_as_binds(c, self.name) for c in args] self._has_args = self._has_args or bool(parsed_args) self.packagenames = [] - self._bind = kwargs.get('bind', None) + self._bind = kwargs.get("bind", None) self.clause_expr = ClauseList( - operator=operators.comma_op, - group_contents=True, *parsed_args).self_group() + operator=operators.comma_op, group_contents=True, *parsed_args + ).self_group() self.type = sqltypes.to_instance( - kwargs.pop("type_", None) or getattr(self, 'type', None)) + kwargs.pop("type_", None) or getattr(self, "type", None) + ) + register_function("cast", Cast) register_function("extract", Extract) @@ -660,13 +683,15 @@ class next_value(GenericFunction): that does not provide support for sequences. """ + type = sqltypes.Integer() name = "next_value" def __init__(self, seq, **kw): - assert isinstance(seq, schema.Sequence), \ - "next_value() accepts a Sequence object as input." - self._bind = kw.get('bind', None) + assert isinstance( + seq, schema.Sequence + ), "next_value() accepts a Sequence object as input." + self._bind = kw.get("bind", None) self.sequence = seq @property @@ -684,8 +709,8 @@ class ReturnTypeFromArgs(GenericFunction): def __init__(self, *args, **kwargs): args = [_literal_as_binds(c, self.name) for c in args] - kwargs.setdefault('type_', _type_from_args(args)) - kwargs['_parsed_args'] = args + kwargs.setdefault("type_", _type_from_args(args)) + kwargs["_parsed_args"] = args super(ReturnTypeFromArgs, self).__init__(*args, **kwargs) @@ -733,7 +758,7 @@ class count(GenericFunction): def __init__(self, expression=None, **kwargs): if expression is None: - expression = literal_column('*') + expression = literal_column("*") super(count, self).__init__(expression, **kwargs) @@ -797,15 +822,15 @@ class array_agg(GenericFunction): def __init__(self, *args, **kwargs): args = [_literal_as_binds(c) for c in args] - default_array_type = kwargs.pop('_default_array_type', sqltypes.ARRAY) - if 'type_' not in kwargs: + default_array_type = kwargs.pop("_default_array_type", sqltypes.ARRAY) + if "type_" not in kwargs: type_from_args = _type_from_args(args) if isinstance(type_from_args, sqltypes.ARRAY): - kwargs['type_'] = type_from_args + kwargs["type_"] = type_from_args else: - kwargs['type_'] = default_array_type(type_from_args) - kwargs['_parsed_args'] = args + kwargs["type_"] = default_array_type(type_from_args) + kwargs["_parsed_args"] = args super(array_agg, self).__init__(*args, **kwargs) @@ -883,6 +908,7 @@ class rank(GenericFunction): .. versionadded:: 1.1 """ + type = sqltypes.Integer() @@ -897,6 +923,7 @@ class dense_rank(GenericFunction): .. versionadded:: 1.1 """ + type = sqltypes.Integer() @@ -911,6 +938,7 @@ class percent_rank(GenericFunction): .. versionadded:: 1.1 """ + type = sqltypes.Numeric() @@ -925,6 +953,7 @@ class cume_dist(GenericFunction): .. versionadded:: 1.1 """ + type = sqltypes.Numeric() diff --git a/lib/sqlalchemy/sql/naming.py b/lib/sqlalchemy/sql/naming.py index 0107ce724..144cc4dfc 100644 --- a/lib/sqlalchemy/sql/naming.py +++ b/lib/sqlalchemy/sql/naming.py @@ -10,8 +10,16 @@ """ -from .schema import Constraint, ForeignKeyConstraint, PrimaryKeyConstraint, \ - UniqueConstraint, CheckConstraint, Index, Table, Column +from .schema import ( + Constraint, + ForeignKeyConstraint, + PrimaryKeyConstraint, + UniqueConstraint, + CheckConstraint, + Index, + Table, + Column, +) from .. import event, events from .. import exc from .elements import _truncated_label, _defer_name, _defer_none_name, conv @@ -19,7 +27,6 @@ import re class ConventionDict(object): - def __init__(self, const, table, convention): self.const = const self._is_fk = isinstance(const, ForeignKeyConstraint) @@ -79,8 +86,8 @@ class ConventionDict(object): def __getitem__(self, key): if key in self.convention: return self.convention[key](self.const, self.table) - elif hasattr(self, '_key_%s' % key): - return getattr(self, '_key_%s' % key)() + elif hasattr(self, "_key_%s" % key): + return getattr(self, "_key_%s" % key)() else: col_template = re.match(r".*_?column_(\d+)(_?N)?_.+", key) if col_template: @@ -108,12 +115,13 @@ class ConventionDict(object): return getattr(self, attr)(idx) raise KeyError(key) + _prefix_dict = { Index: "ix", PrimaryKeyConstraint: "pk", CheckConstraint: "ck", UniqueConstraint: "uq", - ForeignKeyConstraint: "fk" + ForeignKeyConstraint: "fk", } @@ -134,15 +142,18 @@ def _constraint_name_for_table(const, table): if isinstance(const.name, conv): return const.name - elif convention is not None and \ - not isinstance(const.name, conv) and \ - ( - const.name is None or - "constraint_name" in convention or - isinstance(const.name, _defer_name)): + elif ( + convention is not None + and not isinstance(const.name, conv) + and ( + const.name is None + or "constraint_name" in convention + or isinstance(const.name, _defer_name) + ) + ): return conv( - convention % ConventionDict(const, table, - metadata.naming_convention) + convention + % ConventionDict(const, table, metadata.naming_convention) ) elif isinstance(convention, _defer_none_name): return None @@ -155,9 +166,11 @@ def _constraint_name(const, table): # for column-attached constraint, set another event # to link the column attached to the table as this constraint # associated with the table. - event.listen(table, "after_parent_attach", - lambda col, table: _constraint_name(const, table) - ) + event.listen( + table, + "after_parent_attach", + lambda col, table: _constraint_name(const, table), + ) elif isinstance(table, Table): if isinstance(const.name, (conv, _defer_name)): return diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index 5b4a28a06..2b843d751 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -13,8 +13,25 @@ from .. import util from operator import ( - and_, or_, inv, add, mul, sub, mod, truediv, lt, le, ne, gt, ge, eq, neg, - getitem, lshift, rshift, contains + and_, + or_, + inv, + add, + mul, + sub, + mod, + truediv, + lt, + le, + ne, + gt, + ge, + eq, + neg, + getitem, + lshift, + rshift, + contains, ) if util.py2k: @@ -37,6 +54,7 @@ class Operators(object): :class:`.ColumnOperators`. """ + __slots__ = () def __and__(self, other): @@ -105,8 +123,8 @@ class Operators(object): return self.operate(inv) def op( - self, opstring, precedence=0, is_comparison=False, - return_type=None): + self, opstring, precedence=0, is_comparison=False, return_type=None + ): """produce a generic operator function. e.g.:: @@ -168,6 +186,7 @@ class Operators(object): def against(other): return operator(self, other) + return against def bool_op(self, opstring, precedence=0): @@ -247,12 +266,18 @@ class custom_op(object): :meth:`.Operators.bool_op` """ - __name__ = 'custom_op' + + __name__ = "custom_op" def __init__( - self, opstring, precedence=0, is_comparison=False, - return_type=None, natural_self_precedent=False, - eager_grouping=False): + self, + opstring, + precedence=0, + is_comparison=False, + return_type=None, + natural_self_precedent=False, + eager_grouping=False, + ): self.opstring = opstring self.precedence = precedence self.is_comparison = is_comparison @@ -263,8 +288,7 @@ class custom_op(object): ) def __eq__(self, other): - return isinstance(other, custom_op) and \ - other.opstring == self.opstring + return isinstance(other, custom_op) and other.opstring == self.opstring def __hash__(self): return id(self) @@ -1138,6 +1162,7 @@ class ColumnOperators(Operators): """ return self.reverse_operate(truediv, other) + _commutative = {eq, ne, add, mul} _comparison = {eq, ne, lt, gt, ge, le} @@ -1261,20 +1286,18 @@ def _escaped_like_impl(fn, other, escape, autoescape): if autoescape: if autoescape is not True: util.warn( - "The autoescape parameter is now a simple boolean True/False") + "The autoescape parameter is now a simple boolean True/False" + ) if escape is None: - escape = '/' + escape = "/" if not isinstance(other, util.compat.string_types): raise TypeError("String value expected when autoescape=True") - if escape not in ('%', '_'): + if escape not in ("%", "_"): other = other.replace(escape, escape + escape) - other = ( - other.replace('%', escape + '%'). - replace('_', escape + '_') - ) + other = other.replace("%", escape + "%").replace("_", escape + "_") return fn(other, escape=escape) @@ -1362,8 +1385,7 @@ def json_path_getitem_op(a, b): def is_comparison(op): - return op in _comparison or \ - isinstance(op, custom_op) and op.is_comparison + return op in _comparison or isinstance(op, custom_op) and op.is_comparison def is_commutative(op): @@ -1371,13 +1393,16 @@ def is_commutative(op): def is_ordering_modifier(op): - return op in (asc_op, desc_op, - nullsfirst_op, nullslast_op) + return op in (asc_op, desc_op, nullsfirst_op, nullslast_op) def is_natural_self_precedent(op): - return op in _natural_self_precedent or \ - isinstance(op, custom_op) and op.natural_self_precedent + return ( + op in _natural_self_precedent + or isinstance(op, custom_op) + and op.natural_self_precedent + ) + _booleans = (inv, istrue, isfalse, and_, or_) @@ -1385,12 +1410,8 @@ _booleans = (inv, istrue, isfalse, and_, or_) def is_boolean(op): return is_comparison(op) or op in _booleans -_mirror = { - gt: lt, - ge: le, - lt: gt, - le: ge -} + +_mirror = {gt: lt, ge: le, lt: gt, le: ge} def mirror(op): @@ -1404,17 +1425,18 @@ def mirror(op): _associative = _commutative.union([concat_op, and_, or_]).difference([eq, ne]) -_natural_self_precedent = _associative.union([ - getitem, json_getitem_op, json_path_getitem_op]) +_natural_self_precedent = _associative.union( + [getitem, json_getitem_op, json_path_getitem_op] +) """Operators where if we have (a op b) op c, we don't want to parenthesize (a op b). """ -_asbool = util.symbol('_asbool', canonical=-10) -_smallest = util.symbol('_smallest', canonical=-100) -_largest = util.symbol('_largest', canonical=100) +_asbool = util.symbol("_asbool", canonical=-10) +_smallest = util.symbol("_smallest", canonical=-100) +_largest = util.symbol("_largest", canonical=100) _PRECEDENCE = { from_: 15, @@ -1424,7 +1446,6 @@ _PRECEDENCE = { getitem: 15, json_getitem_op: 15, json_path_getitem_op: 15, - mul: 8, truediv: 8, div: 8, @@ -1432,22 +1453,17 @@ _PRECEDENCE = { neg: 8, add: 7, sub: 7, - concat_op: 6, - match_op: 5, notmatch_op: 5, - ilike_op: 5, notilike_op: 5, like_op: 5, notlike_op: 5, in_op: 5, notin_op: 5, - is_: 5, isnot: 5, - eq: 5, ne: 5, is_distinct_from: 5, @@ -1458,7 +1474,6 @@ _PRECEDENCE = { lt: 5, ge: 5, le: 5, - between_op: 5, notbetween_op: 5, distinct_op: 5, @@ -1468,17 +1483,14 @@ _PRECEDENCE = { and_: 3, or_: 2, comma_op: -1, - desc_op: 3, asc_op: 3, collate: 4, - as_: -1, exists: 0, - _asbool: -10, _smallest: _smallest, - _largest: _largest + _largest: _largest, } @@ -1486,7 +1498,6 @@ def is_precedent(operator, against): if operator is against and is_natural_self_precedent(operator): return False else: - return (_PRECEDENCE.get(operator, - getattr(operator, 'precedence', _smallest)) <= - _PRECEDENCE.get(against, - getattr(against, 'precedence', _largest))) + return _PRECEDENCE.get( + operator, getattr(operator, "precedence", _smallest) + ) <= _PRECEDENCE.get(against, getattr(against, "precedence", _largest)) diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 3e9aa174a..d6c3f5000 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -36,25 +36,31 @@ import operator from . import visitors from . import type_api from .base import _bind_or_error, ColumnCollection -from .elements import ClauseElement, ColumnClause, \ - _as_truncated, TextClause, _literal_as_text,\ - ColumnElement, quoted_name +from .elements import ( + ClauseElement, + ColumnClause, + _as_truncated, + TextClause, + _literal_as_text, + ColumnElement, + quoted_name, +) from .selectable import TableClause import collections import sqlalchemy from . import ddl -RETAIN_SCHEMA = util.symbol('retain_schema') +RETAIN_SCHEMA = util.symbol("retain_schema") BLANK_SCHEMA = util.symbol( - 'blank_schema', + "blank_schema", """Symbol indicating that a :class:`.Table` or :class:`.Sequence` should have 'None' for its schema, even if the parent :class:`.MetaData` has specified a schema. .. versionadded:: 1.0.14 - """ + """, ) @@ -69,11 +75,15 @@ def _get_table_key(name, schema): # break an import cycle def _copy_expression(expression, source_table, target_table): def replace(col): - if isinstance(col, Column) and \ - col.table is source_table and col.key in source_table.c: + if ( + isinstance(col, Column) + and col.table is source_table + and col.key in source_table.c + ): return target_table.c[col.key] else: return None + return visitors.replacement_traverse(expression, {}, replace) @@ -81,7 +91,7 @@ def _copy_expression(expression, source_table, target_table): class SchemaItem(SchemaEventTarget, visitors.Visitable): """Base class for items that define a database schema.""" - __visit_name__ = 'schema_item' + __visit_name__ = "schema_item" def _init_items(self, *args): """Initialize the list of child items for this SchemaItem.""" @@ -95,10 +105,10 @@ class SchemaItem(SchemaEventTarget, visitors.Visitable): return [] def __repr__(self): - return util.generic_repr(self, omit_kwarg=['info']) + return util.generic_repr(self, omit_kwarg=["info"]) @property - @util.deprecated('0.9', 'Use ``<obj>.name.quote``') + @util.deprecated("0.9", "Use ``<obj>.name.quote``") def quote(self): """Return the value of the ``quote`` flag passed to this schema object, for those schema items which @@ -121,7 +131,7 @@ class SchemaItem(SchemaEventTarget, visitors.Visitable): return {} def _schema_item_copy(self, schema_item): - if 'info' in self.__dict__: + if "info" in self.__dict__: schema_item.info = self.info.copy() schema_item.dispatch._update(self.dispatch) return schema_item @@ -396,7 +406,7 @@ class Table(DialectKWArgs, SchemaItem, TableClause): """ - __visit_name__ = 'table' + __visit_name__ = "table" def __new__(cls, *args, **kw): if not args: @@ -408,26 +418,26 @@ class Table(DialectKWArgs, SchemaItem, TableClause): except IndexError: raise TypeError("Table() takes at least two arguments") - schema = kw.get('schema', None) + schema = kw.get("schema", None) if schema is None: schema = metadata.schema elif schema is BLANK_SCHEMA: schema = None - keep_existing = kw.pop('keep_existing', False) - extend_existing = kw.pop('extend_existing', False) - if 'useexisting' in kw: + keep_existing = kw.pop("keep_existing", False) + extend_existing = kw.pop("extend_existing", False) + if "useexisting" in kw: msg = "useexisting is deprecated. Use extend_existing." util.warn_deprecated(msg) if extend_existing: msg = "useexisting is synonymous with extend_existing." raise exc.ArgumentError(msg) - extend_existing = kw.pop('useexisting', False) + extend_existing = kw.pop("useexisting", False) if keep_existing and extend_existing: msg = "keep_existing and extend_existing are mutually exclusive." raise exc.ArgumentError(msg) - mustexist = kw.pop('mustexist', False) + mustexist = kw.pop("mustexist", False) key = _get_table_key(name, schema) if key in metadata.tables: if not keep_existing and not extend_existing and bool(args): @@ -436,15 +446,15 @@ class Table(DialectKWArgs, SchemaItem, TableClause): "instance. Specify 'extend_existing=True' " "to redefine " "options and columns on an " - "existing Table object." % key) + "existing Table object." % key + ) table = metadata.tables[key] if extend_existing: table._init_existing(*args, **kw) return table else: if mustexist: - raise exc.InvalidRequestError( - "Table '%s' not defined" % (key)) + raise exc.InvalidRequestError("Table '%s' not defined" % (key)) table = object.__new__(cls) table.dispatch.before_parent_attach(table, metadata) metadata._add_table(name, schema, table) @@ -457,7 +467,7 @@ class Table(DialectKWArgs, SchemaItem, TableClause): metadata._remove_table(name, schema) @property - @util.deprecated('0.9', 'Use ``table.schema.quote``') + @util.deprecated("0.9", "Use ``table.schema.quote``") def quote_schema(self): """Return the value of the ``quote_schema`` flag passed to this :class:`.Table`. @@ -478,23 +488,25 @@ class Table(DialectKWArgs, SchemaItem, TableClause): def _init(self, name, metadata, *args, **kwargs): super(Table, self).__init__( - quoted_name(name, kwargs.pop('quote', None))) + quoted_name(name, kwargs.pop("quote", None)) + ) self.metadata = metadata - self.schema = kwargs.pop('schema', None) + self.schema = kwargs.pop("schema", None) if self.schema is None: self.schema = metadata.schema elif self.schema is BLANK_SCHEMA: self.schema = None else: - quote_schema = kwargs.pop('quote_schema', None) + quote_schema = kwargs.pop("quote_schema", None) self.schema = quoted_name(self.schema, quote_schema) self.indexes = set() self.constraints = set() self._columns = ColumnCollection() - PrimaryKeyConstraint(_implicit_generated=True).\ - _set_parent_with_dispatch(self) + PrimaryKeyConstraint( + _implicit_generated=True + )._set_parent_with_dispatch(self) self.foreign_keys = set() self._extra_dependencies = set() if self.schema is not None: @@ -502,26 +514,26 @@ class Table(DialectKWArgs, SchemaItem, TableClause): else: self.fullname = self.name - autoload_with = kwargs.pop('autoload_with', None) - autoload = kwargs.pop('autoload', autoload_with is not None) + autoload_with = kwargs.pop("autoload_with", None) + autoload = kwargs.pop("autoload", autoload_with is not None) # this argument is only used with _init_existing() - kwargs.pop('autoload_replace', True) + kwargs.pop("autoload_replace", True) _extend_on = kwargs.pop("_extend_on", None) - include_columns = kwargs.pop('include_columns', None) + include_columns = kwargs.pop("include_columns", None) - self.implicit_returning = kwargs.pop('implicit_returning', True) + self.implicit_returning = kwargs.pop("implicit_returning", True) - self.comment = kwargs.pop('comment', None) + self.comment = kwargs.pop("comment", None) - if 'info' in kwargs: - self.info = kwargs.pop('info') - if 'listeners' in kwargs: - listeners = kwargs.pop('listeners') + if "info" in kwargs: + self.info = kwargs.pop("info") + if "listeners" in kwargs: + listeners = kwargs.pop("listeners") for evt, fn in listeners: event.listen(self, evt, fn) - self._prefixes = kwargs.pop('prefixes', []) + self._prefixes = kwargs.pop("prefixes", []) self._extra_kwargs(**kwargs) @@ -530,21 +542,29 @@ class Table(DialectKWArgs, SchemaItem, TableClause): # circular foreign keys if autoload: self._autoload( - metadata, autoload_with, - include_columns, _extend_on=_extend_on) + metadata, autoload_with, include_columns, _extend_on=_extend_on + ) # initialize all the column, etc. objects. done after reflection to # allow user-overrides self._init_items(*args) - def _autoload(self, metadata, autoload_with, include_columns, - exclude_columns=(), _extend_on=None): + def _autoload( + self, + metadata, + autoload_with, + include_columns, + exclude_columns=(), + _extend_on=None, + ): if autoload_with: autoload_with.run_callable( autoload_with.dialect.reflecttable, - self, include_columns, exclude_columns, - _extend_on=_extend_on + self, + include_columns, + exclude_columns, + _extend_on=_extend_on, ) else: bind = _bind_or_error( @@ -553,11 +573,14 @@ class Table(DialectKWArgs, SchemaItem, TableClause): "Pass an engine to the Table via " "autoload_with=<someengine>, " "or associate the MetaData with an engine via " - "metadata.bind=<someengine>") + "metadata.bind=<someengine>", + ) bind.run_callable( bind.dialect.reflecttable, - self, include_columns, exclude_columns, - _extend_on=_extend_on + self, + include_columns, + exclude_columns, + _extend_on=_extend_on, ) @property @@ -582,34 +605,36 @@ class Table(DialectKWArgs, SchemaItem, TableClause): return set(fkc.constraint for fkc in self.foreign_keys) def _init_existing(self, *args, **kwargs): - autoload_with = kwargs.pop('autoload_with', None) - autoload = kwargs.pop('autoload', autoload_with is not None) - autoload_replace = kwargs.pop('autoload_replace', True) - schema = kwargs.pop('schema', None) - _extend_on = kwargs.pop('_extend_on', None) + autoload_with = kwargs.pop("autoload_with", None) + autoload = kwargs.pop("autoload", autoload_with is not None) + autoload_replace = kwargs.pop("autoload_replace", True) + schema = kwargs.pop("schema", None) + _extend_on = kwargs.pop("_extend_on", None) if schema and schema != self.schema: raise exc.ArgumentError( "Can't change schema of existing table from '%s' to '%s'", - (self.schema, schema)) + (self.schema, schema), + ) - include_columns = kwargs.pop('include_columns', None) + include_columns = kwargs.pop("include_columns", None) if include_columns is not None: for c in self.c: if c.name not in include_columns: self._columns.remove(c) - for key in ('quote', 'quote_schema'): + for key in ("quote", "quote_schema"): if key in kwargs: raise exc.ArgumentError( - "Can't redefine 'quote' or 'quote_schema' arguments") + "Can't redefine 'quote' or 'quote_schema' arguments" + ) - if 'comment' in kwargs: - self.comment = kwargs.pop('comment', None) + if "comment" in kwargs: + self.comment = kwargs.pop("comment", None) - if 'info' in kwargs: - self.info = kwargs.pop('info') + if "info" in kwargs: + self.info = kwargs.pop("info") if autoload: if not autoload_replace: @@ -620,8 +645,12 @@ class Table(DialectKWArgs, SchemaItem, TableClause): else: exclude_columns = () self._autoload( - self.metadata, autoload_with, - include_columns, exclude_columns, _extend_on=_extend_on) + self.metadata, + autoload_with, + include_columns, + exclude_columns, + _extend_on=_extend_on, + ) self._extra_kwargs(**kwargs) self._init_items(*args) @@ -653,10 +682,12 @@ class Table(DialectKWArgs, SchemaItem, TableClause): return _get_table_key(self.name, self.schema) def __repr__(self): - return "Table(%s)" % ', '.join( - [repr(self.name)] + [repr(self.metadata)] + - [repr(x) for x in self.columns] + - ["%s=%s" % (k, repr(getattr(self, k))) for k in ['schema']]) + return "Table(%s)" % ", ".join( + [repr(self.name)] + + [repr(self.metadata)] + + [repr(x) for x in self.columns] + + ["%s=%s" % (k, repr(getattr(self, k))) for k in ["schema"]] + ) def __str__(self): return _get_table_key(self.description, self.schema) @@ -735,17 +766,19 @@ class Table(DialectKWArgs, SchemaItem, TableClause): def adapt_listener(target, connection, **kw): listener(event_name, target, connection) - event.listen(self, "" + event_name.replace('-', '_'), adapt_listener) + event.listen(self, "" + event_name.replace("-", "_"), adapt_listener) def _set_parent(self, metadata): metadata._add_table(self.name, self.schema, self) self.metadata = metadata - def get_children(self, column_collections=True, - schema_visitor=False, **kw): + def get_children( + self, column_collections=True, schema_visitor=False, **kw + ): if not schema_visitor: return TableClause.get_children( - self, column_collections=column_collections, **kw) + self, column_collections=column_collections, **kw + ) else: if column_collections: return list(self.columns) @@ -758,8 +791,9 @@ class Table(DialectKWArgs, SchemaItem, TableClause): if bind is None: bind = _bind_or_error(self) - return bind.run_callable(bind.dialect.has_table, - self.name, schema=self.schema) + return bind.run_callable( + bind.dialect.has_table, self.name, schema=self.schema + ) def create(self, bind=None, checkfirst=False): """Issue a ``CREATE`` statement for this @@ -774,9 +808,7 @@ class Table(DialectKWArgs, SchemaItem, TableClause): if bind is None: bind = _bind_or_error(self) - bind._run_visitor(ddl.SchemaGenerator, - self, - checkfirst=checkfirst) + bind._run_visitor(ddl.SchemaGenerator, self, checkfirst=checkfirst) def drop(self, bind=None, checkfirst=False): """Issue a ``DROP`` statement for this @@ -790,12 +822,15 @@ class Table(DialectKWArgs, SchemaItem, TableClause): """ if bind is None: bind = _bind_or_error(self) - bind._run_visitor(ddl.SchemaDropper, - self, - checkfirst=checkfirst) - - def tometadata(self, metadata, schema=RETAIN_SCHEMA, - referred_schema_fn=None, name=None): + bind._run_visitor(ddl.SchemaDropper, self, checkfirst=checkfirst) + + def tometadata( + self, + metadata, + schema=RETAIN_SCHEMA, + referred_schema_fn=None, + name=None, + ): """Return a copy of this :class:`.Table` associated with a different :class:`.MetaData`. @@ -868,29 +903,37 @@ class Table(DialectKWArgs, SchemaItem, TableClause): schema = metadata.schema key = _get_table_key(name, schema) if key in metadata.tables: - util.warn("Table '%s' already exists within the given " - "MetaData - not copying." % self.description) + util.warn( + "Table '%s' already exists within the given " + "MetaData - not copying." % self.description + ) return metadata.tables[key] args = [] for c in self.columns: args.append(c.copy(schema=schema)) table = Table( - name, metadata, schema=schema, + name, + metadata, + schema=schema, comment=self.comment, - *args, **self.kwargs + *args, + **self.kwargs ) for c in self.constraints: if isinstance(c, ForeignKeyConstraint): referred_schema = c._referred_schema if referred_schema_fn: fk_constraint_schema = referred_schema_fn( - self, schema, c, referred_schema) + self, schema, c, referred_schema + ) else: fk_constraint_schema = ( - schema if referred_schema == self.schema else None) + schema if referred_schema == self.schema else None + ) table.append_constraint( - c.copy(schema=fk_constraint_schema, target_table=table)) + c.copy(schema=fk_constraint_schema, target_table=table) + ) elif not c._type_bound: # skip unique constraints that would be generated # by the 'unique' flag on Column @@ -898,25 +941,30 @@ class Table(DialectKWArgs, SchemaItem, TableClause): continue table.append_constraint( - c.copy(schema=schema, target_table=table)) + c.copy(schema=schema, target_table=table) + ) for index in self.indexes: # skip indexes that would be generated # by the 'index' flag on Column if index._column_flag: continue - Index(index.name, - unique=index.unique, - *[_copy_expression(expr, self, table) - for expr in index.expressions], - _table=table, - **index.kwargs) + Index( + index.name, + unique=index.unique, + *[ + _copy_expression(expr, self, table) + for expr in index.expressions + ], + _table=table, + **index.kwargs + ) return self._schema_item_copy(table) class Column(DialectKWArgs, SchemaItem, ColumnClause): """Represents a column in a database table.""" - __visit_name__ = 'column' + __visit_name__ = "column" def __init__(self, *args, **kwargs): r""" @@ -1192,14 +1240,15 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): """ - name = kwargs.pop('name', None) - type_ = kwargs.pop('type_', None) + name = kwargs.pop("name", None) + type_ = kwargs.pop("type_", None) args = list(args) if args: if isinstance(args[0], util.string_types): if name is not None: raise exc.ArgumentError( - "May not pass name positionally and as a keyword.") + "May not pass name positionally and as a keyword." + ) name = args.pop(0) if args: coltype = args[0] @@ -1207,40 +1256,42 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): if hasattr(coltype, "_sqla_type"): if type_ is not None: raise exc.ArgumentError( - "May not pass type_ positionally and as a keyword.") + "May not pass type_ positionally and as a keyword." + ) type_ = args.pop(0) if name is not None: - name = quoted_name(name, kwargs.pop('quote', None)) + name = quoted_name(name, kwargs.pop("quote", None)) elif "quote" in kwargs: - raise exc.ArgumentError("Explicit 'name' is required when " - "sending 'quote' argument") + raise exc.ArgumentError( + "Explicit 'name' is required when " "sending 'quote' argument" + ) super(Column, self).__init__(name, type_) - self.key = kwargs.pop('key', name) - self.primary_key = kwargs.pop('primary_key', False) - self.nullable = kwargs.pop('nullable', not self.primary_key) - self.default = kwargs.pop('default', None) - self.server_default = kwargs.pop('server_default', None) - self.server_onupdate = kwargs.pop('server_onupdate', None) + self.key = kwargs.pop("key", name) + self.primary_key = kwargs.pop("primary_key", False) + self.nullable = kwargs.pop("nullable", not self.primary_key) + self.default = kwargs.pop("default", None) + self.server_default = kwargs.pop("server_default", None) + self.server_onupdate = kwargs.pop("server_onupdate", None) # these default to None because .index and .unique is *not* # an informational flag about Column - there can still be an # Index or UniqueConstraint referring to this Column. - self.index = kwargs.pop('index', None) - self.unique = kwargs.pop('unique', None) + self.index = kwargs.pop("index", None) + self.unique = kwargs.pop("unique", None) - self.system = kwargs.pop('system', False) - self.doc = kwargs.pop('doc', None) - self.onupdate = kwargs.pop('onupdate', None) - self.autoincrement = kwargs.pop('autoincrement', "auto") + self.system = kwargs.pop("system", False) + self.doc = kwargs.pop("doc", None) + self.onupdate = kwargs.pop("onupdate", None) + self.autoincrement = kwargs.pop("autoincrement", "auto") self.constraints = set() self.foreign_keys = set() - self.comment = kwargs.pop('comment', None) + self.comment = kwargs.pop("comment", None) # check if this Column is proxying another column - if '_proxies' in kwargs: - self._proxies = kwargs.pop('_proxies') + if "_proxies" in kwargs: + self._proxies = kwargs.pop("_proxies") # otherwise, add DDL-related events elif isinstance(self.type, SchemaEventTarget): self.type._set_parent_with_dispatch(self) @@ -1249,14 +1300,13 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): if isinstance(self.default, (ColumnDefault, Sequence)): args.append(self.default) else: - if getattr(self.type, '_warn_on_bytestring', False): + if getattr(self.type, "_warn_on_bytestring", False): if isinstance(self.default, util.binary_type): util.warn( "Unicode column '%s' has non-unicode " - "default value %r specified." % ( - self.key, - self.default - )) + "default value %r specified." + % (self.key, self.default) + ) args.append(ColumnDefault(self.default)) if self.server_default is not None: @@ -1275,30 +1325,31 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): if isinstance(self.server_onupdate, FetchedValue): args.append(self.server_onupdate._as_for_update(True)) else: - args.append(DefaultClause(self.server_onupdate, - for_update=True)) + args.append( + DefaultClause(self.server_onupdate, for_update=True) + ) self._init_items(*args) util.set_creation_order(self) - if 'info' in kwargs: - self.info = kwargs.pop('info') + if "info" in kwargs: + self.info = kwargs.pop("info") self._extra_kwargs(**kwargs) def _extra_kwargs(self, **kwargs): self._validate_dialect_kwargs(kwargs) -# @property -# def quote(self): -# return getattr(self.name, "quote", None) + # @property + # def quote(self): + # return getattr(self.name, "quote", None) def __str__(self): if self.name is None: return "(no name)" elif self.table is not None: if self.table.named_with_column: - return (self.table.description + "." + self.description) + return self.table.description + "." + self.description else: return self.description else: @@ -1320,40 +1371,47 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): def __repr__(self): kwarg = [] if self.key != self.name: - kwarg.append('key') + kwarg.append("key") if self.primary_key: - kwarg.append('primary_key') + kwarg.append("primary_key") if not self.nullable: - kwarg.append('nullable') + kwarg.append("nullable") if self.onupdate: - kwarg.append('onupdate') + kwarg.append("onupdate") if self.default: - kwarg.append('default') + kwarg.append("default") if self.server_default: - kwarg.append('server_default') - return "Column(%s)" % ', '.join( - [repr(self.name)] + [repr(self.type)] + - [repr(x) for x in self.foreign_keys if x is not None] + - [repr(x) for x in self.constraints] + - [(self.table is not None and "table=<%s>" % - self.table.description or "table=None")] + - ["%s=%s" % (k, repr(getattr(self, k))) for k in kwarg]) + kwarg.append("server_default") + return "Column(%s)" % ", ".join( + [repr(self.name)] + + [repr(self.type)] + + [repr(x) for x in self.foreign_keys if x is not None] + + [repr(x) for x in self.constraints] + + [ + ( + self.table is not None + and "table=<%s>" % self.table.description + or "table=None" + ) + ] + + ["%s=%s" % (k, repr(getattr(self, k))) for k in kwarg] + ) def _set_parent(self, table): if not self.name: raise exc.ArgumentError( "Column must be constructed with a non-blank name or " - "assign a non-blank .name before adding to a Table.") + "assign a non-blank .name before adding to a Table." + ) if self.key is None: self.key = self.name - existing = getattr(self, 'table', None) + existing = getattr(self, "table", None) if existing is not None and existing is not table: raise exc.ArgumentError( - "Column object '%s' already assigned to Table '%s'" % ( - self.key, - existing.description - )) + "Column object '%s' already assigned to Table '%s'" + % (self.key, existing.description) + ) if self.key in table._columns: col = table._columns.get(self.key) @@ -1373,8 +1431,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): elif self.key in table.primary_key: raise exc.ArgumentError( "Trying to redefine primary-key column '%s' as a " - "non-primary-key column on table '%s'" % ( - self.key, table.fullname)) + "non-primary-key column on table '%s'" + % (self.key, table.fullname) + ) self.table = table @@ -1383,7 +1442,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): raise exc.ArgumentError( "The 'index' keyword argument on Column is boolean only. " "To create indexes with a specific name, create an " - "explicit Index object external to the Table.") + "explicit Index object external to the Table." + ) Index(None, self, unique=bool(self.unique), _column_flag=True) elif self.unique: if isinstance(self.unique, util.string_types): @@ -1392,9 +1452,11 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): "only. To create unique constraints or indexes with a " "specific name, append an explicit UniqueConstraint to " "the Table's list of elements, or create an explicit " - "Index object external to the Table.") + "Index object external to the Table." + ) table.append_constraint( - UniqueConstraint(self.key, _column_flag=True)) + UniqueConstraint(self.key, _column_flag=True) + ) self._setup_on_memoized_fks(lambda fk: fk._set_remote_table(table)) @@ -1413,7 +1475,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): if self.table is not None: fn(self, self.table) else: - event.listen(self, 'after_parent_attach', fn) + event.listen(self, "after_parent_attach", fn) def copy(self, **kw): """Create a copy of this ``Column``, unitialized. @@ -1423,9 +1485,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): """ # Constraint objects plus non-constraint-bound ForeignKey objects - args = \ - [c.copy(**kw) for c in self.constraints if not c._type_bound] + \ - [c.copy(**kw) for c in self.foreign_keys if not c.constraint] + args = [ + c.copy(**kw) for c in self.constraints if not c._type_bound + ] + [c.copy(**kw) for c in self.foreign_keys if not c.constraint] type_ = self.type if isinstance(type_, SchemaEventTarget): @@ -1452,8 +1514,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): ) return self._schema_item_copy(c) - def _make_proxy(self, selectable, name=None, key=None, - name_is_truncatable=False, **kw): + def _make_proxy( + self, selectable, name=None, key=None, name_is_truncatable=False, **kw + ): """Create a *proxy* for this column. This is a copy of this ``Column`` referenced by a different parent @@ -1462,22 +1525,28 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): information is not transferred. """ - fk = [ForeignKey(f.column, _constraint=f.constraint) - for f in self.foreign_keys] + fk = [ + ForeignKey(f.column, _constraint=f.constraint) + for f in self.foreign_keys + ] if name is None and self.name is None: raise exc.InvalidRequestError( "Cannot initialize a sub-selectable" " with this Column object until its 'name' has " - "been assigned.") + "been assigned." + ) try: c = self._constructor( - _as_truncated(name or self.name) if - name_is_truncatable else (name or self.name), + _as_truncated(name or self.name) + if name_is_truncatable + else (name or self.name), self.type, key=key if key else name if name else self.key, primary_key=self.primary_key, nullable=self.nullable, - _proxies=[self], *fk) + _proxies=[self], + *fk + ) except TypeError: util.raise_from_cause( TypeError( @@ -1485,7 +1554,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): "Ensure the class includes a _constructor() " "attribute or method which accepts the " "standard Column constructor arguments, or " - "references the Column class itself." % self.__class__) + "references the Column class itself." % self.__class__ + ) ) c.table = selectable @@ -1499,9 +1569,11 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): def get_children(self, schema_visitor=False, **kwargs): if schema_visitor: - return [x for x in (self.default, self.onupdate) - if x is not None] + \ - list(self.foreign_keys) + list(self.constraints) + return ( + [x for x in (self.default, self.onupdate) if x is not None] + + list(self.foreign_keys) + + list(self.constraints) + ) else: return ColumnClause.get_children(self, **kwargs) @@ -1543,13 +1615,23 @@ class ForeignKey(DialectKWArgs, SchemaItem): """ - __visit_name__ = 'foreign_key' - - def __init__(self, column, _constraint=None, use_alter=False, name=None, - onupdate=None, ondelete=None, deferrable=None, - initially=None, link_to_name=False, match=None, - info=None, - **dialect_kw): + __visit_name__ = "foreign_key" + + def __init__( + self, + column, + _constraint=None, + use_alter=False, + name=None, + onupdate=None, + ondelete=None, + deferrable=None, + initially=None, + link_to_name=False, + match=None, + info=None, + **dialect_kw + ): r""" Construct a column-level FOREIGN KEY. @@ -1626,7 +1708,7 @@ class ForeignKey(DialectKWArgs, SchemaItem): if isinstance(self._colspec, util.string_types): self._table_column = None else: - if hasattr(self._colspec, '__clause_element__'): + if hasattr(self._colspec, "__clause_element__"): self._table_column = self._colspec.__clause_element__() else: self._table_column = self._colspec @@ -1634,9 +1716,11 @@ class ForeignKey(DialectKWArgs, SchemaItem): if not isinstance(self._table_column, ColumnClause): raise exc.ArgumentError( "String, Column, or Column-bound argument " - "expected, got %r" % self._table_column) + "expected, got %r" % self._table_column + ) elif not isinstance( - self._table_column.table, (util.NoneType, TableClause)): + self._table_column.table, (util.NoneType, TableClause) + ): raise exc.ArgumentError( "ForeignKey received Column not bound " "to a Table, got: %r" % self._table_column.table @@ -1715,7 +1799,9 @@ class ForeignKey(DialectKWArgs, SchemaItem): return "%s.%s" % (table_name, colname) elif self._table_column is not None: return "%s.%s" % ( - self._table_column.table.fullname, self._table_column.key) + self._table_column.table.fullname, + self._table_column.key, + ) else: return self._colspec @@ -1756,12 +1842,12 @@ class ForeignKey(DialectKWArgs, SchemaItem): def _column_tokens(self): """parse a string-based _colspec into its component parts.""" - m = self._get_colspec().split('.') + m = self._get_colspec().split(".") if m is None: raise exc.ArgumentError( - "Invalid foreign key column specification: %s" % - self._colspec) - if (len(m) == 1): + "Invalid foreign key column specification: %s" % self._colspec + ) + if len(m) == 1: tname = m.pop() colname = None else: @@ -1777,8 +1863,8 @@ class ForeignKey(DialectKWArgs, SchemaItem): # indirectly related -- Ticket #594. This assumes that '.' # will never appear *within* any component of the FK. - if (len(m) > 0): - schema = '.'.join(m) + if len(m) > 0: + schema = ".".join(m) else: schema = None return schema, tname, colname @@ -1787,12 +1873,14 @@ class ForeignKey(DialectKWArgs, SchemaItem): if self.parent is None: raise exc.InvalidRequestError( "this ForeignKey object does not yet have a " - "parent Column associated with it.") + "parent Column associated with it." + ) elif self.parent.table is None: raise exc.InvalidRequestError( "this ForeignKey's parent column is not yet associated " - "with a Table.") + "with a Table." + ) parenttable = self.parent.table @@ -1817,7 +1905,7 @@ class ForeignKey(DialectKWArgs, SchemaItem): return parenttable, tablekey, colname def _link_to_col_by_colstring(self, parenttable, table, colname): - if not hasattr(self.constraint, '_referred_table'): + if not hasattr(self.constraint, "_referred_table"): self.constraint._referred_table = table else: assert self.constraint._referred_table is table @@ -1843,9 +1931,11 @@ class ForeignKey(DialectKWArgs, SchemaItem): raise exc.NoReferencedColumnError( "Could not initialize target column " "for ForeignKey '%s' on table '%s': " - "table '%s' has no column named '%s'" % - (self._colspec, parenttable.name, table.name, key), - table.name, key) + "table '%s' has no column named '%s'" + % (self._colspec, parenttable.name, table.name, key), + table.name, + key, + ) self._set_target_column(_column) @@ -1861,6 +1951,7 @@ class ForeignKey(DialectKWArgs, SchemaItem): def set_type(fk): if fk.parent.type._isnull: fk.parent.type = column.type + self.parent._setup_on_memoized_fks(set_type) self.column = column @@ -1888,21 +1979,25 @@ class ForeignKey(DialectKWArgs, SchemaItem): raise exc.NoReferencedTableError( "Foreign key associated with column '%s' could not find " "table '%s' with which to generate a " - "foreign key to target column '%s'" % - (self.parent, tablekey, colname), - tablekey) + "foreign key to target column '%s'" + % (self.parent, tablekey, colname), + tablekey, + ) elif parenttable.key not in parenttable.metadata: raise exc.InvalidRequestError( "Table %s is no longer associated with its " - "parent MetaData" % parenttable) + "parent MetaData" % parenttable + ) else: raise exc.NoReferencedColumnError( "Could not initialize target column for " "ForeignKey '%s' on table '%s': " - "table '%s' has no column named '%s'" % ( - self._colspec, parenttable.name, tablekey, colname), - tablekey, colname) - elif hasattr(self._colspec, '__clause_element__'): + "table '%s' has no column named '%s'" + % (self._colspec, parenttable.name, tablekey, colname), + tablekey, + colname, + ) + elif hasattr(self._colspec, "__clause_element__"): _column = self._colspec.__clause_element__() return _column else: @@ -1912,7 +2007,8 @@ class ForeignKey(DialectKWArgs, SchemaItem): def _set_parent(self, column): if self.parent is not None and self.parent is not column: raise exc.InvalidRequestError( - "This ForeignKey already has a parent !") + "This ForeignKey already has a parent !" + ) self.parent = column self.parent.foreign_keys.add(self) self.parent._on_table_attach(self._set_table) @@ -1935,9 +2031,14 @@ class ForeignKey(DialectKWArgs, SchemaItem): # on the hosting Table when attached to the Table. if self.constraint is None and isinstance(table, Table): self.constraint = ForeignKeyConstraint( - [], [], use_alter=self.use_alter, name=self.name, - onupdate=self.onupdate, ondelete=self.ondelete, - deferrable=self.deferrable, initially=self.initially, + [], + [], + use_alter=self.use_alter, + name=self.name, + onupdate=self.onupdate, + ondelete=self.ondelete, + deferrable=self.deferrable, + initially=self.initially, match=self.match, **self._unvalidated_dialect_kw ) @@ -1953,13 +2054,12 @@ class ForeignKey(DialectKWArgs, SchemaItem): if table_key in parenttable.metadata.tables: table = parenttable.metadata.tables[table_key] try: - self._link_to_col_by_colstring( - parenttable, table, colname) + self._link_to_col_by_colstring(parenttable, table, colname) except exc.NoReferencedColumnError: # this is OK, we'll try later pass parenttable.metadata._fk_memos[fk_key].append(self) - elif hasattr(self._colspec, '__clause_element__'): + elif hasattr(self._colspec, "__clause_element__"): _column = self._colspec.__clause_element__() self._set_target_column(_column) else: @@ -1971,7 +2071,8 @@ class _NotAColumnExpr(object): def _not_a_column_expr(self): raise exc.InvalidRequestError( "This %s cannot be used directly " - "as a column expression." % self.__class__.__name__) + "as a column expression." % self.__class__.__name__ + ) __clause_element__ = self_group = lambda self: self._not_a_column_expr() _from_objects = property(lambda self: self._not_a_column_expr()) @@ -1980,7 +2081,7 @@ class _NotAColumnExpr(object): class DefaultGenerator(_NotAColumnExpr, SchemaItem): """Base class for column *default* values.""" - __visit_name__ = 'default_generator' + __visit_name__ = "default_generator" is_sequence = False is_server_default = False @@ -2007,7 +2108,7 @@ class DefaultGenerator(_NotAColumnExpr, SchemaItem): @property def bind(self): """Return the connectable associated with this default.""" - if getattr(self, 'column', None) is not None: + if getattr(self, "column", None) is not None: return self.column.table.bind else: return None @@ -2064,7 +2165,8 @@ class ColumnDefault(DefaultGenerator): super(ColumnDefault, self).__init__(**kwargs) if isinstance(arg, FetchedValue): raise exc.ArgumentError( - "ColumnDefault may not be a server-side default type.") + "ColumnDefault may not be a server-side default type." + ) if util.callable(arg): arg = self._maybe_wrap_callable(arg) self.arg = arg @@ -2079,9 +2181,11 @@ class ColumnDefault(DefaultGenerator): @util.memoized_property def is_scalar(self): - return not self.is_callable and \ - not self.is_clause_element and \ - not self.is_sequence + return ( + not self.is_callable + and not self.is_clause_element + and not self.is_sequence + ) @util.memoized_property @util.dependencies("sqlalchemy.sql.sqltypes") @@ -2114,17 +2218,19 @@ class ColumnDefault(DefaultGenerator): else: raise exc.ArgumentError( "ColumnDefault Python function takes zero or one " - "positional arguments") + "positional arguments" + ) def _visit_name(self): if self.for_update: return "column_onupdate" else: return "column_default" + __visit_name__ = property(_visit_name) def __repr__(self): - return "ColumnDefault(%r)" % (self.arg, ) + return "ColumnDefault(%r)" % (self.arg,) class Sequence(DefaultGenerator): @@ -2157,15 +2263,29 @@ class Sequence(DefaultGenerator): """ - __visit_name__ = 'sequence' + __visit_name__ = "sequence" is_sequence = True - def __init__(self, name, start=None, increment=None, minvalue=None, - maxvalue=None, nominvalue=None, nomaxvalue=None, cycle=None, - schema=None, cache=None, order=None, optional=False, - quote=None, metadata=None, quote_schema=None, - for_update=False): + def __init__( + self, + name, + start=None, + increment=None, + minvalue=None, + maxvalue=None, + nominvalue=None, + nomaxvalue=None, + cycle=None, + schema=None, + cache=None, + order=None, + optional=False, + quote=None, + metadata=None, + quote_schema=None, + for_update=False, + ): """Construct a :class:`.Sequence` object. :param name: The name of the sequence. @@ -2353,27 +2473,22 @@ class Sequence(DefaultGenerator): if bind is None: bind = _bind_or_error(self) - bind._run_visitor(ddl.SchemaGenerator, - self, - checkfirst=checkfirst) + bind._run_visitor(ddl.SchemaGenerator, self, checkfirst=checkfirst) def drop(self, bind=None, checkfirst=True): """Drops this sequence from the database.""" if bind is None: bind = _bind_or_error(self) - bind._run_visitor(ddl.SchemaDropper, - self, - checkfirst=checkfirst) + bind._run_visitor(ddl.SchemaDropper, self, checkfirst=checkfirst) def _not_a_column_expr(self): raise exc.InvalidRequestError( "This %s cannot be used directly " "as a column expression. Use func.next_value(sequence) " "to produce a 'next value' function that's usable " - "as a column element." - % self.__class__.__name__) - + "as a column element." % self.__class__.__name__ + ) @inspection._self_inspects @@ -2396,6 +2511,7 @@ class FetchedValue(_NotAColumnExpr, SchemaEventTarget): :ref:`triggered_columns` """ + is_server_default = True reflected = False has_argument = False @@ -2412,7 +2528,7 @@ class FetchedValue(_NotAColumnExpr, SchemaEventTarget): def _clone(self, for_update): n = self.__class__.__new__(self.__class__) n.__dict__.update(self.__dict__) - n.__dict__.pop('column', None) + n.__dict__.pop("column", None) n.for_update = for_update return n @@ -2452,16 +2568,15 @@ class DefaultClause(FetchedValue): has_argument = True def __init__(self, arg, for_update=False, _reflected=False): - util.assert_arg_type(arg, (util.string_types[0], - ClauseElement, - TextClause), 'arg') + util.assert_arg_type( + arg, (util.string_types[0], ClauseElement, TextClause), "arg" + ) super(DefaultClause, self).__init__(for_update) self.arg = arg self.reflected = _reflected def __repr__(self): - return "DefaultClause(%r, for_update=%r)" % \ - (self.arg, self.for_update) + return "DefaultClause(%r, for_update=%r)" % (self.arg, self.for_update) class PassiveDefault(DefaultClause): @@ -2471,10 +2586,13 @@ class PassiveDefault(DefaultClause): :class:`.PassiveDefault` is deprecated. Use :class:`.DefaultClause`. """ - @util.deprecated("0.6", - ":class:`.PassiveDefault` is deprecated. " - "Use :class:`.DefaultClause`.", - False) + + @util.deprecated( + "0.6", + ":class:`.PassiveDefault` is deprecated. " + "Use :class:`.DefaultClause`.", + False, + ) def __init__(self, *arg, **kw): DefaultClause.__init__(self, *arg, **kw) @@ -2482,11 +2600,18 @@ class PassiveDefault(DefaultClause): class Constraint(DialectKWArgs, SchemaItem): """A table-level SQL constraint.""" - __visit_name__ = 'constraint' - - def __init__(self, name=None, deferrable=None, initially=None, - _create_rule=None, info=None, _type_bound=False, - **dialect_kw): + __visit_name__ = "constraint" + + def __init__( + self, + name=None, + deferrable=None, + initially=None, + _create_rule=None, + info=None, + _type_bound=False, + **dialect_kw + ): r"""Create a SQL constraint. :param name: @@ -2548,7 +2673,8 @@ class Constraint(DialectKWArgs, SchemaItem): pass raise exc.InvalidRequestError( "This constraint is not bound to a table. Did you " - "mean to call table.append_constraint(constraint) ?") + "mean to call table.append_constraint(constraint) ?" + ) def _set_parent(self, parent): self.parent = parent @@ -2559,7 +2685,7 @@ class Constraint(DialectKWArgs, SchemaItem): def _to_schema_column(element): - if hasattr(element, '__clause_element__'): + if hasattr(element, "__clause_element__"): element = element.__clause_element__() if not isinstance(element, Column): raise exc.ArgumentError("schema.Column object expected") @@ -2567,9 +2693,9 @@ def _to_schema_column(element): def _to_schema_column_or_string(element): - if hasattr(element, '__clause_element__'): + if hasattr(element, "__clause_element__"): element = element.__clause_element__() - if not isinstance(element, util.string_types + (ColumnElement, )): + if not isinstance(element, util.string_types + (ColumnElement,)): msg = "Element %r is not a string name or column element" raise exc.ArgumentError(msg % element) return element @@ -2588,11 +2714,12 @@ class ColumnCollectionMixin(object): _allow_multiple_tables = False def __init__(self, *columns, **kw): - _autoattach = kw.pop('_autoattach', True) - self._column_flag = kw.pop('_column_flag', False) + _autoattach = kw.pop("_autoattach", True) + self._column_flag = kw.pop("_column_flag", False) self.columns = ColumnCollection() - self._pending_colargs = [_to_schema_column_or_string(c) - for c in columns] + self._pending_colargs = [ + _to_schema_column_or_string(c) for c in columns + ] if _autoattach and self._pending_colargs: self._check_attach() @@ -2601,7 +2728,7 @@ class ColumnCollectionMixin(object): for expr in expressions: strname = None column = None - if hasattr(expr, '__clause_element__'): + if hasattr(expr, "__clause_element__"): expr = expr.__clause_element__() if not isinstance(expr, (ColumnElement, TextClause)): @@ -2609,21 +2736,16 @@ class ColumnCollectionMixin(object): strname = expr else: cols = [] - visitors.traverse(expr, {}, {'column': cols.append}) + visitors.traverse(expr, {}, {"column": cols.append}) if cols: column = cols[0] add_element = column if column is not None else strname yield expr, column, strname, add_element def _check_attach(self, evt=False): - col_objs = [ - c for c in self._pending_colargs - if isinstance(c, Column) - ] + col_objs = [c for c in self._pending_colargs if isinstance(c, Column)] - cols_w_table = [ - c for c in col_objs if isinstance(c.table, Table) - ] + cols_w_table = [c for c in col_objs if isinstance(c.table, Table)] cols_wo_table = set(col_objs).difference(cols_w_table) @@ -2636,6 +2758,7 @@ class ColumnCollectionMixin(object): # columns are specified as strings. has_string_cols = set(self._pending_colargs).difference(col_objs) if not has_string_cols: + def _col_attached(column, table): # this isinstance() corresponds with the # isinstance() above; only want to count Table-bound @@ -2644,6 +2767,7 @@ class ColumnCollectionMixin(object): cols_wo_table.discard(column) if not cols_wo_table: self._check_attach(evt=True) + self._cols_wo_table = cols_wo_table for col in cols_wo_table: col._on_table_attach(_col_attached) @@ -2659,9 +2783,11 @@ class ColumnCollectionMixin(object): others = [c for c in columns[1:] if c.table is not table] if others: raise exc.ArgumentError( - "Column(s) %s are not part of table '%s'." % - (", ".join("'%s'" % c for c in others), - table.description) + "Column(s) %s are not part of table '%s'." + % ( + ", ".join("'%s'" % c for c in others), + table.description, + ) ) def _set_parent(self, table): @@ -2694,11 +2820,12 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint): arguments are propagated to the :class:`.Constraint` superclass. """ - _autoattach = kw.pop('_autoattach', True) - _column_flag = kw.pop('_column_flag', False) + _autoattach = kw.pop("_autoattach", True) + _column_flag = kw.pop("_column_flag", False) Constraint.__init__(self, **kw) ColumnCollectionMixin.__init__( - self, *columns, _autoattach=_autoattach, _column_flag=_column_flag) + self, *columns, _autoattach=_autoattach, _column_flag=_column_flag + ) columns = None """A :class:`.ColumnCollection` representing the set of columns @@ -2714,8 +2841,12 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint): return x in self.columns def copy(self, **kw): - c = self.__class__(name=self.name, deferrable=self.deferrable, - initially=self.initially, *self.columns.keys()) + c = self.__class__( + name=self.name, + deferrable=self.deferrable, + initially=self.initially, + *self.columns.keys() + ) return self._schema_item_copy(c) def contains_column(self, col): @@ -2747,9 +2878,19 @@ class CheckConstraint(ColumnCollectionConstraint): _allow_multiple_tables = True - def __init__(self, sqltext, name=None, deferrable=None, - initially=None, table=None, info=None, _create_rule=None, - _autoattach=True, _type_bound=False, **kw): + def __init__( + self, + sqltext, + name=None, + deferrable=None, + initially=None, + table=None, + info=None, + _create_rule=None, + _autoattach=True, + _type_bound=False, + **kw + ): r"""Construct a CHECK constraint. :param sqltext: @@ -2781,14 +2922,19 @@ class CheckConstraint(ColumnCollectionConstraint): self.sqltext = _literal_as_text(sqltext, warn=False) columns = [] - visitors.traverse(self.sqltext, {}, {'column': columns.append}) - - super(CheckConstraint, self).\ - __init__( - name=name, deferrable=deferrable, - initially=initially, _create_rule=_create_rule, info=info, - _type_bound=_type_bound, _autoattach=_autoattach, - *columns, **kw) + visitors.traverse(self.sqltext, {}, {"column": columns.append}) + + super(CheckConstraint, self).__init__( + name=name, + deferrable=deferrable, + initially=initially, + _create_rule=_create_rule, + info=info, + _type_bound=_type_bound, + _autoattach=_autoattach, + *columns, + **kw + ) if table is not None: self._set_parent_with_dispatch(table) @@ -2797,22 +2943,24 @@ class CheckConstraint(ColumnCollectionConstraint): return "check_constraint" else: return "column_check_constraint" + __visit_name__ = property(__visit_name__) def copy(self, target_table=None, **kw): if target_table is not None: - sqltext = _copy_expression( - self.sqltext, self.table, target_table) + sqltext = _copy_expression(self.sqltext, self.table, target_table) else: sqltext = self.sqltext - c = CheckConstraint(sqltext, - name=self.name, - initially=self.initially, - deferrable=self.deferrable, - _create_rule=self._create_rule, - table=target_table, - _autoattach=False, - _type_bound=self._type_bound) + c = CheckConstraint( + sqltext, + name=self.name, + initially=self.initially, + deferrable=self.deferrable, + _create_rule=self._create_rule, + table=target_table, + _autoattach=False, + _type_bound=self._type_bound, + ) return self._schema_item_copy(c) @@ -2828,12 +2976,25 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): Examples of foreign key configuration are in :ref:`metadata_foreignkeys`. """ - __visit_name__ = 'foreign_key_constraint' - def __init__(self, columns, refcolumns, name=None, onupdate=None, - ondelete=None, deferrable=None, initially=None, - use_alter=False, link_to_name=False, match=None, - table=None, info=None, **dialect_kw): + __visit_name__ = "foreign_key_constraint" + + def __init__( + self, + columns, + refcolumns, + name=None, + onupdate=None, + ondelete=None, + deferrable=None, + initially=None, + use_alter=False, + link_to_name=False, + match=None, + table=None, + info=None, + **dialect_kw + ): r"""Construct a composite-capable FOREIGN KEY. :param columns: A sequence of local column names. The named columns @@ -2905,8 +3066,13 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): """ Constraint.__init__( - self, name=name, deferrable=deferrable, initially=initially, - info=info, **dialect_kw) + self, + name=name, + deferrable=deferrable, + initially=initially, + info=info, + **dialect_kw + ) self.onupdate = onupdate self.ondelete = ondelete self.link_to_name = link_to_name @@ -2927,7 +3093,8 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): raise exc.ArgumentError( "ForeignKeyConstraint number " "of constrained columns must match the number of " - "referenced columns.") + "referenced columns." + ) # standalone ForeignKeyConstraint - create # associated ForeignKey objects which will be applied to hosted @@ -2946,7 +3113,8 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): deferrable=self.deferrable, initially=self.initially, **self.dialect_kwargs - ) for refcol in refcolumns + ) + for refcol in refcolumns ] ColumnCollectionMixin.__init__(self, *columns) @@ -2978,9 +3146,7 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): @property def _elements(self): # legacy - provide a dictionary view of (column_key, fk) - return util.OrderedDict( - zip(self.column_keys, self.elements) - ) + return util.OrderedDict(zip(self.column_keys, self.elements)) @property def _referred_schema(self): @@ -3004,18 +3170,14 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): return self.elements[0].column.table def _validate_dest_table(self, table): - table_keys = set([elem._table_key() - for elem in self.elements]) + table_keys = set([elem._table_key() for elem in self.elements]) if None not in table_keys and len(table_keys) > 1: elem0, elem1 = sorted(table_keys)[0:2] raise exc.ArgumentError( - 'ForeignKeyConstraint on %s(%s) refers to ' - 'multiple remote tables: %s and %s' % ( - table.fullname, - self._col_description, - elem0, - elem1 - )) + "ForeignKeyConstraint on %s(%s) refers to " + "multiple remote tables: %s and %s" + % (table.fullname, self._col_description, elem0, elem1) + ) @property def column_keys(self): @@ -3034,8 +3196,8 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): return self.columns.keys() else: return [ - col.key if isinstance(col, ColumnElement) - else str(col) for col in self._pending_colargs + col.key if isinstance(col, ColumnElement) else str(col) + for col in self._pending_colargs ] @property @@ -3051,11 +3213,11 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): raise exc.ArgumentError( "Can't create ForeignKeyConstraint " "on table '%s': no column " - "named '%s' is present." % (table.description, ke.args[0])) + "named '%s' is present." % (table.description, ke.args[0]) + ) for col, fk in zip(self.columns, self.elements): - if not hasattr(fk, 'parent') or \ - fk.parent is not col: + if not hasattr(fk, "parent") or fk.parent is not col: fk._set_parent_with_dispatch(col) self._validate_dest_table(table) @@ -3063,13 +3225,16 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): def copy(self, schema=None, target_table=None, **kw): fkc = ForeignKeyConstraint( [x.parent.key for x in self.elements], - [x._get_colspec( - schema=schema, - table_name=target_table.name - if target_table is not None - and x._table_key() == x.parent.table.key - else None) - for x in self.elements], + [ + x._get_colspec( + schema=schema, + table_name=target_table.name + if target_table is not None + and x._table_key() == x.parent.table.key + else None, + ) + for x in self.elements + ], name=self.name, onupdate=self.onupdate, ondelete=self.ondelete, @@ -3077,11 +3242,9 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): deferrable=self.deferrable, initially=self.initially, link_to_name=self.link_to_name, - match=self.match + match=self.match, ) - for self_fk, other_fk in zip( - self.elements, - fkc.elements): + for self_fk, other_fk in zip(self.elements, fkc.elements): self_fk._schema_item_copy(other_fk) return self._schema_item_copy(fkc) @@ -3160,10 +3323,10 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): """ - __visit_name__ = 'primary_key_constraint' + __visit_name__ = "primary_key_constraint" def __init__(self, *columns, **kw): - self._implicit_generated = kw.pop('_implicit_generated', False) + self._implicit_generated = kw.pop("_implicit_generated", False) super(PrimaryKeyConstraint, self).__init__(*columns, **kw) def _set_parent(self, table): @@ -3175,18 +3338,21 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): table.constraints.add(self) table_pks = [c for c in table.c if c.primary_key] - if self.columns and table_pks and \ - set(table_pks) != set(self.columns.values()): + if ( + self.columns + and table_pks + and set(table_pks) != set(self.columns.values()) + ): util.warn( "Table '%s' specifies columns %s as primary_key=True, " "not matching locally specified columns %s; setting the " "current primary key columns to %s. This warning " - "may become an exception in a future release" % - ( + "may become an exception in a future release" + % ( table.name, ", ".join("'%s'" % c.name for c in table_pks), ", ".join("'%s'" % c.name for c in self.columns), - ", ".join("'%s'" % c.name for c in self.columns) + ", ".join("'%s'" % c.name for c in self.columns), ) ) table_pks[:] = [] @@ -3241,28 +3407,28 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): @util.memoized_property def _autoincrement_column(self): - def _validate_autoinc(col, autoinc_true): if col.type._type_affinity is None or not issubclass( - col.type._type_affinity, - type_api.INTEGERTYPE._type_affinity): + col.type._type_affinity, type_api.INTEGERTYPE._type_affinity + ): if autoinc_true: raise exc.ArgumentError( "Column type %s on column '%s' is not " - "compatible with autoincrement=True" % ( - col.type, - col - )) + "compatible with autoincrement=True" % (col.type, col) + ) else: return False - elif not isinstance(col.default, (type(None), Sequence)) and \ - not autoinc_true: - return False + elif ( + not isinstance(col.default, (type(None), Sequence)) + and not autoinc_true + ): + return False elif col.server_default is not None and not autoinc_true: return False - elif ( - col.foreign_keys and col.autoincrement - not in (True, 'ignore_fk')): + elif col.foreign_keys and col.autoincrement not in ( + True, + "ignore_fk", + ): return False return True @@ -3272,10 +3438,10 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): if col.autoincrement is True: _validate_autoinc(col, True) return col - elif ( - col.autoincrement in ('auto', 'ignore_fk') and - _validate_autoinc(col, False) - ): + elif col.autoincrement in ( + "auto", + "ignore_fk", + ) and _validate_autoinc(col, False): return col else: @@ -3286,8 +3452,8 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): if autoinc is not None: raise exc.ArgumentError( "Only one Column may be marked " - "autoincrement=True, found both %s and %s." % - (col.name, autoinc.name) + "autoincrement=True, found both %s and %s." + % (col.name, autoinc.name) ) else: autoinc = col @@ -3304,7 +3470,7 @@ class UniqueConstraint(ColumnCollectionConstraint): UniqueConstraint. """ - __visit_name__ = 'unique_constraint' + __visit_name__ = "unique_constraint" class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem): @@ -3382,7 +3548,7 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem): """ - __visit_name__ = 'index' + __visit_name__ = "index" def __init__(self, name, *expressions, **kw): r"""Construct an index object. @@ -3420,30 +3586,35 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem): columns = [] processed_expressions = [] - for expr, column, strname, add_element in self.\ - _extract_col_expression_collection(expressions): + for ( + expr, + column, + strname, + add_element, + ) in self._extract_col_expression_collection(expressions): if add_element is not None: columns.append(add_element) processed_expressions.append(expr) self.expressions = processed_expressions self.name = quoted_name(name, kw.pop("quote", None)) - self.unique = kw.pop('unique', False) - _column_flag = kw.pop('_column_flag', False) - if 'info' in kw: - self.info = kw.pop('info') + self.unique = kw.pop("unique", False) + _column_flag = kw.pop("_column_flag", False) + if "info" in kw: + self.info = kw.pop("info") # TODO: consider "table" argument being public, but for # the purpose of the fix here, it starts as private. - if '_table' in kw: - table = kw.pop('_table') + if "_table" in kw: + table = kw.pop("_table") self._validate_dialect_kwargs(kw) # will call _set_parent() if table-bound column # objects are present ColumnCollectionMixin.__init__( - self, *columns, _column_flag=_column_flag) + self, *columns, _column_flag=_column_flag + ) if table is not None: self._set_parent(table) @@ -3454,20 +3625,17 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem): if self.table is not None and table is not self.table: raise exc.ArgumentError( "Index '%s' is against table '%s', and " - "cannot be associated with table '%s'." % ( - self.name, - self.table.description, - table.description - ) + "cannot be associated with table '%s'." + % (self.name, self.table.description, table.description) ) self.table = table table.indexes.add(self) self.expressions = [ - expr if isinstance(expr, ClauseElement) - else colexpr - for expr, colexpr in util.zip_longest(self.expressions, - self.columns) + expr if isinstance(expr, ClauseElement) else colexpr + for expr, colexpr in util.zip_longest( + self.expressions, self.columns + ) ] @property @@ -3506,17 +3674,16 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem): bind._run_visitor(ddl.SchemaDropper, self) def __repr__(self): - return 'Index(%s)' % ( + return "Index(%s)" % ( ", ".join( - [repr(self.name)] + - [repr(e) for e in self.expressions] + - (self.unique and ["unique=True"] or []) - )) + [repr(self.name)] + + [repr(e) for e in self.expressions] + + (self.unique and ["unique=True"] or []) + ) + ) -DEFAULT_NAMING_CONVENTION = util.immutabledict({ - "ix": 'ix_%(column_0_label)s' -}) +DEFAULT_NAMING_CONVENTION = util.immutabledict({"ix": "ix_%(column_0_label)s"}) class MetaData(SchemaItem): @@ -3542,13 +3709,17 @@ class MetaData(SchemaItem): """ - __visit_name__ = 'metadata' - - def __init__(self, bind=None, reflect=False, schema=None, - quote_schema=None, - naming_convention=DEFAULT_NAMING_CONVENTION, - info=None - ): + __visit_name__ = "metadata" + + def __init__( + self, + bind=None, + reflect=False, + schema=None, + quote_schema=None, + naming_convention=DEFAULT_NAMING_CONVENTION, + info=None, + ): """Create a new MetaData object. :param bind: @@ -3712,12 +3883,15 @@ class MetaData(SchemaItem): self.bind = bind if reflect: - util.warn_deprecated("reflect=True is deprecate; please " - "use the reflect() method.") + util.warn_deprecated( + "reflect=True is deprecate; please " + "use the reflect() method." + ) if not bind: raise exc.ArgumentError( "A bind must be supplied in conjunction " - "with reflect=True") + "with reflect=True" + ) self.reflect() tables = None @@ -3735,7 +3909,7 @@ class MetaData(SchemaItem): """ def __repr__(self): - return 'MetaData(bind=%r)' % self.bind + return "MetaData(bind=%r)" % self.bind def __contains__(self, table_or_key): if not isinstance(table_or_key, util.string_types): @@ -3755,27 +3929,32 @@ class MetaData(SchemaItem): for fk in removed.foreign_keys: fk._remove_from_metadata(self) if self._schemas: - self._schemas = set([t.schema - for t in self.tables.values() - if t.schema is not None]) + self._schemas = set( + [ + t.schema + for t in self.tables.values() + if t.schema is not None + ] + ) def __getstate__(self): - return {'tables': self.tables, - 'schema': self.schema, - 'schemas': self._schemas, - 'sequences': self._sequences, - 'fk_memos': self._fk_memos, - 'naming_convention': self.naming_convention - } + return { + "tables": self.tables, + "schema": self.schema, + "schemas": self._schemas, + "sequences": self._sequences, + "fk_memos": self._fk_memos, + "naming_convention": self.naming_convention, + } def __setstate__(self, state): - self.tables = state['tables'] - self.schema = state['schema'] - self.naming_convention = state['naming_convention'] + self.tables = state["tables"] + self.schema = state["schema"] + self.naming_convention = state["naming_convention"] self._bind = None - self._sequences = state['sequences'] - self._schemas = state['schemas'] - self._fk_memos = state['fk_memos'] + self._sequences = state["sequences"] + self._schemas = state["schemas"] + self._fk_memos = state["fk_memos"] def is_bound(self): """True if this MetaData is bound to an Engine or Connection.""" @@ -3805,10 +3984,11 @@ class MetaData(SchemaItem): def _bind_to(self, url, bind): """Bind this MetaData to an Engine, Connection, string or URL.""" - if isinstance(bind, util.string_types + (url.URL, )): + if isinstance(bind, util.string_types + (url.URL,)): self._bind = sqlalchemy.create_engine(bind) else: self._bind = bind + bind = property(bind, _bind_to) def clear(self): @@ -3858,12 +4038,20 @@ class MetaData(SchemaItem): """ - return ddl.sort_tables(sorted(self.tables.values(), key=lambda t: t.key)) + return ddl.sort_tables( + sorted(self.tables.values(), key=lambda t: t.key) + ) - def reflect(self, bind=None, schema=None, views=False, only=None, - extend_existing=False, - autoload_replace=True, - **dialect_kwargs): + def reflect( + self, + bind=None, + schema=None, + views=False, + only=None, + extend_existing=False, + autoload_replace=True, + **dialect_kwargs + ): r"""Load all available table definitions from the database. Automatically creates ``Table`` entries in this ``MetaData`` for any @@ -3926,11 +4114,11 @@ class MetaData(SchemaItem): with bind.connect() as conn: reflect_opts = { - 'autoload': True, - 'autoload_with': conn, - 'extend_existing': extend_existing, - 'autoload_replace': autoload_replace, - '_extend_on': set() + "autoload": True, + "autoload_with": conn, + "extend_existing": extend_existing, + "autoload_replace": autoload_replace, + "_extend_on": set(), } reflect_opts.update(dialect_kwargs) @@ -3939,42 +4127,49 @@ class MetaData(SchemaItem): schema = self.schema if schema is not None: - reflect_opts['schema'] = schema + reflect_opts["schema"] = schema available = util.OrderedSet( - bind.engine.table_names(schema, connection=conn)) + bind.engine.table_names(schema, connection=conn) + ) if views: - available.update( - bind.dialect.get_view_names(conn, schema) - ) + available.update(bind.dialect.get_view_names(conn, schema)) if schema is not None: - available_w_schema = util.OrderedSet(["%s.%s" % (schema, name) - for name in available]) + available_w_schema = util.OrderedSet( + ["%s.%s" % (schema, name) for name in available] + ) else: available_w_schema = available current = set(self.tables) if only is None: - load = [name for name, schname in - zip(available, available_w_schema) - if extend_existing or schname not in current] + load = [ + name + for name, schname in zip(available, available_w_schema) + if extend_existing or schname not in current + ] elif util.callable(only): - load = [name for name, schname in - zip(available, available_w_schema) - if (extend_existing or schname not in current) - and only(name, self)] + load = [ + name + for name, schname in zip(available, available_w_schema) + if (extend_existing or schname not in current) + and only(name, self) + ] else: missing = [name for name in only if name not in available] if missing: - s = schema and (" schema '%s'" % schema) or '' + s = schema and (" schema '%s'" % schema) or "" raise exc.InvalidRequestError( - 'Could not reflect: requested table(s) not available ' - 'in %r%s: (%s)' % - (bind.engine, s, ', '.join(missing))) - load = [name for name in only if extend_existing or - name not in current] + "Could not reflect: requested table(s) not available " + "in %r%s: (%s)" % (bind.engine, s, ", ".join(missing)) + ) + load = [ + name + for name in only + if extend_existing or name not in current + ] for name in load: try: @@ -3989,11 +4184,12 @@ class MetaData(SchemaItem): See :class:`.DDLEvents`. """ + def adapt_listener(target, connection, **kw): - tables = kw['tables'] + tables = kw["tables"] listener(event, target, connection, tables=tables) - event.listen(self, "" + event_name.replace('-', '_'), adapt_listener) + event.listen(self, "" + event_name.replace("-", "_"), adapt_listener) def create_all(self, bind=None, tables=None, checkfirst=True): """Create all tables stored in this metadata. @@ -4017,10 +4213,9 @@ class MetaData(SchemaItem): """ if bind is None: bind = _bind_or_error(self) - bind._run_visitor(ddl.SchemaGenerator, - self, - checkfirst=checkfirst, - tables=tables) + bind._run_visitor( + ddl.SchemaGenerator, self, checkfirst=checkfirst, tables=tables + ) def drop_all(self, bind=None, tables=None, checkfirst=True): """Drop all tables stored in this metadata. @@ -4044,10 +4239,9 @@ class MetaData(SchemaItem): """ if bind is None: bind = _bind_or_error(self) - bind._run_visitor(ddl.SchemaDropper, - self, - checkfirst=checkfirst, - tables=tables) + bind._run_visitor( + ddl.SchemaDropper, self, checkfirst=checkfirst, tables=tables + ) class ThreadLocalMetaData(MetaData): @@ -4064,7 +4258,7 @@ class ThreadLocalMetaData(MetaData): """ - __visit_name__ = 'metadata' + __visit_name__ = "metadata" def __init__(self): """Construct a ThreadLocalMetaData.""" @@ -4080,13 +4274,13 @@ class ThreadLocalMetaData(MetaData): string or URL to automatically create a basic Engine for this bind with ``create_engine()``.""" - return getattr(self.context, '_engine', None) + return getattr(self.context, "_engine", None) @util.dependencies("sqlalchemy.engine.url") def _bind_to(self, url, bind): """Bind to a Connectable in the caller's thread.""" - if isinstance(bind, util.string_types + (url.URL, )): + if isinstance(bind, util.string_types + (url.URL,)): try: self.context._engine = self.__engines[bind] except KeyError: @@ -4104,14 +4298,16 @@ class ThreadLocalMetaData(MetaData): def is_bound(self): """True if there is a bind for this thread.""" - return (hasattr(self.context, '_engine') and - self.context._engine is not None) + return ( + hasattr(self.context, "_engine") + and self.context._engine is not None + ) def dispose(self): """Dispose all bound engines, in all thread contexts.""" for e in self.__engines.values(): - if hasattr(e, 'dispose'): + if hasattr(e, "dispose"): e.dispose() @@ -4128,22 +4324,25 @@ class _SchemaTranslateMap(object): """ - __slots__ = 'map_', '__call__', 'hash_key', 'is_default' + + __slots__ = "map_", "__call__", "hash_key", "is_default" _default_schema_getter = operator.attrgetter("schema") def __init__(self, map_): self.map_ = map_ if map_ is not None: + def schema_for_object(obj): effective_schema = self._default_schema_getter(obj) effective_schema = obj._translate_schema( - effective_schema, map_) + effective_schema, map_ + ) return effective_schema + self.__call__ = schema_for_object self.hash_key = ";".join( - "%s=%s" % (k, map_[k]) - for k in sorted(map_, key=str) + "%s=%s" % (k, map_[k]) for k in sorted(map_, key=str) ) self.is_default = False else: @@ -4160,6 +4359,6 @@ class _SchemaTranslateMap(object): else: return _SchemaTranslateMap(map_) + _default_schema_map = _SchemaTranslateMap(None) _schema_getter = _SchemaTranslateMap._schema_getter - diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index f64f152c4..1f1800514 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -10,15 +10,39 @@ SQL tables and derived rowsets. """ -from .elements import ClauseElement, TextClause, ClauseList, \ - and_, Grouping, UnaryExpression, literal_column, BindParameter -from .elements import _clone, \ - _literal_as_text, _interpret_as_column_or_from, _expand_cloned,\ - _select_iterables, _anonymous_label, _clause_element_as_expr,\ - _cloned_intersection, _cloned_difference, True_, \ - _literal_as_label_reference, _literal_and_labels_as_label_reference -from .base import Immutable, Executable, _generative, \ - ColumnCollection, ColumnSet, _from_objects, Generative +from .elements import ( + ClauseElement, + TextClause, + ClauseList, + and_, + Grouping, + UnaryExpression, + literal_column, + BindParameter, +) +from .elements import ( + _clone, + _literal_as_text, + _interpret_as_column_or_from, + _expand_cloned, + _select_iterables, + _anonymous_label, + _clause_element_as_expr, + _cloned_intersection, + _cloned_difference, + True_, + _literal_as_label_reference, + _literal_and_labels_as_label_reference, +) +from .base import ( + Immutable, + Executable, + _generative, + ColumnCollection, + ColumnSet, + _from_objects, + Generative, +) from . import type_api from .. import inspection from .. import util @@ -40,7 +64,8 @@ def _interpret_as_from(element): "Textual SQL FROM expression %(expr)r should be " "explicitly declared as text(%(expr)r), " "or use table(%(expr)r) for more specificity", - {"expr": util.ellipses_string(element)}) + {"expr": util.ellipses_string(element)}, + ) return TextClause(util.text_type(element)) try: @@ -73,7 +98,7 @@ def _offset_or_limit_clause(element, name=None, type_=None): """ if element is None: return None - elif hasattr(element, '__clause_element__'): + elif hasattr(element, "__clause_element__"): return element.__clause_element__() elif isinstance(element, Visitable): return element @@ -97,7 +122,8 @@ def _offset_or_limit_clause_asint(clause, attrname): except AttributeError: raise exc.CompileError( "This SELECT structure does not use a simple " - "integer value for %s" % attrname) + "integer value for %s" % attrname + ) else: return util.asint(value) @@ -225,12 +251,14 @@ def tablesample(selectable, sampling, name=None, seed=None): """ return _interpret_as_from(selectable).tablesample( - sampling, name=name, seed=seed) + sampling, name=name, seed=seed + ) class Selectable(ClauseElement): """mark a class as being selectable""" - __visit_name__ = 'selectable' + + __visit_name__ = "selectable" is_selectable = True @@ -265,15 +293,17 @@ class HasPrefixes(object): limit rendering of this prefix to only that dialect. """ - dialect = kw.pop('dialect', None) + dialect = kw.pop("dialect", None) if kw: - raise exc.ArgumentError("Unsupported argument(s): %s" % - ",".join(kw)) + raise exc.ArgumentError( + "Unsupported argument(s): %s" % ",".join(kw) + ) self._setup_prefixes(expr, dialect) def _setup_prefixes(self, prefixes, dialect=None): self._prefixes = self._prefixes + tuple( - [(_literal_as_text(p, warn=False), dialect) for p in prefixes]) + [(_literal_as_text(p, warn=False), dialect) for p in prefixes] + ) class HasSuffixes(object): @@ -301,15 +331,17 @@ class HasSuffixes(object): limit rendering of this suffix to only that dialect. """ - dialect = kw.pop('dialect', None) + dialect = kw.pop("dialect", None) if kw: - raise exc.ArgumentError("Unsupported argument(s): %s" % - ",".join(kw)) + raise exc.ArgumentError( + "Unsupported argument(s): %s" % ",".join(kw) + ) self._setup_suffixes(expr, dialect) def _setup_suffixes(self, suffixes, dialect=None): self._suffixes = self._suffixes + tuple( - [(_literal_as_text(p, warn=False), dialect) for p in suffixes]) + [(_literal_as_text(p, warn=False), dialect) for p in suffixes] + ) class FromClause(Selectable): @@ -330,7 +362,8 @@ class FromClause(Selectable): """ - __visit_name__ = 'fromclause' + + __visit_name__ = "fromclause" named_with_column = False _hide_froms = [] @@ -359,13 +392,14 @@ class FromClause(Selectable): _memoized_property = util.group_expirable_memoized_property(["_columns"]) @util.deprecated( - '1.1', + "1.1", message="``FromClause.count()`` is deprecated. Counting " "rows requires that the correct column expression and " "accommodations for joins, DISTINCT, etc. must be made, " "otherwise results may not be what's expected. " "Please use an appropriate ``func.count()`` expression " - "directly.") + "directly.", + ) @util.dependencies("sqlalchemy.sql.functions") def count(self, functions, whereclause=None, **params): """return a SELECT COUNT generated against this @@ -392,10 +426,11 @@ class FromClause(Selectable): else: col = list(self.columns)[0] return Select( - [functions.func.count(col).label('tbl_row_count')], + [functions.func.count(col).label("tbl_row_count")], whereclause, from_obj=[self], - **params) + **params + ) def select(self, whereclause=None, **params): """return a SELECT of this :class:`.FromClause`. @@ -603,8 +638,9 @@ class FromClause(Selectable): def embedded(expanded_proxy_set, target_set): for t in target_set.difference(expanded_proxy_set): - if not set(_expand_cloned([t]) - ).intersection(expanded_proxy_set): + if not set(_expand_cloned([t])).intersection( + expanded_proxy_set + ): return False return True @@ -617,8 +653,10 @@ class FromClause(Selectable): for c in cols: expanded_proxy_set = set(_expand_cloned(c.proxy_set)) i = target_set.intersection(expanded_proxy_set) - if i and (not require_embedded - or embedded(expanded_proxy_set, target_set)): + if i and ( + not require_embedded + or embedded(expanded_proxy_set, target_set) + ): if col is None: # no corresponding column yet, pick this one. @@ -646,12 +684,20 @@ class FromClause(Selectable): col_distance = util.reduce( operator.add, - [sc._annotations.get('weight', 1) for sc in - col.proxy_set if sc.shares_lineage(column)]) + [ + sc._annotations.get("weight", 1) + for sc in col.proxy_set + if sc.shares_lineage(column) + ], + ) c_distance = util.reduce( operator.add, - [sc._annotations.get('weight', 1) for sc in - c.proxy_set if sc.shares_lineage(column)]) + [ + sc._annotations.get("weight", 1) + for sc in c.proxy_set + if sc.shares_lineage(column) + ], + ) if c_distance < col_distance: col, intersect = c, i return col @@ -663,7 +709,7 @@ class FromClause(Selectable): Used primarily for error message formatting. """ - return getattr(self, 'name', self.__class__.__name__ + " object") + return getattr(self, "name", self.__class__.__name__ + " object") def _reset_exported(self): """delete memoized collections when a FromClause is cloned.""" @@ -683,7 +729,7 @@ class FromClause(Selectable): """ - if '_columns' not in self.__dict__: + if "_columns" not in self.__dict__: self._init_collections() self._populate_column_collection() return self._columns.as_immutable() @@ -706,14 +752,16 @@ class FromClause(Selectable): self._populate_column_collection() return self.foreign_keys - c = property(attrgetter('columns'), - doc="An alias for the :attr:`.columns` attribute.") - _select_iterable = property(attrgetter('columns')) + c = property( + attrgetter("columns"), + doc="An alias for the :attr:`.columns` attribute.", + ) + _select_iterable = property(attrgetter("columns")) def _init_collections(self): - assert '_columns' not in self.__dict__ - assert 'primary_key' not in self.__dict__ - assert 'foreign_keys' not in self.__dict__ + assert "_columns" not in self.__dict__ + assert "primary_key" not in self.__dict__ + assert "foreign_keys" not in self.__dict__ self._columns = ColumnCollection() self.primary_key = ColumnSet() @@ -721,7 +769,7 @@ class FromClause(Selectable): @property def _cols_populated(self): - return '_columns' in self.__dict__ + return "_columns" in self.__dict__ def _populate_column_collection(self): """Called on subclasses to establish the .c collection. @@ -758,8 +806,7 @@ class FromClause(Selectable): """ if not self._cols_populated: return None - elif (column.key in self.columns and - self.columns[column.key] is column): + elif column.key in self.columns and self.columns[column.key] is column: return column else: return None @@ -780,7 +827,8 @@ class Join(FromClause): :meth:`.FromClause.join` """ - __visit_name__ = 'join' + + __visit_name__ = "join" _is_join = True @@ -829,8 +877,9 @@ class Join(FromClause): return cls(left, right, onclause, isouter=True, full=full) @classmethod - def _create_join(cls, left, right, onclause=None, isouter=False, - full=False): + def _create_join( + cls, left, right, onclause=None, isouter=False, full=False + ): """Produce a :class:`.Join` object, given two :class:`.FromClause` expressions. @@ -882,26 +931,34 @@ class Join(FromClause): self.left.description, id(self.left), self.right.description, - id(self.right)) + id(self.right), + ) def is_derived_from(self, fromclause): - return fromclause is self or \ - self.left.is_derived_from(fromclause) or \ - self.right.is_derived_from(fromclause) + return ( + fromclause is self + or self.left.is_derived_from(fromclause) + or self.right.is_derived_from(fromclause) + ) def self_group(self, against=None): return FromGrouping(self) @util.dependencies("sqlalchemy.sql.util") def _populate_column_collection(self, sqlutil): - columns = [c for c in self.left.columns] + \ - [c for c in self.right.columns] + columns = [c for c in self.left.columns] + [ + c for c in self.right.columns + ] - self.primary_key.extend(sqlutil.reduce_columns( - (c for c in columns if c.primary_key), self.onclause)) + self.primary_key.extend( + sqlutil.reduce_columns( + (c for c in columns if c.primary_key), self.onclause + ) + ) self._columns.update((col._label, col) for col in columns) - self.foreign_keys.update(itertools.chain( - *[col.foreign_keys for col in columns])) + self.foreign_keys.update( + itertools.chain(*[col.foreign_keys for col in columns]) + ) def _refresh_for_new_column(self, column): col = self.left._refresh_for_new_column(column) @@ -933,9 +990,14 @@ class Join(FromClause): return self._join_condition(left, right, a_subset=left_right) @classmethod - def _join_condition(cls, a, b, ignore_nonexistent_tables=False, - a_subset=None, - consider_as_foreign_keys=None): + def _join_condition( + cls, + a, + b, + ignore_nonexistent_tables=False, + a_subset=None, + consider_as_foreign_keys=None, + ): """create a join condition between two tables or selectables. e.g.:: @@ -963,26 +1025,31 @@ class Join(FromClause): """ constraints = cls._joincond_scan_left_right( - a, a_subset, b, consider_as_foreign_keys) + a, a_subset, b, consider_as_foreign_keys + ) if len(constraints) > 1: cls._joincond_trim_constraints( - a, b, constraints, consider_as_foreign_keys) + a, b, constraints, consider_as_foreign_keys + ) if len(constraints) == 0: if isinstance(b, FromGrouping): - hint = " Perhaps you meant to convert the right side to a "\ + hint = ( + " Perhaps you meant to convert the right side to a " "subquery using alias()?" + ) else: hint = "" raise exc.NoForeignKeysError( "Can't find any foreign key relationships " - "between '%s' and '%s'.%s" % - (a.description, b.description, hint)) + "between '%s' and '%s'.%s" + % (a.description, b.description, hint) + ) crit = [(x == y) for x, y in list(constraints.values())[0]] if len(crit) == 1: - return (crit[0]) + return crit[0] else: return and_(*crit) @@ -994,24 +1061,30 @@ class Join(FromClause): left_right = None constraints = cls._joincond_scan_left_right( - a=left, b=right, a_subset=left_right, - consider_as_foreign_keys=consider_as_foreign_keys) + a=left, + b=right, + a_subset=left_right, + consider_as_foreign_keys=consider_as_foreign_keys, + ) return bool(constraints) @classmethod def _joincond_scan_left_right( - cls, a, a_subset, b, consider_as_foreign_keys): + cls, a, a_subset, b, consider_as_foreign_keys + ): constraints = collections.defaultdict(list) for left in (a_subset, a): if left is None: continue for fk in sorted( - b.foreign_keys, - key=lambda fk: fk.parent._creation_order): - if consider_as_foreign_keys is not None and \ - fk.parent not in consider_as_foreign_keys: + b.foreign_keys, key=lambda fk: fk.parent._creation_order + ): + if ( + consider_as_foreign_keys is not None + and fk.parent not in consider_as_foreign_keys + ): continue try: col = fk.get_referent(left) @@ -1025,10 +1098,12 @@ class Join(FromClause): constraints[fk.constraint].append((col, fk.parent)) if left is not b: for fk in sorted( - left.foreign_keys, - key=lambda fk: fk.parent._creation_order): - if consider_as_foreign_keys is not None and \ - fk.parent not in consider_as_foreign_keys: + left.foreign_keys, key=lambda fk: fk.parent._creation_order + ): + if ( + consider_as_foreign_keys is not None + and fk.parent not in consider_as_foreign_keys + ): continue try: col = fk.get_referent(b) @@ -1046,14 +1121,16 @@ class Join(FromClause): @classmethod def _joincond_trim_constraints( - cls, a, b, constraints, consider_as_foreign_keys): + cls, a, b, constraints, consider_as_foreign_keys + ): # more than one constraint matched. narrow down the list # to include just those FKCs that match exactly to # "consider_as_foreign_keys". if consider_as_foreign_keys: for const in list(constraints): if set(f.parent for f in const.elements) != set( - consider_as_foreign_keys): + consider_as_foreign_keys + ): del constraints[const] # if still multiple constraints, but @@ -1070,8 +1147,8 @@ class Join(FromClause): "tables have more than one foreign key " "constraint relationship between them. " "Please specify the 'onclause' of this " - "join explicitly." % (a.description, b.description)) - + "join explicitly." % (a.description, b.description) + ) def select(self, whereclause=None, **kwargs): r"""Create a :class:`.Select` from this :class:`.Join`. @@ -1200,27 +1277,37 @@ class Join(FromClause): """ if flat: assert name is None, "Can't send name argument with flat" - left_a, right_a = self.left.alias(flat=True), \ - self.right.alias(flat=True) - adapter = sqlutil.ClauseAdapter(left_a).\ - chain(sqlutil.ClauseAdapter(right_a)) + left_a, right_a = ( + self.left.alias(flat=True), + self.right.alias(flat=True), + ) + adapter = sqlutil.ClauseAdapter(left_a).chain( + sqlutil.ClauseAdapter(right_a) + ) - return left_a.join(right_a, adapter.traverse(self.onclause), - isouter=self.isouter, full=self.full) + return left_a.join( + right_a, + adapter.traverse(self.onclause), + isouter=self.isouter, + full=self.full, + ) else: return self.select(use_labels=True, correlate=False).alias(name) @property def _hide_froms(self): - return itertools.chain(*[_from_objects(x.left, x.right) - for x in self._cloned_set]) + return itertools.chain( + *[_from_objects(x.left, x.right) for x in self._cloned_set] + ) @property def _from_objects(self): - return [self] + \ - self.onclause._from_objects + \ - self.left._from_objects + \ - self.right._from_objects + return ( + [self] + + self.onclause._from_objects + + self.left._from_objects + + self.right._from_objects + ) class Alias(FromClause): @@ -1236,7 +1323,7 @@ class Alias(FromClause): """ - __visit_name__ = 'alias' + __visit_name__ = "alias" named_with_column = True _is_from_container = True @@ -1252,15 +1339,16 @@ class Alias(FromClause): self.element = selectable if name is None: if self.original.named_with_column: - name = getattr(self.original, 'name', None) - name = _anonymous_label('%%(%d %s)s' % (id(self), name - or 'anon')) + name = getattr(self.original, "name", None) + name = _anonymous_label("%%(%d %s)s" % (id(self), name or "anon")) self.name = name def self_group(self, against=None): - if isinstance(against, CompoundSelect) and \ - isinstance(self.original, Select) and \ - self.original._needs_parens_for_grouping(): + if ( + isinstance(against, CompoundSelect) + and isinstance(self.original, Select) + and self.original._needs_parens_for_grouping() + ): return FromGrouping(self) return super(Alias, self).self_group(against=against) @@ -1270,14 +1358,15 @@ class Alias(FromClause): if util.py3k: return self.name else: - return self.name.encode('ascii', 'backslashreplace') + return self.name.encode("ascii", "backslashreplace") def as_scalar(self): try: return self.element.as_scalar() except AttributeError: - raise AttributeError("Element %s does not support " - "'as_scalar()'" % self.element) + raise AttributeError( + "Element %s does not support " "'as_scalar()'" % self.element + ) def is_derived_from(self, fromclause): if fromclause in self._cloned_set: @@ -1344,7 +1433,7 @@ class Lateral(Alias): """ - __visit_name__ = 'lateral' + __visit_name__ = "lateral" _is_lateral = True @@ -1363,11 +1452,9 @@ class TableSample(Alias): """ - __visit_name__ = 'tablesample' + __visit_name__ = "tablesample" - def __init__(self, selectable, sampling, - name=None, - seed=None): + def __init__(self, selectable, sampling, name=None, seed=None): self.sampling = sampling self.seed = seed super(TableSample, self).__init__(selectable, name=name) @@ -1390,14 +1477,18 @@ class CTE(Generative, HasSuffixes, Alias): .. versionadded:: 0.7.6 """ - __visit_name__ = 'cte' - - def __init__(self, selectable, - name=None, - recursive=False, - _cte_alias=None, - _restates=frozenset(), - _suffixes=None): + + __visit_name__ = "cte" + + def __init__( + self, + selectable, + name=None, + recursive=False, + _cte_alias=None, + _restates=frozenset(), + _suffixes=None, + ): self.recursive = recursive self._cte_alias = _cte_alias self._restates = _restates @@ -1409,9 +1500,9 @@ class CTE(Generative, HasSuffixes, Alias): super(CTE, self)._copy_internals(clone, **kw) if self._cte_alias is not None: self._cte_alias = clone(self._cte_alias, **kw) - self._restates = frozenset([ - clone(elem, **kw) for elem in self._restates - ]) + self._restates = frozenset( + [clone(elem, **kw) for elem in self._restates] + ) @util.dependencies("sqlalchemy.sql.dml") def _populate_column_collection(self, dml): @@ -1428,7 +1519,7 @@ class CTE(Generative, HasSuffixes, Alias): name=name, recursive=self.recursive, _cte_alias=self, - _suffixes=self._suffixes + _suffixes=self._suffixes, ) def union(self, other): @@ -1437,7 +1528,7 @@ class CTE(Generative, HasSuffixes, Alias): name=self.name, recursive=self.recursive, _restates=self._restates.union([self]), - _suffixes=self._suffixes + _suffixes=self._suffixes, ) def union_all(self, other): @@ -1446,7 +1537,7 @@ class CTE(Generative, HasSuffixes, Alias): name=self.name, recursive=self.recursive, _restates=self._restates.union([self]), - _suffixes=self._suffixes + _suffixes=self._suffixes, ) @@ -1620,7 +1711,8 @@ class HasCTE(object): class FromGrouping(FromClause): """Represent a grouping of a FROM clause""" - __visit_name__ = 'grouping' + + __visit_name__ = "grouping" def __init__(self, element): self.element = element @@ -1651,7 +1743,7 @@ class FromGrouping(FromClause): return self.element._hide_froms def get_children(self, **kwargs): - return self.element, + return (self.element,) def _copy_internals(self, clone=_clone, **kw): self.element = clone(self.element, **kw) @@ -1664,10 +1756,10 @@ class FromGrouping(FromClause): return getattr(self.element, attr) def __getstate__(self): - return {'element': self.element} + return {"element": self.element} def __setstate__(self, state): - self.element = state['element'] + self.element = state["element"] class TableClause(Immutable, FromClause): @@ -1699,7 +1791,7 @@ class TableClause(Immutable, FromClause): """ - __visit_name__ = 'table' + __visit_name__ = "table" named_with_column = True @@ -1744,7 +1836,7 @@ class TableClause(Immutable, FromClause): if util.py3k: return self.name else: - return self.name.encode('ascii', 'backslashreplace') + return self.name.encode("ascii", "backslashreplace") def append_column(self, c): self._columns[c.key] = c @@ -1773,7 +1865,8 @@ class TableClause(Immutable, FromClause): @util.dependencies("sqlalchemy.sql.dml") def update( - self, dml, whereclause=None, values=None, inline=False, **kwargs): + self, dml, whereclause=None, values=None, inline=False, **kwargs + ): """Generate an :func:`.update` construct against this :class:`.TableClause`. @@ -1785,8 +1878,13 @@ class TableClause(Immutable, FromClause): """ - return dml.Update(self, whereclause=whereclause, - values=values, inline=inline, **kwargs) + return dml.Update( + self, + whereclause=whereclause, + values=values, + inline=inline, + **kwargs + ) @util.dependencies("sqlalchemy.sql.dml") def delete(self, dml, whereclause=None, **kwargs): @@ -1809,7 +1907,6 @@ class TableClause(Immutable, FromClause): class ForUpdateArg(ClauseElement): - @classmethod def parse_legacy_select(self, arg): """Parse the for_update argument of :func:`.select`. @@ -1836,11 +1933,11 @@ class ForUpdateArg(ClauseElement): return None nowait = read = False - if arg == 'nowait': + if arg == "nowait": nowait = True - elif arg == 'read': + elif arg == "read": read = True - elif arg == 'read_nowait': + elif arg == "read_nowait": read = nowait = True elif arg is not True: raise exc.ArgumentError("Unknown for_update argument: %r" % arg) @@ -1860,12 +1957,12 @@ class ForUpdateArg(ClauseElement): def __eq__(self, other): return ( - isinstance(other, ForUpdateArg) and - other.nowait == self.nowait and - other.read == self.read and - other.skip_locked == self.skip_locked and - other.key_share == self.key_share and - other.of is self.of + isinstance(other, ForUpdateArg) + and other.nowait == self.nowait + and other.read == self.read + and other.skip_locked == self.skip_locked + and other.key_share == self.key_share + and other.of is self.of ) def __hash__(self): @@ -1876,8 +1973,13 @@ class ForUpdateArg(ClauseElement): self.of = [clone(col, **kw) for col in self.of] def __init__( - self, nowait=False, read=False, of=None, - skip_locked=False, key_share=False): + self, + nowait=False, + read=False, + of=None, + skip_locked=False, + key_share=False, + ): """Represents arguments specified to :meth:`.Select.for_update`. .. versionadded:: 0.9.0 @@ -1889,8 +1991,9 @@ class ForUpdateArg(ClauseElement): self.skip_locked = skip_locked self.key_share = key_share if of is not None: - self.of = [_interpret_as_column_or_from(elem) - for elem in util.to_list(of)] + self.of = [ + _interpret_as_column_or_from(elem) for elem in util.to_list(of) + ] else: self.of = None @@ -1930,17 +2033,20 @@ class SelectBase(HasCTE, Executable, FromClause): return self.as_scalar().label(name) @_generative - @util.deprecated('0.6', - message="``autocommit()`` is deprecated. Use " - ":meth:`.Executable.execution_options` with the " - "'autocommit' flag.") + @util.deprecated( + "0.6", + message="``autocommit()`` is deprecated. Use " + ":meth:`.Executable.execution_options` with the " + "'autocommit' flag.", + ) def autocommit(self): """return a new selectable with the 'autocommit' flag set to True. """ - self._execution_options = \ - self._execution_options.union({'autocommit': True}) + self._execution_options = self._execution_options.union( + {"autocommit": True} + ) def _generate(self): """Override the default _generate() method to also clear out @@ -1973,34 +2079,38 @@ class GenerativeSelect(SelectBase): used for other SELECT-like objects, e.g. :class:`.TextAsFrom`. """ + _order_by_clause = ClauseList() _group_by_clause = ClauseList() _limit_clause = None _offset_clause = None _for_update_arg = None - def __init__(self, - use_labels=False, - for_update=False, - limit=None, - offset=None, - order_by=None, - group_by=None, - bind=None, - autocommit=None): + def __init__( + self, + use_labels=False, + for_update=False, + limit=None, + offset=None, + order_by=None, + group_by=None, + bind=None, + autocommit=None, + ): self.use_labels = use_labels if for_update is not False: - self._for_update_arg = (ForUpdateArg. - parse_legacy_select(for_update)) + self._for_update_arg = ForUpdateArg.parse_legacy_select(for_update) if autocommit is not None: - util.warn_deprecated('autocommit on select() is ' - 'deprecated. Use .execution_options(a' - 'utocommit=True)') - self._execution_options = \ - self._execution_options.union( - {'autocommit': autocommit}) + util.warn_deprecated( + "autocommit on select() is " + "deprecated. Use .execution_options(a" + "utocommit=True)" + ) + self._execution_options = self._execution_options.union( + {"autocommit": autocommit} + ) if limit is not None: self._limit_clause = _offset_or_limit_clause(limit) if offset is not None: @@ -2010,11 +2120,13 @@ class GenerativeSelect(SelectBase): if order_by is not None: self._order_by_clause = ClauseList( *util.to_list(order_by), - _literal_as_text=_literal_and_labels_as_label_reference) + _literal_as_text=_literal_and_labels_as_label_reference + ) if group_by is not None: self._group_by_clause = ClauseList( *util.to_list(group_by), - _literal_as_text=_literal_as_label_reference) + _literal_as_text=_literal_as_label_reference + ) @property def for_update(self): @@ -2030,8 +2142,14 @@ class GenerativeSelect(SelectBase): self._for_update_arg = ForUpdateArg.parse_legacy_select(value) @_generative - def with_for_update(self, nowait=False, read=False, of=None, - skip_locked=False, key_share=False): + def with_for_update( + self, + nowait=False, + read=False, + of=None, + skip_locked=False, + key_share=False, + ): """Specify a ``FOR UPDATE`` clause for this :class:`.GenerativeSelect`. E.g.:: @@ -2079,9 +2197,13 @@ class GenerativeSelect(SelectBase): .. versionadded:: 1.1.0 """ - self._for_update_arg = ForUpdateArg(nowait=nowait, read=read, of=of, - skip_locked=skip_locked, - key_share=key_share) + self._for_update_arg = ForUpdateArg( + nowait=nowait, + read=read, + of=of, + skip_locked=skip_locked, + key_share=key_share, + ) @_generative def apply_labels(self): @@ -2209,11 +2331,12 @@ class GenerativeSelect(SelectBase): if len(clauses) == 1 and clauses[0] is None: self._order_by_clause = ClauseList() else: - if getattr(self, '_order_by_clause', None) is not None: + if getattr(self, "_order_by_clause", None) is not None: clauses = list(self._order_by_clause) + list(clauses) self._order_by_clause = ClauseList( *clauses, - _literal_as_text=_literal_and_labels_as_label_reference) + _literal_as_text=_literal_and_labels_as_label_reference + ) def append_group_by(self, *clauses): """Append the given GROUP BY criterion applied to this selectable. @@ -2228,10 +2351,11 @@ class GenerativeSelect(SelectBase): if len(clauses) == 1 and clauses[0] is None: self._group_by_clause = ClauseList() else: - if getattr(self, '_group_by_clause', None) is not None: + if getattr(self, "_group_by_clause", None) is not None: clauses = list(self._group_by_clause) + list(clauses) self._group_by_clause = ClauseList( - *clauses, _literal_as_text=_literal_as_label_reference) + *clauses, _literal_as_text=_literal_as_label_reference + ) @property def _label_resolve_dict(self): @@ -2265,19 +2389,19 @@ class CompoundSelect(GenerativeSelect): """ - __visit_name__ = 'compound_select' + __visit_name__ = "compound_select" - UNION = util.symbol('UNION') - UNION_ALL = util.symbol('UNION ALL') - EXCEPT = util.symbol('EXCEPT') - EXCEPT_ALL = util.symbol('EXCEPT ALL') - INTERSECT = util.symbol('INTERSECT') - INTERSECT_ALL = util.symbol('INTERSECT ALL') + UNION = util.symbol("UNION") + UNION_ALL = util.symbol("UNION ALL") + EXCEPT = util.symbol("EXCEPT") + EXCEPT_ALL = util.symbol("EXCEPT ALL") + INTERSECT = util.symbol("INTERSECT") + INTERSECT_ALL = util.symbol("INTERSECT ALL") _is_from_container = True def __init__(self, keyword, *selects, **kwargs): - self._auto_correlate = kwargs.pop('correlate', False) + self._auto_correlate = kwargs.pop("correlate", False) self.keyword = keyword self.selects = [] @@ -2291,12 +2415,16 @@ class CompoundSelect(GenerativeSelect): numcols = len(s.c._all_columns) elif len(s.c._all_columns) != numcols: raise exc.ArgumentError( - 'All selectables passed to ' - 'CompoundSelect must have identical numbers of ' - 'columns; select #%d has %d columns, select ' - '#%d has %d' % - (1, len(self.selects[0].c._all_columns), - n + 1, len(s.c._all_columns)) + "All selectables passed to " + "CompoundSelect must have identical numbers of " + "columns; select #%d has %d columns, select " + "#%d has %d" + % ( + 1, + len(self.selects[0].c._all_columns), + n + 1, + len(s.c._all_columns), + ) ) self.selects.append(s.self_group(against=self)) @@ -2305,9 +2433,7 @@ class CompoundSelect(GenerativeSelect): @property def _label_resolve_dict(self): - d = dict( - (c.key, c) for c in self.c - ) + d = dict((c.key, c) for c in self.c) return d, d, d @classmethod @@ -2416,8 +2542,7 @@ class CompoundSelect(GenerativeSelect): :func:`select`. """ - return CompoundSelect( - CompoundSelect.INTERSECT_ALL, *selects, **kwargs) + return CompoundSelect(CompoundSelect.INTERSECT_ALL, *selects, **kwargs) def _scalar_type(self): return self.selects[0]._scalar_type() @@ -2445,8 +2570,10 @@ class CompoundSelect(GenerativeSelect): # those fks too. proxy = cols[0]._make_proxy( - self, name=cols[0]._label if self.use_labels else None, - key=cols[0]._key_label if self.use_labels else None) + self, + name=cols[0]._label if self.use_labels else None, + key=cols[0]._key_label if self.use_labels else None, + ) # hand-construct the "_proxies" collection to include all # derived columns place a 'weight' annotation corresponding @@ -2455,7 +2582,8 @@ class CompoundSelect(GenerativeSelect): # conflicts proxy._proxies = [ - c._annotate({'weight': i + 1}) for (i, c) in enumerate(cols)] + c._annotate({"weight": i + 1}) for (i, c) in enumerate(cols) + ] def _refresh_for_new_column(self, column): for s in self.selects: @@ -2464,25 +2592,32 @@ class CompoundSelect(GenerativeSelect): if not self._cols_populated: return None - raise NotImplementedError("CompoundSelect constructs don't support " - "addition of columns to underlying " - "selectables") + raise NotImplementedError( + "CompoundSelect constructs don't support " + "addition of columns to underlying " + "selectables" + ) def _copy_internals(self, clone=_clone, **kw): super(CompoundSelect, self)._copy_internals(clone, **kw) self._reset_exported() self.selects = [clone(s, **kw) for s in self.selects] - if hasattr(self, '_col_map'): + if hasattr(self, "_col_map"): del self._col_map for attr in ( - '_order_by_clause', '_group_by_clause', '_for_update_arg'): + "_order_by_clause", + "_group_by_clause", + "_for_update_arg", + ): if getattr(self, attr) is not None: setattr(self, attr, clone(getattr(self, attr), **kw)) def get_children(self, column_collections=True, **kwargs): - return (column_collections and list(self.c) or []) \ - + [self._order_by_clause, self._group_by_clause] \ + return ( + (column_collections and list(self.c) or []) + + [self._order_by_clause, self._group_by_clause] + list(self.selects) + ) def bind(self): if self._bind: @@ -2496,6 +2631,7 @@ class CompoundSelect(GenerativeSelect): def _set_bind(self, bind): self._bind = bind + bind = property(bind, _set_bind) @@ -2504,7 +2640,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): """ - __visit_name__ = 'select' + __visit_name__ = "select" _prefixes = () _suffixes = () @@ -2517,16 +2653,18 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): _memoized_property = SelectBase._memoized_property _is_select = True - def __init__(self, - columns=None, - whereclause=None, - from_obj=None, - distinct=False, - having=None, - correlate=True, - prefixes=None, - suffixes=None, - **kwargs): + def __init__( + self, + columns=None, + whereclause=None, + from_obj=None, + distinct=False, + having=None, + correlate=True, + prefixes=None, + suffixes=None, + **kwargs + ): """Construct a new :class:`.Select`. Similar functionality is also available via the @@ -2729,22 +2867,23 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): self._distinct = True else: self._distinct = [ - _literal_as_text(e) - for e in util.to_list(distinct) + _literal_as_text(e) for e in util.to_list(distinct) ] if from_obj is not None: self._from_obj = util.OrderedSet( - _interpret_as_from(f) - for f in util.to_list(from_obj)) + _interpret_as_from(f) for f in util.to_list(from_obj) + ) else: self._from_obj = util.OrderedSet() try: cols_present = bool(columns) except TypeError: - raise exc.ArgumentError("columns argument to select() must " - "be a Python list or other iterable") + raise exc.ArgumentError( + "columns argument to select() must " + "be a Python list or other iterable" + ) if cols_present: self._raw_columns = [] @@ -2757,14 +2896,16 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): self._raw_columns = [] if whereclause is not None: - self._whereclause = _literal_as_text( - whereclause).self_group(against=operators._asbool) + self._whereclause = _literal_as_text(whereclause).self_group( + against=operators._asbool + ) else: self._whereclause = None if having is not None: - self._having = _literal_as_text( - having).self_group(against=operators._asbool) + self._having = _literal_as_text(having).self_group( + against=operators._asbool + ) else: self._having = None @@ -2789,12 +2930,14 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): for item in itertools.chain( _from_objects(*self._raw_columns), _from_objects(self._whereclause) - if self._whereclause is not None else (), - self._from_obj + if self._whereclause is not None + else (), + self._from_obj, ): if item is self: raise exc.InvalidRequestError( - "select() construct refers to itself as a FROM") + "select() construct refers to itself as a FROM" + ) if translate and item in translate: item = translate[item] if not seen.intersection(item._cloned_set): @@ -2803,8 +2946,9 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): return froms - def _get_display_froms(self, explicit_correlate_froms=None, - implicit_correlate_froms=None): + def _get_display_froms( + self, explicit_correlate_froms=None, implicit_correlate_froms=None + ): """Return the full list of 'from' clauses to be displayed. Takes into account a set of existing froms which may be @@ -2815,17 +2959,17 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): """ froms = self._froms - toremove = set(itertools.chain(*[ - _expand_cloned(f._hide_froms) - for f in froms])) + toremove = set( + itertools.chain(*[_expand_cloned(f._hide_froms) for f in froms]) + ) if toremove: # if we're maintaining clones of froms, # add the copies out to the toremove list. only include # clones that are lexical equivalents. if self._from_cloned: toremove.update( - self._from_cloned[f] for f in - toremove.intersection(self._from_cloned) + self._from_cloned[f] + for f in toremove.intersection(self._from_cloned) if self._from_cloned[f]._is_lexical_equivalent(f) ) # filter out to FROM clauses not in the list, @@ -2836,41 +2980,53 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): to_correlate = self._correlate if to_correlate: froms = [ - f for f in froms if f not in - _cloned_intersection( + f + for f in froms + if f + not in _cloned_intersection( _cloned_intersection( - froms, explicit_correlate_froms or ()), - to_correlate + froms, explicit_correlate_froms or () + ), + to_correlate, ) ] if self._correlate_except is not None: froms = [ - f for f in froms if f not in - _cloned_difference( + f + for f in froms + if f + not in _cloned_difference( _cloned_intersection( - froms, explicit_correlate_froms or ()), - self._correlate_except + froms, explicit_correlate_froms or () + ), + self._correlate_except, ) ] - if self._auto_correlate and \ - implicit_correlate_froms and \ - len(froms) > 1: + if ( + self._auto_correlate + and implicit_correlate_froms + and len(froms) > 1 + ): froms = [ - f for f in froms if f not in - _cloned_intersection(froms, implicit_correlate_froms) + f + for f in froms + if f + not in _cloned_intersection(froms, implicit_correlate_froms) ] if not len(froms): - raise exc.InvalidRequestError("Select statement '%s" - "' returned no FROM clauses " - "due to auto-correlation; " - "specify correlate(<tables>) " - "to control correlation " - "manually." % self) + raise exc.InvalidRequestError( + "Select statement '%s" + "' returned no FROM clauses " + "due to auto-correlation; " + "specify correlate(<tables>) " + "to control correlation " + "manually." % self + ) return froms @@ -2885,7 +3041,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): return self._get_display_froms() - def with_statement_hint(self, text, dialect_name='*'): + def with_statement_hint(self, text, dialect_name="*"): """add a statement hint to this :class:`.Select`. This method is similar to :meth:`.Select.with_hint` except that @@ -2906,7 +3062,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): return self.with_hint(None, text, dialect_name) @_generative - def with_hint(self, selectable, text, dialect_name='*'): + def with_hint(self, selectable, text, dialect_name="*"): r"""Add an indexing or other executional context hint for the given selectable to this :class:`.Select`. @@ -2940,17 +3096,18 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): """ if selectable is None: - self._statement_hints += ((dialect_name, text), ) + self._statement_hints += ((dialect_name, text),) else: - self._hints = self._hints.union( - {(selectable, dialect_name): text}) + self._hints = self._hints.union({(selectable, dialect_name): text}) @property def type(self): - raise exc.InvalidRequestError("Select objects don't have a type. " - "Call as_scalar() on this Select " - "object to return a 'scalar' version " - "of this Select.") + raise exc.InvalidRequestError( + "Select objects don't have a type. " + "Call as_scalar() on this Select " + "object to return a 'scalar' version " + "of this Select." + ) @_memoized_property.method def locate_all_froms(self): @@ -2977,10 +3134,13 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): with_cols = dict( (c._resolve_label or c._label or c.key, c) for c in _select_iterables(self._raw_columns) - if c._allow_label_resolve) + if c._allow_label_resolve + ) only_froms = dict( - (c.key, c) for c in - _select_iterables(self.froms) if c._allow_label_resolve) + (c.key, c) + for c in _select_iterables(self.froms) + if c._allow_label_resolve + ) only_cols = with_cols.copy() for key, value in only_froms.items(): with_cols.setdefault(key, value) @@ -3011,11 +3171,13 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): # gets cleared on each generation. previously we were "baking" # _froms into self._from_obj. self._from_cloned = from_cloned = dict( - (f, clone(f, **kw)) for f in self._from_obj.union(self._froms)) + (f, clone(f, **kw)) for f in self._from_obj.union(self._froms) + ) # 3. update persistent _from_obj with the cloned versions. - self._from_obj = util.OrderedSet(from_cloned[f] for f in - self._from_obj) + self._from_obj = util.OrderedSet( + from_cloned[f] for f in self._from_obj + ) # the _correlate collection is done separately, what can happen # here is the same item is _correlate as in _from_obj but the @@ -3023,16 +3185,22 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): # RelationshipProperty.Comparator._criterion_exists() does # this). Also keep _correlate liberally open with its previous # contents, as this set is used for matching, not rendering. - self._correlate = set(clone(f) for f in - self._correlate).union(self._correlate) + self._correlate = set(clone(f) for f in self._correlate).union( + self._correlate + ) # 4. clone other things. The difficulty here is that Column # objects are not actually cloned, and refer to their original # .table, resulting in the wrong "from" parent after a clone # operation. Hence _from_cloned and _from_obj supersede what is # present here. self._raw_columns = [clone(c, **kw) for c in self._raw_columns] - for attr in '_whereclause', '_having', '_order_by_clause', \ - '_group_by_clause', '_for_update_arg': + for attr in ( + "_whereclause", + "_having", + "_order_by_clause", + "_group_by_clause", + "_for_update_arg", + ): if getattr(self, attr) is not None: setattr(self, attr, clone(getattr(self, attr), **kw)) @@ -3043,12 +3211,21 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): def get_children(self, column_collections=True, **kwargs): """return child elements as per the ClauseElement specification.""" - return (column_collections and list(self.columns) or []) + \ - self._raw_columns + list(self._froms) + \ - [x for x in - (self._whereclause, self._having, - self._order_by_clause, self._group_by_clause) - if x is not None] + return ( + (column_collections and list(self.columns) or []) + + self._raw_columns + + list(self._froms) + + [ + x + for x in ( + self._whereclause, + self._having, + self._order_by_clause, + self._group_by_clause, + ) + if x is not None + ] + ) @_generative def column(self, column): @@ -3094,7 +3271,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): sqlutil.reduce_columns( self.inner_columns, only_synonyms=only_synonyms, - *(self._whereclause, ) + tuple(self._from_obj) + *(self._whereclause,) + tuple(self._from_obj) ) ) @@ -3307,7 +3484,8 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): self._correlate = () else: self._correlate = set(self._correlate).union( - _interpret_as_from(f) for f in fromclauses) + _interpret_as_from(f) for f in fromclauses + ) @_generative def correlate_except(self, *fromclauses): @@ -3349,7 +3527,8 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): self._correlate_except = () else: self._correlate_except = set(self._correlate_except or ()).union( - _interpret_as_from(f) for f in fromclauses) + _interpret_as_from(f) for f in fromclauses + ) def append_correlation(self, fromclause): """append the given correlation expression to this select() @@ -3363,7 +3542,8 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): self._auto_correlate = False self._correlate = set(self._correlate).union( - _interpret_as_from(f) for f in fromclause) + _interpret_as_from(f) for f in fromclause + ) def append_column(self, column): """append the given column expression to the columns clause of this @@ -3415,8 +3595,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): """ self._reset_exported() - self._whereclause = and_( - True_._ifnone(self._whereclause), whereclause) + self._whereclause = and_(True_._ifnone(self._whereclause), whereclause) def append_having(self, having): """append the given expression to this select() construct's HAVING @@ -3463,19 +3642,17 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): return [ name_for_col(c) - for c in util.unique_list( - _select_iterables(self._raw_columns)) + for c in util.unique_list(_select_iterables(self._raw_columns)) ] else: return [ (None, c) - for c in util.unique_list( - _select_iterables(self._raw_columns)) + for c in util.unique_list(_select_iterables(self._raw_columns)) ] def _populate_column_collection(self): for name, c in self._columns_plus_names: - if not hasattr(c, '_make_proxy'): + if not hasattr(c, "_make_proxy"): continue if name is None: key = None @@ -3486,9 +3663,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): else: key = None - c._make_proxy(self, key=key, - name=name, - name_is_truncatable=True) + c._make_proxy(self, key=key, name=name, name_is_truncatable=True) def _refresh_for_new_column(self, column): for fromclause in self._froms: @@ -3501,15 +3676,16 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): self, name=col._label if self.use_labels else None, key=col._key_label if self.use_labels else None, - name_is_truncatable=True) + name_is_truncatable=True, + ) return None return None def _needs_parens_for_grouping(self): return ( - self._limit_clause is not None or - self._offset_clause is not None or - bool(self._order_by_clause.clauses) + self._limit_clause is not None + or self._offset_clause is not None + or bool(self._order_by_clause.clauses) ) def self_group(self, against=None): @@ -3521,8 +3697,10 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): expressions and should not require explicit use. """ - if isinstance(against, CompoundSelect) and \ - not self._needs_parens_for_grouping(): + if ( + isinstance(against, CompoundSelect) + and not self._needs_parens_for_grouping() + ): return self return FromGrouping(self) @@ -3586,6 +3764,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): def _set_bind(self, bind): self._bind = bind + bind = property(bind, _set_bind) @@ -3600,9 +3779,12 @@ class ScalarSelect(Generative, Grouping): @property def columns(self): - raise exc.InvalidRequestError('Scalar Select expression has no ' - 'columns; use this object directly ' - 'within a column-level expression.') + raise exc.InvalidRequestError( + "Scalar Select expression has no " + "columns; use this object directly " + "within a column-level expression." + ) + c = columns @_generative @@ -3621,6 +3803,7 @@ class Exists(UnaryExpression): """Represent an ``EXISTS`` clause. """ + __visit_name__ = UnaryExpression.__visit_name__ _from_objects = [] @@ -3646,12 +3829,16 @@ class Exists(UnaryExpression): s = args[0] else: if not args: - args = ([literal_column('*')],) + args = ([literal_column("*")],) s = Select(*args, **kwargs).as_scalar().self_group() - UnaryExpression.__init__(self, s, operator=operators.exists, - type_=type_api.BOOLEANTYPE, - wraps_column_expression=True) + UnaryExpression.__init__( + self, + s, + operator=operators.exists, + type_=type_api.BOOLEANTYPE, + wraps_column_expression=True, + ) def select(self, whereclause=None, **params): return Select([self], whereclause, **params) @@ -3706,6 +3893,7 @@ class TextAsFrom(SelectBase): :meth:`.TextClause.columns` """ + __visit_name__ = "text_as_from" _textual = True diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index c5708940b..61fc6d3c9 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -15,10 +15,21 @@ import collections import json from . import elements -from .type_api import TypeEngine, TypeDecorator, to_instance, Variant, \ - Emulated, NativeForEmulated -from .elements import quoted_name, TypeCoerce as type_coerce, _defer_name, \ - Slice, _literal_as_binds +from .type_api import ( + TypeEngine, + TypeDecorator, + to_instance, + Variant, + Emulated, + NativeForEmulated, +) +from .elements import ( + quoted_name, + TypeCoerce as type_coerce, + _defer_name, + Slice, + _literal_as_binds, +) from .. import exc, util, processors from .base import _bind_or_error, SchemaEventTarget from . import operators @@ -51,14 +62,15 @@ class _LookupExpressionAdapter(object): def _adapt_expression(self, op, other_comparator): othertype = other_comparator.type._type_affinity lookup = self.type._expression_adaptations.get( - op, self._blank_dict).get( - othertype, self.type) + op, self._blank_dict + ).get(othertype, self.type) if lookup is othertype: return (op, other_comparator.type) elif lookup is self.type._type_affinity: return (op, self.type) else: return (op, to_instance(lookup)) + comparator_factory = Comparator @@ -68,17 +80,16 @@ class Concatenable(object): typically strings.""" class Comparator(TypeEngine.Comparator): - def _adapt_expression(self, op, other_comparator): - if (op is operators.add and - isinstance( - other_comparator, - (Concatenable.Comparator, NullType.Comparator) - )): + if op is operators.add and isinstance( + other_comparator, + (Concatenable.Comparator, NullType.Comparator), + ): return operators.concat_op, self.expr.type else: return super(Concatenable.Comparator, self)._adapt_expression( - op, other_comparator) + op, other_comparator + ) comparator_factory = Comparator @@ -94,17 +105,15 @@ class Indexable(object): """ class Comparator(TypeEngine.Comparator): - def _setup_getitem(self, index): raise NotImplementedError() def __getitem__(self, index): - adjusted_op, adjusted_right_expr, result_type = \ - self._setup_getitem(index) + adjusted_op, adjusted_right_expr, result_type = self._setup_getitem( + index + ) return self.operate( - adjusted_op, - adjusted_right_expr, - result_type=result_type + adjusted_op, adjusted_right_expr, result_type=result_type ) comparator_factory = Comparator @@ -124,13 +133,16 @@ class String(Concatenable, TypeEngine): """ - __visit_name__ = 'string' + __visit_name__ = "string" - def __init__(self, length=None, collation=None, - convert_unicode=False, - unicode_error=None, - _warn_on_bytestring=False - ): + def __init__( + self, + length=None, + collation=None, + convert_unicode=False, + unicode_error=None, + _warn_on_bytestring=False, + ): """ Create a string-holding type. @@ -207,9 +219,10 @@ class String(Concatenable, TypeEngine): strings from a column with varied or corrupted encodings. """ - if unicode_error is not None and convert_unicode != 'force': - raise exc.ArgumentError("convert_unicode must be 'force' " - "when unicode_error is set.") + if unicode_error is not None and convert_unicode != "force": + raise exc.ArgumentError( + "convert_unicode must be 'force' " "when unicode_error is set." + ) self.length = length self.collation = collation @@ -222,23 +235,29 @@ class String(Concatenable, TypeEngine): value = value.replace("'", "''") if dialect.identifier_preparer._double_percents: - value = value.replace('%', '%%') + value = value.replace("%", "%%") return "'%s'" % value + return process def bind_processor(self, dialect): if self.convert_unicode or dialect.convert_unicode: - if dialect.supports_unicode_binds and \ - self.convert_unicode != 'force': + if ( + dialect.supports_unicode_binds + and self.convert_unicode != "force" + ): if self._warn_on_bytestring: + def process(value): if isinstance(value, util.binary_type): util.warn_limited( "Unicode type received non-unicode " "bind param value %r.", - (util.ellipses_string(value),)) + (util.ellipses_string(value),), + ) return value + return process else: return None @@ -253,29 +272,34 @@ class String(Concatenable, TypeEngine): util.warn_limited( "Unicode type received non-unicode bind " "param value %r.", - (util.ellipses_string(value),)) + (util.ellipses_string(value),), + ) return value + return process else: return None def result_processor(self, dialect, coltype): wants_unicode = self.convert_unicode or dialect.convert_unicode - needs_convert = wants_unicode and \ - (dialect.returns_unicode_strings is not True or - self.convert_unicode in ('force', 'force_nocheck')) + needs_convert = wants_unicode and ( + dialect.returns_unicode_strings is not True + or self.convert_unicode in ("force", "force_nocheck") + ) needs_isinstance = ( - needs_convert and - dialect.returns_unicode_strings and - self.convert_unicode != 'force_nocheck' + needs_convert + and dialect.returns_unicode_strings + and self.convert_unicode != "force_nocheck" ) if needs_convert: if needs_isinstance: return processors.to_conditional_unicode_processor_factory( - dialect.encoding, self.unicode_error) + dialect.encoding, self.unicode_error + ) else: return processors.to_unicode_processor_factory( - dialect.encoding, self.unicode_error) + dialect.encoding, self.unicode_error + ) else: return None @@ -301,7 +325,8 @@ class Text(String): argument here, it will be rejected by others. """ - __visit_name__ = 'text' + + __visit_name__ = "text" class Unicode(String): @@ -360,7 +385,7 @@ class Unicode(String): """ - __visit_name__ = 'unicode' + __visit_name__ = "unicode" def __init__(self, length=None, **kwargs): """ @@ -371,8 +396,8 @@ class Unicode(String): defaults to ``True``. """ - kwargs.setdefault('convert_unicode', True) - kwargs.setdefault('_warn_on_bytestring', True) + kwargs.setdefault("convert_unicode", True) + kwargs.setdefault("_warn_on_bytestring", True) super(Unicode, self).__init__(length=length, **kwargs) @@ -389,7 +414,7 @@ class UnicodeText(Text): """ - __visit_name__ = 'unicode_text' + __visit_name__ = "unicode_text" def __init__(self, length=None, **kwargs): """ @@ -400,8 +425,8 @@ class UnicodeText(Text): defaults to ``True``. """ - kwargs.setdefault('convert_unicode', True) - kwargs.setdefault('_warn_on_bytestring', True) + kwargs.setdefault("convert_unicode", True) + kwargs.setdefault("_warn_on_bytestring", True) super(UnicodeText, self).__init__(length=length, **kwargs) @@ -409,7 +434,7 @@ class Integer(_LookupExpressionAdapter, TypeEngine): """A type for ``int`` integers.""" - __visit_name__ = 'integer' + __visit_name__ = "integer" def get_dbapi_type(self, dbapi): return dbapi.NUMBER @@ -421,6 +446,7 @@ class Integer(_LookupExpressionAdapter, TypeEngine): def literal_processor(self, dialect): def process(value): return str(value) + return process @util.memoized_property @@ -438,18 +464,9 @@ class Integer(_LookupExpressionAdapter, TypeEngine): Integer: self.__class__, Numeric: Numeric, }, - operators.div: { - Integer: self.__class__, - Numeric: Numeric, - }, - operators.truediv: { - Integer: self.__class__, - Numeric: Numeric, - }, - operators.sub: { - Integer: self.__class__, - Numeric: Numeric, - }, + operators.div: {Integer: self.__class__, Numeric: Numeric}, + operators.truediv: {Integer: self.__class__, Numeric: Numeric}, + operators.sub: {Integer: self.__class__, Numeric: Numeric}, } @@ -462,7 +479,7 @@ class SmallInteger(Integer): """ - __visit_name__ = 'small_integer' + __visit_name__ = "small_integer" class BigInteger(Integer): @@ -474,7 +491,7 @@ class BigInteger(Integer): """ - __visit_name__ = 'big_integer' + __visit_name__ = "big_integer" class Numeric(_LookupExpressionAdapter, TypeEngine): @@ -517,12 +534,17 @@ class Numeric(_LookupExpressionAdapter, TypeEngine): """ - __visit_name__ = 'numeric' + __visit_name__ = "numeric" _default_decimal_return_scale = 10 - def __init__(self, precision=None, scale=None, - decimal_return_scale=None, asdecimal=True): + def __init__( + self, + precision=None, + scale=None, + decimal_return_scale=None, + asdecimal=True, + ): """ Construct a Numeric. @@ -587,6 +609,7 @@ class Numeric(_LookupExpressionAdapter, TypeEngine): def literal_processor(self, dialect): def process(value): return str(value) + return process @property @@ -608,19 +631,23 @@ class Numeric(_LookupExpressionAdapter, TypeEngine): # we're a "numeric", DBAPI will give us Decimal directly return None else: - util.warn('Dialect %s+%s does *not* support Decimal ' - 'objects natively, and SQLAlchemy must ' - 'convert from floating point - rounding ' - 'errors and other issues may occur. Please ' - 'consider storing Decimal numbers as strings ' - 'or integers on this platform for lossless ' - 'storage.' % (dialect.name, dialect.driver)) + util.warn( + "Dialect %s+%s does *not* support Decimal " + "objects natively, and SQLAlchemy must " + "convert from floating point - rounding " + "errors and other issues may occur. Please " + "consider storing Decimal numbers as strings " + "or integers on this platform for lossless " + "storage." % (dialect.name, dialect.driver) + ) # we're a "numeric", DBAPI returns floats, convert. return processors.to_decimal_processor_factory( decimal.Decimal, - self.scale if self.scale is not None - else self._default_decimal_return_scale) + self.scale + if self.scale is not None + else self._default_decimal_return_scale, + ) else: if dialect.supports_native_decimal: return processors.to_float @@ -635,22 +662,13 @@ class Numeric(_LookupExpressionAdapter, TypeEngine): Numeric: self.__class__, Integer: self.__class__, }, - operators.div: { - Numeric: self.__class__, - Integer: self.__class__, - }, + operators.div: {Numeric: self.__class__, Integer: self.__class__}, operators.truediv: { Numeric: self.__class__, Integer: self.__class__, }, - operators.add: { - Numeric: self.__class__, - Integer: self.__class__, - }, - operators.sub: { - Numeric: self.__class__, - Integer: self.__class__, - } + operators.add: {Numeric: self.__class__, Integer: self.__class__}, + operators.sub: {Numeric: self.__class__, Integer: self.__class__}, } @@ -675,12 +693,17 @@ class Float(Numeric): """ - __visit_name__ = 'float' + __visit_name__ = "float" scale = None - def __init__(self, precision=None, asdecimal=False, - decimal_return_scale=None, **kwargs): + def __init__( + self, + precision=None, + asdecimal=False, + decimal_return_scale=None, + **kwargs + ): r""" Construct a Float. @@ -713,14 +736,15 @@ class Float(Numeric): self.asdecimal = asdecimal self.decimal_return_scale = decimal_return_scale if kwargs: - util.warn_deprecated("Additional keyword arguments " - "passed to Float ignored.") + util.warn_deprecated( + "Additional keyword arguments " "passed to Float ignored." + ) def result_processor(self, dialect, coltype): if self.asdecimal: return processors.to_decimal_processor_factory( - decimal.Decimal, - self._effective_decimal_return_scale) + decimal.Decimal, self._effective_decimal_return_scale + ) elif dialect.supports_native_decimal: return processors.to_float else: @@ -746,7 +770,7 @@ class DateTime(_LookupExpressionAdapter, TypeEngine): """ - __visit_name__ = 'datetime' + __visit_name__ = "datetime" def __init__(self, timezone=False): """Construct a new :class:`.DateTime`. @@ -777,13 +801,8 @@ class DateTime(_LookupExpressionAdapter, TypeEngine): # static/functions-datetime.html. return { - operators.add: { - Interval: self.__class__, - }, - operators.sub: { - Interval: self.__class__, - DateTime: Interval, - }, + operators.add: {Interval: self.__class__}, + operators.sub: {Interval: self.__class__, DateTime: Interval}, } @@ -791,7 +810,7 @@ class Date(_LookupExpressionAdapter, TypeEngine): """A type for ``datetime.date()`` objects.""" - __visit_name__ = 'date' + __visit_name__ = "date" def get_dbapi_type(self, dbapi): return dbapi.DATETIME @@ -814,12 +833,9 @@ class Date(_LookupExpressionAdapter, TypeEngine): operators.sub: { # date - integer = date Integer: self.__class__, - # date - date = integer. Date: Integer, - Interval: DateTime, - # date - datetime = interval, # this one is not in the PG docs # but works @@ -832,7 +848,7 @@ class Time(_LookupExpressionAdapter, TypeEngine): """A type for ``datetime.time()`` objects.""" - __visit_name__ = 'time' + __visit_name__ = "time" def __init__(self, timezone=False): self.timezone = timezone @@ -850,14 +866,8 @@ class Time(_LookupExpressionAdapter, TypeEngine): # static/functions-datetime.html. return { - operators.add: { - Date: DateTime, - Interval: self.__class__ - }, - operators.sub: { - Time: Interval, - Interval: self.__class__, - }, + operators.add: {Date: DateTime, Interval: self.__class__}, + operators.sub: {Time: Interval, Interval: self.__class__}, } @@ -872,6 +882,7 @@ class _Binary(TypeEngine): def process(value): value = value.decode(dialect.encoding).replace("'", "''") return "'%s'" % value + return process @property @@ -891,14 +902,17 @@ class _Binary(TypeEngine): return DBAPIBinary(value) else: return None + return process # Python 3 has native bytes() type # both sqlite3 and pg8000 seem to return it, # psycopg2 as of 2.5 returns 'memoryview' if util.py2k: + def result_processor(self, dialect, coltype): if util.jython: + def process(value): if value is not None: if isinstance(value, array.array): @@ -906,15 +920,19 @@ class _Binary(TypeEngine): return str(value) else: return None + else: process = processors.to_str return process + else: + def result_processor(self, dialect, coltype): def process(value): if value is not None: value = bytes(value) return value + return process def coerce_compared_value(self, op, value): @@ -939,7 +957,7 @@ class LargeBinary(_Binary): """ - __visit_name__ = 'large_binary' + __visit_name__ = "large_binary" def __init__(self, length=None): """ @@ -958,8 +976,9 @@ class Binary(LargeBinary): """Deprecated. Renamed to LargeBinary.""" def __init__(self, *arg, **kw): - util.warn_deprecated('The Binary type has been renamed to ' - 'LargeBinary.') + util.warn_deprecated( + "The Binary type has been renamed to " "LargeBinary." + ) LargeBinary.__init__(self, *arg, **kw) @@ -986,8 +1005,15 @@ class SchemaType(SchemaEventTarget): """ - def __init__(self, name=None, schema=None, metadata=None, - inherit_schema=False, quote=None, _create_events=True): + def __init__( + self, + name=None, + schema=None, + metadata=None, + inherit_schema=False, + quote=None, + _create_events=True, + ): if name is not None: self.name = quoted_name(name, quote) else: @@ -1001,12 +1027,12 @@ class SchemaType(SchemaEventTarget): event.listen( self.metadata, "before_create", - util.portable_instancemethod(self._on_metadata_create) + util.portable_instancemethod(self._on_metadata_create), ) event.listen( self.metadata, "after_drop", - util.portable_instancemethod(self._on_metadata_drop) + util.portable_instancemethod(self._on_metadata_drop), ) def _translate_schema(self, effective_schema, map_): @@ -1018,7 +1044,7 @@ class SchemaType(SchemaEventTarget): def _variant_mapping_for_set_table(self, column): if isinstance(column.type, Variant): variant_mapping = column.type.mapping.copy() - variant_mapping['_default'] = column.type.impl + variant_mapping["_default"] = column.type.impl else: variant_mapping = None return variant_mapping @@ -1036,15 +1062,15 @@ class SchemaType(SchemaEventTarget): table, "before_create", util.portable_instancemethod( - self._on_table_create, - {"variant_mapping": variant_mapping}) + self._on_table_create, {"variant_mapping": variant_mapping} + ), ) event.listen( table, "after_drop", util.portable_instancemethod( - self._on_table_drop, - {"variant_mapping": variant_mapping}) + self._on_table_drop, {"variant_mapping": variant_mapping} + ), ) if self.metadata is None: # TODO: what's the difference between self.metadata @@ -1054,29 +1080,33 @@ class SchemaType(SchemaEventTarget): "before_create", util.portable_instancemethod( self._on_metadata_create, - {"variant_mapping": variant_mapping}) + {"variant_mapping": variant_mapping}, + ), ) event.listen( table.metadata, "after_drop", util.portable_instancemethod( self._on_metadata_drop, - {"variant_mapping": variant_mapping}) + {"variant_mapping": variant_mapping}, + ), ) def copy(self, **kw): return self.adapt(self.__class__, _create_events=True) def adapt(self, impltype, **kw): - schema = kw.pop('schema', self.schema) - metadata = kw.pop('metadata', self.metadata) - _create_events = kw.pop('_create_events', False) - return impltype(name=self.name, - schema=schema, - inherit_schema=self.inherit_schema, - metadata=metadata, - _create_events=_create_events, - **kw) + schema = kw.pop("schema", self.schema) + metadata = kw.pop("metadata", self.metadata) + _create_events = kw.pop("_create_events", False) + return impltype( + name=self.name, + schema=schema, + inherit_schema=self.inherit_schema, + metadata=metadata, + _create_events=_create_events, + **kw + ) @property def bind(self): @@ -1133,15 +1163,17 @@ class SchemaType(SchemaEventTarget): t._on_metadata_drop(target, bind, **kw) def _is_impl_for_variant(self, dialect, kw): - variant_mapping = kw.pop('variant_mapping', None) + variant_mapping = kw.pop("variant_mapping", None) if variant_mapping is None: return True - if dialect.name in variant_mapping and \ - variant_mapping[dialect.name] is self: + if ( + dialect.name in variant_mapping + and variant_mapping[dialect.name] is self + ): return True elif dialect.name not in variant_mapping: - return variant_mapping['_default'] is self + return variant_mapping["_default"] is self class Enum(Emulated, String, SchemaType): @@ -1220,7 +1252,8 @@ class Enum(Emulated, String, SchemaType): :class:`.mysql.ENUM` - MySQL-specific type """ - __visit_name__ = 'enum' + + __visit_name__ = "enum" def __init__(self, *enums, **kw): r"""Construct an enum. @@ -1322,15 +1355,15 @@ class Enum(Emulated, String, SchemaType): other arguments in kw to pass through. """ - self.native_enum = kw.pop('native_enum', True) - self.create_constraint = kw.pop('create_constraint', True) - self.values_callable = kw.pop('values_callable', None) + self.native_enum = kw.pop("native_enum", True) + self.create_constraint = kw.pop("create_constraint", True) + self.values_callable = kw.pop("values_callable", None) values, objects = self._parse_into_values(enums, kw) self._setup_for_values(values, objects, kw) - convert_unicode = kw.pop('convert_unicode', None) - self.validate_strings = kw.pop('validate_strings', False) + convert_unicode = kw.pop("convert_unicode", None) + self.validate_strings = kw.pop("validate_strings", False) if convert_unicode is None: for e in self.enums: @@ -1347,33 +1380,35 @@ class Enum(Emulated, String, SchemaType): self._valid_lookup[None] = self._object_lookup[None] = None super(Enum, self).__init__( - length=length, - convert_unicode=convert_unicode, + length=length, convert_unicode=convert_unicode ) if self.enum_class: - kw.setdefault('name', self.enum_class.__name__.lower()) + kw.setdefault("name", self.enum_class.__name__.lower()) SchemaType.__init__( self, - name=kw.pop('name', None), - schema=kw.pop('schema', None), - metadata=kw.pop('metadata', None), - inherit_schema=kw.pop('inherit_schema', False), - quote=kw.pop('quote', None), - _create_events=kw.pop('_create_events', True) + name=kw.pop("name", None), + schema=kw.pop("schema", None), + metadata=kw.pop("metadata", None), + inherit_schema=kw.pop("inherit_schema", False), + quote=kw.pop("quote", None), + _create_events=kw.pop("_create_events", True), ) def _parse_into_values(self, enums, kw): - if not enums and '_enums' in kw: - enums = kw.pop('_enums') + if not enums and "_enums" in kw: + enums = kw.pop("_enums") - if len(enums) == 1 and hasattr(enums[0], '__members__'): + if len(enums) == 1 and hasattr(enums[0], "__members__"): self.enum_class = enums[0] if self.values_callable: values = self.values_callable(self.enum_class) else: values = list(self.enum_class.__members__) - objects = [self.enum_class.__members__[k] for k in self.enum_class.__members__] + objects = [ + self.enum_class.__members__[k] + for k in self.enum_class.__members__ + ] return values, objects else: self.enum_class = None @@ -1382,18 +1417,16 @@ class Enum(Emulated, String, SchemaType): def _setup_for_values(self, values, objects, kw): self.enums = list(values) - self._valid_lookup = dict( - zip(reversed(objects), reversed(values)) - ) + self._valid_lookup = dict(zip(reversed(objects), reversed(values))) - self._object_lookup = dict( - zip(values, objects) - ) + self._object_lookup = dict(zip(values, objects)) - self._valid_lookup.update([ - (value, self._valid_lookup[self._object_lookup[value]]) - for value in values - ]) + self._valid_lookup.update( + [ + (value, self._valid_lookup[self._object_lookup[value]]) + for value in values + ] + ) @property def native(self): @@ -1411,22 +1444,24 @@ class Enum(Emulated, String, SchemaType): # here between an INSERT statement and a criteria used in a SELECT, # for now we're staying conservative w/ behavioral changes (perhaps # someone has a trigger that handles strings on INSERT) - if not self.validate_strings and \ - isinstance(elem, compat.string_types): + if not self.validate_strings and isinstance( + elem, compat.string_types + ): return elem else: raise LookupError( - '"%s" is not among the defined enum values' % elem) + '"%s" is not among the defined enum values' % elem + ) class Comparator(String.Comparator): - def _adapt_expression(self, op, other_comparator): op, typ = super(Enum.Comparator, self)._adapt_expression( - op, other_comparator) + op, other_comparator + ) if op is operators.concat_op: typ = String( - self.type.length, - convert_unicode=self.type.convert_unicode) + self.type.length, convert_unicode=self.type.convert_unicode + ) return op, typ comparator_factory = Comparator @@ -1436,38 +1471,40 @@ class Enum(Emulated, String, SchemaType): return self._object_lookup[elem] except KeyError: raise LookupError( - '"%s" is not among the defined enum values' % elem) + '"%s" is not among the defined enum values' % elem + ) def __repr__(self): return util.generic_repr( self, - additional_kw=[('native_enum', True)], + additional_kw=[("native_enum", True)], to_inspect=[Enum, SchemaType], ) def adapt_to_emulated(self, impltype, **kw): kw.setdefault("convert_unicode", self.convert_unicode) kw.setdefault("validate_strings", self.validate_strings) - kw.setdefault('name', self.name) - kw.setdefault('schema', self.schema) - kw.setdefault('inherit_schema', self.inherit_schema) - kw.setdefault('metadata', self.metadata) - kw.setdefault('_create_events', False) - kw.setdefault('native_enum', self.native_enum) - kw.setdefault('values_callable', self.values_callable) - kw.setdefault('create_constraint', self.create_constraint) - assert '_enums' in kw + kw.setdefault("name", self.name) + kw.setdefault("schema", self.schema) + kw.setdefault("inherit_schema", self.inherit_schema) + kw.setdefault("metadata", self.metadata) + kw.setdefault("_create_events", False) + kw.setdefault("native_enum", self.native_enum) + kw.setdefault("values_callable", self.values_callable) + kw.setdefault("create_constraint", self.create_constraint) + assert "_enums" in kw return impltype(**kw) def adapt(self, impltype, **kw): - kw['_enums'] = self._enums_argument + kw["_enums"] = self._enums_argument return super(Enum, self).adapt(impltype, **kw) def _should_create_constraint(self, compiler, **kw): if not self._is_impl_for_variant(compiler.dialect, kw): return False - return not self.native_enum or \ - not compiler.dialect.supports_native_enum + return ( + not self.native_enum or not compiler.dialect.supports_native_enum + ) @util.dependencies("sqlalchemy.sql.schema") def _set_table(self, schema, column, table): @@ -1483,20 +1520,21 @@ class Enum(Emulated, String, SchemaType): name=_defer_name(self.name), _create_rule=util.portable_instancemethod( self._should_create_constraint, - {"variant_mapping": variant_mapping}), - _type_bound=True + {"variant_mapping": variant_mapping}, + ), + _type_bound=True, ) assert e.table is table def literal_processor(self, dialect): - parent_processor = super( - Enum, self).literal_processor(dialect) + parent_processor = super(Enum, self).literal_processor(dialect) def process(value): value = self._db_value_for_elem(value) if parent_processor: value = parent_processor(value) return value + return process def bind_processor(self, dialect): @@ -1510,8 +1548,7 @@ class Enum(Emulated, String, SchemaType): return process def result_processor(self, dialect, coltype): - parent_processor = super(Enum, self).result_processor( - dialect, coltype) + parent_processor = super(Enum, self).result_processor(dialect, coltype) def process(value): if parent_processor: @@ -1548,8 +1585,9 @@ class PickleType(TypeDecorator): impl = LargeBinary - def __init__(self, protocol=pickle.HIGHEST_PROTOCOL, - pickler=None, comparator=None): + def __init__( + self, protocol=pickle.HIGHEST_PROTOCOL, pickler=None, comparator=None + ): """ Construct a PickleType. @@ -1570,40 +1608,46 @@ class PickleType(TypeDecorator): super(PickleType, self).__init__() def __reduce__(self): - return PickleType, (self.protocol, - None, - self.comparator) + return PickleType, (self.protocol, None, self.comparator) def bind_processor(self, dialect): impl_processor = self.impl.bind_processor(dialect) dumps = self.pickler.dumps protocol = self.protocol if impl_processor: + def process(value): if value is not None: value = dumps(value, protocol) return impl_processor(value) + else: + def process(value): if value is not None: value = dumps(value, protocol) return value + return process def result_processor(self, dialect, coltype): impl_processor = self.impl.result_processor(dialect, coltype) loads = self.pickler.loads if impl_processor: + def process(value): value = impl_processor(value) if value is None: return None return loads(value) + else: + def process(value): if value is None: return None return loads(value) + return process def compare_values(self, x, y): @@ -1635,11 +1679,10 @@ class Boolean(Emulated, TypeEngine, SchemaType): """ - __visit_name__ = 'boolean' + __visit_name__ = "boolean" native = True - def __init__( - self, create_constraint=True, name=None, _create_events=True): + def __init__(self, create_constraint=True, name=None, _create_events=True): """Construct a Boolean. :param create_constraint: defaults to True. If the boolean @@ -1657,8 +1700,10 @@ class Boolean(Emulated, TypeEngine, SchemaType): def _should_create_constraint(self, compiler, **kw): if not self._is_impl_for_variant(compiler.dialect, kw): return False - return not compiler.dialect.supports_native_boolean and \ - compiler.dialect.non_native_boolean_check_constraint + return ( + not compiler.dialect.supports_native_boolean + and compiler.dialect.non_native_boolean_check_constraint + ) @util.dependencies("sqlalchemy.sql.schema") def _set_table(self, schema, column, table): @@ -1672,8 +1717,9 @@ class Boolean(Emulated, TypeEngine, SchemaType): name=_defer_name(self.name), _create_rule=util.portable_instancemethod( self._should_create_constraint, - {"variant_mapping": variant_mapping}), - _type_bound=True + {"variant_mapping": variant_mapping}, + ), + _type_bound=True, ) assert e.table is table @@ -1686,11 +1732,11 @@ class Boolean(Emulated, TypeEngine, SchemaType): def _strict_as_bool(self, value): if value not in self._strict_bools: if not isinstance(value, int): - raise TypeError( - "Not a boolean value: %r" % value) + raise TypeError("Not a boolean value: %r" % value) else: raise ValueError( - "Value %r is not None, True, or False" % value) + "Value %r is not None, True, or False" % value + ) return value def literal_processor(self, dialect): @@ -1700,6 +1746,7 @@ class Boolean(Emulated, TypeEngine, SchemaType): def process(value): return true if self._strict_as_bool(value) else false + return process def bind_processor(self, dialect): @@ -1714,6 +1761,7 @@ class Boolean(Emulated, TypeEngine, SchemaType): if value is not None: value = _coerce(value) return value + return process def result_processor(self, dialect, coltype): @@ -1736,18 +1784,10 @@ class _AbstractInterval(_LookupExpressionAdapter, TypeEngine): DateTime: DateTime, Time: Time, }, - operators.sub: { - Interval: self.__class__ - }, - operators.mul: { - Numeric: self.__class__ - }, - operators.truediv: { - Numeric: self.__class__ - }, - operators.div: { - Numeric: self.__class__ - } + operators.sub: {Interval: self.__class__}, + operators.mul: {Numeric: self.__class__}, + operators.truediv: {Numeric: self.__class__}, + operators.div: {Numeric: self.__class__}, } @property @@ -1780,9 +1820,7 @@ class Interval(Emulated, _AbstractInterval, TypeDecorator): impl = DateTime epoch = dt.datetime.utcfromtimestamp(0) - def __init__(self, native=True, - second_precision=None, - day_precision=None): + def __init__(self, native=True, second_precision=None, day_precision=None): """Construct an Interval object. :param native: when True, use the actual @@ -1815,31 +1853,39 @@ class Interval(Emulated, _AbstractInterval, TypeDecorator): impl_processor = self.impl.bind_processor(dialect) epoch = self.epoch if impl_processor: + def process(value): if value is not None: value = epoch + value return impl_processor(value) + else: + def process(value): if value is not None: value = epoch + value return value + return process def result_processor(self, dialect, coltype): impl_processor = self.impl.result_processor(dialect, coltype) epoch = self.epoch if impl_processor: + def process(value): value = impl_processor(value) if value is None: return None return value - epoch + else: + def process(value): if value is None: return None return value - epoch + return process @@ -1986,10 +2032,11 @@ class JSON(Indexable, TypeEngine): """ - __visit_name__ = 'JSON' + + __visit_name__ = "JSON" hashable = False - NULL = util.symbol('JSON_NULL') + NULL = util.symbol("JSON_NULL") """Describe the json value of NULL. This value is used to force the JSON value of ``"null"`` to be @@ -2109,20 +2156,25 @@ class JSON(Indexable, TypeEngine): class Comparator(Indexable.Comparator, Concatenable.Comparator): """Define comparison operations for :class:`.types.JSON`.""" - @util.dependencies('sqlalchemy.sql.default_comparator') + @util.dependencies("sqlalchemy.sql.default_comparator") def _setup_getitem(self, default_comparator, index): - if not isinstance(index, util.string_types) and \ - isinstance(index, compat.collections_abc.Sequence): + if not isinstance(index, util.string_types) and isinstance( + index, compat.collections_abc.Sequence + ): index = default_comparator._check_literal( - self.expr, operators.json_path_getitem_op, - index, bindparam_type=JSON.JSONPathType + self.expr, + operators.json_path_getitem_op, + index, + bindparam_type=JSON.JSONPathType, ) operator = operators.json_path_getitem_op else: index = default_comparator._check_literal( - self.expr, operators.json_getitem_op, - index, bindparam_type=JSON.JSONIndexType + self.expr, + operators.json_getitem_op, + index, + bindparam_type=JSON.JSONIndexType, ) operator = operators.json_getitem_op @@ -2172,6 +2224,7 @@ class JSON(Indexable, TypeEngine): if string_process: value = string_process(value) return json_deserializer(value) + return process @@ -2266,7 +2319,8 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): :class:`.postgresql.ARRAY` """ - __visit_name__ = 'ARRAY' + + __visit_name__ = "ARRAY" zero_indexes = False """if True, Python zero-based indexes should be interpreted as one-based @@ -2285,21 +2339,23 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): if isinstance(index, slice): return_type = self.type if self.type.zero_indexes: - index = slice( - index.start + 1, - index.stop + 1, - index.step - ) + index = slice(index.start + 1, index.stop + 1, index.step) index = Slice( _literal_as_binds( - index.start, name=self.expr.key, - type_=type_api.INTEGERTYPE), + index.start, + name=self.expr.key, + type_=type_api.INTEGERTYPE, + ), _literal_as_binds( - index.stop, name=self.expr.key, - type_=type_api.INTEGERTYPE), + index.stop, + name=self.expr.key, + type_=type_api.INTEGERTYPE, + ), _literal_as_binds( - index.step, name=self.expr.key, - type_=type_api.INTEGERTYPE) + index.step, + name=self.expr.key, + type_=type_api.INTEGERTYPE, + ), ) else: if self.type.zero_indexes: @@ -2307,16 +2363,18 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): if self.type.dimensions is None or self.type.dimensions == 1: return_type = self.type.item_type else: - adapt_kw = {'dimensions': self.type.dimensions - 1} + adapt_kw = {"dimensions": self.type.dimensions - 1} return_type = self.type.adapt( - self.type.__class__, **adapt_kw) + self.type.__class__, **adapt_kw + ) return operators.getitem, index, return_type def contains(self, *arg, **kw): raise NotImplementedError( "ARRAY.contains() not implemented for the base " - "ARRAY type; please use the dialect-specific ARRAY type") + "ARRAY type; please use the dialect-specific ARRAY type" + ) @util.dependencies("sqlalchemy.sql.elements") def any(self, elements, other, operator=None): @@ -2350,7 +2408,7 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): operator = operator if operator else operators.eq return operator( elements._literal_as_binds(other), - elements.CollectionAggregate._create_any(self.expr) + elements.CollectionAggregate._create_any(self.expr), ) @util.dependencies("sqlalchemy.sql.elements") @@ -2385,13 +2443,14 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): operator = operator if operator else operators.eq return operator( elements._literal_as_binds(other), - elements.CollectionAggregate._create_all(self.expr) + elements.CollectionAggregate._create_all(self.expr), ) comparator_factory = Comparator - def __init__(self, item_type, as_tuple=False, dimensions=None, - zero_indexes=False): + def __init__( + self, item_type, as_tuple=False, dimensions=None, zero_indexes=False + ): """Construct an :class:`.types.ARRAY`. E.g.:: @@ -2424,8 +2483,10 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): """ if isinstance(item_type, ARRAY): - raise ValueError("Do not nest ARRAY types; ARRAY(basetype) " - "handles multi-dimensional arrays of basetype") + raise ValueError( + "Do not nest ARRAY types; ARRAY(basetype) " + "handles multi-dimensional arrays of basetype" + ) if isinstance(item_type, type): item_type = item_type() self.item_type = item_type @@ -2463,35 +2524,37 @@ class REAL(Float): """The SQL REAL type.""" - __visit_name__ = 'REAL' + __visit_name__ = "REAL" class FLOAT(Float): """The SQL FLOAT type.""" - __visit_name__ = 'FLOAT' + __visit_name__ = "FLOAT" class NUMERIC(Numeric): """The SQL NUMERIC type.""" - __visit_name__ = 'NUMERIC' + __visit_name__ = "NUMERIC" class DECIMAL(Numeric): """The SQL DECIMAL type.""" - __visit_name__ = 'DECIMAL' + __visit_name__ = "DECIMAL" class INTEGER(Integer): """The SQL INT or INTEGER type.""" - __visit_name__ = 'INTEGER' + __visit_name__ = "INTEGER" + + INT = INTEGER @@ -2499,14 +2562,14 @@ class SMALLINT(SmallInteger): """The SQL SMALLINT type.""" - __visit_name__ = 'SMALLINT' + __visit_name__ = "SMALLINT" class BIGINT(BigInteger): """The SQL BIGINT type.""" - __visit_name__ = 'BIGINT' + __visit_name__ = "BIGINT" class TIMESTAMP(DateTime): @@ -2520,7 +2583,7 @@ class TIMESTAMP(DateTime): """ - __visit_name__ = 'TIMESTAMP' + __visit_name__ = "TIMESTAMP" def __init__(self, timezone=False): """Construct a new :class:`.TIMESTAMP`. @@ -2543,28 +2606,28 @@ class DATETIME(DateTime): """The SQL DATETIME type.""" - __visit_name__ = 'DATETIME' + __visit_name__ = "DATETIME" class DATE(Date): """The SQL DATE type.""" - __visit_name__ = 'DATE' + __visit_name__ = "DATE" class TIME(Time): """The SQL TIME type.""" - __visit_name__ = 'TIME' + __visit_name__ = "TIME" class TEXT(Text): """The SQL TEXT type.""" - __visit_name__ = 'TEXT' + __visit_name__ = "TEXT" class CLOB(Text): @@ -2574,63 +2637,63 @@ class CLOB(Text): This type is found in Oracle and Informix. """ - __visit_name__ = 'CLOB' + __visit_name__ = "CLOB" class VARCHAR(String): """The SQL VARCHAR type.""" - __visit_name__ = 'VARCHAR' + __visit_name__ = "VARCHAR" class NVARCHAR(Unicode): """The SQL NVARCHAR type.""" - __visit_name__ = 'NVARCHAR' + __visit_name__ = "NVARCHAR" class CHAR(String): """The SQL CHAR type.""" - __visit_name__ = 'CHAR' + __visit_name__ = "CHAR" class NCHAR(Unicode): """The SQL NCHAR type.""" - __visit_name__ = 'NCHAR' + __visit_name__ = "NCHAR" class BLOB(LargeBinary): """The SQL BLOB type.""" - __visit_name__ = 'BLOB' + __visit_name__ = "BLOB" class BINARY(_Binary): """The SQL BINARY type.""" - __visit_name__ = 'BINARY' + __visit_name__ = "BINARY" class VARBINARY(_Binary): """The SQL VARBINARY type.""" - __visit_name__ = 'VARBINARY' + __visit_name__ = "VARBINARY" class BOOLEAN(Boolean): """The SQL BOOLEAN type.""" - __visit_name__ = 'BOOLEAN' + __visit_name__ = "BOOLEAN" class NullType(TypeEngine): @@ -2657,7 +2720,8 @@ class NullType(TypeEngine): construct. """ - __visit_name__ = 'null' + + __visit_name__ = "null" _isnull = True @@ -2666,16 +2730,18 @@ class NullType(TypeEngine): def literal_processor(self, dialect): def process(value): return "NULL" + return process class Comparator(TypeEngine.Comparator): - def _adapt_expression(self, op, other_comparator): - if isinstance(other_comparator, NullType.Comparator) or \ - not operators.is_commutative(op): + if isinstance( + other_comparator, NullType.Comparator + ) or not operators.is_commutative(op): return op, self.expr.type else: return other_comparator._adapt_expression(op, self) + comparator_factory = Comparator @@ -2694,6 +2760,7 @@ class MatchType(Boolean): """ + NULLTYPE = NullType() BOOLEANTYPE = Boolean() STRINGTYPE = String() @@ -2709,7 +2776,7 @@ _type_map = { dt.datetime: DateTime(), dt.time: Time(), dt.timedelta: Interval(), - util.NoneType: NULLTYPE + util.NoneType: NULLTYPE, } if util.py3k: @@ -2729,19 +2796,23 @@ def _resolve_value_to_type(value): # objects. insp = inspection.inspect(value, False) if ( - insp is not None and - # foil mock.Mock() and other impostors by ensuring - # the inspection target itself self-inspects - insp.__class__ in inspection._registrars + insp is not None + and + # foil mock.Mock() and other impostors by ensuring + # the inspection target itself self-inspects + insp.__class__ in inspection._registrars ): raise exc.ArgumentError( - "Object %r is not legal as a SQL literal value" % value) + "Object %r is not legal as a SQL literal value" % value + ) return NULLTYPE else: return _result_type + # back-assign to type_api from . import type_api + type_api.BOOLEANTYPE = BOOLEANTYPE type_api.STRINGTYPE = STRINGTYPE type_api.INTEGERTYPE = INTEGERTYPE diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index a8dfa19be..7fe780783 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -49,7 +49,8 @@ class TypeEngine(Visitable): """ - __slots__ = 'expr', 'type' + + __slots__ = "expr", "type" default_comparator = None @@ -57,16 +58,15 @@ class TypeEngine(Visitable): self.expr = expr self.type = expr.type - @util.dependencies('sqlalchemy.sql.default_comparator') + @util.dependencies("sqlalchemy.sql.default_comparator") def operate(self, default_comparator, op, *other, **kwargs): o = default_comparator.operator_lookup[op.__name__] return o[0](self.expr, op, *(other + o[1:]), **kwargs) - @util.dependencies('sqlalchemy.sql.default_comparator') + @util.dependencies("sqlalchemy.sql.default_comparator") def reverse_operate(self, default_comparator, op, other, **kwargs): o = default_comparator.operator_lookup[op.__name__] - return o[0](self.expr, op, other, - reverse=True, *o[1:], **kwargs) + return o[0](self.expr, op, other, reverse=True, *o[1:], **kwargs) def _adapt_expression(self, op, other_comparator): """evaluate the return type of <self> <op> <othertype>, @@ -97,7 +97,7 @@ class TypeEngine(Visitable): return op, self.type def __reduce__(self): - return _reconstitute_comparator, (self.expr, ) + return _reconstitute_comparator, (self.expr,) hashable = True """Flag, if False, means values from this type aren't hashable. @@ -313,8 +313,10 @@ class TypeEngine(Visitable): """ - return self.__class__.column_expression.__code__ \ + return ( + self.__class__.column_expression.__code__ is not TypeEngine.column_expression.__code__ + ) def bind_expression(self, bindvalue): """"Given a bind value (i.e. a :class:`.BindParameter` instance), @@ -351,8 +353,10 @@ class TypeEngine(Visitable): """ - return self.__class__.bind_expression.__code__ \ + return ( + self.__class__.bind_expression.__code__ is not TypeEngine.bind_expression.__code__ + ) @staticmethod def _to_instance(cls_or_self): @@ -441,9 +445,9 @@ class TypeEngine(Visitable): """ try: - return dialect._type_memos[self]['impl'] + return dialect._type_memos[self]["impl"] except KeyError: - return self._dialect_info(dialect)['impl'] + return self._dialect_info(dialect)["impl"] def _unwrapped_dialect_impl(self, dialect): """Return the 'unwrapped' dialect impl for this type. @@ -462,20 +466,20 @@ class TypeEngine(Visitable): def _cached_literal_processor(self, dialect): """Return a dialect-specific literal processor for this type.""" try: - return dialect._type_memos[self]['literal'] + return dialect._type_memos[self]["literal"] except KeyError: d = self._dialect_info(dialect) - d['literal'] = lp = d['impl'].literal_processor(dialect) + d["literal"] = lp = d["impl"].literal_processor(dialect) return lp def _cached_bind_processor(self, dialect): """Return a dialect-specific bind processor for this type.""" try: - return dialect._type_memos[self]['bind'] + return dialect._type_memos[self]["bind"] except KeyError: d = self._dialect_info(dialect) - d['bind'] = bp = d['impl'].bind_processor(dialect) + d["bind"] = bp = d["impl"].bind_processor(dialect) return bp def _cached_result_processor(self, dialect, coltype): @@ -488,7 +492,7 @@ class TypeEngine(Visitable): # key assumption: DBAPI type codes are # constants. Else this dictionary would # grow unbounded. - d[coltype] = rp = d['impl'].result_processor(dialect, coltype) + d[coltype] = rp = d["impl"].result_processor(dialect, coltype) return rp def _cached_custom_processor(self, dialect, key, fn): @@ -496,7 +500,7 @@ class TypeEngine(Visitable): return dialect._type_memos[self][key] except KeyError: d = self._dialect_info(dialect) - impl = d['impl'] + impl = d["impl"] d[key] = result = fn(impl) return result @@ -513,7 +517,7 @@ class TypeEngine(Visitable): impl = self.adapt(type(self)) # this can't be self, else we create a cycle assert impl is not self - dialect._type_memos[self] = d = {'impl': impl} + dialect._type_memos[self] = d = {"impl": impl} return d def _gen_dialect_impl(self, dialect): @@ -549,8 +553,10 @@ class TypeEngine(Visitable): """ _coerced_type = _resolve_value_to_type(value) - if _coerced_type is NULLTYPE or _coerced_type._type_affinity \ - is self._type_affinity: + if ( + _coerced_type is NULLTYPE + or _coerced_type._type_affinity is self._type_affinity + ): return self else: return _coerced_type @@ -586,8 +592,7 @@ class TypeEngine(Visitable): def __str__(self): if util.py2k: - return unicode(self.compile()).\ - encode('ascii', 'backslashreplace') + return unicode(self.compile()).encode("ascii", "backslashreplace") else: return str(self.compile()) @@ -645,15 +650,16 @@ class UserDefinedType(util.with_metaclass(VisitableCheckKWArg, TypeEngine)): ``type_expression``, if it receives ``**kw`` in its signature. """ + __visit_name__ = "user_defined" - ensure_kwarg = 'get_col_spec' + ensure_kwarg = "get_col_spec" class Comparator(TypeEngine.Comparator): __slots__ = () def _adapt_expression(self, op, other_comparator): - if hasattr(self.type, 'adapt_operator'): + if hasattr(self.type, "adapt_operator"): util.warn_deprecated( "UserDefinedType.adapt_operator is deprecated. Create " "a UserDefinedType.Comparator subclass instead which " @@ -854,6 +860,7 @@ class TypeDecorator(SchemaEventTarget, TypeEngine): will cause the index value ``'foo'`` to be JSON encoded. """ + __visit_name__ = "type_decorator" def __init__(self, *args, **kwargs): @@ -874,14 +881,16 @@ class TypeDecorator(SchemaEventTarget, TypeEngine): """ - if not hasattr(self.__class__, 'impl'): - raise AssertionError("TypeDecorator implementations " - "require a class-level variable " - "'impl' which refers to the class of " - "type being decorated") + if not hasattr(self.__class__, "impl"): + raise AssertionError( + "TypeDecorator implementations " + "require a class-level variable " + "'impl' which refers to the class of " + "type being decorated" + ) self.impl = to_instance(self.__class__.impl, *args, **kwargs) - coerce_to_is_types = (util.NoneType, ) + coerce_to_is_types = (util.NoneType,) """Specify those Python types which should be coerced at the expression level to "IS <constant>" when compared using ``==`` (and same for ``IS NOT`` in conjunction with ``!=``. @@ -906,24 +915,27 @@ class TypeDecorator(SchemaEventTarget, TypeEngine): __slots__ = () def operate(self, op, *other, **kwargs): - kwargs['_python_is_types'] = self.expr.type.coerce_to_is_types + kwargs["_python_is_types"] = self.expr.type.coerce_to_is_types return super(TypeDecorator.Comparator, self).operate( - op, *other, **kwargs) + op, *other, **kwargs + ) def reverse_operate(self, op, other, **kwargs): - kwargs['_python_is_types'] = self.expr.type.coerce_to_is_types + kwargs["_python_is_types"] = self.expr.type.coerce_to_is_types return super(TypeDecorator.Comparator, self).reverse_operate( - op, other, **kwargs) + op, other, **kwargs + ) @property def comparator_factory(self): if TypeDecorator.Comparator in self.impl.comparator_factory.__mro__: return self.impl.comparator_factory else: - return type("TDComparator", - (TypeDecorator.Comparator, - self.impl.comparator_factory), - {}) + return type( + "TDComparator", + (TypeDecorator.Comparator, self.impl.comparator_factory), + {}, + ) def _gen_dialect_impl(self, dialect): """ @@ -939,10 +951,11 @@ class TypeDecorator(SchemaEventTarget, TypeEngine): typedesc = self._unwrapped_dialect_impl(dialect) tt = self.copy() if not isinstance(tt, self.__class__): - raise AssertionError('Type object %s does not properly ' - 'implement the copy() method, it must ' - 'return an object of type %s' % - (self, self.__class__)) + raise AssertionError( + "Type object %s does not properly " + "implement the copy() method, it must " + "return an object of type %s" % (self, self.__class__) + ) tt.impl = typedesc return tt @@ -1099,8 +1112,10 @@ class TypeDecorator(SchemaEventTarget, TypeEngine): """ - return self.__class__.process_bind_param.__code__ \ + return ( + self.__class__.process_bind_param.__code__ is not TypeDecorator.process_bind_param.__code__ + ) @util.memoized_property def _has_literal_processor(self): @@ -1109,8 +1124,10 @@ class TypeDecorator(SchemaEventTarget, TypeEngine): """ - return self.__class__.process_literal_param.__code__ \ + return ( + self.__class__.process_literal_param.__code__ is not TypeDecorator.process_literal_param.__code__ + ) def literal_processor(self, dialect): """Provide a literal processing function for the given @@ -1147,9 +1164,12 @@ class TypeDecorator(SchemaEventTarget, TypeEngine): if process_param: impl_processor = self.impl.literal_processor(dialect) if impl_processor: + def process(value): return impl_processor(process_param(value, dialect)) + else: + def process(value): return process_param(value, dialect) @@ -1180,10 +1200,12 @@ class TypeDecorator(SchemaEventTarget, TypeEngine): process_param = self.process_bind_param impl_processor = self.impl.bind_processor(dialect) if impl_processor: + def process(value): return impl_processor(process_param(value, dialect)) else: + def process(value): return process_param(value, dialect) @@ -1200,8 +1222,10 @@ class TypeDecorator(SchemaEventTarget, TypeEngine): exception throw. """ - return self.__class__.process_result_value.__code__ \ + return ( + self.__class__.process_result_value.__code__ is not TypeDecorator.process_result_value.__code__ + ) def result_processor(self, dialect, coltype): """Provide a result value processing function for the given @@ -1225,13 +1249,14 @@ class TypeDecorator(SchemaEventTarget, TypeEngine): """ if self._has_result_processor: process_value = self.process_result_value - impl_processor = self.impl.result_processor(dialect, - coltype) + impl_processor = self.impl.result_processor(dialect, coltype) if impl_processor: + def process(value): return process_value(impl_processor(value), dialect) else: + def process(value): return process_value(value, dialect) @@ -1397,7 +1422,8 @@ class Variant(TypeDecorator): if dialect_name in self.mapping: raise exc.ArgumentError( "Dialect '%s' is already present in " - "the mapping for this Variant" % dialect_name) + "the mapping for this Variant" % dialect_name + ) mapping = self.mapping.copy() mapping[dialect_name] = type_ return Variant(self.impl, mapping) @@ -1439,6 +1465,6 @@ def adapt_type(typeobj, colspecs): # but it turns out the originally given "generic" type # is actually a subclass of our resulting type, then we were already # given a more specific type than that required; so use that. - if (issubclass(typeobj.__class__, impltype)): + if issubclass(typeobj.__class__, impltype): return typeobj return typeobj.adapt(impltype) diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 12cfe09d1..4feaf9938 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -15,15 +15,29 @@ from . import operators, visitors from itertools import chain from collections import deque -from .elements import BindParameter, ColumnClause, ColumnElement, \ - Null, UnaryExpression, literal_column, Label, _label_reference, \ - _textual_label_reference -from .selectable import SelectBase, ScalarSelect, Join, FromClause, FromGrouping +from .elements import ( + BindParameter, + ColumnClause, + ColumnElement, + Null, + UnaryExpression, + literal_column, + Label, + _label_reference, + _textual_label_reference, +) +from .selectable import ( + SelectBase, + ScalarSelect, + Join, + FromClause, + FromGrouping, +) from .schema import Column join_condition = util.langhelpers.public_factory( - Join._join_condition, - ".sql.util.join_condition") + Join._join_condition, ".sql.util.join_condition" +) # names that are still being imported from the outside from .annotation import _shallow_annotate, _deep_annotate, _deep_deannotate @@ -88,8 +102,9 @@ def find_left_clause_that_matches_given(clauses, join_from): for idx in liberal_idx: f = clauses[idx] for s in selectables: - if set(surface_selectables(f)).\ - intersection(surface_selectables(s)): + if set(surface_selectables(f)).intersection( + surface_selectables(s) + ): conservative_idx.append(idx) break if conservative_idx: @@ -184,8 +199,9 @@ def visit_binary_product(fn, expr): # we don't want to dig into correlated subqueries, # those are just column elements by themselves yield element - elif element.__visit_name__ == 'binary' and \ - operators.is_comparison(element.operator): + elif element.__visit_name__ == "binary" and operators.is_comparison( + element.operator + ): stack.insert(0, element) for l in visit(element.left): for r in visit(element.right): @@ -199,38 +215,47 @@ def visit_binary_product(fn, expr): for elem in element.get_children(): for e in visit(elem): yield e + list(visit(expr)) -def find_tables(clause, check_columns=False, - include_aliases=False, include_joins=False, - include_selects=False, include_crud=False): +def find_tables( + clause, + check_columns=False, + include_aliases=False, + include_joins=False, + include_selects=False, + include_crud=False, +): """locate Table objects within the given expression.""" tables = [] _visitors = {} if include_selects: - _visitors['select'] = _visitors['compound_select'] = tables.append + _visitors["select"] = _visitors["compound_select"] = tables.append if include_joins: - _visitors['join'] = tables.append + _visitors["join"] = tables.append if include_aliases: - _visitors['alias'] = tables.append + _visitors["alias"] = tables.append if include_crud: - _visitors['insert'] = _visitors['update'] = \ - _visitors['delete'] = lambda ent: tables.append(ent.table) + _visitors["insert"] = _visitors["update"] = _visitors[ + "delete" + ] = lambda ent: tables.append(ent.table) if check_columns: + def visit_column(column): tables.append(column.table) - _visitors['column'] = visit_column - _visitors['table'] = tables.append + _visitors["column"] = visit_column - visitors.traverse(clause, {'column_collections': False}, _visitors) + _visitors["table"] = tables.append + + visitors.traverse(clause, {"column_collections": False}, _visitors) return tables @@ -243,10 +268,9 @@ def unwrap_order_by(clause): stack = deque([clause]) while stack: t = stack.popleft() - if isinstance(t, ColumnElement) and \ - ( - not isinstance(t, UnaryExpression) or - not operators.is_ordering_modifier(t.modifier) + if isinstance(t, ColumnElement) and ( + not isinstance(t, UnaryExpression) + or not operators.is_ordering_modifier(t.modifier) ): if isinstance(t, _label_reference): t = t.element @@ -266,9 +290,7 @@ def unwrap_label_reference(element): if isinstance(elem, (_label_reference, _textual_label_reference)): return elem.element - return visitors.replacement_traverse( - element, {}, replace - ) + return visitors.replacement_traverse(element, {}, replace) def expand_column_list_from_order_by(collist, order_by): @@ -278,17 +300,16 @@ def expand_column_list_from_order_by(collist, order_by): in the collist. """ - cols_already_present = set([ - col.element if col._order_by_label_element is not None - else col for col in collist - ]) + cols_already_present = set( + [ + col.element if col._order_by_label_element is not None else col + for col in collist + ] + ) return [ - col for col in - chain(*[ - unwrap_order_by(o) - for o in order_by - ]) + col + for col in chain(*[unwrap_order_by(o) for o in order_by]) if col not in cols_already_present ] @@ -325,9 +346,9 @@ def surface_column_elements(clause, include_scalar_selects=True): be addressable in the WHERE clause of a SELECT if this element were in the columns clause.""" - filter_ = (FromGrouping, ) + filter_ = (FromGrouping,) if not include_scalar_selects: - filter_ += (SelectBase, ) + filter_ += (SelectBase,) stack = deque([clause]) while stack: @@ -343,9 +364,7 @@ def selectables_overlap(left, right): """Return True if left/right have some overlapping selectable""" return bool( - set(surface_selectables(left)).intersection( - surface_selectables(right) - ) + set(surface_selectables(left)).intersection(surface_selectables(right)) ) @@ -366,7 +385,7 @@ def bind_values(clause): def visit_bindparam(bind): v.append(bind.effective_value) - visitors.traverse(clause, {}, {'bindparam': visit_bindparam}) + visitors.traverse(clause, {}, {"bindparam": visit_bindparam}) return v @@ -383,7 +402,7 @@ class _repr_base(object): _TUPLE = 1 _DICT = 2 - __slots__ = 'max_chars', + __slots__ = ("max_chars",) def trunc(self, value): rep = repr(value) @@ -391,10 +410,12 @@ class _repr_base(object): if lenrep > self.max_chars: segment_length = self.max_chars // 2 rep = ( - rep[0:segment_length] + - (" ... (%d characters truncated) ... " - % (lenrep - self.max_chars)) + - rep[-segment_length:] + rep[0:segment_length] + + ( + " ... (%d characters truncated) ... " + % (lenrep - self.max_chars) + ) + + rep[-segment_length:] ) return rep @@ -402,7 +423,7 @@ class _repr_base(object): class _repr_row(_repr_base): """Provide a string view of a row.""" - __slots__ = 'row', + __slots__ = ("row",) def __init__(self, row, max_chars=300): self.row = row @@ -412,7 +433,7 @@ class _repr_row(_repr_base): trunc = self.trunc return "(%s%s)" % ( ", ".join(trunc(value) for value in self.row), - "," if len(self.row) == 1 else "" + "," if len(self.row) == 1 else "", ) @@ -424,7 +445,7 @@ class _repr_params(_repr_base): """ - __slots__ = 'params', 'batches', + __slots__ = "params", "batches" def __init__(self, params, batches, max_chars=300): self.params = params @@ -435,11 +456,13 @@ class _repr_params(_repr_base): if isinstance(self.params, list): typ = self._LIST ismulti = self.params and isinstance( - self.params[0], (list, dict, tuple)) + self.params[0], (list, dict, tuple) + ) elif isinstance(self.params, tuple): typ = self._TUPLE ismulti = self.params and isinstance( - self.params[0], (list, dict, tuple)) + self.params[0], (list, dict, tuple) + ) elif isinstance(self.params, dict): typ = self._DICT ismulti = False @@ -448,11 +471,15 @@ class _repr_params(_repr_base): if ismulti and len(self.params) > self.batches: msg = " ... displaying %i of %i total bound parameter sets ... " - return ' '.join(( - self._repr_multi(self.params[:self.batches - 2], typ)[0:-1], - msg % (self.batches, len(self.params)), - self._repr_multi(self.params[-2:], typ)[1:] - )) + return " ".join( + ( + self._repr_multi(self.params[: self.batches - 2], typ)[ + 0:-1 + ], + msg % (self.batches, len(self.params)), + self._repr_multi(self.params[-2:], typ)[1:], + ) + ) elif ismulti: return self._repr_multi(self.params, typ) else: @@ -467,12 +494,13 @@ class _repr_params(_repr_base): elif isinstance(multi_params[0], dict): elem_type = self._DICT else: - assert False, \ - "Unknown parameter type %s" % (type(multi_params[0])) + assert False, "Unknown parameter type %s" % ( + type(multi_params[0]) + ) elements = ", ".join( - self._repr_params(params, elem_type) - for params in multi_params) + self._repr_params(params, elem_type) for params in multi_params + ) else: elements = "" @@ -493,13 +521,10 @@ class _repr_params(_repr_base): elif typ is self._TUPLE: return "(%s%s)" % ( ", ".join(trunc(value) for value in params), - "," if len(params) == 1 else "" - + "," if len(params) == 1 else "", ) else: - return "[%s]" % ( - ", ".join(trunc(value) for value in params) - ) + return "[%s]" % (", ".join(trunc(value) for value in params)) def adapt_criterion_to_null(crit, nulls): @@ -509,20 +534,24 @@ def adapt_criterion_to_null(crit, nulls): """ def visit_binary(binary): - if isinstance(binary.left, BindParameter) \ - and binary.left._identifying_key in nulls: + if ( + isinstance(binary.left, BindParameter) + and binary.left._identifying_key in nulls + ): # reverse order if the NULL is on the left side binary.left = binary.right binary.right = Null() binary.operator = operators.is_ binary.negate = operators.isnot - elif isinstance(binary.right, BindParameter) \ - and binary.right._identifying_key in nulls: + elif ( + isinstance(binary.right, BindParameter) + and binary.right._identifying_key in nulls + ): binary.right = Null() binary.operator = operators.is_ binary.negate = operators.isnot - return visitors.cloned_traverse(crit, {}, {'binary': visit_binary}) + return visitors.cloned_traverse(crit, {}, {"binary": visit_binary}) def splice_joins(left, right, stop_on=None): @@ -570,8 +599,8 @@ def reduce_columns(columns, *clauses, **kw): in the selectable to just those that are not repeated. """ - ignore_nonexistent_tables = kw.pop('ignore_nonexistent_tables', False) - only_synonyms = kw.pop('only_synonyms', False) + ignore_nonexistent_tables = kw.pop("ignore_nonexistent_tables", False) + only_synonyms = kw.pop("only_synonyms", False) columns = util.ordered_column_set(columns) @@ -597,39 +626,48 @@ def reduce_columns(columns, *clauses, **kw): continue else: raise - if fk_col.shares_lineage(c) and \ - (not only_synonyms or - c.name == col.name): + if fk_col.shares_lineage(c) and ( + not only_synonyms or c.name == col.name + ): omit.add(col) break if clauses: + def visit_binary(binary): if binary.operator == operators.eq: cols = util.column_set( - chain(*[c.proxy_set for c in columns.difference(omit)])) + chain(*[c.proxy_set for c in columns.difference(omit)]) + ) if binary.left in cols and binary.right in cols: for c in reversed(columns): - if c.shares_lineage(binary.right) and \ - (not only_synonyms or - c.name == binary.left.name): + if c.shares_lineage(binary.right) and ( + not only_synonyms or c.name == binary.left.name + ): omit.add(c) break + for clause in clauses: if clause is not None: - visitors.traverse(clause, {}, {'binary': visit_binary}) + visitors.traverse(clause, {}, {"binary": visit_binary}) return ColumnSet(columns.difference(omit)) -def criterion_as_pairs(expression, consider_as_foreign_keys=None, - consider_as_referenced_keys=None, any_operator=False): +def criterion_as_pairs( + expression, + consider_as_foreign_keys=None, + consider_as_referenced_keys=None, + any_operator=False, +): """traverse an expression and locate binary criterion pairs.""" if consider_as_foreign_keys and consider_as_referenced_keys: - raise exc.ArgumentError("Can only specify one of " - "'consider_as_foreign_keys' or " - "'consider_as_referenced_keys'") + raise exc.ArgumentError( + "Can only specify one of " + "'consider_as_foreign_keys' or " + "'consider_as_referenced_keys'" + ) def col_is(a, b): # return a is b @@ -638,37 +676,44 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, def visit_binary(binary): if not any_operator and binary.operator is not operators.eq: return - if not isinstance(binary.left, ColumnElement) or \ - not isinstance(binary.right, ColumnElement): + if not isinstance(binary.left, ColumnElement) or not isinstance( + binary.right, ColumnElement + ): return if consider_as_foreign_keys: - if binary.left in consider_as_foreign_keys and \ - (col_is(binary.right, binary.left) or - binary.right not in consider_as_foreign_keys): + if binary.left in consider_as_foreign_keys and ( + col_is(binary.right, binary.left) + or binary.right not in consider_as_foreign_keys + ): pairs.append((binary.right, binary.left)) - elif binary.right in consider_as_foreign_keys and \ - (col_is(binary.left, binary.right) or - binary.left not in consider_as_foreign_keys): + elif binary.right in consider_as_foreign_keys and ( + col_is(binary.left, binary.right) + or binary.left not in consider_as_foreign_keys + ): pairs.append((binary.left, binary.right)) elif consider_as_referenced_keys: - if binary.left in consider_as_referenced_keys and \ - (col_is(binary.right, binary.left) or - binary.right not in consider_as_referenced_keys): + if binary.left in consider_as_referenced_keys and ( + col_is(binary.right, binary.left) + or binary.right not in consider_as_referenced_keys + ): pairs.append((binary.left, binary.right)) - elif binary.right in consider_as_referenced_keys and \ - (col_is(binary.left, binary.right) or - binary.left not in consider_as_referenced_keys): + elif binary.right in consider_as_referenced_keys and ( + col_is(binary.left, binary.right) + or binary.left not in consider_as_referenced_keys + ): pairs.append((binary.right, binary.left)) else: - if isinstance(binary.left, Column) and \ - isinstance(binary.right, Column): + if isinstance(binary.left, Column) and isinstance( + binary.right, Column + ): if binary.left.references(binary.right): pairs.append((binary.right, binary.left)) elif binary.right.references(binary.left): pairs.append((binary.left, binary.right)) + pairs = [] - visitors.traverse(expression, {}, {'binary': visit_binary}) + visitors.traverse(expression, {}, {"binary": visit_binary}) return pairs @@ -699,28 +744,38 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): """ - def __init__(self, selectable, equivalents=None, - include_fn=None, exclude_fn=None, - adapt_on_names=False, anonymize_labels=False): + def __init__( + self, + selectable, + equivalents=None, + include_fn=None, + exclude_fn=None, + adapt_on_names=False, + anonymize_labels=False, + ): self.__traverse_options__ = { - 'stop_on': [selectable], - 'anonymize_labels': anonymize_labels} + "stop_on": [selectable], + "anonymize_labels": anonymize_labels, + } self.selectable = selectable self.include_fn = include_fn self.exclude_fn = exclude_fn self.equivalents = util.column_dict(equivalents or {}) self.adapt_on_names = adapt_on_names - def _corresponding_column(self, col, require_embedded, - _seen=util.EMPTY_SET): + def _corresponding_column( + self, col, require_embedded, _seen=util.EMPTY_SET + ): newcol = self.selectable.corresponding_column( - col, - require_embedded=require_embedded) + col, require_embedded=require_embedded + ) if newcol is None and col in self.equivalents and col not in _seen: for equiv in self.equivalents[col]: newcol = self._corresponding_column( - equiv, require_embedded=require_embedded, - _seen=_seen.union([col])) + equiv, + require_embedded=require_embedded, + _seen=_seen.union([col]), + ) if newcol is not None: return newcol if self.adapt_on_names and newcol is None: @@ -728,8 +783,9 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): return newcol def replace(self, col): - if isinstance(col, FromClause) and \ - self.selectable.is_derived_from(col): + if isinstance(col, FromClause) and self.selectable.is_derived_from( + col + ): return self.selectable elif not isinstance(col, ColumnElement): return None @@ -772,16 +828,27 @@ class ColumnAdapter(ClauseAdapter): """ - def __init__(self, selectable, equivalents=None, - chain_to=None, adapt_required=False, - include_fn=None, exclude_fn=None, - adapt_on_names=False, - allow_label_resolve=True, - anonymize_labels=False): - ClauseAdapter.__init__(self, selectable, equivalents, - include_fn=include_fn, exclude_fn=exclude_fn, - adapt_on_names=adapt_on_names, - anonymize_labels=anonymize_labels) + def __init__( + self, + selectable, + equivalents=None, + chain_to=None, + adapt_required=False, + include_fn=None, + exclude_fn=None, + adapt_on_names=False, + allow_label_resolve=True, + anonymize_labels=False, + ): + ClauseAdapter.__init__( + self, + selectable, + equivalents, + include_fn=include_fn, + exclude_fn=exclude_fn, + adapt_on_names=adapt_on_names, + anonymize_labels=anonymize_labels, + ) if chain_to: self.chain(chain_to) @@ -800,9 +867,7 @@ class ColumnAdapter(ClauseAdapter): def __getitem__(self, key): if ( self.parent.include_fn and not self.parent.include_fn(key) - ) or ( - self.parent.exclude_fn and self.parent.exclude_fn(key) - ): + ) or (self.parent.exclude_fn and self.parent.exclude_fn(key)): if self.parent._wrap: return self.parent._wrap.columns[key] else: @@ -843,7 +908,7 @@ class ColumnAdapter(ClauseAdapter): def __getstate__(self): d = self.__dict__.copy() - del d['columns'] + del d["columns"] return d def __setstate__(self, state): diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index b39ec8167..bf1743643 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -29,11 +29,20 @@ from .. import util import operator from .. import exc -__all__ = ['VisitableType', 'Visitable', 'ClauseVisitor', - 'CloningVisitor', 'ReplacingCloningVisitor', 'iterate', - 'iterate_depthfirst', 'traverse_using', 'traverse', - 'traverse_depthfirst', - 'cloned_traverse', 'replacement_traverse'] +__all__ = [ + "VisitableType", + "Visitable", + "ClauseVisitor", + "CloningVisitor", + "ReplacingCloningVisitor", + "iterate", + "iterate_depthfirst", + "traverse_using", + "traverse", + "traverse_depthfirst", + "cloned_traverse", + "replacement_traverse", +] class VisitableType(type): @@ -53,8 +62,7 @@ class VisitableType(type): """ def __init__(cls, clsname, bases, clsdict): - if clsname != 'Visitable' and \ - hasattr(cls, '__visit_name__'): + if clsname != "Visitable" and hasattr(cls, "__visit_name__"): _generate_dispatch(cls) super(VisitableType, cls).__init__(clsname, bases, clsdict) @@ -64,7 +72,7 @@ def _generate_dispatch(cls): """Return an optimized visit dispatch function for the cls for use by the compiler. """ - if '__visit_name__' in cls.__dict__: + if "__visit_name__" in cls.__dict__: visit_name = cls.__visit_name__ if isinstance(visit_name, str): # There is an optimization opportunity here because the @@ -79,12 +87,13 @@ def _generate_dispatch(cls): raise exc.UnsupportedCompilationError(visitor, cls) else: return meth(self, **kw) + else: # The optimization opportunity is lost for this case because the # __visit_name__ is not yet a string. As a result, the visit # string has to be recalculated with each compilation. def _compiler_dispatch(self, visitor, **kw): - visit_attr = 'visit_%s' % self.__visit_name__ + visit_attr = "visit_%s" % self.__visit_name__ try: meth = getattr(visitor, visit_attr) except AttributeError: @@ -92,8 +101,7 @@ def _generate_dispatch(cls): else: return meth(self, **kw) - _compiler_dispatch.__doc__ = \ - """Look for an attribute named "visit_" + self.__visit_name__ + _compiler_dispatch.__doc__ = """Look for an attribute named "visit_" + self.__visit_name__ on the visitor, and call it with the same kw params. """ cls._compiler_dispatch = _compiler_dispatch @@ -137,7 +145,7 @@ class ClauseVisitor(object): visitors = {} for name in dir(self): - if name.startswith('visit_'): + if name.startswith("visit_"): visitors[name[6:]] = getattr(self, name) return visitors @@ -148,7 +156,7 @@ class ClauseVisitor(object): v = self while v: yield v - v = getattr(v, '_next', None) + v = getattr(v, "_next", None) def chain(self, visitor): """'chain' an additional ClauseVisitor onto this ClauseVisitor. @@ -178,7 +186,8 @@ class CloningVisitor(ClauseVisitor): """traverse and visit the given expression structure.""" return cloned_traverse( - obj, self.__traverse_options__, self._visitor_dict) + obj, self.__traverse_options__, self._visitor_dict + ) class ReplacingCloningVisitor(CloningVisitor): @@ -204,6 +213,7 @@ class ReplacingCloningVisitor(CloningVisitor): e = v.replace(elem) if e is not None: return e + return replacement_traverse(obj, self.__traverse_options__, replace) @@ -282,7 +292,7 @@ def cloned_traverse(obj, opts, visitors): modifications by visitors.""" cloned = {} - stop_on = set(opts.get('stop_on', [])) + stop_on = set(opts.get("stop_on", [])) def clone(elem): if elem in stop_on: @@ -306,11 +316,13 @@ def replacement_traverse(obj, opts, replace): replacement by a given replacement function.""" cloned = {} - stop_on = {id(x) for x in opts.get('stop_on', [])} + stop_on = {id(x) for x in opts.get("stop_on", [])} def clone(elem, **kw): - if id(elem) in stop_on or \ - 'no_replacement_traverse' in elem._annotations: + if ( + id(elem) in stop_on + or "no_replacement_traverse" in elem._annotations + ): return elem else: newelem = replace(elem) |
