diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2019-01-06 01:14:26 -0500 |
|---|---|---|
| committer | mike bayer <mike_mp@zzzcomputing.com> | 2019-01-06 17:34:50 +0000 |
| commit | 1e1a38e7801f410f244e4bbb44ec795ae152e04e (patch) | |
| tree | 28e725c5c8188bd0cfd133d1e268dbca9b524978 /lib/sqlalchemy/ext | |
| parent | 404e69426b05a82d905cbb3ad33adafccddb00dd (diff) | |
| download | sqlalchemy-1e1a38e7801f410f244e4bbb44ec795ae152e04e.tar.gz | |
Run black -l 79 against all source files
This is a straight reformat run using black as is, with no edits
applied at all.
The black run will format code consistently, however in
some cases that are prevalent in SQLAlchemy code it produces
too-long lines. The too-long lines will be resolved in the
following commit that will resolve all remaining flake8 issues
including shadowed builtins, long lines, import order, unused
imports, duplicate imports, and docstring issues.
Change-Id: I7eda77fed3d8e73df84b3651fd6cfcfe858d4dc9
Diffstat (limited to 'lib/sqlalchemy/ext')
| -rw-r--r-- | lib/sqlalchemy/ext/__init__.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/associationproxy.py | 288 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/automap.py | 309 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/baked.py | 135 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/compiler.py | 30 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/declarative/__init__.py | 35 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/declarative/api.py | 117 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/declarative/base.py | 326 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/declarative/clsregistry.py | 125 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/horizontal_shard.py | 61 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/hybrid.py | 30 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/indexable.py | 12 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/instrumentation.py | 78 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/mutable.py | 78 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/orderinglist.py | 29 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/serializer.py | 26 |
16 files changed, 1021 insertions, 659 deletions
diff --git a/lib/sqlalchemy/ext/__init__.py b/lib/sqlalchemy/ext/__init__.py index 9558b2a1f..9fed09e2b 100644 --- a/lib/sqlalchemy/ext/__init__.py +++ b/lib/sqlalchemy/ext/__init__.py @@ -8,4 +8,3 @@ from .. import util as _sa_util _sa_util.dependencies.resolve_all("sqlalchemy.ext") - diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index ff9433d4d..56b91ce0b 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -76,7 +76,7 @@ def association_proxy(target_collection, attr, **kw): return AssociationProxy(target_collection, attr, **kw) -ASSOCIATION_PROXY = util.symbol('ASSOCIATION_PROXY') +ASSOCIATION_PROXY = util.symbol("ASSOCIATION_PROXY") """Symbol indicating an :class:`InspectionAttr` that's of type :class:`.AssociationProxy`. @@ -92,10 +92,17 @@ class AssociationProxy(interfaces.InspectionAttrInfo): is_attribute = False extension_type = ASSOCIATION_PROXY - def __init__(self, target_collection, attr, creator=None, - getset_factory=None, proxy_factory=None, - proxy_bulk_set=None, info=None, - cascade_scalar_deletes=False): + def __init__( + self, + target_collection, + attr, + creator=None, + getset_factory=None, + proxy_factory=None, + proxy_bulk_set=None, + info=None, + cascade_scalar_deletes=False, + ): """Construct a new :class:`.AssociationProxy`. The :func:`.association_proxy` function is provided as the usual @@ -162,8 +169,11 @@ class AssociationProxy(interfaces.InspectionAttrInfo): self.proxy_bulk_set = proxy_bulk_set self.cascade_scalar_deletes = cascade_scalar_deletes - self.key = '_%s_%s_%s' % ( - type(self).__name__, target_collection, id(self)) + self.key = "_%s_%s_%s" % ( + type(self).__name__, + target_collection, + id(self), + ) if info: self.info = info @@ -264,12 +274,17 @@ class AssociationProxy(interfaces.InspectionAttrInfo): def getter(target): return _getter(target) if target is not None else None + if collection_class is dict: + def setter(o, k, v): setattr(o, attr, v) + else: + def setter(o, v): setattr(o, attr, v) + return getter, setter @@ -325,20 +340,21 @@ class AssociationProxyInstance(object): def for_proxy(cls, parent, owning_class, parent_instance): target_collection = parent.target_collection value_attr = parent.value_attr - prop = orm.class_mapper(owning_class).\ - get_property(target_collection) + prop = orm.class_mapper(owning_class).get_property(target_collection) # this was never asserted before but this should be made clear. if not isinstance(prop, orm.RelationshipProperty): raise NotImplementedError( "association proxy to a non-relationship " - "intermediary is not supported") + "intermediary is not supported" + ) target_class = prop.mapper.class_ try: target_assoc = cls._cls_unwrap_target_assoc_proxy( - target_class, value_attr) + target_class, value_attr + ) except AttributeError: # the proxied attribute doesn't exist on the target class; # return an "ambiguous" instance that will work on a per-object @@ -353,8 +369,8 @@ class AssociationProxyInstance(object): @classmethod def _construct_for_assoc( - cls, target_assoc, parent, owning_class, - target_class, value_attr): + cls, target_assoc, parent, owning_class, target_class, value_attr + ): if target_assoc is not None: return ObjectAssociationProxyInstance( parent, owning_class, target_class, value_attr @@ -371,8 +387,9 @@ class AssociationProxyInstance(object): ) def _get_property(self): - return orm.class_mapper(self.owning_class).\ - get_property(self.target_collection) + return orm.class_mapper(self.owning_class).get_property( + self.target_collection + ) @property def _comparator(self): @@ -388,7 +405,8 @@ class AssociationProxyInstance(object): @util.memoized_property def _unwrap_target_assoc_proxy(self): return self._cls_unwrap_target_assoc_proxy( - self.target_class, self.value_attr) + self.target_class, self.value_attr + ) @property def remote_attr(self): @@ -448,8 +466,11 @@ class AssociationProxyInstance(object): @util.memoized_property def _value_is_scalar(self): - return not self._get_property().\ - mapper.get_property(self.value_attr).uselist + return ( + not self._get_property() + .mapper.get_property(self.value_attr) + .uselist + ) @property def _target_is_object(self): @@ -468,12 +489,17 @@ class AssociationProxyInstance(object): def getter(target): return _getter(target) if target is not None else None + if collection_class is dict: + def setter(o, k, v): return setattr(o, attr, v) + else: + def setter(o, v): return setattr(o, attr, v) + return getter, setter @property @@ -500,14 +526,18 @@ class AssociationProxyInstance(object): return proxy self.collection_class, proxy = self._new( - _lazy_collection(obj, self.target_collection)) + _lazy_collection(obj, self.target_collection) + ) setattr(obj, self.key, (id(obj), id(self), proxy)) return proxy def set(self, obj, values): if self.scalar: - creator = self.parent.creator \ - if self.parent.creator else self.target_class + creator = ( + self.parent.creator + if self.parent.creator + else self.target_class + ) target = getattr(obj, self.target_collection) if target is None: if values is None: @@ -535,35 +565,52 @@ class AssociationProxyInstance(object): delattr(obj, self.target_collection) def _new(self, lazy_collection): - creator = self.parent.creator if self.parent.creator else \ - self.target_class + creator = ( + self.parent.creator if self.parent.creator else self.target_class + ) collection_class = util.duck_type_collection(lazy_collection()) if self.parent.proxy_factory: - return collection_class, self.parent.proxy_factory( - lazy_collection, creator, self.value_attr, self) + return ( + collection_class, + self.parent.proxy_factory( + lazy_collection, creator, self.value_attr, self + ), + ) if self.parent.getset_factory: - getter, setter = self.parent.getset_factory( - collection_class, self) + getter, setter = self.parent.getset_factory(collection_class, self) else: getter, setter = self.parent._default_getset(collection_class) if collection_class is list: - return collection_class, _AssociationList( - lazy_collection, creator, getter, setter, self) + return ( + collection_class, + _AssociationList( + lazy_collection, creator, getter, setter, self + ), + ) elif collection_class is dict: - return collection_class, _AssociationDict( - lazy_collection, creator, getter, setter, self) + return ( + collection_class, + _AssociationDict( + lazy_collection, creator, getter, setter, self + ), + ) elif collection_class is set: - return collection_class, _AssociationSet( - lazy_collection, creator, getter, setter, self) + return ( + collection_class, + _AssociationSet( + lazy_collection, creator, getter, setter, self + ), + ) else: raise exc.ArgumentError( - 'could not guess which interface to use for ' + "could not guess which interface to use for " 'collection_class "%s" backing "%s"; specify a ' - 'proxy_factory and proxy_bulk_set manually' % - (self.collection_class.__name__, self.target_collection)) + "proxy_factory and proxy_bulk_set manually" + % (self.collection_class.__name__, self.target_collection) + ) def _set(self, proxy, values): if self.parent.proxy_bulk_set: @@ -576,16 +623,19 @@ class AssociationProxyInstance(object): proxy.update(values) else: raise exc.ArgumentError( - 'no proxy_bulk_set supplied for custom ' - 'collection_class implementation') + "no proxy_bulk_set supplied for custom " + "collection_class implementation" + ) def _inflate(self, proxy): - creator = self.parent.creator and \ - self.parent.creator or self.target_class + creator = ( + self.parent.creator and self.parent.creator or self.target_class + ) if self.parent.getset_factory: getter, setter = self.parent.getset_factory( - self.collection_class, self) + self.collection_class, self + ) else: getter, setter = self.parent._default_getset(self.collection_class) @@ -594,12 +644,13 @@ class AssociationProxyInstance(object): proxy.setter = setter def _criterion_exists(self, criterion=None, **kwargs): - is_has = kwargs.pop('is_has', None) + is_has = kwargs.pop("is_has", None) target_assoc = self._unwrap_target_assoc_proxy if target_assoc is not None: inner = target_assoc._criterion_exists( - criterion=criterion, **kwargs) + criterion=criterion, **kwargs + ) return self._comparator._criterion_exists(inner) if self._target_is_object: @@ -631,15 +682,15 @@ class AssociationProxyInstance(object): """ if self._unwrap_target_assoc_proxy is None and ( - self.scalar and ( - not self._target_is_object or self._value_is_scalar) + self.scalar + and (not self._target_is_object or self._value_is_scalar) ): raise exc.InvalidRequestError( - "'any()' not implemented for scalar " - "attributes. Use has()." + "'any()' not implemented for scalar " "attributes. Use has()." ) return self._criterion_exists( - criterion=criterion, is_has=False, **kwargs) + criterion=criterion, is_has=False, **kwargs + ) def has(self, criterion=None, **kwargs): """Produce a proxied 'has' expression using EXISTS. @@ -651,14 +702,15 @@ class AssociationProxyInstance(object): """ if self._unwrap_target_assoc_proxy is None and ( - not self.scalar or ( - self._target_is_object and not self._value_is_scalar) + not self.scalar + or (self._target_is_object and not self._value_is_scalar) ): raise exc.InvalidRequestError( - "'has()' not implemented for collections. " - "Use any().") + "'has()' not implemented for collections. " "Use any()." + ) return self._criterion_exists( - criterion=criterion, is_has=True, **kwargs) + criterion=criterion, is_has=True, **kwargs + ) class AmbiguousAssociationProxyInstance(AssociationProxyInstance): @@ -673,10 +725,14 @@ class AmbiguousAssociationProxyInstance(AssociationProxyInstance): "Association proxy %s.%s refers to an attribute '%s' that is not " "directly mapped on class %s; therefore this operation cannot " "proceed since we don't know what type of object is referred " - "towards" % ( - self.owning_class.__name__, self.target_collection, - self.value_attr, self.target_class - )) + "towards" + % ( + self.owning_class.__name__, + self.target_collection, + self.value_attr, + self.target_class, + ) + ) def get(self, obj): self._ambiguous() @@ -718,27 +774,32 @@ class AmbiguousAssociationProxyInstance(AssociationProxyInstance): return self def _populate_cache(self, instance_class): - prop = orm.class_mapper(self.owning_class).\ - get_property(self.target_collection) + prop = orm.class_mapper(self.owning_class).get_property( + self.target_collection + ) if inspect(instance_class).mapper.isa(prop.mapper): target_class = instance_class try: target_assoc = self._cls_unwrap_target_assoc_proxy( - target_class, self.value_attr) + target_class, self.value_attr + ) except AttributeError: pass else: - self._lookup_cache[instance_class] = \ - self._construct_for_assoc( - target_assoc, self.parent, self.owning_class, - target_class, self.value_attr + self._lookup_cache[instance_class] = self._construct_for_assoc( + target_assoc, + self.parent, + self.owning_class, + target_class, + self.value_attr, ) class ObjectAssociationProxyInstance(AssociationProxyInstance): """an :class:`.AssociationProxyInstance` that has an object as a target. """ + _target_is_object = True _is_canonical = True @@ -756,17 +817,21 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance): if target_assoc is not None: return self._comparator._criterion_exists( target_assoc.contains(obj) - if not target_assoc.scalar else target_assoc == obj + if not target_assoc.scalar + else target_assoc == obj ) - elif self._target_is_object and self.scalar and \ - not self._value_is_scalar: + elif ( + self._target_is_object + and self.scalar + and not self._value_is_scalar + ): return self._comparator.has( getattr(self.target_class, self.value_attr).contains(obj) ) - elif self._target_is_object and self.scalar and \ - self._value_is_scalar: + elif self._target_is_object and self.scalar and self._value_is_scalar: raise exc.InvalidRequestError( - "contains() doesn't apply to a scalar object endpoint; use ==") + "contains() doesn't apply to a scalar object endpoint; use ==" + ) else: return self._comparator._criterion_exists(**{self.value_attr: obj}) @@ -777,7 +842,7 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance): if obj is None: return or_( self._comparator.has(**{self.value_attr: obj}), - self._comparator == None + self._comparator == None, ) else: return self._comparator.has(**{self.value_attr: obj}) @@ -786,14 +851,17 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance): # note the has() here will fail for collections; eq_() # is only allowed with a scalar. return self._comparator.has( - getattr(self.target_class, self.value_attr) != obj) + getattr(self.target_class, self.value_attr) != obj + ) class ColumnAssociationProxyInstance( - ColumnOperators, AssociationProxyInstance): + ColumnOperators, AssociationProxyInstance +): """an :class:`.AssociationProxyInstance` that has a database column as a target. """ + _target_is_object = False _is_canonical = True @@ -803,9 +871,7 @@ class ColumnAssociationProxyInstance( self.remote_attr.operate(operator.eq, other) ) if other is None: - return or_( - expr, self._comparator == None - ) + return or_(expr, self._comparator == None) else: return expr @@ -824,11 +890,11 @@ class _lazy_collection(object): return getattr(self.parent, self.target) def __getstate__(self): - return {'obj': self.parent, 'target': self.target} + return {"obj": self.parent, "target": self.target} def __setstate__(self, state): - self.parent = state['obj'] - self.target = state['target'] + self.parent = state["obj"] + self.target = state["target"] class _AssociationCollection(object): @@ -874,11 +940,11 @@ class _AssociationCollection(object): __nonzero__ = __bool__ def __getstate__(self): - return {'parent': self.parent, 'lazy_collection': self.lazy_collection} + return {"parent": self.parent, "lazy_collection": self.lazy_collection} def __setstate__(self, state): - self.parent = state['parent'] - self.lazy_collection = state['lazy_collection'] + self.parent = state["parent"] + self.lazy_collection = state["lazy_collection"] self.parent._inflate(self) @@ -925,8 +991,8 @@ class _AssociationList(_AssociationCollection): if len(value) != len(rng): raise ValueError( "attempt to assign sequence of size %s to " - "extended slice of size %s" % (len(value), - len(rng))) + "extended slice of size %s" % (len(value), len(rng)) + ) for i, item in zip(rng, value): self._set(self.col[i], item) @@ -968,8 +1034,14 @@ class _AssociationList(_AssociationCollection): col.append(item) def count(self, value): - return sum([1 for _ in - util.itertools_filter(lambda v: v == value, iter(self))]) + return sum( + [ + 1 + for _ in util.itertools_filter( + lambda v: v == value, iter(self) + ) + ] + ) def extend(self, values): for v in values: @@ -999,7 +1071,7 @@ class _AssociationList(_AssociationCollection): raise NotImplementedError def clear(self): - del self.col[0:len(self.col)] + del self.col[0 : len(self.col)] def __eq__(self, other): return list(self) == other @@ -1040,6 +1112,7 @@ class _AssociationList(_AssociationCollection): if not isinstance(n, int): return NotImplemented return list(self) * n + __rmul__ = __mul__ def __iadd__(self, iterable): @@ -1072,13 +1145,17 @@ class _AssociationList(_AssociationCollection): raise TypeError("%s objects are unhashable" % type(self).__name__) for func_name, func in list(locals().items()): - if (util.callable(func) and func.__name__ == func_name and - not func.__doc__ and hasattr(list, func_name)): + if ( + util.callable(func) + and func.__name__ == func_name + and not func.__doc__ + and hasattr(list, func_name) + ): func.__doc__ = getattr(list, func_name).__doc__ del func_name, func -_NotProvided = util.symbol('_NotProvided') +_NotProvided = util.symbol("_NotProvided") class _AssociationDict(_AssociationCollection): @@ -1160,6 +1237,7 @@ class _AssociationDict(_AssociationCollection): return self.col.keys() if util.py2k: + def iteritems(self): return ((key, self._get(self.col[key])) for key in self.col) @@ -1174,7 +1252,9 @@ class _AssociationDict(_AssociationCollection): def items(self): return [(k, self._get(self.col[k])) for k in self] + else: + def items(self): return ((key, self._get(self.col[key])) for key in self.col) @@ -1194,14 +1274,15 @@ class _AssociationDict(_AssociationCollection): def update(self, *a, **kw): if len(a) > 1: - raise TypeError('update expected at most 1 arguments, got %i' % - len(a)) + raise TypeError( + "update expected at most 1 arguments, got %i" % len(a) + ) elif len(a) == 1: seq_or_map = a[0] # discern dict from sequence - took the advice from # http://www.voidspace.org.uk/python/articles/duck_typing.shtml # still not perfect :( - if hasattr(seq_or_map, 'keys'): + if hasattr(seq_or_map, "keys"): for item in seq_or_map: self[item] = seq_or_map[item] else: @@ -1211,7 +1292,8 @@ class _AssociationDict(_AssociationCollection): except ValueError: raise ValueError( "dictionary update sequence " - "requires 2-element tuples") + "requires 2-element tuples" + ) for key, value in kw: self[key] = value @@ -1223,8 +1305,12 @@ class _AssociationDict(_AssociationCollection): raise TypeError("%s objects are unhashable" % type(self).__name__) for func_name, func in list(locals().items()): - if (util.callable(func) and func.__name__ == func_name and - not func.__doc__ and hasattr(dict, func_name)): + if ( + util.callable(func) + and func.__name__ == func_name + and not func.__doc__ + and hasattr(dict, func_name) + ): func.__doc__ = getattr(dict, func_name).__doc__ del func_name, func @@ -1288,7 +1374,7 @@ class _AssociationSet(_AssociationCollection): def pop(self): if not self.col: - raise KeyError('pop from an empty set') + raise KeyError("pop from an empty set") member = self.col.pop() return self._get(member) @@ -1420,7 +1506,11 @@ class _AssociationSet(_AssociationCollection): raise TypeError("%s objects are unhashable" % type(self).__name__) for func_name, func in list(locals().items()): - if (util.callable(func) and func.__name__ == func_name and - not func.__doc__ and hasattr(set, func_name)): + if ( + util.callable(func) + and func.__name__ == func_name + and not func.__doc__ + and hasattr(set, func_name) + ): func.__doc__ = getattr(set, func_name).__doc__ del func_name, func diff --git a/lib/sqlalchemy/ext/automap.py b/lib/sqlalchemy/ext/automap.py index cafb3d61c..747373a2a 100644 --- a/lib/sqlalchemy/ext/automap.py +++ b/lib/sqlalchemy/ext/automap.py @@ -580,7 +580,8 @@ def name_for_scalar_relationship(base, local_cls, referred_cls, constraint): def name_for_collection_relationship( - base, local_cls, referred_cls, constraint): + base, local_cls, referred_cls, constraint +): """Return the attribute name that should be used to refer from one class to another, for a collection reference. @@ -607,7 +608,8 @@ def name_for_collection_relationship( def generate_relationship( - base, direction, return_fn, attrname, local_cls, referred_cls, **kw): + base, direction, return_fn, attrname, local_cls, referred_cls, **kw +): r"""Generate a :func:`.relationship` or :func:`.backref` on behalf of two mapped classes. @@ -677,6 +679,7 @@ class AutomapBase(object): :ref:`automap_toplevel` """ + __abstract__ = True classes = None @@ -694,15 +697,16 @@ class AutomapBase(object): @classmethod def prepare( - cls, - engine=None, - reflect=False, - schema=None, - classname_for_table=classname_for_table, - collection_class=list, - name_for_scalar_relationship=name_for_scalar_relationship, - name_for_collection_relationship=name_for_collection_relationship, - generate_relationship=generate_relationship): + cls, + engine=None, + reflect=False, + schema=None, + classname_for_table=classname_for_table, + collection_class=list, + name_for_scalar_relationship=name_for_scalar_relationship, + name_for_collection_relationship=name_for_collection_relationship, + generate_relationship=generate_relationship, + ): """Extract mapped classes and relationships from the :class:`.MetaData` and perform mappings. @@ -752,15 +756,16 @@ class AutomapBase(object): engine, schema=schema, extend_existing=True, - autoload_replace=False + autoload_replace=False, ) _CONFIGURE_MUTEX.acquire() try: table_to_map_config = dict( (m.local_table, m) - for m in _DeferredMapperConfig. - classes_for_base(cls, sort=False) + for m in _DeferredMapperConfig.classes_for_base( + cls, sort=False + ) ) many_to_many = [] @@ -774,30 +779,39 @@ class AutomapBase(object): elif table not in table_to_map_config: mapped_cls = type( classname_for_table(cls, table.name, table), - (cls, ), - {"__table__": table} + (cls,), + {"__table__": table}, ) map_config = _DeferredMapperConfig.config_for_cls( - mapped_cls) + mapped_cls + ) cls.classes[map_config.cls.__name__] = mapped_cls table_to_map_config[table] = map_config for map_config in table_to_map_config.values(): - _relationships_for_fks(cls, - map_config, - table_to_map_config, - collection_class, - name_for_scalar_relationship, - name_for_collection_relationship, - generate_relationship) + _relationships_for_fks( + cls, + map_config, + table_to_map_config, + collection_class, + name_for_scalar_relationship, + name_for_collection_relationship, + generate_relationship, + ) for lcl_m2m, rem_m2m, m2m_const, table in many_to_many: - _m2m_relationship(cls, lcl_m2m, rem_m2m, m2m_const, table, - table_to_map_config, - collection_class, - name_for_scalar_relationship, - name_for_collection_relationship, - generate_relationship) + _m2m_relationship( + cls, + lcl_m2m, + rem_m2m, + m2m_const, + table, + table_to_map_config, + collection_class, + name_for_scalar_relationship, + name_for_collection_relationship, + generate_relationship, + ) for map_config in _DeferredMapperConfig.classes_for_base(cls): map_config.map() @@ -853,20 +867,27 @@ def automap_base(declarative_base=None, **kw): return type( Base.__name__, - (AutomapBase, Base,), - {"__abstract__": True, "classes": util.Properties({})} + (AutomapBase, Base), + {"__abstract__": True, "classes": util.Properties({})}, ) def _is_many_to_many(automap_base, table): - fk_constraints = [const for const in table.constraints - if isinstance(const, ForeignKeyConstraint)] + fk_constraints = [ + const + for const in table.constraints + if isinstance(const, ForeignKeyConstraint) + ] if len(fk_constraints) != 2: return None, None, None cols = sum( - [[fk.parent for fk in fk_constraint.elements] - for fk_constraint in fk_constraints], []) + [ + [fk.parent for fk in fk_constraint.elements] + for fk_constraint in fk_constraints + ], + [], + ) if set(cols) != set(table.c): return None, None, None @@ -874,15 +895,19 @@ def _is_many_to_many(automap_base, table): return ( fk_constraints[0].elements[0].column.table, fk_constraints[1].elements[0].column.table, - fk_constraints + fk_constraints, ) -def _relationships_for_fks(automap_base, map_config, table_to_map_config, - collection_class, - name_for_scalar_relationship, - name_for_collection_relationship, - generate_relationship): +def _relationships_for_fks( + automap_base, + map_config, + table_to_map_config, + collection_class, + name_for_scalar_relationship, + name_for_collection_relationship, + generate_relationship, +): local_table = map_config.local_table local_cls = map_config.cls # derived from a weakref, may be None @@ -898,32 +923,33 @@ def _relationships_for_fks(automap_base, map_config, table_to_map_config, referred_cls = referred_cfg.cls if local_cls is not referred_cls and issubclass( - local_cls, referred_cls): + local_cls, referred_cls + ): continue relationship_name = name_for_scalar_relationship( - automap_base, - local_cls, - referred_cls, constraint) + automap_base, local_cls, referred_cls, constraint + ) backref_name = name_for_collection_relationship( - automap_base, - referred_cls, - local_cls, - constraint + automap_base, referred_cls, local_cls, constraint ) o2m_kws = {} nullable = False not in {fk.parent.nullable for fk in fks} if not nullable: - o2m_kws['cascade'] = "all, delete-orphan" + o2m_kws["cascade"] = "all, delete-orphan" - if constraint.ondelete and \ - constraint.ondelete.lower() == "cascade": - o2m_kws['passive_deletes'] = True + if ( + constraint.ondelete + and constraint.ondelete.lower() == "cascade" + ): + o2m_kws["passive_deletes"] = True else: - if constraint.ondelete and \ - constraint.ondelete.lower() == "set null": - o2m_kws['passive_deletes'] = True + if ( + constraint.ondelete + and constraint.ondelete.lower() == "set null" + ): + o2m_kws["passive_deletes"] = True create_backref = backref_name not in referred_cfg.properties @@ -931,54 +957,65 @@ def _relationships_for_fks(automap_base, map_config, table_to_map_config, if create_backref: backref_obj = generate_relationship( automap_base, - interfaces.ONETOMANY, backref, - backref_name, referred_cls, local_cls, + interfaces.ONETOMANY, + backref, + backref_name, + referred_cls, + local_cls, collection_class=collection_class, - **o2m_kws) + **o2m_kws + ) else: backref_obj = None - rel = generate_relationship(automap_base, - interfaces.MANYTOONE, - relationship, - relationship_name, - local_cls, referred_cls, - foreign_keys=[ - fk.parent - for fk in constraint.elements], - backref=backref_obj, - remote_side=[ - fk.column - for fk in constraint.elements] - ) + rel = generate_relationship( + automap_base, + interfaces.MANYTOONE, + relationship, + relationship_name, + local_cls, + referred_cls, + foreign_keys=[fk.parent for fk in constraint.elements], + backref=backref_obj, + remote_side=[fk.column for fk in constraint.elements], + ) if rel is not None: map_config.properties[relationship_name] = rel if not create_backref: referred_cfg.properties[ - backref_name].back_populates = relationship_name + backref_name + ].back_populates = relationship_name elif create_backref: - rel = generate_relationship(automap_base, - interfaces.ONETOMANY, - relationship, - backref_name, - referred_cls, local_cls, - foreign_keys=[ - fk.parent - for fk in constraint.elements], - back_populates=relationship_name, - collection_class=collection_class, - **o2m_kws) + rel = generate_relationship( + automap_base, + interfaces.ONETOMANY, + relationship, + backref_name, + referred_cls, + local_cls, + foreign_keys=[fk.parent for fk in constraint.elements], + back_populates=relationship_name, + collection_class=collection_class, + **o2m_kws + ) if rel is not None: referred_cfg.properties[backref_name] = rel map_config.properties[ - relationship_name].back_populates = backref_name - - -def _m2m_relationship(automap_base, lcl_m2m, rem_m2m, m2m_const, table, - table_to_map_config, - collection_class, - name_for_scalar_relationship, - name_for_collection_relationship, - generate_relationship): + relationship_name + ].back_populates = backref_name + + +def _m2m_relationship( + automap_base, + lcl_m2m, + rem_m2m, + m2m_const, + table, + table_to_map_config, + collection_class, + name_for_scalar_relationship, + name_for_collection_relationship, + generate_relationship, +): map_config = table_to_map_config.get(lcl_m2m, None) referred_cfg = table_to_map_config.get(rem_m2m, None) @@ -989,14 +1026,10 @@ def _m2m_relationship(automap_base, lcl_m2m, rem_m2m, m2m_const, table, referred_cls = referred_cfg.cls relationship_name = name_for_collection_relationship( - automap_base, - local_cls, - referred_cls, m2m_const[0]) + automap_base, local_cls, referred_cls, m2m_const[0] + ) backref_name = name_for_collection_relationship( - automap_base, - referred_cls, - local_cls, - m2m_const[1] + automap_base, referred_cls, local_cls, m2m_const[1] ) create_backref = backref_name not in referred_cfg.properties @@ -1008,48 +1041,56 @@ def _m2m_relationship(automap_base, lcl_m2m, rem_m2m, m2m_const, table, interfaces.MANYTOMANY, backref, backref_name, - referred_cls, local_cls, - collection_class=collection_class + referred_cls, + local_cls, + collection_class=collection_class, ) else: backref_obj = None - rel = generate_relationship(automap_base, - interfaces.MANYTOMANY, - relationship, - relationship_name, - local_cls, referred_cls, - secondary=table, - primaryjoin=and_( - fk.column == fk.parent - for fk in m2m_const[0].elements), - secondaryjoin=and_( - fk.column == fk.parent - for fk in m2m_const[1].elements), - backref=backref_obj, - collection_class=collection_class - ) + rel = generate_relationship( + automap_base, + interfaces.MANYTOMANY, + relationship, + relationship_name, + local_cls, + referred_cls, + secondary=table, + primaryjoin=and_( + fk.column == fk.parent for fk in m2m_const[0].elements + ), + secondaryjoin=and_( + fk.column == fk.parent for fk in m2m_const[1].elements + ), + backref=backref_obj, + collection_class=collection_class, + ) if rel is not None: map_config.properties[relationship_name] = rel if not create_backref: referred_cfg.properties[ - backref_name].back_populates = relationship_name + backref_name + ].back_populates = relationship_name elif create_backref: - rel = generate_relationship(automap_base, - interfaces.MANYTOMANY, - relationship, - backref_name, - referred_cls, local_cls, - secondary=table, - primaryjoin=and_( - fk.column == fk.parent - for fk in m2m_const[1].elements), - secondaryjoin=and_( - fk.column == fk.parent - for fk in m2m_const[0].elements), - back_populates=relationship_name, - collection_class=collection_class) + rel = generate_relationship( + automap_base, + interfaces.MANYTOMANY, + relationship, + backref_name, + referred_cls, + local_cls, + secondary=table, + primaryjoin=and_( + fk.column == fk.parent for fk in m2m_const[1].elements + ), + secondaryjoin=and_( + fk.column == fk.parent for fk in m2m_const[0].elements + ), + back_populates=relationship_name, + collection_class=collection_class, + ) if rel is not None: referred_cfg.properties[backref_name] = rel map_config.properties[ - relationship_name].back_populates = backref_name + relationship_name + ].back_populates = backref_name diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py index 516879142..f55231a09 100644 --- a/lib/sqlalchemy/ext/baked.py +++ b/lib/sqlalchemy/ext/baked.py @@ -38,7 +38,8 @@ class Bakery(object): """ - __slots__ = 'cls', 'cache' + + __slots__ = "cls", "cache" def __init__(self, cls_, cache): self.cls = cls_ @@ -51,7 +52,7 @@ class Bakery(object): class BakedQuery(object): """A builder object for :class:`.query.Query` objects.""" - __slots__ = 'steps', '_bakery', '_cache_key', '_spoiled' + __slots__ = "steps", "_bakery", "_cache_key", "_spoiled" def __init__(self, bakery, initial_fn, args=()): self._cache_key = () @@ -148,7 +149,7 @@ class BakedQuery(object): """ if not full and not self._spoiled: _spoil_point = self._clone() - _spoil_point._cache_key += ('_query_only', ) + _spoil_point._cache_key += ("_query_only",) self.steps = [_spoil_point._retrieve_baked_query] self._spoiled = True return self @@ -164,7 +165,7 @@ class BakedQuery(object): session will want to use. """ - return self._cache_key + (session._query_cls, ) + return self._cache_key + (session._query_cls,) def _with_lazyload_options(self, options, effective_path, cache_path=None): """Cloning version of _add_lazyload_options. @@ -201,16 +202,20 @@ class BakedQuery(object): key += cache_key self.add_criteria( - lambda q: q._with_current_path(effective_path). - _conditional_options(*options), - cache_path.path, key + lambda q: q._with_current_path( + effective_path + )._conditional_options(*options), + cache_path.path, + key, ) def _retrieve_baked_query(self, session): query = self._bakery.get(self._effective_key(session), None) if query is None: query = self._as_query(session) - self._bakery[self._effective_key(session)] = query.with_session(None) + self._bakery[self._effective_key(session)] = query.with_session( + None + ) return query.with_session(session) def _bake(self, session): @@ -227,8 +232,12 @@ class BakedQuery(object): # so delete some compilation-use-only attributes that can take up # space for attr in ( - '_correlate', '_from_obj', '_mapper_adapter_map', - '_joinpath', '_joinpoint'): + "_correlate", + "_from_obj", + "_mapper_adapter_map", + "_joinpath", + "_joinpoint", + ): query.__dict__.pop(attr, None) self._bakery[self._effective_key(session)] = context return context @@ -276,11 +285,13 @@ class BakedQuery(object): session = query_or_session.session if session is None: raise sa_exc.ArgumentError( - "Given Query needs to be associated with a Session") + "Given Query needs to be associated with a Session" + ) else: raise TypeError( - "Query or Session object expected, got %r." % - type(query_or_session)) + "Query or Session object expected, got %r." + % type(query_or_session) + ) return self._as_query(session) def _as_query(self, session): @@ -299,10 +310,10 @@ class BakedQuery(object): a "baked" query so that we save on performance too. """ - context.attributes['baked_queries'] = baked_queries = [] + context.attributes["baked_queries"] = baked_queries = [] for k, v in list(context.attributes.items()): if isinstance(v, Query): - if 'subquery' in k: + if "subquery" in k: bk = BakedQuery(self._bakery, lambda *args: v) bk._cache_key = self._cache_key + k bk._bake(session) @@ -310,15 +321,17 @@ class BakedQuery(object): del context.attributes[k] def _unbake_subquery_loaders( - self, session, context, params, post_criteria): + self, session, context, params, post_criteria + ): """Retrieve subquery eager loaders stored by _bake_subquery_loaders and turn them back into Result objects that will iterate just like a Query object. """ for k, cache_key, query in context.attributes["baked_queries"]: - bk = BakedQuery(self._bakery, - lambda sess, q=query: q.with_session(sess)) + bk = BakedQuery( + self._bakery, lambda sess, q=query: q.with_session(sess) + ) bk._cache_key = cache_key q = bk.for_session(session) for fn in post_criteria: @@ -334,7 +347,8 @@ class Result(object): against a target :class:`.Session`, and is then invoked for results. """ - __slots__ = 'bq', 'session', '_params', '_post_criteria' + + __slots__ = "bq", "session", "_params", "_post_criteria" def __init__(self, bq, session): self.bq = bq @@ -350,7 +364,8 @@ class Result(object): elif len(args) > 0: raise sa_exc.ArgumentError( "params() takes zero or one positional argument, " - "which is a dictionary.") + "which is a dictionary." + ) self._params.update(kw) return self @@ -403,7 +418,8 @@ class Result(object): context.attributes = context.attributes.copy() bq._unbake_subquery_loaders( - self.session, context, self._params, self._post_criteria) + self.session, context, self._params, self._post_criteria + ) context.statement.use_labels = True if context.autoflush and not context.populate_existing: @@ -426,7 +442,7 @@ class Result(object): """ - col = func.count(literal_column('*')) + col = func.count(literal_column("*")) bq = self.bq.with_criteria(lambda q: q.from_self(col)) return bq.for_session(self.session).params(self._params).scalar() @@ -456,8 +472,10 @@ class Result(object): """ bq = self.bq.with_criteria(lambda q: q.slice(0, 1)) ret = list( - bq.for_session(self.session).params(self._params). - _using_post_criteria(self._post_criteria)) + bq.for_session(self.session) + .params(self._params) + ._using_post_criteria(self._post_criteria) + ) if len(ret) > 0: return ret[0] else: @@ -473,7 +491,8 @@ class Result(object): ret = self.one_or_none() except orm_exc.MultipleResultsFound: raise orm_exc.MultipleResultsFound( - "Multiple rows were found for one()") + "Multiple rows were found for one()" + ) else: if ret is None: raise orm_exc.NoResultFound("No row was found for one()") @@ -497,7 +516,8 @@ class Result(object): return None else: raise orm_exc.MultipleResultsFound( - "Multiple rows were found for one_or_none()") + "Multiple rows were found for one_or_none()" + ) def all(self): """Return all rows. @@ -533,13 +553,18 @@ class Result(object): # None present in ident - turn those comparisons # into "IS NULL" if None in primary_key_identity: - nones = set([ - _get_params[col].key for col, value in - zip(mapper.primary_key, primary_key_identity) - if value is None - ]) + nones = set( + [ + _get_params[col].key + for col, value in zip( + mapper.primary_key, primary_key_identity + ) + if value is None + ] + ) _lcl_get_clause = sql_util.adapt_criterion_to_null( - _lcl_get_clause, nones) + _lcl_get_clause, nones + ) _lcl_get_clause = q._adapt_clause(_lcl_get_clause, True, False) q._criterion = _lcl_get_clause @@ -556,16 +581,20 @@ class Result(object): # key so that if a race causes multiple calls to _get_clause, # we've cached on ours bq = bq._clone() - bq._cache_key += (_get_clause, ) + bq._cache_key += (_get_clause,) bq = bq.with_criteria( - setup, tuple(elem is None for elem in primary_key_identity)) + setup, tuple(elem is None for elem in primary_key_identity) + ) - params = dict([ - (_get_params[primary_key].key, id_val) - for id_val, primary_key - in zip(primary_key_identity, mapper.primary_key) - ]) + params = dict( + [ + (_get_params[primary_key].key, id_val) + for id_val, primary_key in zip( + primary_key_identity, mapper.primary_key + ) + ] + ) result = list(bq.for_session(self.session).params(**params)) l = len(result) @@ -578,7 +607,8 @@ class Result(object): @util.deprecated( - "1.2", "Baked lazy loading is now the default implementation.") + "1.2", "Baked lazy loading is now the default implementation." +) def bake_lazy_loaders(): """Enable the use of baked queries for all lazyloaders systemwide. @@ -590,7 +620,8 @@ def bake_lazy_loaders(): @util.deprecated( - "1.2", "Baked lazy loading is now the default implementation.") + "1.2", "Baked lazy loading is now the default implementation." +) def unbake_lazy_loaders(): """Disable the use of baked queries for all lazyloaders systemwide. @@ -601,7 +632,8 @@ def unbake_lazy_loaders(): """ raise NotImplementedError( - "Baked lazy loading is now the default implementation") + "Baked lazy loading is now the default implementation" + ) @strategy_options.loader_option() @@ -615,20 +647,27 @@ def baked_lazyload(loadopt, attr): @baked_lazyload._add_unbound_fn @util.deprecated( - "1.2", "Baked lazy loading is now the default " - "implementation for lazy loading.") + "1.2", + "Baked lazy loading is now the default " + "implementation for lazy loading.", +) def baked_lazyload(*keys): return strategy_options._UnboundLoad._from_keys( - strategy_options._UnboundLoad.baked_lazyload, keys, False, {}) + strategy_options._UnboundLoad.baked_lazyload, keys, False, {} + ) @baked_lazyload._add_unbound_all_fn @util.deprecated( - "1.2", "Baked lazy loading is now the default " - "implementation for lazy loading.") + "1.2", + "Baked lazy loading is now the default " + "implementation for lazy loading.", +) def baked_lazyload_all(*keys): return strategy_options._UnboundLoad._from_keys( - strategy_options._UnboundLoad.baked_lazyload, keys, True, {}) + strategy_options._UnboundLoad.baked_lazyload, keys, True, {} + ) + baked_lazyload = baked_lazyload._unbound_fn baked_lazyload_all = baked_lazyload_all._unbound_all_fn diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py index 6a0909d36..220b2c057 100644 --- a/lib/sqlalchemy/ext/compiler.py +++ b/lib/sqlalchemy/ext/compiler.py @@ -407,37 +407,44 @@ def compiles(class_, *specs): def decorate(fn): # get an existing @compiles handler - existing = class_.__dict__.get('_compiler_dispatcher', None) + existing = class_.__dict__.get("_compiler_dispatcher", None) # get the original handler. All ClauseElement classes have one # of these, but some TypeEngine classes will not. - existing_dispatch = getattr(class_, '_compiler_dispatch', None) + existing_dispatch = getattr(class_, "_compiler_dispatch", None) if not existing: existing = _dispatcher() if existing_dispatch: + def _wrap_existing_dispatch(element, compiler, **kw): try: return existing_dispatch(element, compiler, **kw) except exc.UnsupportedCompilationError: raise exc.CompileError( "%s construct has no default " - "compilation handler." % type(element)) - existing.specs['default'] = _wrap_existing_dispatch + "compilation handler." % type(element) + ) + + existing.specs["default"] = _wrap_existing_dispatch # TODO: why is the lambda needed ? - setattr(class_, '_compiler_dispatch', - lambda *arg, **kw: existing(*arg, **kw)) - setattr(class_, '_compiler_dispatcher', existing) + setattr( + class_, + "_compiler_dispatch", + lambda *arg, **kw: existing(*arg, **kw), + ) + setattr(class_, "_compiler_dispatcher", existing) if specs: for s in specs: existing.specs[s] = fn else: - existing.specs['default'] = fn + existing.specs["default"] = fn return fn + return decorate @@ -445,7 +452,7 @@ def deregister(class_): """Remove all custom compilers associated with a given :class:`.ClauseElement` type.""" - if hasattr(class_, '_compiler_dispatcher'): + if hasattr(class_, "_compiler_dispatcher"): # regenerate default _compiler_dispatch visitors._generate_dispatch(class_) # remove custom directive @@ -461,10 +468,11 @@ class _dispatcher(object): fn = self.specs.get(compiler.dialect.name, None) if not fn: try: - fn = self.specs['default'] + fn = self.specs["default"] except KeyError: raise exc.CompileError( "%s construct has no default " - "compilation handler." % type(element)) + "compilation handler." % type(element) + ) return fn(element, compiler, **kw) diff --git a/lib/sqlalchemy/ext/declarative/__init__.py b/lib/sqlalchemy/ext/declarative/__init__.py index cb81f51e5..2b0a37884 100644 --- a/lib/sqlalchemy/ext/declarative/__init__.py +++ b/lib/sqlalchemy/ext/declarative/__init__.py @@ -5,14 +5,31 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from .api import declarative_base, synonym_for, comparable_using, \ - instrument_declarative, ConcreteBase, AbstractConcreteBase, \ - DeclarativeMeta, DeferredReflection, has_inherited_table,\ - declared_attr, as_declarative +from .api import ( + declarative_base, + synonym_for, + comparable_using, + instrument_declarative, + ConcreteBase, + AbstractConcreteBase, + DeclarativeMeta, + DeferredReflection, + has_inherited_table, + declared_attr, + as_declarative, +) -__all__ = ['declarative_base', 'synonym_for', 'has_inherited_table', - 'comparable_using', 'instrument_declarative', 'declared_attr', - 'as_declarative', - 'ConcreteBase', 'AbstractConcreteBase', 'DeclarativeMeta', - 'DeferredReflection'] +__all__ = [ + "declarative_base", + "synonym_for", + "has_inherited_table", + "comparable_using", + "instrument_declarative", + "declared_attr", + "as_declarative", + "ConcreteBase", + "AbstractConcreteBase", + "DeclarativeMeta", + "DeferredReflection", +] diff --git a/lib/sqlalchemy/ext/declarative/api.py b/lib/sqlalchemy/ext/declarative/api.py index 865cd16f0..987e92119 100644 --- a/lib/sqlalchemy/ext/declarative/api.py +++ b/lib/sqlalchemy/ext/declarative/api.py @@ -8,9 +8,13 @@ from ...schema import Table, MetaData, Column -from ...orm import synonym as _orm_synonym, \ - comparable_property,\ - interfaces, properties, attributes +from ...orm import ( + synonym as _orm_synonym, + comparable_property, + interfaces, + properties, + attributes, +) from ...orm.util import polymorphic_union from ...orm.base import _mapper_or_none from ...util import OrderedDict, hybridmethod, hybridproperty @@ -19,9 +23,13 @@ from ... import exc import weakref import re -from .base import _as_declarative, \ - _declarative_constructor,\ - _DeferredMapperConfig, _add_attribute, _del_attribute +from .base import ( + _as_declarative, + _declarative_constructor, + _DeferredMapperConfig, + _add_attribute, + _del_attribute, +) from .clsregistry import _class_resolver @@ -31,10 +39,10 @@ def instrument_declarative(cls, registry, metadata): MetaData object. """ - if '_decl_class_registry' in cls.__dict__: + if "_decl_class_registry" in cls.__dict__: raise exc.InvalidRequestError( - "Class %r already has been " - "instrumented declaratively" % cls) + "Class %r already has been " "instrumented declaratively" % cls + ) cls._decl_class_registry = registry cls.metadata = metadata _as_declarative(cls, cls.__name__, cls.__dict__) @@ -54,14 +62,14 @@ def has_inherited_table(cls): """ for class_ in cls.__mro__[1:]: - if getattr(class_, '__table__', None) is not None: + if getattr(class_, "__table__", None) is not None: return True return False class DeclarativeMeta(type): def __init__(cls, classname, bases, dict_): - if '_decl_class_registry' not in cls.__dict__: + if "_decl_class_registry" not in cls.__dict__: _as_declarative(cls, classname, cls.__dict__) type.__init__(cls, classname, bases, dict_) @@ -71,6 +79,7 @@ class DeclarativeMeta(type): def __delattr__(cls, key): _del_attribute(cls, key) + def synonym_for(name, map_column=False): """Decorator that produces an :func:`.orm.synonym` attribute in conjunction with a Python descriptor. @@ -104,8 +113,10 @@ def synonym_for(name, map_column=False): can be achieved with synonyms. """ + def decorate(fn): return _orm_synonym(name, map_column=map_column, descriptor=fn) + return decorate @@ -127,8 +138,10 @@ def comparable_using(comparator_factory): prop = comparable_property(MyComparatorType) """ + def decorate(fn): return comparable_property(comparator_factory, fn) + return decorate @@ -190,14 +203,16 @@ class declared_attr(interfaces._MappedAttribute, property): self._cascading = cascading def __get__(desc, self, cls): - reg = cls.__dict__.get('_sa_declared_attr_reg', None) + reg = cls.__dict__.get("_sa_declared_attr_reg", None) if reg is None: - if not re.match(r'^__.+__$', desc.fget.__name__) and \ - attributes.manager_of_class(cls) is None: + if ( + not re.match(r"^__.+__$", desc.fget.__name__) + and attributes.manager_of_class(cls) is None + ): util.warn( "Unmanaged access of declarative attribute %s from " - "non-mapped class %s" % - (desc.fget.__name__, cls.__name__)) + "non-mapped class %s" % (desc.fget.__name__, cls.__name__) + ) return desc.fget(cls) elif desc in reg: return reg[desc] @@ -283,10 +298,16 @@ class _stateful_declared_attr(declared_attr): return declared_attr(fn, **self.kw) -def declarative_base(bind=None, metadata=None, mapper=None, cls=object, - name='Base', constructor=_declarative_constructor, - class_registry=None, - metaclass=DeclarativeMeta): +def declarative_base( + bind=None, + metadata=None, + mapper=None, + cls=object, + name="Base", + constructor=_declarative_constructor, + class_registry=None, + metaclass=DeclarativeMeta, +): r"""Construct a base class for declarative class definitions. The new base class will be given a metaclass that produces @@ -357,16 +378,17 @@ def declarative_base(bind=None, metadata=None, mapper=None, cls=object, class_registry = weakref.WeakValueDictionary() bases = not isinstance(cls, tuple) and (cls,) or cls - class_dict = dict(_decl_class_registry=class_registry, - metadata=lcl_metadata) + class_dict = dict( + _decl_class_registry=class_registry, metadata=lcl_metadata + ) if isinstance(cls, type): - class_dict['__doc__'] = cls.__doc__ + class_dict["__doc__"] = cls.__doc__ if constructor: - class_dict['__init__'] = constructor + class_dict["__init__"] = constructor if mapper: - class_dict['__mapper_cls__'] = mapper + class_dict["__mapper_cls__"] = mapper return metaclass(name, bases, class_dict) @@ -401,9 +423,10 @@ def as_declarative(**kw): :func:`.declarative_base` """ + def decorate(cls): - kw['cls'] = cls - kw['name'] = cls.__name__ + kw["cls"] = cls + kw["name"] = cls.__name__ return declarative_base(**kw) return decorate @@ -456,10 +479,13 @@ class ConcreteBase(object): @classmethod def _create_polymorphic_union(cls, mappers): - return polymorphic_union(OrderedDict( - (mp.polymorphic_identity, mp.local_table) - for mp in mappers - ), 'type', 'pjoin') + return polymorphic_union( + OrderedDict( + (mp.polymorphic_identity, mp.local_table) for mp in mappers + ), + "type", + "pjoin", + ) @classmethod def __declare_first__(cls): @@ -568,7 +594,7 @@ class AbstractConcreteBase(ConcreteBase): @classmethod def _sa_decl_prepare_nocascade(cls): - if getattr(cls, '__mapper__', None): + if getattr(cls, "__mapper__", None): return to_map = _DeferredMapperConfig.config_for_cls(cls) @@ -604,8 +630,9 @@ class AbstractConcreteBase(ConcreteBase): def mapper_args(): args = m_args() - args['polymorphic_on'] = pjoin.c.type + args["polymorphic_on"] = pjoin.c.type return args + to_map.mapper_args_fn = mapper_args m = to_map.map() @@ -684,6 +711,7 @@ class DeferredReflection(object): .. versionadded:: 0.8 """ + @classmethod def prepare(cls, engine): """Reflect all :class:`.Table` objects for all current @@ -696,8 +724,10 @@ class DeferredReflection(object): mapper = thingy.cls.__mapper__ metadata = mapper.class_.metadata for rel in mapper._props.values(): - if isinstance(rel, properties.RelationshipProperty) and \ - rel.secondary is not None: + if ( + isinstance(rel, properties.RelationshipProperty) + and rel.secondary is not None + ): if isinstance(rel.secondary, Table): cls._reflect_table(rel.secondary, engine) elif isinstance(rel.secondary, _class_resolver): @@ -711,6 +741,7 @@ class DeferredReflection(object): t1 = Table(key, metadata) cls._reflect_table(t1, engine) return t1 + return _resolve @classmethod @@ -724,10 +755,12 @@ class DeferredReflection(object): @classmethod def _reflect_table(cls, table, engine): - Table(table.name, - table.metadata, - extend_existing=True, - autoload_replace=False, - autoload=True, - autoload_with=engine, - schema=table.schema) + Table( + table.name, + table.metadata, + extend_existing=True, + autoload_replace=False, + autoload=True, + autoload_with=engine, + schema=table.schema, + ) diff --git a/lib/sqlalchemy/ext/declarative/base.py b/lib/sqlalchemy/ext/declarative/base.py index f27314b5e..07778f733 100644 --- a/lib/sqlalchemy/ext/declarative/base.py +++ b/lib/sqlalchemy/ext/declarative/base.py @@ -39,7 +39,7 @@ def _resolve_for_abstract_or_classical(cls): if cls is object: return None - if _get_immediate_cls_attr(cls, '__abstract__', strict=True): + if _get_immediate_cls_attr(cls, "__abstract__", strict=True): for sup in cls.__bases__: sup = _resolve_for_abstract_or_classical(sup) if sup is not None: @@ -59,7 +59,7 @@ def _dive_for_classically_mapped_class(cls): # if we are within a base hierarchy, don't # search at all for classical mappings - if hasattr(cls, '_decl_class_registry'): + if hasattr(cls, "_decl_class_registry"): return None manager = instrumentation.manager_of_class(cls) @@ -89,15 +89,19 @@ def _get_immediate_cls_attr(cls, attrname, strict=False): return None for base in cls.__mro__: - _is_declarative_inherits = hasattr(base, '_decl_class_registry') - _is_classicial_inherits = not _is_declarative_inherits and \ - _dive_for_classically_mapped_class(base) is not None + _is_declarative_inherits = hasattr(base, "_decl_class_registry") + _is_classicial_inherits = ( + not _is_declarative_inherits + and _dive_for_classically_mapped_class(base) is not None + ) if attrname in base.__dict__ and ( - base is cls or - ((base in cls.__bases__ if strict else True) + base is cls + or ( + (base in cls.__bases__ if strict else True) and not _is_declarative_inherits - and not _is_classicial_inherits) + and not _is_classicial_inherits + ) ): return getattr(base, attrname) else: @@ -108,9 +112,10 @@ def _as_declarative(cls, classname, dict_): global declared_attr, declarative_props if declared_attr is None: from .api import declared_attr + declarative_props = (declared_attr, util.classproperty) - if _get_immediate_cls_attr(cls, '__abstract__', strict=True): + if _get_immediate_cls_attr(cls, "__abstract__", strict=True): return _MapperConfig.setup_mapping(cls, classname, dict_) @@ -119,23 +124,23 @@ def _as_declarative(cls, classname, dict_): def _check_declared_props_nocascade(obj, name, cls): if isinstance(obj, declarative_props): - if getattr(obj, '_cascading', False): + if getattr(obj, "_cascading", False): util.warn( "@declared_attr.cascading is not supported on the %s " "attribute on class %s. This attribute invokes for " - "subclasses in any case." % (name, cls)) + "subclasses in any case." % (name, cls) + ) return True else: return False class _MapperConfig(object): - @classmethod def setup_mapping(cls, cls_, classname, dict_): defer_map = _get_immediate_cls_attr( - cls_, '_sa_decl_prepare_nocascade', strict=True) or \ - hasattr(cls_, '_sa_decl_prepare') + cls_, "_sa_decl_prepare_nocascade", strict=True + ) or hasattr(cls_, "_sa_decl_prepare") if defer_map: cfg_cls = _DeferredMapperConfig @@ -179,12 +184,14 @@ class _MapperConfig(object): self.map() def _setup_declared_events(self): - if _get_immediate_cls_attr(self.cls, '__declare_last__'): + if _get_immediate_cls_attr(self.cls, "__declare_last__"): + @event.listens_for(mapper, "after_configured") def after_configured(): self.cls.__declare_last__() - if _get_immediate_cls_attr(self.cls, '__declare_first__'): + if _get_immediate_cls_attr(self.cls, "__declare_first__"): + @event.listens_for(mapper, "before_configured") def before_configured(): self.cls.__declare_first__() @@ -198,59 +205,62 @@ class _MapperConfig(object): tablename = None for base in cls.__mro__: - class_mapped = base is not cls and \ - _declared_mapping_info(base) is not None and \ - not _get_immediate_cls_attr( - base, '_sa_decl_prepare_nocascade', strict=True) + class_mapped = ( + base is not cls + and _declared_mapping_info(base) is not None + and not _get_immediate_cls_attr( + base, "_sa_decl_prepare_nocascade", strict=True + ) + ) if not class_mapped and base is not cls: self._produce_column_copies(base) for name, obj in vars(base).items(): - if name == '__mapper_args__': - check_decl = \ - _check_declared_props_nocascade(obj, name, cls) - if not mapper_args_fn and ( - not class_mapped or - check_decl - ): + if name == "__mapper_args__": + check_decl = _check_declared_props_nocascade( + obj, name, cls + ) + if not mapper_args_fn and (not class_mapped or check_decl): # don't even invoke __mapper_args__ until # after we've determined everything about the # mapped table. # make a copy of it so a class-level dictionary # is not overwritten when we update column-based # arguments. - mapper_args_fn = lambda: dict(cls.__mapper_args__) # noqa - elif name == '__tablename__': - check_decl = \ - _check_declared_props_nocascade(obj, name, cls) - if not tablename and ( - not class_mapped or - check_decl - ): + mapper_args_fn = lambda: dict( + cls.__mapper_args__ + ) # noqa + elif name == "__tablename__": + check_decl = _check_declared_props_nocascade( + obj, name, cls + ) + if not tablename and (not class_mapped or check_decl): tablename = cls.__tablename__ - elif name == '__table_args__': - check_decl = \ - _check_declared_props_nocascade(obj, name, cls) - if not table_args and ( - not class_mapped or - check_decl - ): + elif name == "__table_args__": + check_decl = _check_declared_props_nocascade( + obj, name, cls + ) + if not table_args and (not class_mapped or check_decl): table_args = cls.__table_args__ if not isinstance( - table_args, (tuple, dict, type(None))): + table_args, (tuple, dict, type(None)) + ): raise exc.ArgumentError( "__table_args__ value must be a tuple, " - "dict, or None") + "dict, or None" + ) if base is not cls: inherited_table_args = True elif class_mapped: if isinstance(obj, declarative_props): - util.warn("Regular (i.e. not __special__) " - "attribute '%s.%s' uses @declared_attr, " - "but owning class %s is mapped - " - "not applying to subclass %s." - % (base.__name__, name, base, cls)) + util.warn( + "Regular (i.e. not __special__) " + "attribute '%s.%s' uses @declared_attr, " + "but owning class %s is mapped - " + "not applying to subclass %s." + % (base.__name__, name, base, cls) + ) continue elif base is not cls: # we're a mixin, abstract base, or something that is @@ -263,7 +273,8 @@ class _MapperConfig(object): "Mapper properties (i.e. deferred," "column_property(), relationship(), etc.) must " "be declared as @declared_attr callables " - "on declarative mixin classes.") + "on declarative mixin classes." + ) elif isinstance(obj, declarative_props): oldclassprop = isinstance(obj, util.classproperty) if not oldclassprop and obj._cascading: @@ -278,15 +289,18 @@ class _MapperConfig(object): "Attribute '%s' on class %s cannot be " "processed due to " "@declared_attr.cascading; " - "skipping" % (name, cls)) - dict_[name] = column_copies[obj] = \ - ret = obj.__get__(obj, cls) + "skipping" % (name, cls) + ) + dict_[name] = column_copies[ + obj + ] = ret = obj.__get__(obj, cls) setattr(cls, name, ret) else: if oldclassprop: util.warn_deprecated( "Use of sqlalchemy.util.classproperty on " - "declarative classes is deprecated.") + "declarative classes is deprecated." + ) # access attribute using normal class access ret = getattr(cls, name) @@ -294,14 +308,20 @@ class _MapperConfig(object): # or similar. note there is no known case that # produces nested proxies, so we are only # looking one level deep right now. - if isinstance(ret, InspectionAttr) and \ - ret._is_internal_proxy and not isinstance( - ret.original_property, MapperProperty): + if ( + isinstance(ret, InspectionAttr) + and ret._is_internal_proxy + and not isinstance( + ret.original_property, MapperProperty + ) + ): ret = ret.descriptor dict_[name] = column_copies[obj] = ret - if isinstance(ret, (Column, MapperProperty)) and \ - ret.doc is None: + if ( + isinstance(ret, (Column, MapperProperty)) + and ret.doc is None + ): ret.doc = obj.__doc__ # here, the attribute is some other kind of property that # we assume is not part of the declarative mapping. @@ -321,8 +341,9 @@ class _MapperConfig(object): util.warn( "Attribute '%s' on class %s appears to be a non-schema " "'sqlalchemy.sql.column()' " - "object; this won't be part of the declarative mapping" % - (key, cls)) + "object; this won't be part of the declarative mapping" + % (key, cls) + ) def _produce_column_copies(self, base): cls = self.cls @@ -340,10 +361,11 @@ class _MapperConfig(object): raise exc.InvalidRequestError( "Columns with foreign keys to other columns " "must be declared as @declared_attr callables " - "on declarative mixin classes. ") + "on declarative mixin classes. " + ) elif name not in dict_ and not ( - '__table__' in dict_ and - (obj.name or name) in dict_['__table__'].c + "__table__" in dict_ + and (obj.name or name) in dict_["__table__"].c ): column_copies[obj] = copy_ = obj.copy() copy_._creation_order = obj._creation_order @@ -357,11 +379,12 @@ class _MapperConfig(object): our_stuff = self.properties late_mapped = _get_immediate_cls_attr( - cls, '_sa_decl_prepare_nocascade', strict=True) + cls, "_sa_decl_prepare_nocascade", strict=True + ) for k in list(dict_): - if k in ('__table__', '__tablename__', '__mapper_args__'): + if k in ("__table__", "__tablename__", "__mapper_args__"): continue value = dict_[k] @@ -371,29 +394,37 @@ class _MapperConfig(object): "Use of @declared_attr.cascading only applies to " "Declarative 'mixin' and 'abstract' classes. " "Currently, this flag is ignored on mapped class " - "%s" % self.cls) + "%s" % self.cls + ) value = getattr(cls, k) - elif isinstance(value, QueryableAttribute) and \ - value.class_ is not cls and \ - value.key != k: + elif ( + isinstance(value, QueryableAttribute) + and value.class_ is not cls + and value.key != k + ): # detect a QueryableAttribute that's already mapped being # assigned elsewhere in userland, turn into a synonym() value = synonym(value.key) setattr(cls, k, value) - if (isinstance(value, tuple) and len(value) == 1 and - isinstance(value[0], (Column, MapperProperty))): - util.warn("Ignoring declarative-like tuple value of attribute " - "'%s': possibly a copy-and-paste error with a comma " - "accidentally placed at the end of the line?" % k) + if ( + isinstance(value, tuple) + and len(value) == 1 + and isinstance(value[0], (Column, MapperProperty)) + ): + util.warn( + "Ignoring declarative-like tuple value of attribute " + "'%s': possibly a copy-and-paste error with a comma " + "accidentally placed at the end of the line?" % k + ) continue elif not isinstance(value, (Column, MapperProperty)): # using @declared_attr for some object that # isn't Column/MapperProperty; remove from the dict_ # and place the evaluated value onto the class. - if not k.startswith('__'): + if not k.startswith("__"): dict_.pop(k) self._warn_for_decl_attributes(cls, k, value) if not late_mapped: @@ -402,7 +433,7 @@ class _MapperConfig(object): # we expect to see the name 'metadata' in some valid cases; # however at this point we see it's assigned to something trying # to be mapped, so raise for that. - elif k == 'metadata': + elif k == "metadata": raise exc.InvalidRequestError( "Attribute name 'metadata' is reserved " "for the MetaData instance when using a " @@ -423,8 +454,7 @@ class _MapperConfig(object): for key, c in list(our_stuff.items()): if isinstance(c, (ColumnProperty, CompositeProperty)): for col in c.columns: - if isinstance(col, Column) and \ - col.table is None: + if isinstance(col, Column) and col.table is None: _undefer_column_name(key, col) if not isinstance(c, CompositeProperty): name_to_prop_key[col.name].add(key) @@ -447,8 +477,8 @@ class _MapperConfig(object): "On class %r, Column object %r named " "directly multiple times, " "only one will be used: %s. " - "Consider using orm.synonym instead" % - (self.classname, name, (", ".join(sorted(keys)))) + "Consider using orm.synonym instead" + % (self.classname, name, (", ".join(sorted(keys)))) ) def _setup_table(self): @@ -459,15 +489,16 @@ class _MapperConfig(object): declared_columns = self.declared_columns declared_columns = self.declared_columns = sorted( - declared_columns, key=lambda c: c._creation_order) + declared_columns, key=lambda c: c._creation_order + ) table = None - if hasattr(cls, '__table_cls__'): + if hasattr(cls, "__table_cls__"): table_cls = util.unbound_method_to_callable(cls.__table_cls__) else: table_cls = Table - if '__table__' not in dict_: + if "__table__" not in dict_: if tablename is not None: args, table_kw = (), {} @@ -480,14 +511,16 @@ class _MapperConfig(object): else: args = table_args - autoload = dict_.get('__autoload__') + autoload = dict_.get("__autoload__") if autoload: - table_kw['autoload'] = True + table_kw["autoload"] = True cls.__table__ = table = table_cls( - tablename, cls.metadata, + tablename, + cls.metadata, *(tuple(declared_columns) + tuple(args)), - **table_kw) + **table_kw + ) else: table = cls.__table__ if declared_columns: @@ -512,21 +545,27 @@ class _MapperConfig(object): c = _resolve_for_abstract_or_classical(c) if c is None: continue - if _declared_mapping_info(c) is not None and \ - not _get_immediate_cls_attr( - c, '_sa_decl_prepare_nocascade', strict=True): + if _declared_mapping_info( + c + ) is not None and not _get_immediate_cls_attr( + c, "_sa_decl_prepare_nocascade", strict=True + ): inherits.append(c) if inherits: if len(inherits) > 1: raise exc.InvalidRequestError( - "Class %s has multiple mapped bases: %r" % (cls, inherits)) + "Class %s has multiple mapped bases: %r" % (cls, inherits) + ) self.inherits = inherits[0] else: self.inherits = None - if table is None and self.inherits is None and \ - not _get_immediate_cls_attr(cls, '__no_table__'): + if ( + table is None + and self.inherits is None + and not _get_immediate_cls_attr(cls, "__no_table__") + ): raise exc.InvalidRequestError( "Class %r does not have a __table__ or __tablename__ " @@ -553,8 +592,8 @@ class _MapperConfig(object): continue raise exc.ArgumentError( "Column '%s' on class %s conflicts with " - "existing column '%s'" % - (c, cls, inherited_table.c[c.name]) + "existing column '%s'" + % (c, cls, inherited_table.c[c.name]) ) if c.primary_key: raise exc.ArgumentError( @@ -562,8 +601,10 @@ class _MapperConfig(object): "class with no table." ) inherited_table.append_column(c) - if inherited_mapped_table is not None and \ - inherited_mapped_table is not inherited_table: + if ( + inherited_mapped_table is not None + and inherited_mapped_table is not inherited_table + ): inherited_mapped_table._refresh_for_new_column(c) def _prepare_mapper_arguments(self): @@ -575,18 +616,19 @@ class _MapperConfig(object): # make sure that column copies are used rather # than the original columns from any mixins - for k in ('version_id_col', 'polymorphic_on',): + for k in ("version_id_col", "polymorphic_on"): if k in mapper_args: v = mapper_args[k] mapper_args[k] = self.column_copies.get(v, v) - assert 'inherits' not in mapper_args, \ - "Can't specify 'inherits' explicitly with declarative mappings" + assert ( + "inherits" not in mapper_args + ), "Can't specify 'inherits' explicitly with declarative mappings" if self.inherits: - mapper_args['inherits'] = self.inherits + mapper_args["inherits"] = self.inherits - if self.inherits and not mapper_args.get('concrete', False): + if self.inherits and not mapper_args.get("concrete", False): # single or joined inheritance # exclude any cols on the inherited table which are # not mapped on the parent class, to avoid @@ -594,16 +636,17 @@ class _MapperConfig(object): inherited_mapper = _declared_mapping_info(self.inherits) inherited_table = inherited_mapper.local_table - if 'exclude_properties' not in mapper_args: - mapper_args['exclude_properties'] = exclude_properties = \ - set( - [c.key for c in inherited_table.c - if c not in inherited_mapper._columntoproperty] - ).union( - inherited_mapper.exclude_properties or () - ) + if "exclude_properties" not in mapper_args: + mapper_args["exclude_properties"] = exclude_properties = set( + [ + c.key + for c in inherited_table.c + if c not in inherited_mapper._columntoproperty + ] + ).union(inherited_mapper.exclude_properties or ()) exclude_properties.difference_update( - [c.key for c in self.declared_columns]) + [c.key for c in self.declared_columns] + ) # look through columns in the current mapper that # are keyed to a propname different than the colname @@ -621,21 +664,20 @@ class _MapperConfig(object): # first. See [ticket:1892] for background. properties[k] = [col] + p.columns result_mapper_args = mapper_args.copy() - result_mapper_args['properties'] = properties + result_mapper_args["properties"] = properties self.mapper_args = result_mapper_args def map(self): self._prepare_mapper_arguments() - if hasattr(self.cls, '__mapper_cls__'): + if hasattr(self.cls, "__mapper_cls__"): mapper_cls = util.unbound_method_to_callable( - self.cls.__mapper_cls__) + self.cls.__mapper_cls__ + ) else: mapper_cls = mapper self.cls.__mapper__ = mp_ = mapper_cls( - self.cls, - self.local_table, - **self.mapper_args + self.cls, self.local_table, **self.mapper_args ) del self.cls._sa_declared_attr_reg return mp_ @@ -663,8 +705,7 @@ class _DeferredMapperConfig(_MapperConfig): @classmethod def has_cls(cls, class_): # 2.6 fails on weakref if class_ is an old style class - return isinstance(class_, type) and \ - weakref.ref(class_) in cls._configs + return isinstance(class_, type) and weakref.ref(class_) in cls._configs @classmethod def config_for_cls(cls, class_): @@ -673,18 +714,15 @@ class _DeferredMapperConfig(_MapperConfig): @classmethod def classes_for_base(cls, base_cls, sort=True): classes_for_base = [ - m for m, cls_ in - [(m, m.cls) for m in cls._configs.values()] + m + for m, cls_ in [(m, m.cls) for m in cls._configs.values()] if cls_ is not None and issubclass(cls_, base_cls) ] if not sort: return classes_for_base - all_m_by_cls = dict( - (m.cls, m) - for m in classes_for_base - ) + all_m_by_cls = dict((m.cls, m) for m in classes_for_base) tuples = [] for m_cls in all_m_by_cls: @@ -693,12 +731,7 @@ class _DeferredMapperConfig(_MapperConfig): for base_cls in m_cls.__bases__ if base_cls in all_m_by_cls ) - return list( - topological.sort( - tuples, - classes_for_base - ) - ) + return list(topological.sort(tuples, classes_for_base)) def map(self): self._configs.pop(self._cls, None) @@ -713,7 +746,7 @@ def _add_attribute(cls, key, value): """ - if '__mapper__' in cls.__dict__: + if "__mapper__" in cls.__dict__: if isinstance(value, Column): _undefer_column_name(key, value) cls.__table__.append_column(value) @@ -726,16 +759,14 @@ def _add_attribute(cls, key, value): cls.__mapper__.add_property(key, value) elif isinstance(value, MapperProperty): cls.__mapper__.add_property( - key, - clsregistry._deferred_relationship(cls, value) + key, clsregistry._deferred_relationship(cls, value) ) elif isinstance(value, QueryableAttribute) and value.key != key: # detect a QueryableAttribute that's already mapped being # assigned elsewhere in userland, turn into a synonym() value = synonym(value.key) cls.__mapper__.add_property( - key, - clsregistry._deferred_relationship(cls, value) + key, clsregistry._deferred_relationship(cls, value) ) else: type.__setattr__(cls, key, value) @@ -746,15 +777,18 @@ def _add_attribute(cls, key, value): def _del_attribute(cls, key): - if '__mapper__' in cls.__dict__ and \ - key in cls.__dict__ and not cls.__mapper__._dispose_called: + if ( + "__mapper__" in cls.__dict__ + and key in cls.__dict__ + and not cls.__mapper__._dispose_called + ): value = cls.__dict__[key] if isinstance( - value, - (Column, ColumnProperty, MapperProperty, QueryableAttribute) + value, (Column, ColumnProperty, MapperProperty, QueryableAttribute) ): raise NotImplementedError( - "Can't un-map individual mapped attributes on a mapped class.") + "Can't un-map individual mapped attributes on a mapped class." + ) else: type.__delattr__(cls, key) cls.__mapper__._expire_memoizations() @@ -776,10 +810,12 @@ def _declarative_constructor(self, **kwargs): for k in kwargs: if not hasattr(cls_, k): raise TypeError( - "%r is an invalid keyword argument for %s" % - (k, cls_.__name__)) + "%r is an invalid keyword argument for %s" % (k, cls_.__name__) + ) setattr(self, k, kwargs[k]) -_declarative_constructor.__name__ = '__init__' + + +_declarative_constructor.__name__ = "__init__" def _undefer_column_name(key, column): diff --git a/lib/sqlalchemy/ext/declarative/clsregistry.py b/lib/sqlalchemy/ext/declarative/clsregistry.py index e941b9ed3..c52ae4a2f 100644 --- a/lib/sqlalchemy/ext/declarative/clsregistry.py +++ b/lib/sqlalchemy/ext/declarative/clsregistry.py @@ -10,8 +10,11 @@ This system allows specification of classes and expressions used in :func:`.relationship` using strings. """ -from ...orm.properties import ColumnProperty, RelationshipProperty, \ - SynonymProperty +from ...orm.properties import ( + ColumnProperty, + RelationshipProperty, + SynonymProperty, +) from ...schema import _get_table_key from ...orm import class_mapper, interfaces from ... import util @@ -35,17 +38,18 @@ def add_class(classname, cls): # class already exists. existing = cls._decl_class_registry[classname] if not isinstance(existing, _MultipleClassMarker): - existing = \ - cls._decl_class_registry[classname] = \ - _MultipleClassMarker([cls, existing]) + existing = cls._decl_class_registry[ + classname + ] = _MultipleClassMarker([cls, existing]) else: cls._decl_class_registry[classname] = cls try: - root_module = cls._decl_class_registry['_sa_module_registry'] + root_module = cls._decl_class_registry["_sa_module_registry"] except KeyError: - cls._decl_class_registry['_sa_module_registry'] = \ - root_module = _ModuleMarker('_sa_module_registry', None) + cls._decl_class_registry[ + "_sa_module_registry" + ] = root_module = _ModuleMarker("_sa_module_registry", None) tokens = cls.__module__.split(".") @@ -71,12 +75,13 @@ class _MultipleClassMarker(object): """ - __slots__ = 'on_remove', 'contents', '__weakref__' + __slots__ = "on_remove", "contents", "__weakref__" def __init__(self, classes, on_remove=None): self.on_remove = on_remove - self.contents = set([ - weakref.ref(item, self._remove_item) for item in classes]) + self.contents = set( + [weakref.ref(item, self._remove_item) for item in classes] + ) _registries.add(self) def __iter__(self): @@ -85,10 +90,10 @@ class _MultipleClassMarker(object): def attempt_get(self, path, key): if len(self.contents) > 1: raise exc.InvalidRequestError( - "Multiple classes found for path \"%s\" " + 'Multiple classes found for path "%s" ' "in the registry of this declarative " - "base. Please use a fully module-qualified path." % - (".".join(path + [key])) + "base. Please use a fully module-qualified path." + % (".".join(path + [key])) ) else: ref = list(self.contents)[0] @@ -108,17 +113,19 @@ class _MultipleClassMarker(object): # protect against class registration race condition against # asynchronous garbage collection calling _remove_item, # [ticket:3208] - modules = set([ - cls.__module__ for cls in - [ref() for ref in self.contents] if cls is not None]) + modules = set( + [ + cls.__module__ + for cls in [ref() for ref in self.contents] + if cls is not None + ] + ) if item.__module__ in modules: util.warn( "This declarative base already contains a class with the " "same class name and module name as %s.%s, and will " - "be replaced in the string-lookup table." % ( - item.__module__, - item.__name__ - ) + "be replaced in the string-lookup table." + % (item.__module__, item.__name__) ) self.contents.add(weakref.ref(item, self._remove_item)) @@ -129,7 +136,7 @@ class _ModuleMarker(object): """ - __slots__ = 'parent', 'name', 'contents', 'mod_ns', 'path', '__weakref__' + __slots__ = "parent", "name", "contents", "mod_ns", "path", "__weakref__" def __init__(self, name, parent): self.parent = parent @@ -170,13 +177,13 @@ class _ModuleMarker(object): existing = self.contents[name] existing.add_item(cls) else: - existing = self.contents[name] = \ - _MultipleClassMarker([cls], - on_remove=lambda: self._remove_item(name)) + existing = self.contents[name] = _MultipleClassMarker( + [cls], on_remove=lambda: self._remove_item(name) + ) class _ModNS(object): - __slots__ = '__parent', + __slots__ = ("__parent",) def __init__(self, parent): self.__parent = parent @@ -193,13 +200,14 @@ class _ModNS(object): else: assert isinstance(value, _MultipleClassMarker) return value.attempt_get(self.__parent.path, key) - raise AttributeError("Module %r has no mapped classes " - "registered under the name %r" % ( - self.__parent.name, key)) + raise AttributeError( + "Module %r has no mapped classes " + "registered under the name %r" % (self.__parent.name, key) + ) class _GetColumns(object): - __slots__ = 'cls', + __slots__ = ("cls",) def __init__(self, cls): self.cls = cls @@ -210,7 +218,8 @@ class _GetColumns(object): if key not in mp.all_orm_descriptors: raise exc.InvalidRequestError( "Class %r does not have a mapped column named %r" - % (self.cls, key)) + % (self.cls, key) + ) desc = mp.all_orm_descriptors[key] if desc.extension_type is interfaces.NOT_EXTENSION: @@ -221,24 +230,25 @@ class _GetColumns(object): raise exc.InvalidRequestError( "Property %r is not an instance of" " ColumnProperty (i.e. does not correspond" - " directly to a Column)." % key) + " directly to a Column)." % key + ) return getattr(self.cls, key) + inspection._inspects(_GetColumns)( - lambda target: inspection.inspect(target.cls)) + lambda target: inspection.inspect(target.cls) +) class _GetTable(object): - __slots__ = 'key', 'metadata' + __slots__ = "key", "metadata" def __init__(self, key, metadata): self.key = key self.metadata = metadata def __getattr__(self, key): - return self.metadata.tables[ - _get_table_key(key, self.key) - ] + return self.metadata.tables[_get_table_key(key, self.key)] def _determine_container(key, value): @@ -264,9 +274,11 @@ class _class_resolver(object): return cls.metadata.tables[key] elif key in cls.metadata._schemas: return _GetTable(key, cls.metadata) - elif '_sa_module_registry' in cls._decl_class_registry and \ - key in cls._decl_class_registry['_sa_module_registry']: - registry = cls._decl_class_registry['_sa_module_registry'] + elif ( + "_sa_module_registry" in cls._decl_class_registry + and key in cls._decl_class_registry["_sa_module_registry"] + ): + registry = cls._decl_class_registry["_sa_module_registry"] return registry.resolve_attr(key) elif self._resolvers: for resolv in self._resolvers: @@ -289,8 +301,8 @@ class _class_resolver(object): "When initializing mapper %s, expression %r failed to " "locate a name (%r). If this is a class name, consider " "adding this relationship() to the %r class after " - "both dependent classes have been defined." % - (self.prop.parent, self.arg, n.args[0], self.cls) + "both dependent classes have been defined." + % (self.prop.parent, self.arg, n.args[0], self.cls) ) @@ -299,10 +311,11 @@ def _resolver(cls, prop): from sqlalchemy.orm import foreign, remote fallback = sqlalchemy.__dict__.copy() - fallback.update({'foreign': foreign, 'remote': remote}) + fallback.update({"foreign": foreign, "remote": remote}) def resolve_arg(arg): return _class_resolver(cls, prop, fallback, arg) + return resolve_arg @@ -311,18 +324,32 @@ def _deferred_relationship(cls, prop): if isinstance(prop, RelationshipProperty): resolve_arg = _resolver(cls, prop) - for attr in ('argument', 'order_by', 'primaryjoin', 'secondaryjoin', - 'secondary', '_user_defined_foreign_keys', 'remote_side'): + for attr in ( + "argument", + "order_by", + "primaryjoin", + "secondaryjoin", + "secondary", + "_user_defined_foreign_keys", + "remote_side", + ): v = getattr(prop, attr) if isinstance(v, util.string_types): setattr(prop, attr, resolve_arg(v)) if prop.backref and isinstance(prop.backref, tuple): key, kwargs = prop.backref - for attr in ('primaryjoin', 'secondaryjoin', 'secondary', - 'foreign_keys', 'remote_side', 'order_by'): - if attr in kwargs and isinstance(kwargs[attr], - util.string_types): + for attr in ( + "primaryjoin", + "secondaryjoin", + "secondary", + "foreign_keys", + "remote_side", + "order_by", + ): + if attr in kwargs and isinstance( + kwargs[attr], util.string_types + ): kwargs[attr] = resolve_arg(kwargs[attr]) return prop diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index f86e4fc93..7248e5b4d 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -20,7 +20,7 @@ from .. import util from ..orm.session import Session from ..orm.query import Query -__all__ = ['ShardedSession', 'ShardedQuery'] +__all__ = ["ShardedSession", "ShardedQuery"] class ShardedQuery(Query): @@ -43,12 +43,10 @@ class ShardedQuery(Query): def _execute_and_instances(self, context): def iter_for_shard(shard_id): - context.attributes['shard_id'] = context.identity_token = shard_id + context.attributes["shard_id"] = context.identity_token = shard_id result = self._connection_from_session( - mapper=self._bind_mapper(), - shard_id=shard_id).execute( - context.statement, - self._params) + mapper=self._bind_mapper(), shard_id=shard_id + ).execute(context.statement, self._params) return self.instances(result, context) if context.identity_token is not None: @@ -70,7 +68,8 @@ class ShardedQuery(Query): mapper=mapper, shard_id=shard_id, clause=stmt, - close_with_result=True) + close_with_result=True, + ) result = conn.execute(stmt, self._params) return result @@ -87,8 +86,13 @@ class ShardedQuery(Query): return ShardedResult(results, rowcount) def _identity_lookup( - self, mapper, primary_key_identity, identity_token=None, - lazy_loaded_from=None, **kw): + self, + mapper, + primary_key_identity, + identity_token=None, + lazy_loaded_from=None, + **kw + ): """override the default Query._identity_lookup method so that we search for a given non-token primary key identity across all possible identity tokens (e.g. shard ids). @@ -97,8 +101,10 @@ class ShardedQuery(Query): if identity_token is not None: return super(ShardedQuery, self)._identity_lookup( - mapper, primary_key_identity, - identity_token=identity_token, **kw + mapper, + primary_key_identity, + identity_token=identity_token, + **kw ) else: q = self.session.query(mapper) @@ -113,13 +119,13 @@ class ShardedQuery(Query): return None - def _get_impl( - self, primary_key_identity, db_load_fn, identity_token=None): + def _get_impl(self, primary_key_identity, db_load_fn, identity_token=None): """Override the default Query._get_impl() method so that we emit a query to the DB for each possible identity token, if we don't have one already. """ + def _db_load_fn(query, primary_key_identity): # load from the database. The original db_load_fn will # use the given Query object to load from the DB, so our @@ -142,7 +148,8 @@ class ShardedQuery(Query): identity_token = self._shard_id return super(ShardedQuery, self)._get_impl( - primary_key_identity, _db_load_fn, identity_token=identity_token) + primary_key_identity, _db_load_fn, identity_token=identity_token + ) class ShardedResult(object): @@ -158,7 +165,7 @@ class ShardedResult(object): .. versionadded:: 1.3 """ - __slots__ = ('result_proxies', 'aggregate_rowcount',) + __slots__ = ("result_proxies", "aggregate_rowcount") def __init__(self, result_proxies, aggregate_rowcount): self.result_proxies = result_proxies @@ -168,9 +175,17 @@ class ShardedResult(object): def rowcount(self): return self.aggregate_rowcount + class ShardedSession(Session): - def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None, - query_cls=ShardedQuery, **kwargs): + def __init__( + self, + shard_chooser, + id_chooser, + query_chooser, + shards=None, + query_cls=ShardedQuery, + **kwargs + ): """Construct a ShardedSession. :param shard_chooser: A callable which, passed a Mapper, a mapped @@ -225,16 +240,16 @@ class ShardedSession(Session): return self.transaction.connection(mapper, shard_id=shard_id) else: return self.get_bind( - mapper, - shard_id=shard_id, - instance=instance + mapper, shard_id=shard_id, instance=instance ).contextual_connect(**kwargs) - def get_bind(self, mapper, shard_id=None, - instance=None, clause=None, **kw): + def get_bind( + self, mapper, shard_id=None, instance=None, clause=None, **kw + ): if shard_id is None: shard_id = self._choose_shard_and_assign( - mapper, instance, clause=clause) + mapper, instance, clause=clause + ) return self.__binds[shard_id] def bind_shard(self, shard_id, bind): diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py index 95eecb93f..d51a083da 100644 --- a/lib/sqlalchemy/ext/hybrid.py +++ b/lib/sqlalchemy/ext/hybrid.py @@ -778,7 +778,7 @@ there's probably a whole lot of amazing things it can be used for. from .. import util from ..orm import attributes, interfaces -HYBRID_METHOD = util.symbol('HYBRID_METHOD') +HYBRID_METHOD = util.symbol("HYBRID_METHOD") """Symbol indicating an :class:`InspectionAttr` that's of type :class:`.hybrid_method`. @@ -791,7 +791,7 @@ HYBRID_METHOD = util.symbol('HYBRID_METHOD') """ -HYBRID_PROPERTY = util.symbol('HYBRID_PROPERTY') +HYBRID_PROPERTY = util.symbol("HYBRID_PROPERTY") """Symbol indicating an :class:`InspectionAttr` that's of type :class:`.hybrid_method`. @@ -860,8 +860,14 @@ class hybrid_property(interfaces.InspectionAttrInfo): extension_type = HYBRID_PROPERTY def __init__( - self, fget, fset=None, fdel=None, - expr=None, custom_comparator=None, update_expr=None): + self, + fget, + fset=None, + fdel=None, + expr=None, + custom_comparator=None, + update_expr=None, + ): """Create a new :class:`.hybrid_property`. Usage is typically via decorator:: @@ -906,7 +912,8 @@ class hybrid_property(interfaces.InspectionAttrInfo): defaults = { key: value for key, value in self.__dict__.items() - if not key.startswith("_")} + if not key.startswith("_") + } defaults.update(**kw) return type(self)(**defaults) @@ -1078,9 +1085,9 @@ class hybrid_property(interfaces.InspectionAttrInfo): return self._get_expr(self.fget) def _get_expr(self, expr): - def _expr(cls): return ExprComparator(cls, expr(cls), self) + util.update_wrapper(_expr, expr) return self._get_comparator(_expr) @@ -1091,8 +1098,13 @@ class hybrid_property(interfaces.InspectionAttrInfo): def expr_comparator(owner): return proxy_attr( - owner, self.__name__, self, comparator(owner), - doc=comparator.__doc__ or self.__doc__) + owner, + self.__name__, + self, + comparator(owner), + doc=comparator.__doc__ or self.__doc__, + ) + return expr_comparator @@ -1108,7 +1120,7 @@ class Comparator(interfaces.PropComparator): def __clause_element__(self): expr = self.expression - if hasattr(expr, '__clause_element__'): + if hasattr(expr, "__clause_element__"): expr = expr.__clause_element__() return expr diff --git a/lib/sqlalchemy/ext/indexable.py b/lib/sqlalchemy/ext/indexable.py index 0bc2b65bb..368e5b00a 100644 --- a/lib/sqlalchemy/ext/indexable.py +++ b/lib/sqlalchemy/ext/indexable.py @@ -232,7 +232,7 @@ from ..orm.attributes import flag_modified from ..ext.hybrid import hybrid_property -__all__ = ['index_property'] +__all__ = ["index_property"] class index_property(hybrid_property): # noqa @@ -251,8 +251,14 @@ class index_property(hybrid_property): # noqa _NO_DEFAULT_ARGUMENT = object() def __init__( - self, attr_name, index, default=_NO_DEFAULT_ARGUMENT, - datatype=None, mutable=True, onebased=True): + self, + attr_name, + index, + default=_NO_DEFAULT_ARGUMENT, + datatype=None, + mutable=True, + onebased=True, + ): """Create a new :class:`.index_property`. :param attr_name: diff --git a/lib/sqlalchemy/ext/instrumentation.py b/lib/sqlalchemy/ext/instrumentation.py index 30a0ab7d7..b2b8dd7c5 100644 --- a/lib/sqlalchemy/ext/instrumentation.py +++ b/lib/sqlalchemy/ext/instrumentation.py @@ -28,15 +28,18 @@ see the example :ref:`examples_instrumentation`. """ from ..orm import instrumentation as orm_instrumentation from ..orm.instrumentation import ( - ClassManager, InstrumentationFactory, _default_state_getter, - _default_dict_getter, _default_manager_getter + ClassManager, + InstrumentationFactory, + _default_state_getter, + _default_dict_getter, + _default_manager_getter, ) from ..orm import attributes, collections, base as orm_base from .. import util from ..orm import exc as orm_exc import weakref -INSTRUMENTATION_MANAGER = '__sa_instrumentation_manager__' +INSTRUMENTATION_MANAGER = "__sa_instrumentation_manager__" """Attribute, elects custom instrumentation when present on a mapped class. Allows a class to specify a slightly or wildly different technique for @@ -66,6 +69,7 @@ def find_native_user_instrumentation_hook(cls): """Find user-specified instrumentation management for a class.""" return getattr(cls, INSTRUMENTATION_MANAGER, None) + instrumentation_finders = [find_native_user_instrumentation_hook] """An extensible sequence of callables which return instrumentation implementations @@ -89,6 +93,7 @@ class ExtendedInstrumentationRegistry(InstrumentationFactory): class managers. """ + _manager_finders = weakref.WeakKeyDictionary() _state_finders = weakref.WeakKeyDictionary() _dict_finders = weakref.WeakKeyDictionary() @@ -104,13 +109,15 @@ class ExtendedInstrumentationRegistry(InstrumentationFactory): return None, None def _check_conflicts(self, class_, factory): - existing_factories = self._collect_management_factories_for(class_).\ - difference([factory]) + existing_factories = self._collect_management_factories_for( + class_ + ).difference([factory]) if existing_factories: raise TypeError( "multiple instrumentation implementations specified " - "in %s inheritance hierarchy: %r" % ( - class_.__name__, list(existing_factories))) + "in %s inheritance hierarchy: %r" + % (class_.__name__, list(existing_factories)) + ) def _extended_class_manager(self, class_, factory): manager = factory(class_) @@ -178,17 +185,20 @@ class ExtendedInstrumentationRegistry(InstrumentationFactory): if instance is None: raise AttributeError("None has no persistent state.") return self._state_finders.get( - instance.__class__, _default_state_getter)(instance) + instance.__class__, _default_state_getter + )(instance) def dict_of(self, instance): if instance is None: raise AttributeError("None has no persistent state.") return self._dict_finders.get( - instance.__class__, _default_dict_getter)(instance) + instance.__class__, _default_dict_getter + )(instance) -orm_instrumentation._instrumentation_factory = \ - _instrumentation_factory = ExtendedInstrumentationRegistry() +orm_instrumentation._instrumentation_factory = ( + _instrumentation_factory +) = ExtendedInstrumentationRegistry() orm_instrumentation.instrumentation_finders = instrumentation_finders @@ -222,14 +232,15 @@ class InstrumentationManager(object): pass def manage(self, class_, manager): - setattr(class_, '_default_class_manager', manager) + setattr(class_, "_default_class_manager", manager) def dispose(self, class_, manager): - delattr(class_, '_default_class_manager') + delattr(class_, "_default_class_manager") def manager_getter(self, class_): def get(cls): return cls._default_class_manager + return get def instrument_attribute(self, class_, key, inst): @@ -260,13 +271,13 @@ class InstrumentationManager(object): pass def install_state(self, class_, instance, state): - setattr(instance, '_default_state', state) + setattr(instance, "_default_state", state) def remove_state(self, class_, instance): - delattr(instance, '_default_state') + delattr(instance, "_default_state") def state_getter(self, class_): - return lambda instance: getattr(instance, '_default_state') + return lambda instance: getattr(instance, "_default_state") def dict_getter(self, class_): return lambda inst: self.get_instance_dict(class_, inst) @@ -314,15 +325,17 @@ class _ClassInstrumentationAdapter(ClassManager): def instrument_collection_class(self, key, collection_class): return self._adapted.instrument_collection_class( - self.class_, key, collection_class) + self.class_, key, collection_class + ) def initialize_collection(self, key, state, factory): - delegate = getattr(self._adapted, 'initialize_collection', None) + delegate = getattr(self._adapted, "initialize_collection", None) if delegate: return delegate(key, state, factory) else: - return ClassManager.initialize_collection(self, key, - state, factory) + return ClassManager.initialize_collection( + self, key, state, factory + ) def new_instance(self, state=None): instance = self.class_.__new__(self.class_) @@ -384,7 +397,7 @@ def _install_instrumented_lookups(): dict( instance_state=_instrumentation_factory.state_of, instance_dict=_instrumentation_factory.dict_of, - manager_of_class=_instrumentation_factory.manager_of_class + manager_of_class=_instrumentation_factory.manager_of_class, ) ) @@ -395,7 +408,7 @@ def _reinstall_default_lookups(): dict( instance_state=_default_state_getter, instance_dict=_default_dict_getter, - manager_of_class=_default_manager_getter + manager_of_class=_default_manager_getter, ) ) _instrumentation_factory._extended = False @@ -403,12 +416,15 @@ def _reinstall_default_lookups(): def _install_lookups(lookups): global instance_state, instance_dict, manager_of_class - instance_state = lookups['instance_state'] - instance_dict = lookups['instance_dict'] - manager_of_class = lookups['manager_of_class'] - orm_base.instance_state = attributes.instance_state = \ - orm_instrumentation.instance_state = instance_state - orm_base.instance_dict = attributes.instance_dict = \ - orm_instrumentation.instance_dict = instance_dict - orm_base.manager_of_class = attributes.manager_of_class = \ - orm_instrumentation.manager_of_class = manager_of_class + instance_state = lookups["instance_state"] + instance_dict = lookups["instance_dict"] + manager_of_class = lookups["manager_of_class"] + orm_base.instance_state = ( + attributes.instance_state + ) = orm_instrumentation.instance_state = instance_state + orm_base.instance_dict = ( + attributes.instance_dict + ) = orm_instrumentation.instance_dict = instance_dict + orm_base.manager_of_class = ( + attributes.manager_of_class + ) = orm_instrumentation.manager_of_class = manager_of_class diff --git a/lib/sqlalchemy/ext/mutable.py b/lib/sqlalchemy/ext/mutable.py index 014cef3cc..0f6ccdc33 100644 --- a/lib/sqlalchemy/ext/mutable.py +++ b/lib/sqlalchemy/ext/mutable.py @@ -502,27 +502,29 @@ class MutableBase(object): def pickle(state, state_dict): val = state.dict.get(key, None) if val is not None: - if 'ext.mutable.values' not in state_dict: - state_dict['ext.mutable.values'] = [] - state_dict['ext.mutable.values'].append(val) + if "ext.mutable.values" not in state_dict: + state_dict["ext.mutable.values"] = [] + state_dict["ext.mutable.values"].append(val) def unpickle(state, state_dict): - if 'ext.mutable.values' in state_dict: - for val in state_dict['ext.mutable.values']: + if "ext.mutable.values" in state_dict: + for val in state_dict["ext.mutable.values"]: val._parents[state.obj()] = key - event.listen(parent_cls, 'load', load, - raw=True, propagate=True) - event.listen(parent_cls, 'refresh', load_attrs, - raw=True, propagate=True) - event.listen(parent_cls, 'refresh_flush', load_attrs, - raw=True, propagate=True) - event.listen(attribute, 'set', set, - raw=True, retval=True, propagate=True) - event.listen(parent_cls, 'pickle', pickle, - raw=True, propagate=True) - event.listen(parent_cls, 'unpickle', unpickle, - raw=True, propagate=True) + event.listen(parent_cls, "load", load, raw=True, propagate=True) + event.listen( + parent_cls, "refresh", load_attrs, raw=True, propagate=True + ) + event.listen( + parent_cls, "refresh_flush", load_attrs, raw=True, propagate=True + ) + event.listen( + attribute, "set", set, raw=True, retval=True, propagate=True + ) + event.listen(parent_cls, "pickle", pickle, raw=True, propagate=True) + event.listen( + parent_cls, "unpickle", unpickle, raw=True, propagate=True + ) class Mutable(MutableBase): @@ -572,7 +574,7 @@ class Mutable(MutableBase): if isinstance(prop.columns[0].type, sqltype): cls.associate_with_attribute(getattr(class_, prop.key)) - event.listen(mapper, 'mapper_configured', listen_for_type) + event.listen(mapper, "mapper_configured", listen_for_type) @classmethod def as_mutable(cls, sqltype): @@ -613,9 +615,11 @@ class Mutable(MutableBase): # and we'll lose our ability to link that type back to the original. # so track our original type w/ columns if isinstance(sqltype, SchemaEventTarget): + @event.listens_for(sqltype, "before_parent_attach") def _add_column_memo(sqltyp, parent): - parent.info['_ext_mutable_orig_type'] = sqltyp + parent.info["_ext_mutable_orig_type"] = sqltyp + schema_event_check = True else: schema_event_check = False @@ -625,16 +629,14 @@ class Mutable(MutableBase): return for prop in mapper.column_attrs: if ( - schema_event_check and - hasattr(prop.expression, 'info') and - prop.expression.info.get('_ext_mutable_orig_type') - is sqltype - ) or ( - prop.columns[0].type is sqltype - ): + schema_event_check + and hasattr(prop.expression, "info") + and prop.expression.info.get("_ext_mutable_orig_type") + is sqltype + ) or (prop.columns[0].type is sqltype): cls.associate_with_attribute(getattr(class_, prop.key)) - event.listen(mapper, 'mapper_configured', listen_for_type) + event.listen(mapper, "mapper_configured", listen_for_type) return sqltype @@ -659,21 +661,27 @@ class MutableComposite(MutableBase): prop = object_mapper(parent).get_property(key) for value, attr_name in zip( - self.__composite_values__(), - prop._attribute_keys): + self.__composite_values__(), prop._attribute_keys + ): setattr(parent, attr_name, value) def _setup_composite_listener(): def _listen_for_type(mapper, class_): for prop in mapper.iterate_properties: - if (hasattr(prop, 'composite_class') and - isinstance(prop.composite_class, type) and - issubclass(prop.composite_class, MutableComposite)): + if ( + hasattr(prop, "composite_class") + and isinstance(prop.composite_class, type) + and issubclass(prop.composite_class, MutableComposite) + ): prop.composite_class._listen_on_attribute( - getattr(class_, prop.key), False, class_) + getattr(class_, prop.key), False, class_ + ) + if not event.contains(Mapper, "mapper_configured", _listen_for_type): - event.listen(Mapper, 'mapper_configured', _listen_for_type) + event.listen(Mapper, "mapper_configured", _listen_for_type) + + _setup_composite_listener() @@ -947,4 +955,4 @@ class MutableSet(Mutable, set): self.update(state) def __reduce_ex__(self, proto): - return (self.__class__, (list(self), )) + return (self.__class__, (list(self),)) diff --git a/lib/sqlalchemy/ext/orderinglist.py b/lib/sqlalchemy/ext/orderinglist.py index 316742a67..2a8522120 100644 --- a/lib/sqlalchemy/ext/orderinglist.py +++ b/lib/sqlalchemy/ext/orderinglist.py @@ -122,7 +122,7 @@ start numbering at 1 or some other integer, provide ``count_from=1``. from ..orm.collections import collection, collection_adapter from .. import util -__all__ = ['ordering_list'] +__all__ = ["ordering_list"] def ordering_list(attr, count_from=None, **kw): @@ -180,8 +180,9 @@ def count_from_n_factory(start): def f(index, collection): return index + start + try: - f.__name__ = 'count_from_%i' % start + f.__name__ = "count_from_%i" % start except TypeError: pass return f @@ -194,14 +195,14 @@ def _unsugar_count_from(**kw): ``count_from`` argument, otherwise passes ``ordering_func`` on unchanged. """ - count_from = kw.pop('count_from', None) - if kw.get('ordering_func', None) is None and count_from is not None: + count_from = kw.pop("count_from", None) + if kw.get("ordering_func", None) is None and count_from is not None: if count_from == 0: - kw['ordering_func'] = count_from_0 + kw["ordering_func"] = count_from_0 elif count_from == 1: - kw['ordering_func'] = count_from_1 + kw["ordering_func"] = count_from_1 else: - kw['ordering_func'] = count_from_n_factory(count_from) + kw["ordering_func"] = count_from_n_factory(count_from) return kw @@ -214,8 +215,9 @@ class OrderingList(list): """ - def __init__(self, ordering_attr=None, ordering_func=None, - reorder_on_append=False): + def __init__( + self, ordering_attr=None, ordering_func=None, reorder_on_append=False + ): """A custom list that manages position information for its children. ``OrderingList`` is a ``collection_class`` list implementation that @@ -311,6 +313,7 @@ class OrderingList(list): """Append without any ordering behavior.""" super(OrderingList, self).append(entity) + _raw_append = collection.adds(1)(_raw_append) def insert(self, index, entity): @@ -361,8 +364,12 @@ class OrderingList(list): return _reconstitute, (self.__class__, self.__dict__, list(self)) for func_name, func in list(locals().items()): - if (util.callable(func) and func.__name__ == func_name and - not func.__doc__ and hasattr(list, func_name)): + if ( + util.callable(func) + and func.__name__ == func_name + and not func.__doc__ + and hasattr(list, func_name) + ): func.__doc__ = getattr(list, func_name).__doc__ del func_name, func diff --git a/lib/sqlalchemy/ext/serializer.py b/lib/sqlalchemy/ext/serializer.py index 2fded51d1..3adcec34f 100644 --- a/lib/sqlalchemy/ext/serializer.py +++ b/lib/sqlalchemy/ext/serializer.py @@ -64,7 +64,7 @@ from ..util import pickle, byte_buffer, b64encode, b64decode, text_type import re -__all__ = ['Serializer', 'Deserializer', 'dumps', 'loads'] +__all__ = ["Serializer", "Deserializer", "dumps", "loads"] def Serializer(*args, **kw): @@ -79,13 +79,18 @@ def Serializer(*args, **kw): elif isinstance(obj, Mapper) and not obj.non_primary: id = "mapper:" + b64encode(pickle.dumps(obj.class_)) elif isinstance(obj, MapperProperty) and not obj.parent.non_primary: - id = "mapperprop:" + b64encode(pickle.dumps(obj.parent.class_)) + \ - ":" + obj.key + id = ( + "mapperprop:" + + b64encode(pickle.dumps(obj.parent.class_)) + + ":" + + obj.key + ) elif isinstance(obj, Table): id = "table:" + text_type(obj.key) elif isinstance(obj, Column) and isinstance(obj.table, Table): - id = "column:" + \ - text_type(obj.table.key) + ":" + text_type(obj.key) + id = ( + "column:" + text_type(obj.table.key) + ":" + text_type(obj.key) + ) elif isinstance(obj, Session): id = "session:" elif isinstance(obj, Engine): @@ -97,8 +102,10 @@ def Serializer(*args, **kw): pickler.persistent_id = persistent_id return pickler + our_ids = re.compile( - r'(mapperprop|mapper|table|column|session|attribute|engine):(.*)') + r"(mapperprop|mapper|table|column|session|attribute|engine):(.*)" +) def Deserializer(file, metadata=None, scoped_session=None, engine=None): @@ -120,7 +127,7 @@ def Deserializer(file, metadata=None, scoped_session=None, engine=None): return None else: type_, args = m.group(1, 2) - if type_ == 'attribute': + if type_ == "attribute": key, clsarg = args.split(":") cls = pickle.loads(b64decode(clsarg)) return getattr(cls, key) @@ -128,13 +135,13 @@ def Deserializer(file, metadata=None, scoped_session=None, engine=None): cls = pickle.loads(b64decode(args)) return class_mapper(cls) elif type_ == "mapperprop": - mapper, keyname = args.split(':') + mapper, keyname = args.split(":") cls = pickle.loads(b64decode(mapper)) return class_mapper(cls).attrs[keyname] elif type_ == "table": return metadata.tables[args] elif type_ == "column": - table, colname = args.split(':') + table, colname = args.split(":") return metadata.tables[table].c[colname] elif type_ == "session": return scoped_session() @@ -142,6 +149,7 @@ def Deserializer(file, metadata=None, scoped_session=None, engine=None): return get_engine() else: raise Exception("Unknown token: %s" % type_) + unpickler.persistent_load = persistent_load return unpickler |
