diff options
Diffstat (limited to 'lib/sqlalchemy/ext')
| -rw-r--r-- | lib/sqlalchemy/ext/__init__.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/associationproxy.py | 288 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/automap.py | 309 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/baked.py | 135 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/compiler.py | 30 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/declarative/__init__.py | 35 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/declarative/api.py | 117 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/declarative/base.py | 326 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/declarative/clsregistry.py | 125 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/horizontal_shard.py | 61 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/hybrid.py | 30 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/indexable.py | 12 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/instrumentation.py | 78 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/mutable.py | 78 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/orderinglist.py | 29 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/serializer.py | 26 |
16 files changed, 1021 insertions, 659 deletions
diff --git a/lib/sqlalchemy/ext/__init__.py b/lib/sqlalchemy/ext/__init__.py index 9558b2a1f..9fed09e2b 100644 --- a/lib/sqlalchemy/ext/__init__.py +++ b/lib/sqlalchemy/ext/__init__.py @@ -8,4 +8,3 @@ from .. import util as _sa_util _sa_util.dependencies.resolve_all("sqlalchemy.ext") - diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index ff9433d4d..56b91ce0b 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -76,7 +76,7 @@ def association_proxy(target_collection, attr, **kw): return AssociationProxy(target_collection, attr, **kw) -ASSOCIATION_PROXY = util.symbol('ASSOCIATION_PROXY') +ASSOCIATION_PROXY = util.symbol("ASSOCIATION_PROXY") """Symbol indicating an :class:`InspectionAttr` that's of type :class:`.AssociationProxy`. @@ -92,10 +92,17 @@ class AssociationProxy(interfaces.InspectionAttrInfo): is_attribute = False extension_type = ASSOCIATION_PROXY - def __init__(self, target_collection, attr, creator=None, - getset_factory=None, proxy_factory=None, - proxy_bulk_set=None, info=None, - cascade_scalar_deletes=False): + def __init__( + self, + target_collection, + attr, + creator=None, + getset_factory=None, + proxy_factory=None, + proxy_bulk_set=None, + info=None, + cascade_scalar_deletes=False, + ): """Construct a new :class:`.AssociationProxy`. The :func:`.association_proxy` function is provided as the usual @@ -162,8 +169,11 @@ class AssociationProxy(interfaces.InspectionAttrInfo): self.proxy_bulk_set = proxy_bulk_set self.cascade_scalar_deletes = cascade_scalar_deletes - self.key = '_%s_%s_%s' % ( - type(self).__name__, target_collection, id(self)) + self.key = "_%s_%s_%s" % ( + type(self).__name__, + target_collection, + id(self), + ) if info: self.info = info @@ -264,12 +274,17 @@ class AssociationProxy(interfaces.InspectionAttrInfo): def getter(target): return _getter(target) if target is not None else None + if collection_class is dict: + def setter(o, k, v): setattr(o, attr, v) + else: + def setter(o, v): setattr(o, attr, v) + return getter, setter @@ -325,20 +340,21 @@ class AssociationProxyInstance(object): def for_proxy(cls, parent, owning_class, parent_instance): target_collection = parent.target_collection value_attr = parent.value_attr - prop = orm.class_mapper(owning_class).\ - get_property(target_collection) + prop = orm.class_mapper(owning_class).get_property(target_collection) # this was never asserted before but this should be made clear. if not isinstance(prop, orm.RelationshipProperty): raise NotImplementedError( "association proxy to a non-relationship " - "intermediary is not supported") + "intermediary is not supported" + ) target_class = prop.mapper.class_ try: target_assoc = cls._cls_unwrap_target_assoc_proxy( - target_class, value_attr) + target_class, value_attr + ) except AttributeError: # the proxied attribute doesn't exist on the target class; # return an "ambiguous" instance that will work on a per-object @@ -353,8 +369,8 @@ class AssociationProxyInstance(object): @classmethod def _construct_for_assoc( - cls, target_assoc, parent, owning_class, - target_class, value_attr): + cls, target_assoc, parent, owning_class, target_class, value_attr + ): if target_assoc is not None: return ObjectAssociationProxyInstance( parent, owning_class, target_class, value_attr @@ -371,8 +387,9 @@ class AssociationProxyInstance(object): ) def _get_property(self): - return orm.class_mapper(self.owning_class).\ - get_property(self.target_collection) + return orm.class_mapper(self.owning_class).get_property( + self.target_collection + ) @property def _comparator(self): @@ -388,7 +405,8 @@ class AssociationProxyInstance(object): @util.memoized_property def _unwrap_target_assoc_proxy(self): return self._cls_unwrap_target_assoc_proxy( - self.target_class, self.value_attr) + self.target_class, self.value_attr + ) @property def remote_attr(self): @@ -448,8 +466,11 @@ class AssociationProxyInstance(object): @util.memoized_property def _value_is_scalar(self): - return not self._get_property().\ - mapper.get_property(self.value_attr).uselist + return ( + not self._get_property() + .mapper.get_property(self.value_attr) + .uselist + ) @property def _target_is_object(self): @@ -468,12 +489,17 @@ class AssociationProxyInstance(object): def getter(target): return _getter(target) if target is not None else None + if collection_class is dict: + def setter(o, k, v): return setattr(o, attr, v) + else: + def setter(o, v): return setattr(o, attr, v) + return getter, setter @property @@ -500,14 +526,18 @@ class AssociationProxyInstance(object): return proxy self.collection_class, proxy = self._new( - _lazy_collection(obj, self.target_collection)) + _lazy_collection(obj, self.target_collection) + ) setattr(obj, self.key, (id(obj), id(self), proxy)) return proxy def set(self, obj, values): if self.scalar: - creator = self.parent.creator \ - if self.parent.creator else self.target_class + creator = ( + self.parent.creator + if self.parent.creator + else self.target_class + ) target = getattr(obj, self.target_collection) if target is None: if values is None: @@ -535,35 +565,52 @@ class AssociationProxyInstance(object): delattr(obj, self.target_collection) def _new(self, lazy_collection): - creator = self.parent.creator if self.parent.creator else \ - self.target_class + creator = ( + self.parent.creator if self.parent.creator else self.target_class + ) collection_class = util.duck_type_collection(lazy_collection()) if self.parent.proxy_factory: - return collection_class, self.parent.proxy_factory( - lazy_collection, creator, self.value_attr, self) + return ( + collection_class, + self.parent.proxy_factory( + lazy_collection, creator, self.value_attr, self + ), + ) if self.parent.getset_factory: - getter, setter = self.parent.getset_factory( - collection_class, self) + getter, setter = self.parent.getset_factory(collection_class, self) else: getter, setter = self.parent._default_getset(collection_class) if collection_class is list: - return collection_class, _AssociationList( - lazy_collection, creator, getter, setter, self) + return ( + collection_class, + _AssociationList( + lazy_collection, creator, getter, setter, self + ), + ) elif collection_class is dict: - return collection_class, _AssociationDict( - lazy_collection, creator, getter, setter, self) + return ( + collection_class, + _AssociationDict( + lazy_collection, creator, getter, setter, self + ), + ) elif collection_class is set: - return collection_class, _AssociationSet( - lazy_collection, creator, getter, setter, self) + return ( + collection_class, + _AssociationSet( + lazy_collection, creator, getter, setter, self + ), + ) else: raise exc.ArgumentError( - 'could not guess which interface to use for ' + "could not guess which interface to use for " 'collection_class "%s" backing "%s"; specify a ' - 'proxy_factory and proxy_bulk_set manually' % - (self.collection_class.__name__, self.target_collection)) + "proxy_factory and proxy_bulk_set manually" + % (self.collection_class.__name__, self.target_collection) + ) def _set(self, proxy, values): if self.parent.proxy_bulk_set: @@ -576,16 +623,19 @@ class AssociationProxyInstance(object): proxy.update(values) else: raise exc.ArgumentError( - 'no proxy_bulk_set supplied for custom ' - 'collection_class implementation') + "no proxy_bulk_set supplied for custom " + "collection_class implementation" + ) def _inflate(self, proxy): - creator = self.parent.creator and \ - self.parent.creator or self.target_class + creator = ( + self.parent.creator and self.parent.creator or self.target_class + ) if self.parent.getset_factory: getter, setter = self.parent.getset_factory( - self.collection_class, self) + self.collection_class, self + ) else: getter, setter = self.parent._default_getset(self.collection_class) @@ -594,12 +644,13 @@ class AssociationProxyInstance(object): proxy.setter = setter def _criterion_exists(self, criterion=None, **kwargs): - is_has = kwargs.pop('is_has', None) + is_has = kwargs.pop("is_has", None) target_assoc = self._unwrap_target_assoc_proxy if target_assoc is not None: inner = target_assoc._criterion_exists( - criterion=criterion, **kwargs) + criterion=criterion, **kwargs + ) return self._comparator._criterion_exists(inner) if self._target_is_object: @@ -631,15 +682,15 @@ class AssociationProxyInstance(object): """ if self._unwrap_target_assoc_proxy is None and ( - self.scalar and ( - not self._target_is_object or self._value_is_scalar) + self.scalar + and (not self._target_is_object or self._value_is_scalar) ): raise exc.InvalidRequestError( - "'any()' not implemented for scalar " - "attributes. Use has()." + "'any()' not implemented for scalar " "attributes. Use has()." ) return self._criterion_exists( - criterion=criterion, is_has=False, **kwargs) + criterion=criterion, is_has=False, **kwargs + ) def has(self, criterion=None, **kwargs): """Produce a proxied 'has' expression using EXISTS. @@ -651,14 +702,15 @@ class AssociationProxyInstance(object): """ if self._unwrap_target_assoc_proxy is None and ( - not self.scalar or ( - self._target_is_object and not self._value_is_scalar) + not self.scalar + or (self._target_is_object and not self._value_is_scalar) ): raise exc.InvalidRequestError( - "'has()' not implemented for collections. " - "Use any().") + "'has()' not implemented for collections. " "Use any()." + ) return self._criterion_exists( - criterion=criterion, is_has=True, **kwargs) + criterion=criterion, is_has=True, **kwargs + ) class AmbiguousAssociationProxyInstance(AssociationProxyInstance): @@ -673,10 +725,14 @@ class AmbiguousAssociationProxyInstance(AssociationProxyInstance): "Association proxy %s.%s refers to an attribute '%s' that is not " "directly mapped on class %s; therefore this operation cannot " "proceed since we don't know what type of object is referred " - "towards" % ( - self.owning_class.__name__, self.target_collection, - self.value_attr, self.target_class - )) + "towards" + % ( + self.owning_class.__name__, + self.target_collection, + self.value_attr, + self.target_class, + ) + ) def get(self, obj): self._ambiguous() @@ -718,27 +774,32 @@ class AmbiguousAssociationProxyInstance(AssociationProxyInstance): return self def _populate_cache(self, instance_class): - prop = orm.class_mapper(self.owning_class).\ - get_property(self.target_collection) + prop = orm.class_mapper(self.owning_class).get_property( + self.target_collection + ) if inspect(instance_class).mapper.isa(prop.mapper): target_class = instance_class try: target_assoc = self._cls_unwrap_target_assoc_proxy( - target_class, self.value_attr) + target_class, self.value_attr + ) except AttributeError: pass else: - self._lookup_cache[instance_class] = \ - self._construct_for_assoc( - target_assoc, self.parent, self.owning_class, - target_class, self.value_attr + self._lookup_cache[instance_class] = self._construct_for_assoc( + target_assoc, + self.parent, + self.owning_class, + target_class, + self.value_attr, ) class ObjectAssociationProxyInstance(AssociationProxyInstance): """an :class:`.AssociationProxyInstance` that has an object as a target. """ + _target_is_object = True _is_canonical = True @@ -756,17 +817,21 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance): if target_assoc is not None: return self._comparator._criterion_exists( target_assoc.contains(obj) - if not target_assoc.scalar else target_assoc == obj + if not target_assoc.scalar + else target_assoc == obj ) - elif self._target_is_object and self.scalar and \ - not self._value_is_scalar: + elif ( + self._target_is_object + and self.scalar + and not self._value_is_scalar + ): return self._comparator.has( getattr(self.target_class, self.value_attr).contains(obj) ) - elif self._target_is_object and self.scalar and \ - self._value_is_scalar: + elif self._target_is_object and self.scalar and self._value_is_scalar: raise exc.InvalidRequestError( - "contains() doesn't apply to a scalar object endpoint; use ==") + "contains() doesn't apply to a scalar object endpoint; use ==" + ) else: return self._comparator._criterion_exists(**{self.value_attr: obj}) @@ -777,7 +842,7 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance): if obj is None: return or_( self._comparator.has(**{self.value_attr: obj}), - self._comparator == None + self._comparator == None, ) else: return self._comparator.has(**{self.value_attr: obj}) @@ -786,14 +851,17 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance): # note the has() here will fail for collections; eq_() # is only allowed with a scalar. return self._comparator.has( - getattr(self.target_class, self.value_attr) != obj) + getattr(self.target_class, self.value_attr) != obj + ) class ColumnAssociationProxyInstance( - ColumnOperators, AssociationProxyInstance): + ColumnOperators, AssociationProxyInstance +): """an :class:`.AssociationProxyInstance` that has a database column as a target. """ + _target_is_object = False _is_canonical = True @@ -803,9 +871,7 @@ class ColumnAssociationProxyInstance( self.remote_attr.operate(operator.eq, other) ) if other is None: - return or_( - expr, self._comparator == None - ) + return or_(expr, self._comparator == None) else: return expr @@ -824,11 +890,11 @@ class _lazy_collection(object): return getattr(self.parent, self.target) def __getstate__(self): - return {'obj': self.parent, 'target': self.target} + return {"obj": self.parent, "target": self.target} def __setstate__(self, state): - self.parent = state['obj'] - self.target = state['target'] + self.parent = state["obj"] + self.target = state["target"] class _AssociationCollection(object): @@ -874,11 +940,11 @@ class _AssociationCollection(object): __nonzero__ = __bool__ def __getstate__(self): - return {'parent': self.parent, 'lazy_collection': self.lazy_collection} + return {"parent": self.parent, "lazy_collection": self.lazy_collection} def __setstate__(self, state): - self.parent = state['parent'] - self.lazy_collection = state['lazy_collection'] + self.parent = state["parent"] + self.lazy_collection = state["lazy_collection"] self.parent._inflate(self) @@ -925,8 +991,8 @@ class _AssociationList(_AssociationCollection): if len(value) != len(rng): raise ValueError( "attempt to assign sequence of size %s to " - "extended slice of size %s" % (len(value), - len(rng))) + "extended slice of size %s" % (len(value), len(rng)) + ) for i, item in zip(rng, value): self._set(self.col[i], item) @@ -968,8 +1034,14 @@ class _AssociationList(_AssociationCollection): col.append(item) def count(self, value): - return sum([1 for _ in - util.itertools_filter(lambda v: v == value, iter(self))]) + return sum( + [ + 1 + for _ in util.itertools_filter( + lambda v: v == value, iter(self) + ) + ] + ) def extend(self, values): for v in values: @@ -999,7 +1071,7 @@ class _AssociationList(_AssociationCollection): raise NotImplementedError def clear(self): - del self.col[0:len(self.col)] + del self.col[0 : len(self.col)] def __eq__(self, other): return list(self) == other @@ -1040,6 +1112,7 @@ class _AssociationList(_AssociationCollection): if not isinstance(n, int): return NotImplemented return list(self) * n + __rmul__ = __mul__ def __iadd__(self, iterable): @@ -1072,13 +1145,17 @@ class _AssociationList(_AssociationCollection): raise TypeError("%s objects are unhashable" % type(self).__name__) for func_name, func in list(locals().items()): - if (util.callable(func) and func.__name__ == func_name and - not func.__doc__ and hasattr(list, func_name)): + if ( + util.callable(func) + and func.__name__ == func_name + and not func.__doc__ + and hasattr(list, func_name) + ): func.__doc__ = getattr(list, func_name).__doc__ del func_name, func -_NotProvided = util.symbol('_NotProvided') +_NotProvided = util.symbol("_NotProvided") class _AssociationDict(_AssociationCollection): @@ -1160,6 +1237,7 @@ class _AssociationDict(_AssociationCollection): return self.col.keys() if util.py2k: + def iteritems(self): return ((key, self._get(self.col[key])) for key in self.col) @@ -1174,7 +1252,9 @@ class _AssociationDict(_AssociationCollection): def items(self): return [(k, self._get(self.col[k])) for k in self] + else: + def items(self): return ((key, self._get(self.col[key])) for key in self.col) @@ -1194,14 +1274,15 @@ class _AssociationDict(_AssociationCollection): def update(self, *a, **kw): if len(a) > 1: - raise TypeError('update expected at most 1 arguments, got %i' % - len(a)) + raise TypeError( + "update expected at most 1 arguments, got %i" % len(a) + ) elif len(a) == 1: seq_or_map = a[0] # discern dict from sequence - took the advice from # http://www.voidspace.org.uk/python/articles/duck_typing.shtml # still not perfect :( - if hasattr(seq_or_map, 'keys'): + if hasattr(seq_or_map, "keys"): for item in seq_or_map: self[item] = seq_or_map[item] else: @@ -1211,7 +1292,8 @@ class _AssociationDict(_AssociationCollection): except ValueError: raise ValueError( "dictionary update sequence " - "requires 2-element tuples") + "requires 2-element tuples" + ) for key, value in kw: self[key] = value @@ -1223,8 +1305,12 @@ class _AssociationDict(_AssociationCollection): raise TypeError("%s objects are unhashable" % type(self).__name__) for func_name, func in list(locals().items()): - if (util.callable(func) and func.__name__ == func_name and - not func.__doc__ and hasattr(dict, func_name)): + if ( + util.callable(func) + and func.__name__ == func_name + and not func.__doc__ + and hasattr(dict, func_name) + ): func.__doc__ = getattr(dict, func_name).__doc__ del func_name, func @@ -1288,7 +1374,7 @@ class _AssociationSet(_AssociationCollection): def pop(self): if not self.col: - raise KeyError('pop from an empty set') + raise KeyError("pop from an empty set") member = self.col.pop() return self._get(member) @@ -1420,7 +1506,11 @@ class _AssociationSet(_AssociationCollection): raise TypeError("%s objects are unhashable" % type(self).__name__) for func_name, func in list(locals().items()): - if (util.callable(func) and func.__name__ == func_name and - not func.__doc__ and hasattr(set, func_name)): + if ( + util.callable(func) + and func.__name__ == func_name + and not func.__doc__ + and hasattr(set, func_name) + ): func.__doc__ = getattr(set, func_name).__doc__ del func_name, func diff --git a/lib/sqlalchemy/ext/automap.py b/lib/sqlalchemy/ext/automap.py index cafb3d61c..747373a2a 100644 --- a/lib/sqlalchemy/ext/automap.py +++ b/lib/sqlalchemy/ext/automap.py @@ -580,7 +580,8 @@ def name_for_scalar_relationship(base, local_cls, referred_cls, constraint): def name_for_collection_relationship( - base, local_cls, referred_cls, constraint): + base, local_cls, referred_cls, constraint +): """Return the attribute name that should be used to refer from one class to another, for a collection reference. @@ -607,7 +608,8 @@ def name_for_collection_relationship( def generate_relationship( - base, direction, return_fn, attrname, local_cls, referred_cls, **kw): + base, direction, return_fn, attrname, local_cls, referred_cls, **kw +): r"""Generate a :func:`.relationship` or :func:`.backref` on behalf of two mapped classes. @@ -677,6 +679,7 @@ class AutomapBase(object): :ref:`automap_toplevel` """ + __abstract__ = True classes = None @@ -694,15 +697,16 @@ class AutomapBase(object): @classmethod def prepare( - cls, - engine=None, - reflect=False, - schema=None, - classname_for_table=classname_for_table, - collection_class=list, - name_for_scalar_relationship=name_for_scalar_relationship, - name_for_collection_relationship=name_for_collection_relationship, - generate_relationship=generate_relationship): + cls, + engine=None, + reflect=False, + schema=None, + classname_for_table=classname_for_table, + collection_class=list, + name_for_scalar_relationship=name_for_scalar_relationship, + name_for_collection_relationship=name_for_collection_relationship, + generate_relationship=generate_relationship, + ): """Extract mapped classes and relationships from the :class:`.MetaData` and perform mappings. @@ -752,15 +756,16 @@ class AutomapBase(object): engine, schema=schema, extend_existing=True, - autoload_replace=False + autoload_replace=False, ) _CONFIGURE_MUTEX.acquire() try: table_to_map_config = dict( (m.local_table, m) - for m in _DeferredMapperConfig. - classes_for_base(cls, sort=False) + for m in _DeferredMapperConfig.classes_for_base( + cls, sort=False + ) ) many_to_many = [] @@ -774,30 +779,39 @@ class AutomapBase(object): elif table not in table_to_map_config: mapped_cls = type( classname_for_table(cls, table.name, table), - (cls, ), - {"__table__": table} + (cls,), + {"__table__": table}, ) map_config = _DeferredMapperConfig.config_for_cls( - mapped_cls) + mapped_cls + ) cls.classes[map_config.cls.__name__] = mapped_cls table_to_map_config[table] = map_config for map_config in table_to_map_config.values(): - _relationships_for_fks(cls, - map_config, - table_to_map_config, - collection_class, - name_for_scalar_relationship, - name_for_collection_relationship, - generate_relationship) + _relationships_for_fks( + cls, + map_config, + table_to_map_config, + collection_class, + name_for_scalar_relationship, + name_for_collection_relationship, + generate_relationship, + ) for lcl_m2m, rem_m2m, m2m_const, table in many_to_many: - _m2m_relationship(cls, lcl_m2m, rem_m2m, m2m_const, table, - table_to_map_config, - collection_class, - name_for_scalar_relationship, - name_for_collection_relationship, - generate_relationship) + _m2m_relationship( + cls, + lcl_m2m, + rem_m2m, + m2m_const, + table, + table_to_map_config, + collection_class, + name_for_scalar_relationship, + name_for_collection_relationship, + generate_relationship, + ) for map_config in _DeferredMapperConfig.classes_for_base(cls): map_config.map() @@ -853,20 +867,27 @@ def automap_base(declarative_base=None, **kw): return type( Base.__name__, - (AutomapBase, Base,), - {"__abstract__": True, "classes": util.Properties({})} + (AutomapBase, Base), + {"__abstract__": True, "classes": util.Properties({})}, ) def _is_many_to_many(automap_base, table): - fk_constraints = [const for const in table.constraints - if isinstance(const, ForeignKeyConstraint)] + fk_constraints = [ + const + for const in table.constraints + if isinstance(const, ForeignKeyConstraint) + ] if len(fk_constraints) != 2: return None, None, None cols = sum( - [[fk.parent for fk in fk_constraint.elements] - for fk_constraint in fk_constraints], []) + [ + [fk.parent for fk in fk_constraint.elements] + for fk_constraint in fk_constraints + ], + [], + ) if set(cols) != set(table.c): return None, None, None @@ -874,15 +895,19 @@ def _is_many_to_many(automap_base, table): return ( fk_constraints[0].elements[0].column.table, fk_constraints[1].elements[0].column.table, - fk_constraints + fk_constraints, ) -def _relationships_for_fks(automap_base, map_config, table_to_map_config, - collection_class, - name_for_scalar_relationship, - name_for_collection_relationship, - generate_relationship): +def _relationships_for_fks( + automap_base, + map_config, + table_to_map_config, + collection_class, + name_for_scalar_relationship, + name_for_collection_relationship, + generate_relationship, +): local_table = map_config.local_table local_cls = map_config.cls # derived from a weakref, may be None @@ -898,32 +923,33 @@ def _relationships_for_fks(automap_base, map_config, table_to_map_config, referred_cls = referred_cfg.cls if local_cls is not referred_cls and issubclass( - local_cls, referred_cls): + local_cls, referred_cls + ): continue relationship_name = name_for_scalar_relationship( - automap_base, - local_cls, - referred_cls, constraint) + automap_base, local_cls, referred_cls, constraint + ) backref_name = name_for_collection_relationship( - automap_base, - referred_cls, - local_cls, - constraint + automap_base, referred_cls, local_cls, constraint ) o2m_kws = {} nullable = False not in {fk.parent.nullable for fk in fks} if not nullable: - o2m_kws['cascade'] = "all, delete-orphan" + o2m_kws["cascade"] = "all, delete-orphan" - if constraint.ondelete and \ - constraint.ondelete.lower() == "cascade": - o2m_kws['passive_deletes'] = True + if ( + constraint.ondelete + and constraint.ondelete.lower() == "cascade" + ): + o2m_kws["passive_deletes"] = True else: - if constraint.ondelete and \ - constraint.ondelete.lower() == "set null": - o2m_kws['passive_deletes'] = True + if ( + constraint.ondelete + and constraint.ondelete.lower() == "set null" + ): + o2m_kws["passive_deletes"] = True create_backref = backref_name not in referred_cfg.properties @@ -931,54 +957,65 @@ def _relationships_for_fks(automap_base, map_config, table_to_map_config, if create_backref: backref_obj = generate_relationship( automap_base, - interfaces.ONETOMANY, backref, - backref_name, referred_cls, local_cls, + interfaces.ONETOMANY, + backref, + backref_name, + referred_cls, + local_cls, collection_class=collection_class, - **o2m_kws) + **o2m_kws + ) else: backref_obj = None - rel = generate_relationship(automap_base, - interfaces.MANYTOONE, - relationship, - relationship_name, - local_cls, referred_cls, - foreign_keys=[ - fk.parent - for fk in constraint.elements], - backref=backref_obj, - remote_side=[ - fk.column - for fk in constraint.elements] - ) + rel = generate_relationship( + automap_base, + interfaces.MANYTOONE, + relationship, + relationship_name, + local_cls, + referred_cls, + foreign_keys=[fk.parent for fk in constraint.elements], + backref=backref_obj, + remote_side=[fk.column for fk in constraint.elements], + ) if rel is not None: map_config.properties[relationship_name] = rel if not create_backref: referred_cfg.properties[ - backref_name].back_populates = relationship_name + backref_name + ].back_populates = relationship_name elif create_backref: - rel = generate_relationship(automap_base, - interfaces.ONETOMANY, - relationship, - backref_name, - referred_cls, local_cls, - foreign_keys=[ - fk.parent - for fk in constraint.elements], - back_populates=relationship_name, - collection_class=collection_class, - **o2m_kws) + rel = generate_relationship( + automap_base, + interfaces.ONETOMANY, + relationship, + backref_name, + referred_cls, + local_cls, + foreign_keys=[fk.parent for fk in constraint.elements], + back_populates=relationship_name, + collection_class=collection_class, + **o2m_kws + ) if rel is not None: referred_cfg.properties[backref_name] = rel map_config.properties[ - relationship_name].back_populates = backref_name - - -def _m2m_relationship(automap_base, lcl_m2m, rem_m2m, m2m_const, table, - table_to_map_config, - collection_class, - name_for_scalar_relationship, - name_for_collection_relationship, - generate_relationship): + relationship_name + ].back_populates = backref_name + + +def _m2m_relationship( + automap_base, + lcl_m2m, + rem_m2m, + m2m_const, + table, + table_to_map_config, + collection_class, + name_for_scalar_relationship, + name_for_collection_relationship, + generate_relationship, +): map_config = table_to_map_config.get(lcl_m2m, None) referred_cfg = table_to_map_config.get(rem_m2m, None) @@ -989,14 +1026,10 @@ def _m2m_relationship(automap_base, lcl_m2m, rem_m2m, m2m_const, table, referred_cls = referred_cfg.cls relationship_name = name_for_collection_relationship( - automap_base, - local_cls, - referred_cls, m2m_const[0]) + automap_base, local_cls, referred_cls, m2m_const[0] + ) backref_name = name_for_collection_relationship( - automap_base, - referred_cls, - local_cls, - m2m_const[1] + automap_base, referred_cls, local_cls, m2m_const[1] ) create_backref = backref_name not in referred_cfg.properties @@ -1008,48 +1041,56 @@ def _m2m_relationship(automap_base, lcl_m2m, rem_m2m, m2m_const, table, interfaces.MANYTOMANY, backref, backref_name, - referred_cls, local_cls, - collection_class=collection_class + referred_cls, + local_cls, + collection_class=collection_class, ) else: backref_obj = None - rel = generate_relationship(automap_base, - interfaces.MANYTOMANY, - relationship, - relationship_name, - local_cls, referred_cls, - secondary=table, - primaryjoin=and_( - fk.column == fk.parent - for fk in m2m_const[0].elements), - secondaryjoin=and_( - fk.column == fk.parent - for fk in m2m_const[1].elements), - backref=backref_obj, - collection_class=collection_class - ) + rel = generate_relationship( + automap_base, + interfaces.MANYTOMANY, + relationship, + relationship_name, + local_cls, + referred_cls, + secondary=table, + primaryjoin=and_( + fk.column == fk.parent for fk in m2m_const[0].elements + ), + secondaryjoin=and_( + fk.column == fk.parent for fk in m2m_const[1].elements + ), + backref=backref_obj, + collection_class=collection_class, + ) if rel is not None: map_config.properties[relationship_name] = rel if not create_backref: referred_cfg.properties[ - backref_name].back_populates = relationship_name + backref_name + ].back_populates = relationship_name elif create_backref: - rel = generate_relationship(automap_base, - interfaces.MANYTOMANY, - relationship, - backref_name, - referred_cls, local_cls, - secondary=table, - primaryjoin=and_( - fk.column == fk.parent - for fk in m2m_const[1].elements), - secondaryjoin=and_( - fk.column == fk.parent - for fk in m2m_const[0].elements), - back_populates=relationship_name, - collection_class=collection_class) + rel = generate_relationship( + automap_base, + interfaces.MANYTOMANY, + relationship, + backref_name, + referred_cls, + local_cls, + secondary=table, + primaryjoin=and_( + fk.column == fk.parent for fk in m2m_const[1].elements + ), + secondaryjoin=and_( + fk.column == fk.parent for fk in m2m_const[0].elements + ), + back_populates=relationship_name, + collection_class=collection_class, + ) if rel is not None: referred_cfg.properties[backref_name] = rel map_config.properties[ - relationship_name].back_populates = backref_name + relationship_name + ].back_populates = backref_name diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py index 516879142..f55231a09 100644 --- a/lib/sqlalchemy/ext/baked.py +++ b/lib/sqlalchemy/ext/baked.py @@ -38,7 +38,8 @@ class Bakery(object): """ - __slots__ = 'cls', 'cache' + + __slots__ = "cls", "cache" def __init__(self, cls_, cache): self.cls = cls_ @@ -51,7 +52,7 @@ class Bakery(object): class BakedQuery(object): """A builder object for :class:`.query.Query` objects.""" - __slots__ = 'steps', '_bakery', '_cache_key', '_spoiled' + __slots__ = "steps", "_bakery", "_cache_key", "_spoiled" def __init__(self, bakery, initial_fn, args=()): self._cache_key = () @@ -148,7 +149,7 @@ class BakedQuery(object): """ if not full and not self._spoiled: _spoil_point = self._clone() - _spoil_point._cache_key += ('_query_only', ) + _spoil_point._cache_key += ("_query_only",) self.steps = [_spoil_point._retrieve_baked_query] self._spoiled = True return self @@ -164,7 +165,7 @@ class BakedQuery(object): session will want to use. """ - return self._cache_key + (session._query_cls, ) + return self._cache_key + (session._query_cls,) def _with_lazyload_options(self, options, effective_path, cache_path=None): """Cloning version of _add_lazyload_options. @@ -201,16 +202,20 @@ class BakedQuery(object): key += cache_key self.add_criteria( - lambda q: q._with_current_path(effective_path). - _conditional_options(*options), - cache_path.path, key + lambda q: q._with_current_path( + effective_path + )._conditional_options(*options), + cache_path.path, + key, ) def _retrieve_baked_query(self, session): query = self._bakery.get(self._effective_key(session), None) if query is None: query = self._as_query(session) - self._bakery[self._effective_key(session)] = query.with_session(None) + self._bakery[self._effective_key(session)] = query.with_session( + None + ) return query.with_session(session) def _bake(self, session): @@ -227,8 +232,12 @@ class BakedQuery(object): # so delete some compilation-use-only attributes that can take up # space for attr in ( - '_correlate', '_from_obj', '_mapper_adapter_map', - '_joinpath', '_joinpoint'): + "_correlate", + "_from_obj", + "_mapper_adapter_map", + "_joinpath", + "_joinpoint", + ): query.__dict__.pop(attr, None) self._bakery[self._effective_key(session)] = context return context @@ -276,11 +285,13 @@ class BakedQuery(object): session = query_or_session.session if session is None: raise sa_exc.ArgumentError( - "Given Query needs to be associated with a Session") + "Given Query needs to be associated with a Session" + ) else: raise TypeError( - "Query or Session object expected, got %r." % - type(query_or_session)) + "Query or Session object expected, got %r." + % type(query_or_session) + ) return self._as_query(session) def _as_query(self, session): @@ -299,10 +310,10 @@ class BakedQuery(object): a "baked" query so that we save on performance too. """ - context.attributes['baked_queries'] = baked_queries = [] + context.attributes["baked_queries"] = baked_queries = [] for k, v in list(context.attributes.items()): if isinstance(v, Query): - if 'subquery' in k: + if "subquery" in k: bk = BakedQuery(self._bakery, lambda *args: v) bk._cache_key = self._cache_key + k bk._bake(session) @@ -310,15 +321,17 @@ class BakedQuery(object): del context.attributes[k] def _unbake_subquery_loaders( - self, session, context, params, post_criteria): + self, session, context, params, post_criteria + ): """Retrieve subquery eager loaders stored by _bake_subquery_loaders and turn them back into Result objects that will iterate just like a Query object. """ for k, cache_key, query in context.attributes["baked_queries"]: - bk = BakedQuery(self._bakery, - lambda sess, q=query: q.with_session(sess)) + bk = BakedQuery( + self._bakery, lambda sess, q=query: q.with_session(sess) + ) bk._cache_key = cache_key q = bk.for_session(session) for fn in post_criteria: @@ -334,7 +347,8 @@ class Result(object): against a target :class:`.Session`, and is then invoked for results. """ - __slots__ = 'bq', 'session', '_params', '_post_criteria' + + __slots__ = "bq", "session", "_params", "_post_criteria" def __init__(self, bq, session): self.bq = bq @@ -350,7 +364,8 @@ class Result(object): elif len(args) > 0: raise sa_exc.ArgumentError( "params() takes zero or one positional argument, " - "which is a dictionary.") + "which is a dictionary." + ) self._params.update(kw) return self @@ -403,7 +418,8 @@ class Result(object): context.attributes = context.attributes.copy() bq._unbake_subquery_loaders( - self.session, context, self._params, self._post_criteria) + self.session, context, self._params, self._post_criteria + ) context.statement.use_labels = True if context.autoflush and not context.populate_existing: @@ -426,7 +442,7 @@ class Result(object): """ - col = func.count(literal_column('*')) + col = func.count(literal_column("*")) bq = self.bq.with_criteria(lambda q: q.from_self(col)) return bq.for_session(self.session).params(self._params).scalar() @@ -456,8 +472,10 @@ class Result(object): """ bq = self.bq.with_criteria(lambda q: q.slice(0, 1)) ret = list( - bq.for_session(self.session).params(self._params). - _using_post_criteria(self._post_criteria)) + bq.for_session(self.session) + .params(self._params) + ._using_post_criteria(self._post_criteria) + ) if len(ret) > 0: return ret[0] else: @@ -473,7 +491,8 @@ class Result(object): ret = self.one_or_none() except orm_exc.MultipleResultsFound: raise orm_exc.MultipleResultsFound( - "Multiple rows were found for one()") + "Multiple rows were found for one()" + ) else: if ret is None: raise orm_exc.NoResultFound("No row was found for one()") @@ -497,7 +516,8 @@ class Result(object): return None else: raise orm_exc.MultipleResultsFound( - "Multiple rows were found for one_or_none()") + "Multiple rows were found for one_or_none()" + ) def all(self): """Return all rows. @@ -533,13 +553,18 @@ class Result(object): # None present in ident - turn those comparisons # into "IS NULL" if None in primary_key_identity: - nones = set([ - _get_params[col].key for col, value in - zip(mapper.primary_key, primary_key_identity) - if value is None - ]) + nones = set( + [ + _get_params[col].key + for col, value in zip( + mapper.primary_key, primary_key_identity + ) + if value is None + ] + ) _lcl_get_clause = sql_util.adapt_criterion_to_null( - _lcl_get_clause, nones) + _lcl_get_clause, nones + ) _lcl_get_clause = q._adapt_clause(_lcl_get_clause, True, False) q._criterion = _lcl_get_clause @@ -556,16 +581,20 @@ class Result(object): # key so that if a race causes multiple calls to _get_clause, # we've cached on ours bq = bq._clone() - bq._cache_key += (_get_clause, ) + bq._cache_key += (_get_clause,) bq = bq.with_criteria( - setup, tuple(elem is None for elem in primary_key_identity)) + setup, tuple(elem is None for elem in primary_key_identity) + ) - params = dict([ - (_get_params[primary_key].key, id_val) - for id_val, primary_key - in zip(primary_key_identity, mapper.primary_key) - ]) + params = dict( + [ + (_get_params[primary_key].key, id_val) + for id_val, primary_key in zip( + primary_key_identity, mapper.primary_key + ) + ] + ) result = list(bq.for_session(self.session).params(**params)) l = len(result) @@ -578,7 +607,8 @@ class Result(object): @util.deprecated( - "1.2", "Baked lazy loading is now the default implementation.") + "1.2", "Baked lazy loading is now the default implementation." +) def bake_lazy_loaders(): """Enable the use of baked queries for all lazyloaders systemwide. @@ -590,7 +620,8 @@ def bake_lazy_loaders(): @util.deprecated( - "1.2", "Baked lazy loading is now the default implementation.") + "1.2", "Baked lazy loading is now the default implementation." +) def unbake_lazy_loaders(): """Disable the use of baked queries for all lazyloaders systemwide. @@ -601,7 +632,8 @@ def unbake_lazy_loaders(): """ raise NotImplementedError( - "Baked lazy loading is now the default implementation") + "Baked lazy loading is now the default implementation" + ) @strategy_options.loader_option() @@ -615,20 +647,27 @@ def baked_lazyload(loadopt, attr): @baked_lazyload._add_unbound_fn @util.deprecated( - "1.2", "Baked lazy loading is now the default " - "implementation for lazy loading.") + "1.2", + "Baked lazy loading is now the default " + "implementation for lazy loading.", +) def baked_lazyload(*keys): return strategy_options._UnboundLoad._from_keys( - strategy_options._UnboundLoad.baked_lazyload, keys, False, {}) + strategy_options._UnboundLoad.baked_lazyload, keys, False, {} + ) @baked_lazyload._add_unbound_all_fn @util.deprecated( - "1.2", "Baked lazy loading is now the default " - "implementation for lazy loading.") + "1.2", + "Baked lazy loading is now the default " + "implementation for lazy loading.", +) def baked_lazyload_all(*keys): return strategy_options._UnboundLoad._from_keys( - strategy_options._UnboundLoad.baked_lazyload, keys, True, {}) + strategy_options._UnboundLoad.baked_lazyload, keys, True, {} + ) + baked_lazyload = baked_lazyload._unbound_fn baked_lazyload_all = baked_lazyload_all._unbound_all_fn diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py index 6a0909d36..220b2c057 100644 --- a/lib/sqlalchemy/ext/compiler.py +++ b/lib/sqlalchemy/ext/compiler.py @@ -407,37 +407,44 @@ def compiles(class_, *specs): def decorate(fn): # get an existing @compiles handler - existing = class_.__dict__.get('_compiler_dispatcher', None) + existing = class_.__dict__.get("_compiler_dispatcher", None) # get the original handler. All ClauseElement classes have one # of these, but some TypeEngine classes will not. - existing_dispatch = getattr(class_, '_compiler_dispatch', None) + existing_dispatch = getattr(class_, "_compiler_dispatch", None) if not existing: existing = _dispatcher() if existing_dispatch: + def _wrap_existing_dispatch(element, compiler, **kw): try: return existing_dispatch(element, compiler, **kw) except exc.UnsupportedCompilationError: raise exc.CompileError( "%s construct has no default " - "compilation handler." % type(element)) - existing.specs['default'] = _wrap_existing_dispatch + "compilation handler." % type(element) + ) + + existing.specs["default"] = _wrap_existing_dispatch # TODO: why is the lambda needed ? - setattr(class_, '_compiler_dispatch', - lambda *arg, **kw: existing(*arg, **kw)) - setattr(class_, '_compiler_dispatcher', existing) + setattr( + class_, + "_compiler_dispatch", + lambda *arg, **kw: existing(*arg, **kw), + ) + setattr(class_, "_compiler_dispatcher", existing) if specs: for s in specs: existing.specs[s] = fn else: - existing.specs['default'] = fn + existing.specs["default"] = fn return fn + return decorate @@ -445,7 +452,7 @@ def deregister(class_): """Remove all custom compilers associated with a given :class:`.ClauseElement` type.""" - if hasattr(class_, '_compiler_dispatcher'): + if hasattr(class_, "_compiler_dispatcher"): # regenerate default _compiler_dispatch visitors._generate_dispatch(class_) # remove custom directive @@ -461,10 +468,11 @@ class _dispatcher(object): fn = self.specs.get(compiler.dialect.name, None) if not fn: try: - fn = self.specs['default'] + fn = self.specs["default"] except KeyError: raise exc.CompileError( "%s construct has no default " - "compilation handler." % type(element)) + "compilation handler." % type(element) + ) return fn(element, compiler, **kw) diff --git a/lib/sqlalchemy/ext/declarative/__init__.py b/lib/sqlalchemy/ext/declarative/__init__.py index cb81f51e5..2b0a37884 100644 --- a/lib/sqlalchemy/ext/declarative/__init__.py +++ b/lib/sqlalchemy/ext/declarative/__init__.py @@ -5,14 +5,31 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from .api import declarative_base, synonym_for, comparable_using, \ - instrument_declarative, ConcreteBase, AbstractConcreteBase, \ - DeclarativeMeta, DeferredReflection, has_inherited_table,\ - declared_attr, as_declarative +from .api import ( + declarative_base, + synonym_for, + comparable_using, + instrument_declarative, + ConcreteBase, + AbstractConcreteBase, + DeclarativeMeta, + DeferredReflection, + has_inherited_table, + declared_attr, + as_declarative, +) -__all__ = ['declarative_base', 'synonym_for', 'has_inherited_table', - 'comparable_using', 'instrument_declarative', 'declared_attr', - 'as_declarative', - 'ConcreteBase', 'AbstractConcreteBase', 'DeclarativeMeta', - 'DeferredReflection'] +__all__ = [ + "declarative_base", + "synonym_for", + "has_inherited_table", + "comparable_using", + "instrument_declarative", + "declared_attr", + "as_declarative", + "ConcreteBase", + "AbstractConcreteBase", + "DeclarativeMeta", + "DeferredReflection", +] diff --git a/lib/sqlalchemy/ext/declarative/api.py b/lib/sqlalchemy/ext/declarative/api.py index 865cd16f0..987e92119 100644 --- a/lib/sqlalchemy/ext/declarative/api.py +++ b/lib/sqlalchemy/ext/declarative/api.py @@ -8,9 +8,13 @@ from ...schema import Table, MetaData, Column -from ...orm import synonym as _orm_synonym, \ - comparable_property,\ - interfaces, properties, attributes +from ...orm import ( + synonym as _orm_synonym, + comparable_property, + interfaces, + properties, + attributes, +) from ...orm.util import polymorphic_union from ...orm.base import _mapper_or_none from ...util import OrderedDict, hybridmethod, hybridproperty @@ -19,9 +23,13 @@ from ... import exc import weakref import re -from .base import _as_declarative, \ - _declarative_constructor,\ - _DeferredMapperConfig, _add_attribute, _del_attribute +from .base import ( + _as_declarative, + _declarative_constructor, + _DeferredMapperConfig, + _add_attribute, + _del_attribute, +) from .clsregistry import _class_resolver @@ -31,10 +39,10 @@ def instrument_declarative(cls, registry, metadata): MetaData object. """ - if '_decl_class_registry' in cls.__dict__: + if "_decl_class_registry" in cls.__dict__: raise exc.InvalidRequestError( - "Class %r already has been " - "instrumented declaratively" % cls) + "Class %r already has been " "instrumented declaratively" % cls + ) cls._decl_class_registry = registry cls.metadata = metadata _as_declarative(cls, cls.__name__, cls.__dict__) @@ -54,14 +62,14 @@ def has_inherited_table(cls): """ for class_ in cls.__mro__[1:]: - if getattr(class_, '__table__', None) is not None: + if getattr(class_, "__table__", None) is not None: return True return False class DeclarativeMeta(type): def __init__(cls, classname, bases, dict_): - if '_decl_class_registry' not in cls.__dict__: + if "_decl_class_registry" not in cls.__dict__: _as_declarative(cls, classname, cls.__dict__) type.__init__(cls, classname, bases, dict_) @@ -71,6 +79,7 @@ class DeclarativeMeta(type): def __delattr__(cls, key): _del_attribute(cls, key) + def synonym_for(name, map_column=False): """Decorator that produces an :func:`.orm.synonym` attribute in conjunction with a Python descriptor. @@ -104,8 +113,10 @@ def synonym_for(name, map_column=False): can be achieved with synonyms. """ + def decorate(fn): return _orm_synonym(name, map_column=map_column, descriptor=fn) + return decorate @@ -127,8 +138,10 @@ def comparable_using(comparator_factory): prop = comparable_property(MyComparatorType) """ + def decorate(fn): return comparable_property(comparator_factory, fn) + return decorate @@ -190,14 +203,16 @@ class declared_attr(interfaces._MappedAttribute, property): self._cascading = cascading def __get__(desc, self, cls): - reg = cls.__dict__.get('_sa_declared_attr_reg', None) + reg = cls.__dict__.get("_sa_declared_attr_reg", None) if reg is None: - if not re.match(r'^__.+__$', desc.fget.__name__) and \ - attributes.manager_of_class(cls) is None: + if ( + not re.match(r"^__.+__$", desc.fget.__name__) + and attributes.manager_of_class(cls) is None + ): util.warn( "Unmanaged access of declarative attribute %s from " - "non-mapped class %s" % - (desc.fget.__name__, cls.__name__)) + "non-mapped class %s" % (desc.fget.__name__, cls.__name__) + ) return desc.fget(cls) elif desc in reg: return reg[desc] @@ -283,10 +298,16 @@ class _stateful_declared_attr(declared_attr): return declared_attr(fn, **self.kw) -def declarative_base(bind=None, metadata=None, mapper=None, cls=object, - name='Base', constructor=_declarative_constructor, - class_registry=None, - metaclass=DeclarativeMeta): +def declarative_base( + bind=None, + metadata=None, + mapper=None, + cls=object, + name="Base", + constructor=_declarative_constructor, + class_registry=None, + metaclass=DeclarativeMeta, +): r"""Construct a base class for declarative class definitions. The new base class will be given a metaclass that produces @@ -357,16 +378,17 @@ def declarative_base(bind=None, metadata=None, mapper=None, cls=object, class_registry = weakref.WeakValueDictionary() bases = not isinstance(cls, tuple) and (cls,) or cls - class_dict = dict(_decl_class_registry=class_registry, - metadata=lcl_metadata) + class_dict = dict( + _decl_class_registry=class_registry, metadata=lcl_metadata + ) if isinstance(cls, type): - class_dict['__doc__'] = cls.__doc__ + class_dict["__doc__"] = cls.__doc__ if constructor: - class_dict['__init__'] = constructor + class_dict["__init__"] = constructor if mapper: - class_dict['__mapper_cls__'] = mapper + class_dict["__mapper_cls__"] = mapper return metaclass(name, bases, class_dict) @@ -401,9 +423,10 @@ def as_declarative(**kw): :func:`.declarative_base` """ + def decorate(cls): - kw['cls'] = cls - kw['name'] = cls.__name__ + kw["cls"] = cls + kw["name"] = cls.__name__ return declarative_base(**kw) return decorate @@ -456,10 +479,13 @@ class ConcreteBase(object): @classmethod def _create_polymorphic_union(cls, mappers): - return polymorphic_union(OrderedDict( - (mp.polymorphic_identity, mp.local_table) - for mp in mappers - ), 'type', 'pjoin') + return polymorphic_union( + OrderedDict( + (mp.polymorphic_identity, mp.local_table) for mp in mappers + ), + "type", + "pjoin", + ) @classmethod def __declare_first__(cls): @@ -568,7 +594,7 @@ class AbstractConcreteBase(ConcreteBase): @classmethod def _sa_decl_prepare_nocascade(cls): - if getattr(cls, '__mapper__', None): + if getattr(cls, "__mapper__", None): return to_map = _DeferredMapperConfig.config_for_cls(cls) @@ -604,8 +630,9 @@ class AbstractConcreteBase(ConcreteBase): def mapper_args(): args = m_args() - args['polymorphic_on'] = pjoin.c.type + args["polymorphic_on"] = pjoin.c.type return args + to_map.mapper_args_fn = mapper_args m = to_map.map() @@ -684,6 +711,7 @@ class DeferredReflection(object): .. versionadded:: 0.8 """ + @classmethod def prepare(cls, engine): """Reflect all :class:`.Table` objects for all current @@ -696,8 +724,10 @@ class DeferredReflection(object): mapper = thingy.cls.__mapper__ metadata = mapper.class_.metadata for rel in mapper._props.values(): - if isinstance(rel, properties.RelationshipProperty) and \ - rel.secondary is not None: + if ( + isinstance(rel, properties.RelationshipProperty) + and rel.secondary is not None + ): if isinstance(rel.secondary, Table): cls._reflect_table(rel.secondary, engine) elif isinstance(rel.secondary, _class_resolver): @@ -711,6 +741,7 @@ class DeferredReflection(object): t1 = Table(key, metadata) cls._reflect_table(t1, engine) return t1 + return _resolve @classmethod @@ -724,10 +755,12 @@ class DeferredReflection(object): @classmethod def _reflect_table(cls, table, engine): - Table(table.name, - table.metadata, - extend_existing=True, - autoload_replace=False, - autoload=True, - autoload_with=engine, - schema=table.schema) + Table( + table.name, + table.metadata, + extend_existing=True, + autoload_replace=False, + autoload=True, + autoload_with=engine, + schema=table.schema, + ) diff --git a/lib/sqlalchemy/ext/declarative/base.py b/lib/sqlalchemy/ext/declarative/base.py index f27314b5e..07778f733 100644 --- a/lib/sqlalchemy/ext/declarative/base.py +++ b/lib/sqlalchemy/ext/declarative/base.py @@ -39,7 +39,7 @@ def _resolve_for_abstract_or_classical(cls): if cls is object: return None - if _get_immediate_cls_attr(cls, '__abstract__', strict=True): + if _get_immediate_cls_attr(cls, "__abstract__", strict=True): for sup in cls.__bases__: sup = _resolve_for_abstract_or_classical(sup) if sup is not None: @@ -59,7 +59,7 @@ def _dive_for_classically_mapped_class(cls): # if we are within a base hierarchy, don't # search at all for classical mappings - if hasattr(cls, '_decl_class_registry'): + if hasattr(cls, "_decl_class_registry"): return None manager = instrumentation.manager_of_class(cls) @@ -89,15 +89,19 @@ def _get_immediate_cls_attr(cls, attrname, strict=False): return None for base in cls.__mro__: - _is_declarative_inherits = hasattr(base, '_decl_class_registry') - _is_classicial_inherits = not _is_declarative_inherits and \ - _dive_for_classically_mapped_class(base) is not None + _is_declarative_inherits = hasattr(base, "_decl_class_registry") + _is_classicial_inherits = ( + not _is_declarative_inherits + and _dive_for_classically_mapped_class(base) is not None + ) if attrname in base.__dict__ and ( - base is cls or - ((base in cls.__bases__ if strict else True) + base is cls + or ( + (base in cls.__bases__ if strict else True) and not _is_declarative_inherits - and not _is_classicial_inherits) + and not _is_classicial_inherits + ) ): return getattr(base, attrname) else: @@ -108,9 +112,10 @@ def _as_declarative(cls, classname, dict_): global declared_attr, declarative_props if declared_attr is None: from .api import declared_attr + declarative_props = (declared_attr, util.classproperty) - if _get_immediate_cls_attr(cls, '__abstract__', strict=True): + if _get_immediate_cls_attr(cls, "__abstract__", strict=True): return _MapperConfig.setup_mapping(cls, classname, dict_) @@ -119,23 +124,23 @@ def _as_declarative(cls, classname, dict_): def _check_declared_props_nocascade(obj, name, cls): if isinstance(obj, declarative_props): - if getattr(obj, '_cascading', False): + if getattr(obj, "_cascading", False): util.warn( "@declared_attr.cascading is not supported on the %s " "attribute on class %s. This attribute invokes for " - "subclasses in any case." % (name, cls)) + "subclasses in any case." % (name, cls) + ) return True else: return False class _MapperConfig(object): - @classmethod def setup_mapping(cls, cls_, classname, dict_): defer_map = _get_immediate_cls_attr( - cls_, '_sa_decl_prepare_nocascade', strict=True) or \ - hasattr(cls_, '_sa_decl_prepare') + cls_, "_sa_decl_prepare_nocascade", strict=True + ) or hasattr(cls_, "_sa_decl_prepare") if defer_map: cfg_cls = _DeferredMapperConfig @@ -179,12 +184,14 @@ class _MapperConfig(object): self.map() def _setup_declared_events(self): - if _get_immediate_cls_attr(self.cls, '__declare_last__'): + if _get_immediate_cls_attr(self.cls, "__declare_last__"): + @event.listens_for(mapper, "after_configured") def after_configured(): self.cls.__declare_last__() - if _get_immediate_cls_attr(self.cls, '__declare_first__'): + if _get_immediate_cls_attr(self.cls, "__declare_first__"): + @event.listens_for(mapper, "before_configured") def before_configured(): self.cls.__declare_first__() @@ -198,59 +205,62 @@ class _MapperConfig(object): tablename = None for base in cls.__mro__: - class_mapped = base is not cls and \ - _declared_mapping_info(base) is not None and \ - not _get_immediate_cls_attr( - base, '_sa_decl_prepare_nocascade', strict=True) + class_mapped = ( + base is not cls + and _declared_mapping_info(base) is not None + and not _get_immediate_cls_attr( + base, "_sa_decl_prepare_nocascade", strict=True + ) + ) if not class_mapped and base is not cls: self._produce_column_copies(base) for name, obj in vars(base).items(): - if name == '__mapper_args__': - check_decl = \ - _check_declared_props_nocascade(obj, name, cls) - if not mapper_args_fn and ( - not class_mapped or - check_decl - ): + if name == "__mapper_args__": + check_decl = _check_declared_props_nocascade( + obj, name, cls + ) + if not mapper_args_fn and (not class_mapped or check_decl): # don't even invoke __mapper_args__ until # after we've determined everything about the # mapped table. # make a copy of it so a class-level dictionary # is not overwritten when we update column-based # arguments. - mapper_args_fn = lambda: dict(cls.__mapper_args__) # noqa - elif name == '__tablename__': - check_decl = \ - _check_declared_props_nocascade(obj, name, cls) - if not tablename and ( - not class_mapped or - check_decl - ): + mapper_args_fn = lambda: dict( + cls.__mapper_args__ + ) # noqa + elif name == "__tablename__": + check_decl = _check_declared_props_nocascade( + obj, name, cls + ) + if not tablename and (not class_mapped or check_decl): tablename = cls.__tablename__ - elif name == '__table_args__': - check_decl = \ - _check_declared_props_nocascade(obj, name, cls) - if not table_args and ( - not class_mapped or - check_decl - ): + elif name == "__table_args__": + check_decl = _check_declared_props_nocascade( + obj, name, cls + ) + if not table_args and (not class_mapped or check_decl): table_args = cls.__table_args__ if not isinstance( - table_args, (tuple, dict, type(None))): + table_args, (tuple, dict, type(None)) + ): raise exc.ArgumentError( "__table_args__ value must be a tuple, " - "dict, or None") + "dict, or None" + ) if base is not cls: inherited_table_args = True elif class_mapped: if isinstance(obj, declarative_props): - util.warn("Regular (i.e. not __special__) " - "attribute '%s.%s' uses @declared_attr, " - "but owning class %s is mapped - " - "not applying to subclass %s." - % (base.__name__, name, base, cls)) + util.warn( + "Regular (i.e. not __special__) " + "attribute '%s.%s' uses @declared_attr, " + "but owning class %s is mapped - " + "not applying to subclass %s." + % (base.__name__, name, base, cls) + ) continue elif base is not cls: # we're a mixin, abstract base, or something that is @@ -263,7 +273,8 @@ class _MapperConfig(object): "Mapper properties (i.e. deferred," "column_property(), relationship(), etc.) must " "be declared as @declared_attr callables " - "on declarative mixin classes.") + "on declarative mixin classes." + ) elif isinstance(obj, declarative_props): oldclassprop = isinstance(obj, util.classproperty) if not oldclassprop and obj._cascading: @@ -278,15 +289,18 @@ class _MapperConfig(object): "Attribute '%s' on class %s cannot be " "processed due to " "@declared_attr.cascading; " - "skipping" % (name, cls)) - dict_[name] = column_copies[obj] = \ - ret = obj.__get__(obj, cls) + "skipping" % (name, cls) + ) + dict_[name] = column_copies[ + obj + ] = ret = obj.__get__(obj, cls) setattr(cls, name, ret) else: if oldclassprop: util.warn_deprecated( "Use of sqlalchemy.util.classproperty on " - "declarative classes is deprecated.") + "declarative classes is deprecated." + ) # access attribute using normal class access ret = getattr(cls, name) @@ -294,14 +308,20 @@ class _MapperConfig(object): # or similar. note there is no known case that # produces nested proxies, so we are only # looking one level deep right now. - if isinstance(ret, InspectionAttr) and \ - ret._is_internal_proxy and not isinstance( - ret.original_property, MapperProperty): + if ( + isinstance(ret, InspectionAttr) + and ret._is_internal_proxy + and not isinstance( + ret.original_property, MapperProperty + ) + ): ret = ret.descriptor dict_[name] = column_copies[obj] = ret - if isinstance(ret, (Column, MapperProperty)) and \ - ret.doc is None: + if ( + isinstance(ret, (Column, MapperProperty)) + and ret.doc is None + ): ret.doc = obj.__doc__ # here, the attribute is some other kind of property that # we assume is not part of the declarative mapping. @@ -321,8 +341,9 @@ class _MapperConfig(object): util.warn( "Attribute '%s' on class %s appears to be a non-schema " "'sqlalchemy.sql.column()' " - "object; this won't be part of the declarative mapping" % - (key, cls)) + "object; this won't be part of the declarative mapping" + % (key, cls) + ) def _produce_column_copies(self, base): cls = self.cls @@ -340,10 +361,11 @@ class _MapperConfig(object): raise exc.InvalidRequestError( "Columns with foreign keys to other columns " "must be declared as @declared_attr callables " - "on declarative mixin classes. ") + "on declarative mixin classes. " + ) elif name not in dict_ and not ( - '__table__' in dict_ and - (obj.name or name) in dict_['__table__'].c + "__table__" in dict_ + and (obj.name or name) in dict_["__table__"].c ): column_copies[obj] = copy_ = obj.copy() copy_._creation_order = obj._creation_order @@ -357,11 +379,12 @@ class _MapperConfig(object): our_stuff = self.properties late_mapped = _get_immediate_cls_attr( - cls, '_sa_decl_prepare_nocascade', strict=True) + cls, "_sa_decl_prepare_nocascade", strict=True + ) for k in list(dict_): - if k in ('__table__', '__tablename__', '__mapper_args__'): + if k in ("__table__", "__tablename__", "__mapper_args__"): continue value = dict_[k] @@ -371,29 +394,37 @@ class _MapperConfig(object): "Use of @declared_attr.cascading only applies to " "Declarative 'mixin' and 'abstract' classes. " "Currently, this flag is ignored on mapped class " - "%s" % self.cls) + "%s" % self.cls + ) value = getattr(cls, k) - elif isinstance(value, QueryableAttribute) and \ - value.class_ is not cls and \ - value.key != k: + elif ( + isinstance(value, QueryableAttribute) + and value.class_ is not cls + and value.key != k + ): # detect a QueryableAttribute that's already mapped being # assigned elsewhere in userland, turn into a synonym() value = synonym(value.key) setattr(cls, k, value) - if (isinstance(value, tuple) and len(value) == 1 and - isinstance(value[0], (Column, MapperProperty))): - util.warn("Ignoring declarative-like tuple value of attribute " - "'%s': possibly a copy-and-paste error with a comma " - "accidentally placed at the end of the line?" % k) + if ( + isinstance(value, tuple) + and len(value) == 1 + and isinstance(value[0], (Column, MapperProperty)) + ): + util.warn( + "Ignoring declarative-like tuple value of attribute " + "'%s': possibly a copy-and-paste error with a comma " + "accidentally placed at the end of the line?" % k + ) continue elif not isinstance(value, (Column, MapperProperty)): # using @declared_attr for some object that # isn't Column/MapperProperty; remove from the dict_ # and place the evaluated value onto the class. - if not k.startswith('__'): + if not k.startswith("__"): dict_.pop(k) self._warn_for_decl_attributes(cls, k, value) if not late_mapped: @@ -402,7 +433,7 @@ class _MapperConfig(object): # we expect to see the name 'metadata' in some valid cases; # however at this point we see it's assigned to something trying # to be mapped, so raise for that. - elif k == 'metadata': + elif k == "metadata": raise exc.InvalidRequestError( "Attribute name 'metadata' is reserved " "for the MetaData instance when using a " @@ -423,8 +454,7 @@ class _MapperConfig(object): for key, c in list(our_stuff.items()): if isinstance(c, (ColumnProperty, CompositeProperty)): for col in c.columns: - if isinstance(col, Column) and \ - col.table is None: + if isinstance(col, Column) and col.table is None: _undefer_column_name(key, col) if not isinstance(c, CompositeProperty): name_to_prop_key[col.name].add(key) @@ -447,8 +477,8 @@ class _MapperConfig(object): "On class %r, Column object %r named " "directly multiple times, " "only one will be used: %s. " - "Consider using orm.synonym instead" % - (self.classname, name, (", ".join(sorted(keys)))) + "Consider using orm.synonym instead" + % (self.classname, name, (", ".join(sorted(keys)))) ) def _setup_table(self): @@ -459,15 +489,16 @@ class _MapperConfig(object): declared_columns = self.declared_columns declared_columns = self.declared_columns = sorted( - declared_columns, key=lambda c: c._creation_order) + declared_columns, key=lambda c: c._creation_order + ) table = None - if hasattr(cls, '__table_cls__'): + if hasattr(cls, "__table_cls__"): table_cls = util.unbound_method_to_callable(cls.__table_cls__) else: table_cls = Table - if '__table__' not in dict_: + if "__table__" not in dict_: if tablename is not None: args, table_kw = (), {} @@ -480,14 +511,16 @@ class _MapperConfig(object): else: args = table_args - autoload = dict_.get('__autoload__') + autoload = dict_.get("__autoload__") if autoload: - table_kw['autoload'] = True + table_kw["autoload"] = True cls.__table__ = table = table_cls( - tablename, cls.metadata, + tablename, + cls.metadata, *(tuple(declared_columns) + tuple(args)), - **table_kw) + **table_kw + ) else: table = cls.__table__ if declared_columns: @@ -512,21 +545,27 @@ class _MapperConfig(object): c = _resolve_for_abstract_or_classical(c) if c is None: continue - if _declared_mapping_info(c) is not None and \ - not _get_immediate_cls_attr( - c, '_sa_decl_prepare_nocascade', strict=True): + if _declared_mapping_info( + c + ) is not None and not _get_immediate_cls_attr( + c, "_sa_decl_prepare_nocascade", strict=True + ): inherits.append(c) if inherits: if len(inherits) > 1: raise exc.InvalidRequestError( - "Class %s has multiple mapped bases: %r" % (cls, inherits)) + "Class %s has multiple mapped bases: %r" % (cls, inherits) + ) self.inherits = inherits[0] else: self.inherits = None - if table is None and self.inherits is None and \ - not _get_immediate_cls_attr(cls, '__no_table__'): + if ( + table is None + and self.inherits is None + and not _get_immediate_cls_attr(cls, "__no_table__") + ): raise exc.InvalidRequestError( "Class %r does not have a __table__ or __tablename__ " @@ -553,8 +592,8 @@ class _MapperConfig(object): continue raise exc.ArgumentError( "Column '%s' on class %s conflicts with " - "existing column '%s'" % - (c, cls, inherited_table.c[c.name]) + "existing column '%s'" + % (c, cls, inherited_table.c[c.name]) ) if c.primary_key: raise exc.ArgumentError( @@ -562,8 +601,10 @@ class _MapperConfig(object): "class with no table." ) inherited_table.append_column(c) - if inherited_mapped_table is not None and \ - inherited_mapped_table is not inherited_table: + if ( + inherited_mapped_table is not None + and inherited_mapped_table is not inherited_table + ): inherited_mapped_table._refresh_for_new_column(c) def _prepare_mapper_arguments(self): @@ -575,18 +616,19 @@ class _MapperConfig(object): # make sure that column copies are used rather # than the original columns from any mixins - for k in ('version_id_col', 'polymorphic_on',): + for k in ("version_id_col", "polymorphic_on"): if k in mapper_args: v = mapper_args[k] mapper_args[k] = self.column_copies.get(v, v) - assert 'inherits' not in mapper_args, \ - "Can't specify 'inherits' explicitly with declarative mappings" + assert ( + "inherits" not in mapper_args + ), "Can't specify 'inherits' explicitly with declarative mappings" if self.inherits: - mapper_args['inherits'] = self.inherits + mapper_args["inherits"] = self.inherits - if self.inherits and not mapper_args.get('concrete', False): + if self.inherits and not mapper_args.get("concrete", False): # single or joined inheritance # exclude any cols on the inherited table which are # not mapped on the parent class, to avoid @@ -594,16 +636,17 @@ class _MapperConfig(object): inherited_mapper = _declared_mapping_info(self.inherits) inherited_table = inherited_mapper.local_table - if 'exclude_properties' not in mapper_args: - mapper_args['exclude_properties'] = exclude_properties = \ - set( - [c.key for c in inherited_table.c - if c not in inherited_mapper._columntoproperty] - ).union( - inherited_mapper.exclude_properties or () - ) + if "exclude_properties" not in mapper_args: + mapper_args["exclude_properties"] = exclude_properties = set( + [ + c.key + for c in inherited_table.c + if c not in inherited_mapper._columntoproperty + ] + ).union(inherited_mapper.exclude_properties or ()) exclude_properties.difference_update( - [c.key for c in self.declared_columns]) + [c.key for c in self.declared_columns] + ) # look through columns in the current mapper that # are keyed to a propname different than the colname @@ -621,21 +664,20 @@ class _MapperConfig(object): # first. See [ticket:1892] for background. properties[k] = [col] + p.columns result_mapper_args = mapper_args.copy() - result_mapper_args['properties'] = properties + result_mapper_args["properties"] = properties self.mapper_args = result_mapper_args def map(self): self._prepare_mapper_arguments() - if hasattr(self.cls, '__mapper_cls__'): + if hasattr(self.cls, "__mapper_cls__"): mapper_cls = util.unbound_method_to_callable( - self.cls.__mapper_cls__) + self.cls.__mapper_cls__ + ) else: mapper_cls = mapper self.cls.__mapper__ = mp_ = mapper_cls( - self.cls, - self.local_table, - **self.mapper_args + self.cls, self.local_table, **self.mapper_args ) del self.cls._sa_declared_attr_reg return mp_ @@ -663,8 +705,7 @@ class _DeferredMapperConfig(_MapperConfig): @classmethod def has_cls(cls, class_): # 2.6 fails on weakref if class_ is an old style class - return isinstance(class_, type) and \ - weakref.ref(class_) in cls._configs + return isinstance(class_, type) and weakref.ref(class_) in cls._configs @classmethod def config_for_cls(cls, class_): @@ -673,18 +714,15 @@ class _DeferredMapperConfig(_MapperConfig): @classmethod def classes_for_base(cls, base_cls, sort=True): classes_for_base = [ - m for m, cls_ in - [(m, m.cls) for m in cls._configs.values()] + m + for m, cls_ in [(m, m.cls) for m in cls._configs.values()] if cls_ is not None and issubclass(cls_, base_cls) ] if not sort: return classes_for_base - all_m_by_cls = dict( - (m.cls, m) - for m in classes_for_base - ) + all_m_by_cls = dict((m.cls, m) for m in classes_for_base) tuples = [] for m_cls in all_m_by_cls: @@ -693,12 +731,7 @@ class _DeferredMapperConfig(_MapperConfig): for base_cls in m_cls.__bases__ if base_cls in all_m_by_cls ) - return list( - topological.sort( - tuples, - classes_for_base - ) - ) + return list(topological.sort(tuples, classes_for_base)) def map(self): self._configs.pop(self._cls, None) @@ -713,7 +746,7 @@ def _add_attribute(cls, key, value): """ - if '__mapper__' in cls.__dict__: + if "__mapper__" in cls.__dict__: if isinstance(value, Column): _undefer_column_name(key, value) cls.__table__.append_column(value) @@ -726,16 +759,14 @@ def _add_attribute(cls, key, value): cls.__mapper__.add_property(key, value) elif isinstance(value, MapperProperty): cls.__mapper__.add_property( - key, - clsregistry._deferred_relationship(cls, value) + key, clsregistry._deferred_relationship(cls, value) ) elif isinstance(value, QueryableAttribute) and value.key != key: # detect a QueryableAttribute that's already mapped being # assigned elsewhere in userland, turn into a synonym() value = synonym(value.key) cls.__mapper__.add_property( - key, - clsregistry._deferred_relationship(cls, value) + key, clsregistry._deferred_relationship(cls, value) ) else: type.__setattr__(cls, key, value) @@ -746,15 +777,18 @@ def _add_attribute(cls, key, value): def _del_attribute(cls, key): - if '__mapper__' in cls.__dict__ and \ - key in cls.__dict__ and not cls.__mapper__._dispose_called: + if ( + "__mapper__" in cls.__dict__ + and key in cls.__dict__ + and not cls.__mapper__._dispose_called + ): value = cls.__dict__[key] if isinstance( - value, - (Column, ColumnProperty, MapperProperty, QueryableAttribute) + value, (Column, ColumnProperty, MapperProperty, QueryableAttribute) ): raise NotImplementedError( - "Can't un-map individual mapped attributes on a mapped class.") + "Can't un-map individual mapped attributes on a mapped class." + ) else: type.__delattr__(cls, key) cls.__mapper__._expire_memoizations() @@ -776,10 +810,12 @@ def _declarative_constructor(self, **kwargs): for k in kwargs: if not hasattr(cls_, k): raise TypeError( - "%r is an invalid keyword argument for %s" % - (k, cls_.__name__)) + "%r is an invalid keyword argument for %s" % (k, cls_.__name__) + ) setattr(self, k, kwargs[k]) -_declarative_constructor.__name__ = '__init__' + + +_declarative_constructor.__name__ = "__init__" def _undefer_column_name(key, column): diff --git a/lib/sqlalchemy/ext/declarative/clsregistry.py b/lib/sqlalchemy/ext/declarative/clsregistry.py index e941b9ed3..c52ae4a2f 100644 --- a/lib/sqlalchemy/ext/declarative/clsregistry.py +++ b/lib/sqlalchemy/ext/declarative/clsregistry.py @@ -10,8 +10,11 @@ This system allows specification of classes and expressions used in :func:`.relationship` using strings. """ -from ...orm.properties import ColumnProperty, RelationshipProperty, \ - SynonymProperty +from ...orm.properties import ( + ColumnProperty, + RelationshipProperty, + SynonymProperty, +) from ...schema import _get_table_key from ...orm import class_mapper, interfaces from ... import util @@ -35,17 +38,18 @@ def add_class(classname, cls): # class already exists. existing = cls._decl_class_registry[classname] if not isinstance(existing, _MultipleClassMarker): - existing = \ - cls._decl_class_registry[classname] = \ - _MultipleClassMarker([cls, existing]) + existing = cls._decl_class_registry[ + classname + ] = _MultipleClassMarker([cls, existing]) else: cls._decl_class_registry[classname] = cls try: - root_module = cls._decl_class_registry['_sa_module_registry'] + root_module = cls._decl_class_registry["_sa_module_registry"] except KeyError: - cls._decl_class_registry['_sa_module_registry'] = \ - root_module = _ModuleMarker('_sa_module_registry', None) + cls._decl_class_registry[ + "_sa_module_registry" + ] = root_module = _ModuleMarker("_sa_module_registry", None) tokens = cls.__module__.split(".") @@ -71,12 +75,13 @@ class _MultipleClassMarker(object): """ - __slots__ = 'on_remove', 'contents', '__weakref__' + __slots__ = "on_remove", "contents", "__weakref__" def __init__(self, classes, on_remove=None): self.on_remove = on_remove - self.contents = set([ - weakref.ref(item, self._remove_item) for item in classes]) + self.contents = set( + [weakref.ref(item, self._remove_item) for item in classes] + ) _registries.add(self) def __iter__(self): @@ -85,10 +90,10 @@ class _MultipleClassMarker(object): def attempt_get(self, path, key): if len(self.contents) > 1: raise exc.InvalidRequestError( - "Multiple classes found for path \"%s\" " + 'Multiple classes found for path "%s" ' "in the registry of this declarative " - "base. Please use a fully module-qualified path." % - (".".join(path + [key])) + "base. Please use a fully module-qualified path." + % (".".join(path + [key])) ) else: ref = list(self.contents)[0] @@ -108,17 +113,19 @@ class _MultipleClassMarker(object): # protect against class registration race condition against # asynchronous garbage collection calling _remove_item, # [ticket:3208] - modules = set([ - cls.__module__ for cls in - [ref() for ref in self.contents] if cls is not None]) + modules = set( + [ + cls.__module__ + for cls in [ref() for ref in self.contents] + if cls is not None + ] + ) if item.__module__ in modules: util.warn( "This declarative base already contains a class with the " "same class name and module name as %s.%s, and will " - "be replaced in the string-lookup table." % ( - item.__module__, - item.__name__ - ) + "be replaced in the string-lookup table." + % (item.__module__, item.__name__) ) self.contents.add(weakref.ref(item, self._remove_item)) @@ -129,7 +136,7 @@ class _ModuleMarker(object): """ - __slots__ = 'parent', 'name', 'contents', 'mod_ns', 'path', '__weakref__' + __slots__ = "parent", "name", "contents", "mod_ns", "path", "__weakref__" def __init__(self, name, parent): self.parent = parent @@ -170,13 +177,13 @@ class _ModuleMarker(object): existing = self.contents[name] existing.add_item(cls) else: - existing = self.contents[name] = \ - _MultipleClassMarker([cls], - on_remove=lambda: self._remove_item(name)) + existing = self.contents[name] = _MultipleClassMarker( + [cls], on_remove=lambda: self._remove_item(name) + ) class _ModNS(object): - __slots__ = '__parent', + __slots__ = ("__parent",) def __init__(self, parent): self.__parent = parent @@ -193,13 +200,14 @@ class _ModNS(object): else: assert isinstance(value, _MultipleClassMarker) return value.attempt_get(self.__parent.path, key) - raise AttributeError("Module %r has no mapped classes " - "registered under the name %r" % ( - self.__parent.name, key)) + raise AttributeError( + "Module %r has no mapped classes " + "registered under the name %r" % (self.__parent.name, key) + ) class _GetColumns(object): - __slots__ = 'cls', + __slots__ = ("cls",) def __init__(self, cls): self.cls = cls @@ -210,7 +218,8 @@ class _GetColumns(object): if key not in mp.all_orm_descriptors: raise exc.InvalidRequestError( "Class %r does not have a mapped column named %r" - % (self.cls, key)) + % (self.cls, key) + ) desc = mp.all_orm_descriptors[key] if desc.extension_type is interfaces.NOT_EXTENSION: @@ -221,24 +230,25 @@ class _GetColumns(object): raise exc.InvalidRequestError( "Property %r is not an instance of" " ColumnProperty (i.e. does not correspond" - " directly to a Column)." % key) + " directly to a Column)." % key + ) return getattr(self.cls, key) + inspection._inspects(_GetColumns)( - lambda target: inspection.inspect(target.cls)) + lambda target: inspection.inspect(target.cls) +) class _GetTable(object): - __slots__ = 'key', 'metadata' + __slots__ = "key", "metadata" def __init__(self, key, metadata): self.key = key self.metadata = metadata def __getattr__(self, key): - return self.metadata.tables[ - _get_table_key(key, self.key) - ] + return self.metadata.tables[_get_table_key(key, self.key)] def _determine_container(key, value): @@ -264,9 +274,11 @@ class _class_resolver(object): return cls.metadata.tables[key] elif key in cls.metadata._schemas: return _GetTable(key, cls.metadata) - elif '_sa_module_registry' in cls._decl_class_registry and \ - key in cls._decl_class_registry['_sa_module_registry']: - registry = cls._decl_class_registry['_sa_module_registry'] + elif ( + "_sa_module_registry" in cls._decl_class_registry + and key in cls._decl_class_registry["_sa_module_registry"] + ): + registry = cls._decl_class_registry["_sa_module_registry"] return registry.resolve_attr(key) elif self._resolvers: for resolv in self._resolvers: @@ -289,8 +301,8 @@ class _class_resolver(object): "When initializing mapper %s, expression %r failed to " "locate a name (%r). If this is a class name, consider " "adding this relationship() to the %r class after " - "both dependent classes have been defined." % - (self.prop.parent, self.arg, n.args[0], self.cls) + "both dependent classes have been defined." + % (self.prop.parent, self.arg, n.args[0], self.cls) ) @@ -299,10 +311,11 @@ def _resolver(cls, prop): from sqlalchemy.orm import foreign, remote fallback = sqlalchemy.__dict__.copy() - fallback.update({'foreign': foreign, 'remote': remote}) + fallback.update({"foreign": foreign, "remote": remote}) def resolve_arg(arg): return _class_resolver(cls, prop, fallback, arg) + return resolve_arg @@ -311,18 +324,32 @@ def _deferred_relationship(cls, prop): if isinstance(prop, RelationshipProperty): resolve_arg = _resolver(cls, prop) - for attr in ('argument', 'order_by', 'primaryjoin', 'secondaryjoin', - 'secondary', '_user_defined_foreign_keys', 'remote_side'): + for attr in ( + "argument", + "order_by", + "primaryjoin", + "secondaryjoin", + "secondary", + "_user_defined_foreign_keys", + "remote_side", + ): v = getattr(prop, attr) if isinstance(v, util.string_types): setattr(prop, attr, resolve_arg(v)) if prop.backref and isinstance(prop.backref, tuple): key, kwargs = prop.backref - for attr in ('primaryjoin', 'secondaryjoin', 'secondary', - 'foreign_keys', 'remote_side', 'order_by'): - if attr in kwargs and isinstance(kwargs[attr], - util.string_types): + for attr in ( + "primaryjoin", + "secondaryjoin", + "secondary", + "foreign_keys", + "remote_side", + "order_by", + ): + if attr in kwargs and isinstance( + kwargs[attr], util.string_types + ): kwargs[attr] = resolve_arg(kwargs[attr]) return prop diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index f86e4fc93..7248e5b4d 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -20,7 +20,7 @@ from .. import util from ..orm.session import Session from ..orm.query import Query -__all__ = ['ShardedSession', 'ShardedQuery'] +__all__ = ["ShardedSession", "ShardedQuery"] class ShardedQuery(Query): @@ -43,12 +43,10 @@ class ShardedQuery(Query): def _execute_and_instances(self, context): def iter_for_shard(shard_id): - context.attributes['shard_id'] = context.identity_token = shard_id + context.attributes["shard_id"] = context.identity_token = shard_id result = self._connection_from_session( - mapper=self._bind_mapper(), - shard_id=shard_id).execute( - context.statement, - self._params) + mapper=self._bind_mapper(), shard_id=shard_id + ).execute(context.statement, self._params) return self.instances(result, context) if context.identity_token is not None: @@ -70,7 +68,8 @@ class ShardedQuery(Query): mapper=mapper, shard_id=shard_id, clause=stmt, - close_with_result=True) + close_with_result=True, + ) result = conn.execute(stmt, self._params) return result @@ -87,8 +86,13 @@ class ShardedQuery(Query): return ShardedResult(results, rowcount) def _identity_lookup( - self, mapper, primary_key_identity, identity_token=None, - lazy_loaded_from=None, **kw): + self, + mapper, + primary_key_identity, + identity_token=None, + lazy_loaded_from=None, + **kw + ): """override the default Query._identity_lookup method so that we search for a given non-token primary key identity across all possible identity tokens (e.g. shard ids). @@ -97,8 +101,10 @@ class ShardedQuery(Query): if identity_token is not None: return super(ShardedQuery, self)._identity_lookup( - mapper, primary_key_identity, - identity_token=identity_token, **kw + mapper, + primary_key_identity, + identity_token=identity_token, + **kw ) else: q = self.session.query(mapper) @@ -113,13 +119,13 @@ class ShardedQuery(Query): return None - def _get_impl( - self, primary_key_identity, db_load_fn, identity_token=None): + def _get_impl(self, primary_key_identity, db_load_fn, identity_token=None): """Override the default Query._get_impl() method so that we emit a query to the DB for each possible identity token, if we don't have one already. """ + def _db_load_fn(query, primary_key_identity): # load from the database. The original db_load_fn will # use the given Query object to load from the DB, so our @@ -142,7 +148,8 @@ class ShardedQuery(Query): identity_token = self._shard_id return super(ShardedQuery, self)._get_impl( - primary_key_identity, _db_load_fn, identity_token=identity_token) + primary_key_identity, _db_load_fn, identity_token=identity_token + ) class ShardedResult(object): @@ -158,7 +165,7 @@ class ShardedResult(object): .. versionadded:: 1.3 """ - __slots__ = ('result_proxies', 'aggregate_rowcount',) + __slots__ = ("result_proxies", "aggregate_rowcount") def __init__(self, result_proxies, aggregate_rowcount): self.result_proxies = result_proxies @@ -168,9 +175,17 @@ class ShardedResult(object): def rowcount(self): return self.aggregate_rowcount + class ShardedSession(Session): - def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None, - query_cls=ShardedQuery, **kwargs): + def __init__( + self, + shard_chooser, + id_chooser, + query_chooser, + shards=None, + query_cls=ShardedQuery, + **kwargs + ): """Construct a ShardedSession. :param shard_chooser: A callable which, passed a Mapper, a mapped @@ -225,16 +240,16 @@ class ShardedSession(Session): return self.transaction.connection(mapper, shard_id=shard_id) else: return self.get_bind( - mapper, - shard_id=shard_id, - instance=instance + mapper, shard_id=shard_id, instance=instance ).contextual_connect(**kwargs) - def get_bind(self, mapper, shard_id=None, - instance=None, clause=None, **kw): + def get_bind( + self, mapper, shard_id=None, instance=None, clause=None, **kw + ): if shard_id is None: shard_id = self._choose_shard_and_assign( - mapper, instance, clause=clause) + mapper, instance, clause=clause + ) return self.__binds[shard_id] def bind_shard(self, shard_id, bind): diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py index 95eecb93f..d51a083da 100644 --- a/lib/sqlalchemy/ext/hybrid.py +++ b/lib/sqlalchemy/ext/hybrid.py @@ -778,7 +778,7 @@ there's probably a whole lot of amazing things it can be used for. from .. import util from ..orm import attributes, interfaces -HYBRID_METHOD = util.symbol('HYBRID_METHOD') +HYBRID_METHOD = util.symbol("HYBRID_METHOD") """Symbol indicating an :class:`InspectionAttr` that's of type :class:`.hybrid_method`. @@ -791,7 +791,7 @@ HYBRID_METHOD = util.symbol('HYBRID_METHOD') """ -HYBRID_PROPERTY = util.symbol('HYBRID_PROPERTY') +HYBRID_PROPERTY = util.symbol("HYBRID_PROPERTY") """Symbol indicating an :class:`InspectionAttr` that's of type :class:`.hybrid_method`. @@ -860,8 +860,14 @@ class hybrid_property(interfaces.InspectionAttrInfo): extension_type = HYBRID_PROPERTY def __init__( - self, fget, fset=None, fdel=None, - expr=None, custom_comparator=None, update_expr=None): + self, + fget, + fset=None, + fdel=None, + expr=None, + custom_comparator=None, + update_expr=None, + ): """Create a new :class:`.hybrid_property`. Usage is typically via decorator:: @@ -906,7 +912,8 @@ class hybrid_property(interfaces.InspectionAttrInfo): defaults = { key: value for key, value in self.__dict__.items() - if not key.startswith("_")} + if not key.startswith("_") + } defaults.update(**kw) return type(self)(**defaults) @@ -1078,9 +1085,9 @@ class hybrid_property(interfaces.InspectionAttrInfo): return self._get_expr(self.fget) def _get_expr(self, expr): - def _expr(cls): return ExprComparator(cls, expr(cls), self) + util.update_wrapper(_expr, expr) return self._get_comparator(_expr) @@ -1091,8 +1098,13 @@ class hybrid_property(interfaces.InspectionAttrInfo): def expr_comparator(owner): return proxy_attr( - owner, self.__name__, self, comparator(owner), - doc=comparator.__doc__ or self.__doc__) + owner, + self.__name__, + self, + comparator(owner), + doc=comparator.__doc__ or self.__doc__, + ) + return expr_comparator @@ -1108,7 +1120,7 @@ class Comparator(interfaces.PropComparator): def __clause_element__(self): expr = self.expression - if hasattr(expr, '__clause_element__'): + if hasattr(expr, "__clause_element__"): expr = expr.__clause_element__() return expr diff --git a/lib/sqlalchemy/ext/indexable.py b/lib/sqlalchemy/ext/indexable.py index 0bc2b65bb..368e5b00a 100644 --- a/lib/sqlalchemy/ext/indexable.py +++ b/lib/sqlalchemy/ext/indexable.py @@ -232,7 +232,7 @@ from ..orm.attributes import flag_modified from ..ext.hybrid import hybrid_property -__all__ = ['index_property'] +__all__ = ["index_property"] class index_property(hybrid_property): # noqa @@ -251,8 +251,14 @@ class index_property(hybrid_property): # noqa _NO_DEFAULT_ARGUMENT = object() def __init__( - self, attr_name, index, default=_NO_DEFAULT_ARGUMENT, - datatype=None, mutable=True, onebased=True): + self, + attr_name, + index, + default=_NO_DEFAULT_ARGUMENT, + datatype=None, + mutable=True, + onebased=True, + ): """Create a new :class:`.index_property`. :param attr_name: diff --git a/lib/sqlalchemy/ext/instrumentation.py b/lib/sqlalchemy/ext/instrumentation.py index 30a0ab7d7..b2b8dd7c5 100644 --- a/lib/sqlalchemy/ext/instrumentation.py +++ b/lib/sqlalchemy/ext/instrumentation.py @@ -28,15 +28,18 @@ see the example :ref:`examples_instrumentation`. """ from ..orm import instrumentation as orm_instrumentation from ..orm.instrumentation import ( - ClassManager, InstrumentationFactory, _default_state_getter, - _default_dict_getter, _default_manager_getter + ClassManager, + InstrumentationFactory, + _default_state_getter, + _default_dict_getter, + _default_manager_getter, ) from ..orm import attributes, collections, base as orm_base from .. import util from ..orm import exc as orm_exc import weakref -INSTRUMENTATION_MANAGER = '__sa_instrumentation_manager__' +INSTRUMENTATION_MANAGER = "__sa_instrumentation_manager__" """Attribute, elects custom instrumentation when present on a mapped class. Allows a class to specify a slightly or wildly different technique for @@ -66,6 +69,7 @@ def find_native_user_instrumentation_hook(cls): """Find user-specified instrumentation management for a class.""" return getattr(cls, INSTRUMENTATION_MANAGER, None) + instrumentation_finders = [find_native_user_instrumentation_hook] """An extensible sequence of callables which return instrumentation implementations @@ -89,6 +93,7 @@ class ExtendedInstrumentationRegistry(InstrumentationFactory): class managers. """ + _manager_finders = weakref.WeakKeyDictionary() _state_finders = weakref.WeakKeyDictionary() _dict_finders = weakref.WeakKeyDictionary() @@ -104,13 +109,15 @@ class ExtendedInstrumentationRegistry(InstrumentationFactory): return None, None def _check_conflicts(self, class_, factory): - existing_factories = self._collect_management_factories_for(class_).\ - difference([factory]) + existing_factories = self._collect_management_factories_for( + class_ + ).difference([factory]) if existing_factories: raise TypeError( "multiple instrumentation implementations specified " - "in %s inheritance hierarchy: %r" % ( - class_.__name__, list(existing_factories))) + "in %s inheritance hierarchy: %r" + % (class_.__name__, list(existing_factories)) + ) def _extended_class_manager(self, class_, factory): manager = factory(class_) @@ -178,17 +185,20 @@ class ExtendedInstrumentationRegistry(InstrumentationFactory): if instance is None: raise AttributeError("None has no persistent state.") return self._state_finders.get( - instance.__class__, _default_state_getter)(instance) + instance.__class__, _default_state_getter + )(instance) def dict_of(self, instance): if instance is None: raise AttributeError("None has no persistent state.") return self._dict_finders.get( - instance.__class__, _default_dict_getter)(instance) + instance.__class__, _default_dict_getter + )(instance) -orm_instrumentation._instrumentation_factory = \ - _instrumentation_factory = ExtendedInstrumentationRegistry() +orm_instrumentation._instrumentation_factory = ( + _instrumentation_factory +) = ExtendedInstrumentationRegistry() orm_instrumentation.instrumentation_finders = instrumentation_finders @@ -222,14 +232,15 @@ class InstrumentationManager(object): pass def manage(self, class_, manager): - setattr(class_, '_default_class_manager', manager) + setattr(class_, "_default_class_manager", manager) def dispose(self, class_, manager): - delattr(class_, '_default_class_manager') + delattr(class_, "_default_class_manager") def manager_getter(self, class_): def get(cls): return cls._default_class_manager + return get def instrument_attribute(self, class_, key, inst): @@ -260,13 +271,13 @@ class InstrumentationManager(object): pass def install_state(self, class_, instance, state): - setattr(instance, '_default_state', state) + setattr(instance, "_default_state", state) def remove_state(self, class_, instance): - delattr(instance, '_default_state') + delattr(instance, "_default_state") def state_getter(self, class_): - return lambda instance: getattr(instance, '_default_state') + return lambda instance: getattr(instance, "_default_state") def dict_getter(self, class_): return lambda inst: self.get_instance_dict(class_, inst) @@ -314,15 +325,17 @@ class _ClassInstrumentationAdapter(ClassManager): def instrument_collection_class(self, key, collection_class): return self._adapted.instrument_collection_class( - self.class_, key, collection_class) + self.class_, key, collection_class + ) def initialize_collection(self, key, state, factory): - delegate = getattr(self._adapted, 'initialize_collection', None) + delegate = getattr(self._adapted, "initialize_collection", None) if delegate: return delegate(key, state, factory) else: - return ClassManager.initialize_collection(self, key, - state, factory) + return ClassManager.initialize_collection( + self, key, state, factory + ) def new_instance(self, state=None): instance = self.class_.__new__(self.class_) @@ -384,7 +397,7 @@ def _install_instrumented_lookups(): dict( instance_state=_instrumentation_factory.state_of, instance_dict=_instrumentation_factory.dict_of, - manager_of_class=_instrumentation_factory.manager_of_class + manager_of_class=_instrumentation_factory.manager_of_class, ) ) @@ -395,7 +408,7 @@ def _reinstall_default_lookups(): dict( instance_state=_default_state_getter, instance_dict=_default_dict_getter, - manager_of_class=_default_manager_getter + manager_of_class=_default_manager_getter, ) ) _instrumentation_factory._extended = False @@ -403,12 +416,15 @@ def _reinstall_default_lookups(): def _install_lookups(lookups): global instance_state, instance_dict, manager_of_class - instance_state = lookups['instance_state'] - instance_dict = lookups['instance_dict'] - manager_of_class = lookups['manager_of_class'] - orm_base.instance_state = attributes.instance_state = \ - orm_instrumentation.instance_state = instance_state - orm_base.instance_dict = attributes.instance_dict = \ - orm_instrumentation.instance_dict = instance_dict - orm_base.manager_of_class = attributes.manager_of_class = \ - orm_instrumentation.manager_of_class = manager_of_class + instance_state = lookups["instance_state"] + instance_dict = lookups["instance_dict"] + manager_of_class = lookups["manager_of_class"] + orm_base.instance_state = ( + attributes.instance_state + ) = orm_instrumentation.instance_state = instance_state + orm_base.instance_dict = ( + attributes.instance_dict + ) = orm_instrumentation.instance_dict = instance_dict + orm_base.manager_of_class = ( + attributes.manager_of_class + ) = orm_instrumentation.manager_of_class = manager_of_class diff --git a/lib/sqlalchemy/ext/mutable.py b/lib/sqlalchemy/ext/mutable.py index 014cef3cc..0f6ccdc33 100644 --- a/lib/sqlalchemy/ext/mutable.py +++ b/lib/sqlalchemy/ext/mutable.py @@ -502,27 +502,29 @@ class MutableBase(object): def pickle(state, state_dict): val = state.dict.get(key, None) if val is not None: - if 'ext.mutable.values' not in state_dict: - state_dict['ext.mutable.values'] = [] - state_dict['ext.mutable.values'].append(val) + if "ext.mutable.values" not in state_dict: + state_dict["ext.mutable.values"] = [] + state_dict["ext.mutable.values"].append(val) def unpickle(state, state_dict): - if 'ext.mutable.values' in state_dict: - for val in state_dict['ext.mutable.values']: + if "ext.mutable.values" in state_dict: + for val in state_dict["ext.mutable.values"]: val._parents[state.obj()] = key - event.listen(parent_cls, 'load', load, - raw=True, propagate=True) - event.listen(parent_cls, 'refresh', load_attrs, - raw=True, propagate=True) - event.listen(parent_cls, 'refresh_flush', load_attrs, - raw=True, propagate=True) - event.listen(attribute, 'set', set, - raw=True, retval=True, propagate=True) - event.listen(parent_cls, 'pickle', pickle, - raw=True, propagate=True) - event.listen(parent_cls, 'unpickle', unpickle, - raw=True, propagate=True) + event.listen(parent_cls, "load", load, raw=True, propagate=True) + event.listen( + parent_cls, "refresh", load_attrs, raw=True, propagate=True + ) + event.listen( + parent_cls, "refresh_flush", load_attrs, raw=True, propagate=True + ) + event.listen( + attribute, "set", set, raw=True, retval=True, propagate=True + ) + event.listen(parent_cls, "pickle", pickle, raw=True, propagate=True) + event.listen( + parent_cls, "unpickle", unpickle, raw=True, propagate=True + ) class Mutable(MutableBase): @@ -572,7 +574,7 @@ class Mutable(MutableBase): if isinstance(prop.columns[0].type, sqltype): cls.associate_with_attribute(getattr(class_, prop.key)) - event.listen(mapper, 'mapper_configured', listen_for_type) + event.listen(mapper, "mapper_configured", listen_for_type) @classmethod def as_mutable(cls, sqltype): @@ -613,9 +615,11 @@ class Mutable(MutableBase): # and we'll lose our ability to link that type back to the original. # so track our original type w/ columns if isinstance(sqltype, SchemaEventTarget): + @event.listens_for(sqltype, "before_parent_attach") def _add_column_memo(sqltyp, parent): - parent.info['_ext_mutable_orig_type'] = sqltyp + parent.info["_ext_mutable_orig_type"] = sqltyp + schema_event_check = True else: schema_event_check = False @@ -625,16 +629,14 @@ class Mutable(MutableBase): return for prop in mapper.column_attrs: if ( - schema_event_check and - hasattr(prop.expression, 'info') and - prop.expression.info.get('_ext_mutable_orig_type') - is sqltype - ) or ( - prop.columns[0].type is sqltype - ): + schema_event_check + and hasattr(prop.expression, "info") + and prop.expression.info.get("_ext_mutable_orig_type") + is sqltype + ) or (prop.columns[0].type is sqltype): cls.associate_with_attribute(getattr(class_, prop.key)) - event.listen(mapper, 'mapper_configured', listen_for_type) + event.listen(mapper, "mapper_configured", listen_for_type) return sqltype @@ -659,21 +661,27 @@ class MutableComposite(MutableBase): prop = object_mapper(parent).get_property(key) for value, attr_name in zip( - self.__composite_values__(), - prop._attribute_keys): + self.__composite_values__(), prop._attribute_keys + ): setattr(parent, attr_name, value) def _setup_composite_listener(): def _listen_for_type(mapper, class_): for prop in mapper.iterate_properties: - if (hasattr(prop, 'composite_class') and - isinstance(prop.composite_class, type) and - issubclass(prop.composite_class, MutableComposite)): + if ( + hasattr(prop, "composite_class") + and isinstance(prop.composite_class, type) + and issubclass(prop.composite_class, MutableComposite) + ): prop.composite_class._listen_on_attribute( - getattr(class_, prop.key), False, class_) + getattr(class_, prop.key), False, class_ + ) + if not event.contains(Mapper, "mapper_configured", _listen_for_type): - event.listen(Mapper, 'mapper_configured', _listen_for_type) + event.listen(Mapper, "mapper_configured", _listen_for_type) + + _setup_composite_listener() @@ -947,4 +955,4 @@ class MutableSet(Mutable, set): self.update(state) def __reduce_ex__(self, proto): - return (self.__class__, (list(self), )) + return (self.__class__, (list(self),)) diff --git a/lib/sqlalchemy/ext/orderinglist.py b/lib/sqlalchemy/ext/orderinglist.py index 316742a67..2a8522120 100644 --- a/lib/sqlalchemy/ext/orderinglist.py +++ b/lib/sqlalchemy/ext/orderinglist.py @@ -122,7 +122,7 @@ start numbering at 1 or some other integer, provide ``count_from=1``. from ..orm.collections import collection, collection_adapter from .. import util -__all__ = ['ordering_list'] +__all__ = ["ordering_list"] def ordering_list(attr, count_from=None, **kw): @@ -180,8 +180,9 @@ def count_from_n_factory(start): def f(index, collection): return index + start + try: - f.__name__ = 'count_from_%i' % start + f.__name__ = "count_from_%i" % start except TypeError: pass return f @@ -194,14 +195,14 @@ def _unsugar_count_from(**kw): ``count_from`` argument, otherwise passes ``ordering_func`` on unchanged. """ - count_from = kw.pop('count_from', None) - if kw.get('ordering_func', None) is None and count_from is not None: + count_from = kw.pop("count_from", None) + if kw.get("ordering_func", None) is None and count_from is not None: if count_from == 0: - kw['ordering_func'] = count_from_0 + kw["ordering_func"] = count_from_0 elif count_from == 1: - kw['ordering_func'] = count_from_1 + kw["ordering_func"] = count_from_1 else: - kw['ordering_func'] = count_from_n_factory(count_from) + kw["ordering_func"] = count_from_n_factory(count_from) return kw @@ -214,8 +215,9 @@ class OrderingList(list): """ - def __init__(self, ordering_attr=None, ordering_func=None, - reorder_on_append=False): + def __init__( + self, ordering_attr=None, ordering_func=None, reorder_on_append=False + ): """A custom list that manages position information for its children. ``OrderingList`` is a ``collection_class`` list implementation that @@ -311,6 +313,7 @@ class OrderingList(list): """Append without any ordering behavior.""" super(OrderingList, self).append(entity) + _raw_append = collection.adds(1)(_raw_append) def insert(self, index, entity): @@ -361,8 +364,12 @@ class OrderingList(list): return _reconstitute, (self.__class__, self.__dict__, list(self)) for func_name, func in list(locals().items()): - if (util.callable(func) and func.__name__ == func_name and - not func.__doc__ and hasattr(list, func_name)): + if ( + util.callable(func) + and func.__name__ == func_name + and not func.__doc__ + and hasattr(list, func_name) + ): func.__doc__ = getattr(list, func_name).__doc__ del func_name, func diff --git a/lib/sqlalchemy/ext/serializer.py b/lib/sqlalchemy/ext/serializer.py index 2fded51d1..3adcec34f 100644 --- a/lib/sqlalchemy/ext/serializer.py +++ b/lib/sqlalchemy/ext/serializer.py @@ -64,7 +64,7 @@ from ..util import pickle, byte_buffer, b64encode, b64decode, text_type import re -__all__ = ['Serializer', 'Deserializer', 'dumps', 'loads'] +__all__ = ["Serializer", "Deserializer", "dumps", "loads"] def Serializer(*args, **kw): @@ -79,13 +79,18 @@ def Serializer(*args, **kw): elif isinstance(obj, Mapper) and not obj.non_primary: id = "mapper:" + b64encode(pickle.dumps(obj.class_)) elif isinstance(obj, MapperProperty) and not obj.parent.non_primary: - id = "mapperprop:" + b64encode(pickle.dumps(obj.parent.class_)) + \ - ":" + obj.key + id = ( + "mapperprop:" + + b64encode(pickle.dumps(obj.parent.class_)) + + ":" + + obj.key + ) elif isinstance(obj, Table): id = "table:" + text_type(obj.key) elif isinstance(obj, Column) and isinstance(obj.table, Table): - id = "column:" + \ - text_type(obj.table.key) + ":" + text_type(obj.key) + id = ( + "column:" + text_type(obj.table.key) + ":" + text_type(obj.key) + ) elif isinstance(obj, Session): id = "session:" elif isinstance(obj, Engine): @@ -97,8 +102,10 @@ def Serializer(*args, **kw): pickler.persistent_id = persistent_id return pickler + our_ids = re.compile( - r'(mapperprop|mapper|table|column|session|attribute|engine):(.*)') + r"(mapperprop|mapper|table|column|session|attribute|engine):(.*)" +) def Deserializer(file, metadata=None, scoped_session=None, engine=None): @@ -120,7 +127,7 @@ def Deserializer(file, metadata=None, scoped_session=None, engine=None): return None else: type_, args = m.group(1, 2) - if type_ == 'attribute': + if type_ == "attribute": key, clsarg = args.split(":") cls = pickle.loads(b64decode(clsarg)) return getattr(cls, key) @@ -128,13 +135,13 @@ def Deserializer(file, metadata=None, scoped_session=None, engine=None): cls = pickle.loads(b64decode(args)) return class_mapper(cls) elif type_ == "mapperprop": - mapper, keyname = args.split(':') + mapper, keyname = args.split(":") cls = pickle.loads(b64decode(mapper)) return class_mapper(cls).attrs[keyname] elif type_ == "table": return metadata.tables[args] elif type_ == "column": - table, colname = args.split(':') + table, colname = args.split(":") return metadata.tables[table].c[colname] elif type_ == "session": return scoped_session() @@ -142,6 +149,7 @@ def Deserializer(file, metadata=None, scoped_session=None, engine=None): return get_engine() else: raise Exception("Unknown token: %s" % type_) + unpickler.persistent_load = persistent_load return unpickler |
