Diffstat (limited to 'lib/sqlalchemy/ext')
-rw-r--r--  lib/sqlalchemy/ext/__init__.py                    1
-rw-r--r--  lib/sqlalchemy/ext/associationproxy.py          288
-rw-r--r--  lib/sqlalchemy/ext/automap.py                   309
-rw-r--r--  lib/sqlalchemy/ext/baked.py                     135
-rw-r--r--  lib/sqlalchemy/ext/compiler.py                   30
-rw-r--r--  lib/sqlalchemy/ext/declarative/__init__.py       35
-rw-r--r--  lib/sqlalchemy/ext/declarative/api.py           117
-rw-r--r--  lib/sqlalchemy/ext/declarative/base.py          326
-rw-r--r--  lib/sqlalchemy/ext/declarative/clsregistry.py   125
-rw-r--r--  lib/sqlalchemy/ext/horizontal_shard.py           61
-rw-r--r--  lib/sqlalchemy/ext/hybrid.py                     30
-rw-r--r--  lib/sqlalchemy/ext/indexable.py                  12
-rw-r--r--  lib/sqlalchemy/ext/instrumentation.py            78
-rw-r--r--  lib/sqlalchemy/ext/mutable.py                    78
-rw-r--r--  lib/sqlalchemy/ext/orderinglist.py               29
-rw-r--r--  lib/sqlalchemy/ext/serializer.py                 26
16 files changed, 1021 insertions, 659 deletions
diff --git a/lib/sqlalchemy/ext/__init__.py b/lib/sqlalchemy/ext/__init__.py
index 9558b2a1f..9fed09e2b 100644
--- a/lib/sqlalchemy/ext/__init__.py
+++ b/lib/sqlalchemy/ext/__init__.py
@@ -8,4 +8,3 @@
from .. import util as _sa_util
_sa_util.dependencies.resolve_all("sqlalchemy.ext")
-
diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py
index ff9433d4d..56b91ce0b 100644
--- a/lib/sqlalchemy/ext/associationproxy.py
+++ b/lib/sqlalchemy/ext/associationproxy.py
@@ -76,7 +76,7 @@ def association_proxy(target_collection, attr, **kw):
return AssociationProxy(target_collection, attr, **kw)
-ASSOCIATION_PROXY = util.symbol('ASSOCIATION_PROXY')
+ASSOCIATION_PROXY = util.symbol("ASSOCIATION_PROXY")
"""Symbol indicating an :class:`InspectionAttr` that's
of type :class:`.AssociationProxy`.
@@ -92,10 +92,17 @@ class AssociationProxy(interfaces.InspectionAttrInfo):
is_attribute = False
extension_type = ASSOCIATION_PROXY
- def __init__(self, target_collection, attr, creator=None,
- getset_factory=None, proxy_factory=None,
- proxy_bulk_set=None, info=None,
- cascade_scalar_deletes=False):
+ def __init__(
+ self,
+ target_collection,
+ attr,
+ creator=None,
+ getset_factory=None,
+ proxy_factory=None,
+ proxy_bulk_set=None,
+ info=None,
+ cascade_scalar_deletes=False,
+ ):
"""Construct a new :class:`.AssociationProxy`.
The :func:`.association_proxy` function is provided as the usual
@@ -162,8 +169,11 @@ class AssociationProxy(interfaces.InspectionAttrInfo):
self.proxy_bulk_set = proxy_bulk_set
self.cascade_scalar_deletes = cascade_scalar_deletes
- self.key = '_%s_%s_%s' % (
- type(self).__name__, target_collection, id(self))
+ self.key = "_%s_%s_%s" % (
+ type(self).__name__,
+ target_collection,
+ id(self),
+ )
if info:
self.info = info
@@ -264,12 +274,17 @@ class AssociationProxy(interfaces.InspectionAttrInfo):
def getter(target):
return _getter(target) if target is not None else None
+
if collection_class is dict:
+
def setter(o, k, v):
setattr(o, attr, v)
+
else:
+
def setter(o, v):
setattr(o, attr, v)
+
return getter, setter
@@ -325,20 +340,21 @@ class AssociationProxyInstance(object):
def for_proxy(cls, parent, owning_class, parent_instance):
target_collection = parent.target_collection
value_attr = parent.value_attr
- prop = orm.class_mapper(owning_class).\
- get_property(target_collection)
+ prop = orm.class_mapper(owning_class).get_property(target_collection)
# this was never asserted before but this should be made clear.
if not isinstance(prop, orm.RelationshipProperty):
raise NotImplementedError(
"association proxy to a non-relationship "
- "intermediary is not supported")
+ "intermediary is not supported"
+ )
target_class = prop.mapper.class_
try:
target_assoc = cls._cls_unwrap_target_assoc_proxy(
- target_class, value_attr)
+ target_class, value_attr
+ )
except AttributeError:
# the proxied attribute doesn't exist on the target class;
# return an "ambiguous" instance that will work on a per-object
@@ -353,8 +369,8 @@ class AssociationProxyInstance(object):
@classmethod
def _construct_for_assoc(
- cls, target_assoc, parent, owning_class,
- target_class, value_attr):
+ cls, target_assoc, parent, owning_class, target_class, value_attr
+ ):
if target_assoc is not None:
return ObjectAssociationProxyInstance(
parent, owning_class, target_class, value_attr
@@ -371,8 +387,9 @@ class AssociationProxyInstance(object):
)
def _get_property(self):
- return orm.class_mapper(self.owning_class).\
- get_property(self.target_collection)
+ return orm.class_mapper(self.owning_class).get_property(
+ self.target_collection
+ )
@property
def _comparator(self):
@@ -388,7 +405,8 @@ class AssociationProxyInstance(object):
@util.memoized_property
def _unwrap_target_assoc_proxy(self):
return self._cls_unwrap_target_assoc_proxy(
- self.target_class, self.value_attr)
+ self.target_class, self.value_attr
+ )
@property
def remote_attr(self):
@@ -448,8 +466,11 @@ class AssociationProxyInstance(object):
@util.memoized_property
def _value_is_scalar(self):
- return not self._get_property().\
- mapper.get_property(self.value_attr).uselist
+ return (
+ not self._get_property()
+ .mapper.get_property(self.value_attr)
+ .uselist
+ )
@property
def _target_is_object(self):
@@ -468,12 +489,17 @@ class AssociationProxyInstance(object):
def getter(target):
return _getter(target) if target is not None else None
+
if collection_class is dict:
+
def setter(o, k, v):
return setattr(o, attr, v)
+
else:
+
def setter(o, v):
return setattr(o, attr, v)
+
return getter, setter
@property
@@ -500,14 +526,18 @@ class AssociationProxyInstance(object):
return proxy
self.collection_class, proxy = self._new(
- _lazy_collection(obj, self.target_collection))
+ _lazy_collection(obj, self.target_collection)
+ )
setattr(obj, self.key, (id(obj), id(self), proxy))
return proxy
def set(self, obj, values):
if self.scalar:
- creator = self.parent.creator \
- if self.parent.creator else self.target_class
+ creator = (
+ self.parent.creator
+ if self.parent.creator
+ else self.target_class
+ )
target = getattr(obj, self.target_collection)
if target is None:
if values is None:
@@ -535,35 +565,52 @@ class AssociationProxyInstance(object):
delattr(obj, self.target_collection)
def _new(self, lazy_collection):
- creator = self.parent.creator if self.parent.creator else \
- self.target_class
+ creator = (
+ self.parent.creator if self.parent.creator else self.target_class
+ )
collection_class = util.duck_type_collection(lazy_collection())
if self.parent.proxy_factory:
- return collection_class, self.parent.proxy_factory(
- lazy_collection, creator, self.value_attr, self)
+ return (
+ collection_class,
+ self.parent.proxy_factory(
+ lazy_collection, creator, self.value_attr, self
+ ),
+ )
if self.parent.getset_factory:
- getter, setter = self.parent.getset_factory(
- collection_class, self)
+ getter, setter = self.parent.getset_factory(collection_class, self)
else:
getter, setter = self.parent._default_getset(collection_class)
if collection_class is list:
- return collection_class, _AssociationList(
- lazy_collection, creator, getter, setter, self)
+ return (
+ collection_class,
+ _AssociationList(
+ lazy_collection, creator, getter, setter, self
+ ),
+ )
elif collection_class is dict:
- return collection_class, _AssociationDict(
- lazy_collection, creator, getter, setter, self)
+ return (
+ collection_class,
+ _AssociationDict(
+ lazy_collection, creator, getter, setter, self
+ ),
+ )
elif collection_class is set:
- return collection_class, _AssociationSet(
- lazy_collection, creator, getter, setter, self)
+ return (
+ collection_class,
+ _AssociationSet(
+ lazy_collection, creator, getter, setter, self
+ ),
+ )
else:
raise exc.ArgumentError(
- 'could not guess which interface to use for '
+ "could not guess which interface to use for "
'collection_class "%s" backing "%s"; specify a '
- 'proxy_factory and proxy_bulk_set manually' %
- (self.collection_class.__name__, self.target_collection))
+ "proxy_factory and proxy_bulk_set manually"
+ % (self.collection_class.__name__, self.target_collection)
+ )
def _set(self, proxy, values):
if self.parent.proxy_bulk_set:
@@ -576,16 +623,19 @@ class AssociationProxyInstance(object):
proxy.update(values)
else:
raise exc.ArgumentError(
- 'no proxy_bulk_set supplied for custom '
- 'collection_class implementation')
+ "no proxy_bulk_set supplied for custom "
+ "collection_class implementation"
+ )
def _inflate(self, proxy):
- creator = self.parent.creator and \
- self.parent.creator or self.target_class
+ creator = (
+ self.parent.creator and self.parent.creator or self.target_class
+ )
if self.parent.getset_factory:
getter, setter = self.parent.getset_factory(
- self.collection_class, self)
+ self.collection_class, self
+ )
else:
getter, setter = self.parent._default_getset(self.collection_class)
@@ -594,12 +644,13 @@ class AssociationProxyInstance(object):
proxy.setter = setter
def _criterion_exists(self, criterion=None, **kwargs):
- is_has = kwargs.pop('is_has', None)
+ is_has = kwargs.pop("is_has", None)
target_assoc = self._unwrap_target_assoc_proxy
if target_assoc is not None:
inner = target_assoc._criterion_exists(
- criterion=criterion, **kwargs)
+ criterion=criterion, **kwargs
+ )
return self._comparator._criterion_exists(inner)
if self._target_is_object:
@@ -631,15 +682,15 @@ class AssociationProxyInstance(object):
"""
if self._unwrap_target_assoc_proxy is None and (
- self.scalar and (
- not self._target_is_object or self._value_is_scalar)
+ self.scalar
+ and (not self._target_is_object or self._value_is_scalar)
):
raise exc.InvalidRequestError(
- "'any()' not implemented for scalar "
- "attributes. Use has()."
+ "'any()' not implemented for scalar " "attributes. Use has()."
)
return self._criterion_exists(
- criterion=criterion, is_has=False, **kwargs)
+ criterion=criterion, is_has=False, **kwargs
+ )
def has(self, criterion=None, **kwargs):
"""Produce a proxied 'has' expression using EXISTS.
@@ -651,14 +702,15 @@ class AssociationProxyInstance(object):
"""
if self._unwrap_target_assoc_proxy is None and (
- not self.scalar or (
- self._target_is_object and not self._value_is_scalar)
+ not self.scalar
+ or (self._target_is_object and not self._value_is_scalar)
):
raise exc.InvalidRequestError(
- "'has()' not implemented for collections. "
- "Use any().")
+ "'has()' not implemented for collections. " "Use any()."
+ )
return self._criterion_exists(
- criterion=criterion, is_has=True, **kwargs)
+ criterion=criterion, is_has=True, **kwargs
+ )
class AmbiguousAssociationProxyInstance(AssociationProxyInstance):
@@ -673,10 +725,14 @@ class AmbiguousAssociationProxyInstance(AssociationProxyInstance):
"Association proxy %s.%s refers to an attribute '%s' that is not "
"directly mapped on class %s; therefore this operation cannot "
"proceed since we don't know what type of object is referred "
- "towards" % (
- self.owning_class.__name__, self.target_collection,
- self.value_attr, self.target_class
- ))
+ "towards"
+ % (
+ self.owning_class.__name__,
+ self.target_collection,
+ self.value_attr,
+ self.target_class,
+ )
+ )
def get(self, obj):
self._ambiguous()
@@ -718,27 +774,32 @@ class AmbiguousAssociationProxyInstance(AssociationProxyInstance):
return self
def _populate_cache(self, instance_class):
- prop = orm.class_mapper(self.owning_class).\
- get_property(self.target_collection)
+ prop = orm.class_mapper(self.owning_class).get_property(
+ self.target_collection
+ )
if inspect(instance_class).mapper.isa(prop.mapper):
target_class = instance_class
try:
target_assoc = self._cls_unwrap_target_assoc_proxy(
- target_class, self.value_attr)
+ target_class, self.value_attr
+ )
except AttributeError:
pass
else:
- self._lookup_cache[instance_class] = \
- self._construct_for_assoc(
- target_assoc, self.parent, self.owning_class,
- target_class, self.value_attr
+ self._lookup_cache[instance_class] = self._construct_for_assoc(
+ target_assoc,
+ self.parent,
+ self.owning_class,
+ target_class,
+ self.value_attr,
)
class ObjectAssociationProxyInstance(AssociationProxyInstance):
"""an :class:`.AssociationProxyInstance` that has an object as a target.
"""
+
_target_is_object = True
_is_canonical = True
@@ -756,17 +817,21 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance):
if target_assoc is not None:
return self._comparator._criterion_exists(
target_assoc.contains(obj)
- if not target_assoc.scalar else target_assoc == obj
+ if not target_assoc.scalar
+ else target_assoc == obj
)
- elif self._target_is_object and self.scalar and \
- not self._value_is_scalar:
+ elif (
+ self._target_is_object
+ and self.scalar
+ and not self._value_is_scalar
+ ):
return self._comparator.has(
getattr(self.target_class, self.value_attr).contains(obj)
)
- elif self._target_is_object and self.scalar and \
- self._value_is_scalar:
+ elif self._target_is_object and self.scalar and self._value_is_scalar:
raise exc.InvalidRequestError(
- "contains() doesn't apply to a scalar object endpoint; use ==")
+ "contains() doesn't apply to a scalar object endpoint; use =="
+ )
else:
return self._comparator._criterion_exists(**{self.value_attr: obj})
@@ -777,7 +842,7 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance):
if obj is None:
return or_(
self._comparator.has(**{self.value_attr: obj}),
- self._comparator == None
+ self._comparator == None,
)
else:
return self._comparator.has(**{self.value_attr: obj})
@@ -786,14 +851,17 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance):
# note the has() here will fail for collections; eq_()
# is only allowed with a scalar.
return self._comparator.has(
- getattr(self.target_class, self.value_attr) != obj)
+ getattr(self.target_class, self.value_attr) != obj
+ )
class ColumnAssociationProxyInstance(
- ColumnOperators, AssociationProxyInstance):
+ ColumnOperators, AssociationProxyInstance
+):
"""an :class:`.AssociationProxyInstance` that has a database column as a
target.
"""
+
_target_is_object = False
_is_canonical = True
@@ -803,9 +871,7 @@ class ColumnAssociationProxyInstance(
self.remote_attr.operate(operator.eq, other)
)
if other is None:
- return or_(
- expr, self._comparator == None
- )
+ return or_(expr, self._comparator == None)
else:
return expr
@@ -824,11 +890,11 @@ class _lazy_collection(object):
return getattr(self.parent, self.target)
def __getstate__(self):
- return {'obj': self.parent, 'target': self.target}
+ return {"obj": self.parent, "target": self.target}
def __setstate__(self, state):
- self.parent = state['obj']
- self.target = state['target']
+ self.parent = state["obj"]
+ self.target = state["target"]
class _AssociationCollection(object):
@@ -874,11 +940,11 @@ class _AssociationCollection(object):
__nonzero__ = __bool__
def __getstate__(self):
- return {'parent': self.parent, 'lazy_collection': self.lazy_collection}
+ return {"parent": self.parent, "lazy_collection": self.lazy_collection}
def __setstate__(self, state):
- self.parent = state['parent']
- self.lazy_collection = state['lazy_collection']
+ self.parent = state["parent"]
+ self.lazy_collection = state["lazy_collection"]
self.parent._inflate(self)
@@ -925,8 +991,8 @@ class _AssociationList(_AssociationCollection):
if len(value) != len(rng):
raise ValueError(
"attempt to assign sequence of size %s to "
- "extended slice of size %s" % (len(value),
- len(rng)))
+ "extended slice of size %s" % (len(value), len(rng))
+ )
for i, item in zip(rng, value):
self._set(self.col[i], item)
@@ -968,8 +1034,14 @@ class _AssociationList(_AssociationCollection):
col.append(item)
def count(self, value):
- return sum([1 for _ in
- util.itertools_filter(lambda v: v == value, iter(self))])
+ return sum(
+ [
+ 1
+ for _ in util.itertools_filter(
+ lambda v: v == value, iter(self)
+ )
+ ]
+ )
def extend(self, values):
for v in values:
@@ -999,7 +1071,7 @@ class _AssociationList(_AssociationCollection):
raise NotImplementedError
def clear(self):
- del self.col[0:len(self.col)]
+ del self.col[0 : len(self.col)]
def __eq__(self, other):
return list(self) == other
@@ -1040,6 +1112,7 @@ class _AssociationList(_AssociationCollection):
if not isinstance(n, int):
return NotImplemented
return list(self) * n
+
__rmul__ = __mul__
def __iadd__(self, iterable):
@@ -1072,13 +1145,17 @@ class _AssociationList(_AssociationCollection):
raise TypeError("%s objects are unhashable" % type(self).__name__)
for func_name, func in list(locals().items()):
- if (util.callable(func) and func.__name__ == func_name and
- not func.__doc__ and hasattr(list, func_name)):
+ if (
+ util.callable(func)
+ and func.__name__ == func_name
+ and not func.__doc__
+ and hasattr(list, func_name)
+ ):
func.__doc__ = getattr(list, func_name).__doc__
del func_name, func
-_NotProvided = util.symbol('_NotProvided')
+_NotProvided = util.symbol("_NotProvided")
class _AssociationDict(_AssociationCollection):
@@ -1160,6 +1237,7 @@ class _AssociationDict(_AssociationCollection):
return self.col.keys()
if util.py2k:
+
def iteritems(self):
return ((key, self._get(self.col[key])) for key in self.col)
@@ -1174,7 +1252,9 @@ class _AssociationDict(_AssociationCollection):
def items(self):
return [(k, self._get(self.col[k])) for k in self]
+
else:
+
def items(self):
return ((key, self._get(self.col[key])) for key in self.col)
@@ -1194,14 +1274,15 @@ class _AssociationDict(_AssociationCollection):
def update(self, *a, **kw):
if len(a) > 1:
- raise TypeError('update expected at most 1 arguments, got %i' %
- len(a))
+ raise TypeError(
+ "update expected at most 1 arguments, got %i" % len(a)
+ )
elif len(a) == 1:
seq_or_map = a[0]
# discern dict from sequence - took the advice from
# http://www.voidspace.org.uk/python/articles/duck_typing.shtml
# still not perfect :(
- if hasattr(seq_or_map, 'keys'):
+ if hasattr(seq_or_map, "keys"):
for item in seq_or_map:
self[item] = seq_or_map[item]
else:
@@ -1211,7 +1292,8 @@ class _AssociationDict(_AssociationCollection):
except ValueError:
raise ValueError(
"dictionary update sequence "
- "requires 2-element tuples")
+ "requires 2-element tuples"
+ )
for key, value in kw:
self[key] = value
@@ -1223,8 +1305,12 @@ class _AssociationDict(_AssociationCollection):
raise TypeError("%s objects are unhashable" % type(self).__name__)
for func_name, func in list(locals().items()):
- if (util.callable(func) and func.__name__ == func_name and
- not func.__doc__ and hasattr(dict, func_name)):
+ if (
+ util.callable(func)
+ and func.__name__ == func_name
+ and not func.__doc__
+ and hasattr(dict, func_name)
+ ):
func.__doc__ = getattr(dict, func_name).__doc__
del func_name, func
@@ -1288,7 +1374,7 @@ class _AssociationSet(_AssociationCollection):
def pop(self):
if not self.col:
- raise KeyError('pop from an empty set')
+ raise KeyError("pop from an empty set")
member = self.col.pop()
return self._get(member)
@@ -1420,7 +1506,11 @@ class _AssociationSet(_AssociationCollection):
raise TypeError("%s objects are unhashable" % type(self).__name__)
for func_name, func in list(locals().items()):
- if (util.callable(func) and func.__name__ == func_name and
- not func.__doc__ and hasattr(set, func_name)):
+ if (
+ util.callable(func)
+ and func.__name__ == func_name
+ and not func.__doc__
+ and hasattr(set, func_name)
+ ):
func.__doc__ = getattr(set, func_name).__doc__
del func_name, func
diff --git a/lib/sqlalchemy/ext/automap.py b/lib/sqlalchemy/ext/automap.py
index cafb3d61c..747373a2a 100644
--- a/lib/sqlalchemy/ext/automap.py
+++ b/lib/sqlalchemy/ext/automap.py
@@ -580,7 +580,8 @@ def name_for_scalar_relationship(base, local_cls, referred_cls, constraint):
def name_for_collection_relationship(
- base, local_cls, referred_cls, constraint):
+ base, local_cls, referred_cls, constraint
+):
"""Return the attribute name that should be used to refer from one
class to another, for a collection reference.
@@ -607,7 +608,8 @@ def name_for_collection_relationship(
def generate_relationship(
- base, direction, return_fn, attrname, local_cls, referred_cls, **kw):
+ base, direction, return_fn, attrname, local_cls, referred_cls, **kw
+):
r"""Generate a :func:`.relationship` or :func:`.backref` on behalf of two
mapped classes.
@@ -677,6 +679,7 @@ class AutomapBase(object):
:ref:`automap_toplevel`
"""
+
__abstract__ = True
classes = None
@@ -694,15 +697,16 @@ class AutomapBase(object):
@classmethod
def prepare(
- cls,
- engine=None,
- reflect=False,
- schema=None,
- classname_for_table=classname_for_table,
- collection_class=list,
- name_for_scalar_relationship=name_for_scalar_relationship,
- name_for_collection_relationship=name_for_collection_relationship,
- generate_relationship=generate_relationship):
+ cls,
+ engine=None,
+ reflect=False,
+ schema=None,
+ classname_for_table=classname_for_table,
+ collection_class=list,
+ name_for_scalar_relationship=name_for_scalar_relationship,
+ name_for_collection_relationship=name_for_collection_relationship,
+ generate_relationship=generate_relationship,
+ ):
"""Extract mapped classes and relationships from the :class:`.MetaData` and
perform mappings.
@@ -752,15 +756,16 @@ class AutomapBase(object):
engine,
schema=schema,
extend_existing=True,
- autoload_replace=False
+ autoload_replace=False,
)
_CONFIGURE_MUTEX.acquire()
try:
table_to_map_config = dict(
(m.local_table, m)
- for m in _DeferredMapperConfig.
- classes_for_base(cls, sort=False)
+ for m in _DeferredMapperConfig.classes_for_base(
+ cls, sort=False
+ )
)
many_to_many = []
@@ -774,30 +779,39 @@ class AutomapBase(object):
elif table not in table_to_map_config:
mapped_cls = type(
classname_for_table(cls, table.name, table),
- (cls, ),
- {"__table__": table}
+ (cls,),
+ {"__table__": table},
)
map_config = _DeferredMapperConfig.config_for_cls(
- mapped_cls)
+ mapped_cls
+ )
cls.classes[map_config.cls.__name__] = mapped_cls
table_to_map_config[table] = map_config
for map_config in table_to_map_config.values():
- _relationships_for_fks(cls,
- map_config,
- table_to_map_config,
- collection_class,
- name_for_scalar_relationship,
- name_for_collection_relationship,
- generate_relationship)
+ _relationships_for_fks(
+ cls,
+ map_config,
+ table_to_map_config,
+ collection_class,
+ name_for_scalar_relationship,
+ name_for_collection_relationship,
+ generate_relationship,
+ )
for lcl_m2m, rem_m2m, m2m_const, table in many_to_many:
- _m2m_relationship(cls, lcl_m2m, rem_m2m, m2m_const, table,
- table_to_map_config,
- collection_class,
- name_for_scalar_relationship,
- name_for_collection_relationship,
- generate_relationship)
+ _m2m_relationship(
+ cls,
+ lcl_m2m,
+ rem_m2m,
+ m2m_const,
+ table,
+ table_to_map_config,
+ collection_class,
+ name_for_scalar_relationship,
+ name_for_collection_relationship,
+ generate_relationship,
+ )
for map_config in _DeferredMapperConfig.classes_for_base(cls):
map_config.map()
@@ -853,20 +867,27 @@ def automap_base(declarative_base=None, **kw):
return type(
Base.__name__,
- (AutomapBase, Base,),
- {"__abstract__": True, "classes": util.Properties({})}
+ (AutomapBase, Base),
+ {"__abstract__": True, "classes": util.Properties({})},
)
def _is_many_to_many(automap_base, table):
- fk_constraints = [const for const in table.constraints
- if isinstance(const, ForeignKeyConstraint)]
+ fk_constraints = [
+ const
+ for const in table.constraints
+ if isinstance(const, ForeignKeyConstraint)
+ ]
if len(fk_constraints) != 2:
return None, None, None
cols = sum(
- [[fk.parent for fk in fk_constraint.elements]
- for fk_constraint in fk_constraints], [])
+ [
+ [fk.parent for fk in fk_constraint.elements]
+ for fk_constraint in fk_constraints
+ ],
+ [],
+ )
if set(cols) != set(table.c):
return None, None, None
@@ -874,15 +895,19 @@ def _is_many_to_many(automap_base, table):
return (
fk_constraints[0].elements[0].column.table,
fk_constraints[1].elements[0].column.table,
- fk_constraints
+ fk_constraints,
)
-def _relationships_for_fks(automap_base, map_config, table_to_map_config,
- collection_class,
- name_for_scalar_relationship,
- name_for_collection_relationship,
- generate_relationship):
+def _relationships_for_fks(
+ automap_base,
+ map_config,
+ table_to_map_config,
+ collection_class,
+ name_for_scalar_relationship,
+ name_for_collection_relationship,
+ generate_relationship,
+):
local_table = map_config.local_table
local_cls = map_config.cls # derived from a weakref, may be None
@@ -898,32 +923,33 @@ def _relationships_for_fks(automap_base, map_config, table_to_map_config,
referred_cls = referred_cfg.cls
if local_cls is not referred_cls and issubclass(
- local_cls, referred_cls):
+ local_cls, referred_cls
+ ):
continue
relationship_name = name_for_scalar_relationship(
- automap_base,
- local_cls,
- referred_cls, constraint)
+ automap_base, local_cls, referred_cls, constraint
+ )
backref_name = name_for_collection_relationship(
- automap_base,
- referred_cls,
- local_cls,
- constraint
+ automap_base, referred_cls, local_cls, constraint
)
o2m_kws = {}
nullable = False not in {fk.parent.nullable for fk in fks}
if not nullable:
- o2m_kws['cascade'] = "all, delete-orphan"
+ o2m_kws["cascade"] = "all, delete-orphan"
- if constraint.ondelete and \
- constraint.ondelete.lower() == "cascade":
- o2m_kws['passive_deletes'] = True
+ if (
+ constraint.ondelete
+ and constraint.ondelete.lower() == "cascade"
+ ):
+ o2m_kws["passive_deletes"] = True
else:
- if constraint.ondelete and \
- constraint.ondelete.lower() == "set null":
- o2m_kws['passive_deletes'] = True
+ if (
+ constraint.ondelete
+ and constraint.ondelete.lower() == "set null"
+ ):
+ o2m_kws["passive_deletes"] = True
create_backref = backref_name not in referred_cfg.properties
@@ -931,54 +957,65 @@ def _relationships_for_fks(automap_base, map_config, table_to_map_config,
if create_backref:
backref_obj = generate_relationship(
automap_base,
- interfaces.ONETOMANY, backref,
- backref_name, referred_cls, local_cls,
+ interfaces.ONETOMANY,
+ backref,
+ backref_name,
+ referred_cls,
+ local_cls,
collection_class=collection_class,
- **o2m_kws)
+ **o2m_kws
+ )
else:
backref_obj = None
- rel = generate_relationship(automap_base,
- interfaces.MANYTOONE,
- relationship,
- relationship_name,
- local_cls, referred_cls,
- foreign_keys=[
- fk.parent
- for fk in constraint.elements],
- backref=backref_obj,
- remote_side=[
- fk.column
- for fk in constraint.elements]
- )
+ rel = generate_relationship(
+ automap_base,
+ interfaces.MANYTOONE,
+ relationship,
+ relationship_name,
+ local_cls,
+ referred_cls,
+ foreign_keys=[fk.parent for fk in constraint.elements],
+ backref=backref_obj,
+ remote_side=[fk.column for fk in constraint.elements],
+ )
if rel is not None:
map_config.properties[relationship_name] = rel
if not create_backref:
referred_cfg.properties[
- backref_name].back_populates = relationship_name
+ backref_name
+ ].back_populates = relationship_name
elif create_backref:
- rel = generate_relationship(automap_base,
- interfaces.ONETOMANY,
- relationship,
- backref_name,
- referred_cls, local_cls,
- foreign_keys=[
- fk.parent
- for fk in constraint.elements],
- back_populates=relationship_name,
- collection_class=collection_class,
- **o2m_kws)
+ rel = generate_relationship(
+ automap_base,
+ interfaces.ONETOMANY,
+ relationship,
+ backref_name,
+ referred_cls,
+ local_cls,
+ foreign_keys=[fk.parent for fk in constraint.elements],
+ back_populates=relationship_name,
+ collection_class=collection_class,
+ **o2m_kws
+ )
if rel is not None:
referred_cfg.properties[backref_name] = rel
map_config.properties[
- relationship_name].back_populates = backref_name
-
-
-def _m2m_relationship(automap_base, lcl_m2m, rem_m2m, m2m_const, table,
- table_to_map_config,
- collection_class,
- name_for_scalar_relationship,
- name_for_collection_relationship,
- generate_relationship):
+ relationship_name
+ ].back_populates = backref_name
+
+
+def _m2m_relationship(
+ automap_base,
+ lcl_m2m,
+ rem_m2m,
+ m2m_const,
+ table,
+ table_to_map_config,
+ collection_class,
+ name_for_scalar_relationship,
+ name_for_collection_relationship,
+ generate_relationship,
+):
map_config = table_to_map_config.get(lcl_m2m, None)
referred_cfg = table_to_map_config.get(rem_m2m, None)
@@ -989,14 +1026,10 @@ def _m2m_relationship(automap_base, lcl_m2m, rem_m2m, m2m_const, table,
referred_cls = referred_cfg.cls
relationship_name = name_for_collection_relationship(
- automap_base,
- local_cls,
- referred_cls, m2m_const[0])
+ automap_base, local_cls, referred_cls, m2m_const[0]
+ )
backref_name = name_for_collection_relationship(
- automap_base,
- referred_cls,
- local_cls,
- m2m_const[1]
+ automap_base, referred_cls, local_cls, m2m_const[1]
)
create_backref = backref_name not in referred_cfg.properties
@@ -1008,48 +1041,56 @@ def _m2m_relationship(automap_base, lcl_m2m, rem_m2m, m2m_const, table,
interfaces.MANYTOMANY,
backref,
backref_name,
- referred_cls, local_cls,
- collection_class=collection_class
+ referred_cls,
+ local_cls,
+ collection_class=collection_class,
)
else:
backref_obj = None
- rel = generate_relationship(automap_base,
- interfaces.MANYTOMANY,
- relationship,
- relationship_name,
- local_cls, referred_cls,
- secondary=table,
- primaryjoin=and_(
- fk.column == fk.parent
- for fk in m2m_const[0].elements),
- secondaryjoin=and_(
- fk.column == fk.parent
- for fk in m2m_const[1].elements),
- backref=backref_obj,
- collection_class=collection_class
- )
+ rel = generate_relationship(
+ automap_base,
+ interfaces.MANYTOMANY,
+ relationship,
+ relationship_name,
+ local_cls,
+ referred_cls,
+ secondary=table,
+ primaryjoin=and_(
+ fk.column == fk.parent for fk in m2m_const[0].elements
+ ),
+ secondaryjoin=and_(
+ fk.column == fk.parent for fk in m2m_const[1].elements
+ ),
+ backref=backref_obj,
+ collection_class=collection_class,
+ )
if rel is not None:
map_config.properties[relationship_name] = rel
if not create_backref:
referred_cfg.properties[
- backref_name].back_populates = relationship_name
+ backref_name
+ ].back_populates = relationship_name
elif create_backref:
- rel = generate_relationship(automap_base,
- interfaces.MANYTOMANY,
- relationship,
- backref_name,
- referred_cls, local_cls,
- secondary=table,
- primaryjoin=and_(
- fk.column == fk.parent
- for fk in m2m_const[1].elements),
- secondaryjoin=and_(
- fk.column == fk.parent
- for fk in m2m_const[0].elements),
- back_populates=relationship_name,
- collection_class=collection_class)
+ rel = generate_relationship(
+ automap_base,
+ interfaces.MANYTOMANY,
+ relationship,
+ backref_name,
+ referred_cls,
+ local_cls,
+ secondary=table,
+ primaryjoin=and_(
+ fk.column == fk.parent for fk in m2m_const[1].elements
+ ),
+ secondaryjoin=and_(
+ fk.column == fk.parent for fk in m2m_const[0].elements
+ ),
+ back_populates=relationship_name,
+ collection_class=collection_class,
+ )
if rel is not None:
referred_cfg.properties[backref_name] = rel
map_config.properties[
- relationship_name].back_populates = backref_name
+ relationship_name
+ ].back_populates = backref_name
diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py
index 516879142..f55231a09 100644
--- a/lib/sqlalchemy/ext/baked.py
+++ b/lib/sqlalchemy/ext/baked.py
@@ -38,7 +38,8 @@ class Bakery(object):
"""
- __slots__ = 'cls', 'cache'
+
+ __slots__ = "cls", "cache"
def __init__(self, cls_, cache):
self.cls = cls_
@@ -51,7 +52,7 @@ class Bakery(object):
class BakedQuery(object):
"""A builder object for :class:`.query.Query` objects."""
- __slots__ = 'steps', '_bakery', '_cache_key', '_spoiled'
+ __slots__ = "steps", "_bakery", "_cache_key", "_spoiled"
def __init__(self, bakery, initial_fn, args=()):
self._cache_key = ()
@@ -148,7 +149,7 @@ class BakedQuery(object):
"""
if not full and not self._spoiled:
_spoil_point = self._clone()
- _spoil_point._cache_key += ('_query_only', )
+ _spoil_point._cache_key += ("_query_only",)
self.steps = [_spoil_point._retrieve_baked_query]
self._spoiled = True
return self
@@ -164,7 +165,7 @@ class BakedQuery(object):
session will want to use.
"""
- return self._cache_key + (session._query_cls, )
+ return self._cache_key + (session._query_cls,)
def _with_lazyload_options(self, options, effective_path, cache_path=None):
"""Cloning version of _add_lazyload_options.
@@ -201,16 +202,20 @@ class BakedQuery(object):
key += cache_key
self.add_criteria(
- lambda q: q._with_current_path(effective_path).
- _conditional_options(*options),
- cache_path.path, key
+ lambda q: q._with_current_path(
+ effective_path
+ )._conditional_options(*options),
+ cache_path.path,
+ key,
)
def _retrieve_baked_query(self, session):
query = self._bakery.get(self._effective_key(session), None)
if query is None:
query = self._as_query(session)
- self._bakery[self._effective_key(session)] = query.with_session(None)
+ self._bakery[self._effective_key(session)] = query.with_session(
+ None
+ )
return query.with_session(session)
def _bake(self, session):
@@ -227,8 +232,12 @@ class BakedQuery(object):
# so delete some compilation-use-only attributes that can take up
# space
for attr in (
- '_correlate', '_from_obj', '_mapper_adapter_map',
- '_joinpath', '_joinpoint'):
+ "_correlate",
+ "_from_obj",
+ "_mapper_adapter_map",
+ "_joinpath",
+ "_joinpoint",
+ ):
query.__dict__.pop(attr, None)
self._bakery[self._effective_key(session)] = context
return context
@@ -276,11 +285,13 @@ class BakedQuery(object):
session = query_or_session.session
if session is None:
raise sa_exc.ArgumentError(
- "Given Query needs to be associated with a Session")
+ "Given Query needs to be associated with a Session"
+ )
else:
raise TypeError(
- "Query or Session object expected, got %r." %
- type(query_or_session))
+ "Query or Session object expected, got %r."
+ % type(query_or_session)
+ )
return self._as_query(session)
def _as_query(self, session):
@@ -299,10 +310,10 @@ class BakedQuery(object):
a "baked" query so that we save on performance too.
"""
- context.attributes['baked_queries'] = baked_queries = []
+ context.attributes["baked_queries"] = baked_queries = []
for k, v in list(context.attributes.items()):
if isinstance(v, Query):
- if 'subquery' in k:
+ if "subquery" in k:
bk = BakedQuery(self._bakery, lambda *args: v)
bk._cache_key = self._cache_key + k
bk._bake(session)
@@ -310,15 +321,17 @@ class BakedQuery(object):
del context.attributes[k]
def _unbake_subquery_loaders(
- self, session, context, params, post_criteria):
+ self, session, context, params, post_criteria
+ ):
"""Retrieve subquery eager loaders stored by _bake_subquery_loaders
and turn them back into Result objects that will iterate just
like a Query object.
"""
for k, cache_key, query in context.attributes["baked_queries"]:
- bk = BakedQuery(self._bakery,
- lambda sess, q=query: q.with_session(sess))
+ bk = BakedQuery(
+ self._bakery, lambda sess, q=query: q.with_session(sess)
+ )
bk._cache_key = cache_key
q = bk.for_session(session)
for fn in post_criteria:
@@ -334,7 +347,8 @@ class Result(object):
against a target :class:`.Session`, and is then invoked for results.
"""
- __slots__ = 'bq', 'session', '_params', '_post_criteria'
+
+ __slots__ = "bq", "session", "_params", "_post_criteria"
def __init__(self, bq, session):
self.bq = bq
@@ -350,7 +364,8 @@ class Result(object):
elif len(args) > 0:
raise sa_exc.ArgumentError(
"params() takes zero or one positional argument, "
- "which is a dictionary.")
+ "which is a dictionary."
+ )
self._params.update(kw)
return self
@@ -403,7 +418,8 @@ class Result(object):
context.attributes = context.attributes.copy()
bq._unbake_subquery_loaders(
- self.session, context, self._params, self._post_criteria)
+ self.session, context, self._params, self._post_criteria
+ )
context.statement.use_labels = True
if context.autoflush and not context.populate_existing:
@@ -426,7 +442,7 @@ class Result(object):
"""
- col = func.count(literal_column('*'))
+ col = func.count(literal_column("*"))
bq = self.bq.with_criteria(lambda q: q.from_self(col))
return bq.for_session(self.session).params(self._params).scalar()
@@ -456,8 +472,10 @@ class Result(object):
"""
bq = self.bq.with_criteria(lambda q: q.slice(0, 1))
ret = list(
- bq.for_session(self.session).params(self._params).
- _using_post_criteria(self._post_criteria))
+ bq.for_session(self.session)
+ .params(self._params)
+ ._using_post_criteria(self._post_criteria)
+ )
if len(ret) > 0:
return ret[0]
else:
@@ -473,7 +491,8 @@ class Result(object):
ret = self.one_or_none()
except orm_exc.MultipleResultsFound:
raise orm_exc.MultipleResultsFound(
- "Multiple rows were found for one()")
+ "Multiple rows were found for one()"
+ )
else:
if ret is None:
raise orm_exc.NoResultFound("No row was found for one()")
@@ -497,7 +516,8 @@ class Result(object):
return None
else:
raise orm_exc.MultipleResultsFound(
- "Multiple rows were found for one_or_none()")
+ "Multiple rows were found for one_or_none()"
+ )
def all(self):
"""Return all rows.
@@ -533,13 +553,18 @@ class Result(object):
# None present in ident - turn those comparisons
# into "IS NULL"
if None in primary_key_identity:
- nones = set([
- _get_params[col].key for col, value in
- zip(mapper.primary_key, primary_key_identity)
- if value is None
- ])
+ nones = set(
+ [
+ _get_params[col].key
+ for col, value in zip(
+ mapper.primary_key, primary_key_identity
+ )
+ if value is None
+ ]
+ )
_lcl_get_clause = sql_util.adapt_criterion_to_null(
- _lcl_get_clause, nones)
+ _lcl_get_clause, nones
+ )
_lcl_get_clause = q._adapt_clause(_lcl_get_clause, True, False)
q._criterion = _lcl_get_clause
@@ -556,16 +581,20 @@ class Result(object):
# key so that if a race causes multiple calls to _get_clause,
# we've cached on ours
bq = bq._clone()
- bq._cache_key += (_get_clause, )
+ bq._cache_key += (_get_clause,)
bq = bq.with_criteria(
- setup, tuple(elem is None for elem in primary_key_identity))
+ setup, tuple(elem is None for elem in primary_key_identity)
+ )
- params = dict([
- (_get_params[primary_key].key, id_val)
- for id_val, primary_key
- in zip(primary_key_identity, mapper.primary_key)
- ])
+ params = dict(
+ [
+ (_get_params[primary_key].key, id_val)
+ for id_val, primary_key in zip(
+ primary_key_identity, mapper.primary_key
+ )
+ ]
+ )
result = list(bq.for_session(self.session).params(**params))
l = len(result)
@@ -578,7 +607,8 @@ class Result(object):
@util.deprecated(
- "1.2", "Baked lazy loading is now the default implementation.")
+ "1.2", "Baked lazy loading is now the default implementation."
+)
def bake_lazy_loaders():
"""Enable the use of baked queries for all lazyloaders systemwide.
@@ -590,7 +620,8 @@ def bake_lazy_loaders():
@util.deprecated(
- "1.2", "Baked lazy loading is now the default implementation.")
+ "1.2", "Baked lazy loading is now the default implementation."
+)
def unbake_lazy_loaders():
"""Disable the use of baked queries for all lazyloaders systemwide.
@@ -601,7 +632,8 @@ def unbake_lazy_loaders():
"""
raise NotImplementedError(
- "Baked lazy loading is now the default implementation")
+ "Baked lazy loading is now the default implementation"
+ )
@strategy_options.loader_option()
@@ -615,20 +647,27 @@ def baked_lazyload(loadopt, attr):
@baked_lazyload._add_unbound_fn
@util.deprecated(
- "1.2", "Baked lazy loading is now the default "
- "implementation for lazy loading.")
+ "1.2",
+ "Baked lazy loading is now the default "
+ "implementation for lazy loading.",
+)
def baked_lazyload(*keys):
return strategy_options._UnboundLoad._from_keys(
- strategy_options._UnboundLoad.baked_lazyload, keys, False, {})
+ strategy_options._UnboundLoad.baked_lazyload, keys, False, {}
+ )
@baked_lazyload._add_unbound_all_fn
@util.deprecated(
- "1.2", "Baked lazy loading is now the default "
- "implementation for lazy loading.")
+ "1.2",
+ "Baked lazy loading is now the default "
+ "implementation for lazy loading.",
+)
def baked_lazyload_all(*keys):
return strategy_options._UnboundLoad._from_keys(
- strategy_options._UnboundLoad.baked_lazyload, keys, True, {})
+ strategy_options._UnboundLoad.baked_lazyload, keys, True, {}
+ )
+
baked_lazyload = baked_lazyload._unbound_fn
baked_lazyload_all = baked_lazyload_all._unbound_all_fn
diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py
index 6a0909d36..220b2c057 100644
--- a/lib/sqlalchemy/ext/compiler.py
+++ b/lib/sqlalchemy/ext/compiler.py
@@ -407,37 +407,44 @@ def compiles(class_, *specs):
def decorate(fn):
# get an existing @compiles handler
- existing = class_.__dict__.get('_compiler_dispatcher', None)
+ existing = class_.__dict__.get("_compiler_dispatcher", None)
# get the original handler. All ClauseElement classes have one
# of these, but some TypeEngine classes will not.
- existing_dispatch = getattr(class_, '_compiler_dispatch', None)
+ existing_dispatch = getattr(class_, "_compiler_dispatch", None)
if not existing:
existing = _dispatcher()
if existing_dispatch:
+
def _wrap_existing_dispatch(element, compiler, **kw):
try:
return existing_dispatch(element, compiler, **kw)
except exc.UnsupportedCompilationError:
raise exc.CompileError(
"%s construct has no default "
- "compilation handler." % type(element))
- existing.specs['default'] = _wrap_existing_dispatch
+ "compilation handler." % type(element)
+ )
+
+ existing.specs["default"] = _wrap_existing_dispatch
# TODO: why is the lambda needed ?
- setattr(class_, '_compiler_dispatch',
- lambda *arg, **kw: existing(*arg, **kw))
- setattr(class_, '_compiler_dispatcher', existing)
+ setattr(
+ class_,
+ "_compiler_dispatch",
+ lambda *arg, **kw: existing(*arg, **kw),
+ )
+ setattr(class_, "_compiler_dispatcher", existing)
if specs:
for s in specs:
existing.specs[s] = fn
else:
- existing.specs['default'] = fn
+ existing.specs["default"] = fn
return fn
+
return decorate
@@ -445,7 +452,7 @@ def deregister(class_):
"""Remove all custom compilers associated with a given
:class:`.ClauseElement` type."""
- if hasattr(class_, '_compiler_dispatcher'):
+ if hasattr(class_, "_compiler_dispatcher"):
# regenerate default _compiler_dispatch
visitors._generate_dispatch(class_)
# remove custom directive
@@ -461,10 +468,11 @@ class _dispatcher(object):
fn = self.specs.get(compiler.dialect.name, None)
if not fn:
try:
- fn = self.specs['default']
+ fn = self.specs["default"]
except KeyError:
raise exc.CompileError(
"%s construct has no default "
- "compilation handler." % type(element))
+ "compilation handler." % type(element)
+ )
return fn(element, compiler, **kw)
diff --git a/lib/sqlalchemy/ext/declarative/__init__.py b/lib/sqlalchemy/ext/declarative/__init__.py
index cb81f51e5..2b0a37884 100644
--- a/lib/sqlalchemy/ext/declarative/__init__.py
+++ b/lib/sqlalchemy/ext/declarative/__init__.py
@@ -5,14 +5,31 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-from .api import declarative_base, synonym_for, comparable_using, \
- instrument_declarative, ConcreteBase, AbstractConcreteBase, \
- DeclarativeMeta, DeferredReflection, has_inherited_table,\
- declared_attr, as_declarative
+from .api import (
+ declarative_base,
+ synonym_for,
+ comparable_using,
+ instrument_declarative,
+ ConcreteBase,
+ AbstractConcreteBase,
+ DeclarativeMeta,
+ DeferredReflection,
+ has_inherited_table,
+ declared_attr,
+ as_declarative,
+)
-__all__ = ['declarative_base', 'synonym_for', 'has_inherited_table',
- 'comparable_using', 'instrument_declarative', 'declared_attr',
- 'as_declarative',
- 'ConcreteBase', 'AbstractConcreteBase', 'DeclarativeMeta',
- 'DeferredReflection']
+__all__ = [
+ "declarative_base",
+ "synonym_for",
+ "has_inherited_table",
+ "comparable_using",
+ "instrument_declarative",
+ "declared_attr",
+ "as_declarative",
+ "ConcreteBase",
+ "AbstractConcreteBase",
+ "DeclarativeMeta",
+ "DeferredReflection",
+]
diff --git a/lib/sqlalchemy/ext/declarative/api.py b/lib/sqlalchemy/ext/declarative/api.py
index 865cd16f0..987e92119 100644
--- a/lib/sqlalchemy/ext/declarative/api.py
+++ b/lib/sqlalchemy/ext/declarative/api.py
@@ -8,9 +8,13 @@
from ...schema import Table, MetaData, Column
-from ...orm import synonym as _orm_synonym, \
- comparable_property,\
- interfaces, properties, attributes
+from ...orm import (
+ synonym as _orm_synonym,
+ comparable_property,
+ interfaces,
+ properties,
+ attributes,
+)
from ...orm.util import polymorphic_union
from ...orm.base import _mapper_or_none
from ...util import OrderedDict, hybridmethod, hybridproperty
@@ -19,9 +23,13 @@ from ... import exc
import weakref
import re
-from .base import _as_declarative, \
- _declarative_constructor,\
- _DeferredMapperConfig, _add_attribute, _del_attribute
+from .base import (
+ _as_declarative,
+ _declarative_constructor,
+ _DeferredMapperConfig,
+ _add_attribute,
+ _del_attribute,
+)
from .clsregistry import _class_resolver
@@ -31,10 +39,10 @@ def instrument_declarative(cls, registry, metadata):
MetaData object.
"""
- if '_decl_class_registry' in cls.__dict__:
+ if "_decl_class_registry" in cls.__dict__:
raise exc.InvalidRequestError(
- "Class %r already has been "
- "instrumented declaratively" % cls)
+ "Class %r already has been " "instrumented declaratively" % cls
+ )
cls._decl_class_registry = registry
cls.metadata = metadata
_as_declarative(cls, cls.__name__, cls.__dict__)
@@ -54,14 +62,14 @@ def has_inherited_table(cls):
"""
for class_ in cls.__mro__[1:]:
- if getattr(class_, '__table__', None) is not None:
+ if getattr(class_, "__table__", None) is not None:
return True
return False
class DeclarativeMeta(type):
def __init__(cls, classname, bases, dict_):
- if '_decl_class_registry' not in cls.__dict__:
+ if "_decl_class_registry" not in cls.__dict__:
_as_declarative(cls, classname, cls.__dict__)
type.__init__(cls, classname, bases, dict_)
@@ -71,6 +79,7 @@ class DeclarativeMeta(type):
def __delattr__(cls, key):
_del_attribute(cls, key)
+
def synonym_for(name, map_column=False):
"""Decorator that produces an :func:`.orm.synonym` attribute in conjunction
with a Python descriptor.
@@ -104,8 +113,10 @@ def synonym_for(name, map_column=False):
can be achieved with synonyms.
"""
+
def decorate(fn):
return _orm_synonym(name, map_column=map_column, descriptor=fn)
+
return decorate
@@ -127,8 +138,10 @@ def comparable_using(comparator_factory):
prop = comparable_property(MyComparatorType)
"""
+
def decorate(fn):
return comparable_property(comparator_factory, fn)
+
return decorate
@@ -190,14 +203,16 @@ class declared_attr(interfaces._MappedAttribute, property):
self._cascading = cascading
def __get__(desc, self, cls):
- reg = cls.__dict__.get('_sa_declared_attr_reg', None)
+ reg = cls.__dict__.get("_sa_declared_attr_reg", None)
if reg is None:
- if not re.match(r'^__.+__$', desc.fget.__name__) and \
- attributes.manager_of_class(cls) is None:
+ if (
+ not re.match(r"^__.+__$", desc.fget.__name__)
+ and attributes.manager_of_class(cls) is None
+ ):
util.warn(
"Unmanaged access of declarative attribute %s from "
- "non-mapped class %s" %
- (desc.fget.__name__, cls.__name__))
+ "non-mapped class %s" % (desc.fget.__name__, cls.__name__)
+ )
return desc.fget(cls)
elif desc in reg:
return reg[desc]
@@ -283,10 +298,16 @@ class _stateful_declared_attr(declared_attr):
return declared_attr(fn, **self.kw)
-def declarative_base(bind=None, metadata=None, mapper=None, cls=object,
- name='Base', constructor=_declarative_constructor,
- class_registry=None,
- metaclass=DeclarativeMeta):
+def declarative_base(
+ bind=None,
+ metadata=None,
+ mapper=None,
+ cls=object,
+ name="Base",
+ constructor=_declarative_constructor,
+ class_registry=None,
+ metaclass=DeclarativeMeta,
+):
r"""Construct a base class for declarative class definitions.
The new base class will be given a metaclass that produces
@@ -357,16 +378,17 @@ def declarative_base(bind=None, metadata=None, mapper=None, cls=object,
class_registry = weakref.WeakValueDictionary()
bases = not isinstance(cls, tuple) and (cls,) or cls
- class_dict = dict(_decl_class_registry=class_registry,
- metadata=lcl_metadata)
+ class_dict = dict(
+ _decl_class_registry=class_registry, metadata=lcl_metadata
+ )
if isinstance(cls, type):
- class_dict['__doc__'] = cls.__doc__
+ class_dict["__doc__"] = cls.__doc__
if constructor:
- class_dict['__init__'] = constructor
+ class_dict["__init__"] = constructor
if mapper:
- class_dict['__mapper_cls__'] = mapper
+ class_dict["__mapper_cls__"] = mapper
return metaclass(name, bases, class_dict)
@@ -401,9 +423,10 @@ def as_declarative(**kw):
:func:`.declarative_base`
"""
+
def decorate(cls):
- kw['cls'] = cls
- kw['name'] = cls.__name__
+ kw["cls"] = cls
+ kw["name"] = cls.__name__
return declarative_base(**kw)
return decorate
@@ -456,10 +479,13 @@ class ConcreteBase(object):
@classmethod
def _create_polymorphic_union(cls, mappers):
- return polymorphic_union(OrderedDict(
- (mp.polymorphic_identity, mp.local_table)
- for mp in mappers
- ), 'type', 'pjoin')
+ return polymorphic_union(
+ OrderedDict(
+ (mp.polymorphic_identity, mp.local_table) for mp in mappers
+ ),
+ "type",
+ "pjoin",
+ )
@classmethod
def __declare_first__(cls):
@@ -568,7 +594,7 @@ class AbstractConcreteBase(ConcreteBase):
@classmethod
def _sa_decl_prepare_nocascade(cls):
- if getattr(cls, '__mapper__', None):
+ if getattr(cls, "__mapper__", None):
return
to_map = _DeferredMapperConfig.config_for_cls(cls)
@@ -604,8 +630,9 @@ class AbstractConcreteBase(ConcreteBase):
def mapper_args():
args = m_args()
- args['polymorphic_on'] = pjoin.c.type
+ args["polymorphic_on"] = pjoin.c.type
return args
+
to_map.mapper_args_fn = mapper_args
m = to_map.map()
@@ -684,6 +711,7 @@ class DeferredReflection(object):
.. versionadded:: 0.8
"""
+
@classmethod
def prepare(cls, engine):
"""Reflect all :class:`.Table` objects for all current
@@ -696,8 +724,10 @@ class DeferredReflection(object):
mapper = thingy.cls.__mapper__
metadata = mapper.class_.metadata
for rel in mapper._props.values():
- if isinstance(rel, properties.RelationshipProperty) and \
- rel.secondary is not None:
+ if (
+ isinstance(rel, properties.RelationshipProperty)
+ and rel.secondary is not None
+ ):
if isinstance(rel.secondary, Table):
cls._reflect_table(rel.secondary, engine)
elif isinstance(rel.secondary, _class_resolver):
@@ -711,6 +741,7 @@ class DeferredReflection(object):
t1 = Table(key, metadata)
cls._reflect_table(t1, engine)
return t1
+
return _resolve
@classmethod
@@ -724,10 +755,12 @@ class DeferredReflection(object):
@classmethod
def _reflect_table(cls, table, engine):
- Table(table.name,
- table.metadata,
- extend_existing=True,
- autoload_replace=False,
- autoload=True,
- autoload_with=engine,
- schema=table.schema)
+ Table(
+ table.name,
+ table.metadata,
+ extend_existing=True,
+ autoload_replace=False,
+ autoload=True,
+ autoload_with=engine,
+ schema=table.schema,
+ )
diff --git a/lib/sqlalchemy/ext/declarative/base.py b/lib/sqlalchemy/ext/declarative/base.py
index f27314b5e..07778f733 100644
--- a/lib/sqlalchemy/ext/declarative/base.py
+++ b/lib/sqlalchemy/ext/declarative/base.py
@@ -39,7 +39,7 @@ def _resolve_for_abstract_or_classical(cls):
if cls is object:
return None
- if _get_immediate_cls_attr(cls, '__abstract__', strict=True):
+ if _get_immediate_cls_attr(cls, "__abstract__", strict=True):
for sup in cls.__bases__:
sup = _resolve_for_abstract_or_classical(sup)
if sup is not None:
@@ -59,7 +59,7 @@ def _dive_for_classically_mapped_class(cls):
# if we are within a base hierarchy, don't
# search at all for classical mappings
- if hasattr(cls, '_decl_class_registry'):
+ if hasattr(cls, "_decl_class_registry"):
return None
manager = instrumentation.manager_of_class(cls)
@@ -89,15 +89,19 @@ def _get_immediate_cls_attr(cls, attrname, strict=False):
return None
for base in cls.__mro__:
- _is_declarative_inherits = hasattr(base, '_decl_class_registry')
- _is_classicial_inherits = not _is_declarative_inherits and \
- _dive_for_classically_mapped_class(base) is not None
+ _is_declarative_inherits = hasattr(base, "_decl_class_registry")
+ _is_classicial_inherits = (
+ not _is_declarative_inherits
+ and _dive_for_classically_mapped_class(base) is not None
+ )
if attrname in base.__dict__ and (
- base is cls or
- ((base in cls.__bases__ if strict else True)
+ base is cls
+ or (
+ (base in cls.__bases__ if strict else True)
and not _is_declarative_inherits
- and not _is_classicial_inherits)
+ and not _is_classicial_inherits
+ )
):
return getattr(base, attrname)
else:
@@ -108,9 +112,10 @@ def _as_declarative(cls, classname, dict_):
global declared_attr, declarative_props
if declared_attr is None:
from .api import declared_attr
+
declarative_props = (declared_attr, util.classproperty)
- if _get_immediate_cls_attr(cls, '__abstract__', strict=True):
+ if _get_immediate_cls_attr(cls, "__abstract__", strict=True):
return
_MapperConfig.setup_mapping(cls, classname, dict_)
@@ -119,23 +124,23 @@ def _as_declarative(cls, classname, dict_):
def _check_declared_props_nocascade(obj, name, cls):
if isinstance(obj, declarative_props):
- if getattr(obj, '_cascading', False):
+ if getattr(obj, "_cascading", False):
util.warn(
"@declared_attr.cascading is not supported on the %s "
"attribute on class %s. This attribute invokes for "
- "subclasses in any case." % (name, cls))
+ "subclasses in any case." % (name, cls)
+ )
return True
else:
return False
class _MapperConfig(object):
-
@classmethod
def setup_mapping(cls, cls_, classname, dict_):
defer_map = _get_immediate_cls_attr(
- cls_, '_sa_decl_prepare_nocascade', strict=True) or \
- hasattr(cls_, '_sa_decl_prepare')
+ cls_, "_sa_decl_prepare_nocascade", strict=True
+ ) or hasattr(cls_, "_sa_decl_prepare")
if defer_map:
cfg_cls = _DeferredMapperConfig
@@ -179,12 +184,14 @@ class _MapperConfig(object):
self.map()
def _setup_declared_events(self):
- if _get_immediate_cls_attr(self.cls, '__declare_last__'):
+ if _get_immediate_cls_attr(self.cls, "__declare_last__"):
+
@event.listens_for(mapper, "after_configured")
def after_configured():
self.cls.__declare_last__()
- if _get_immediate_cls_attr(self.cls, '__declare_first__'):
+ if _get_immediate_cls_attr(self.cls, "__declare_first__"):
+
@event.listens_for(mapper, "before_configured")
def before_configured():
self.cls.__declare_first__()
@@ -198,59 +205,62 @@ class _MapperConfig(object):
tablename = None
for base in cls.__mro__:
- class_mapped = base is not cls and \
- _declared_mapping_info(base) is not None and \
- not _get_immediate_cls_attr(
- base, '_sa_decl_prepare_nocascade', strict=True)
+ class_mapped = (
+ base is not cls
+ and _declared_mapping_info(base) is not None
+ and not _get_immediate_cls_attr(
+ base, "_sa_decl_prepare_nocascade", strict=True
+ )
+ )
if not class_mapped and base is not cls:
self._produce_column_copies(base)
for name, obj in vars(base).items():
- if name == '__mapper_args__':
- check_decl = \
- _check_declared_props_nocascade(obj, name, cls)
- if not mapper_args_fn and (
- not class_mapped or
- check_decl
- ):
+ if name == "__mapper_args__":
+ check_decl = _check_declared_props_nocascade(
+ obj, name, cls
+ )
+ if not mapper_args_fn and (not class_mapped or check_decl):
# don't even invoke __mapper_args__ until
# after we've determined everything about the
# mapped table.
# make a copy of it so a class-level dictionary
# is not overwritten when we update column-based
# arguments.
- mapper_args_fn = lambda: dict(cls.__mapper_args__) # noqa
- elif name == '__tablename__':
- check_decl = \
- _check_declared_props_nocascade(obj, name, cls)
- if not tablename and (
- not class_mapped or
- check_decl
- ):
+ mapper_args_fn = lambda: dict(
+ cls.__mapper_args__
+ ) # noqa
+ elif name == "__tablename__":
+ check_decl = _check_declared_props_nocascade(
+ obj, name, cls
+ )
+ if not tablename and (not class_mapped or check_decl):
tablename = cls.__tablename__
- elif name == '__table_args__':
- check_decl = \
- _check_declared_props_nocascade(obj, name, cls)
- if not table_args and (
- not class_mapped or
- check_decl
- ):
+ elif name == "__table_args__":
+ check_decl = _check_declared_props_nocascade(
+ obj, name, cls
+ )
+ if not table_args and (not class_mapped or check_decl):
table_args = cls.__table_args__
if not isinstance(
- table_args, (tuple, dict, type(None))):
+ table_args, (tuple, dict, type(None))
+ ):
raise exc.ArgumentError(
"__table_args__ value must be a tuple, "
- "dict, or None")
+ "dict, or None"
+ )
if base is not cls:
inherited_table_args = True
elif class_mapped:
if isinstance(obj, declarative_props):
- util.warn("Regular (i.e. not __special__) "
- "attribute '%s.%s' uses @declared_attr, "
- "but owning class %s is mapped - "
- "not applying to subclass %s."
- % (base.__name__, name, base, cls))
+ util.warn(
+ "Regular (i.e. not __special__) "
+ "attribute '%s.%s' uses @declared_attr, "
+ "but owning class %s is mapped - "
+ "not applying to subclass %s."
+ % (base.__name__, name, base, cls)
+ )
continue
elif base is not cls:
# we're a mixin, abstract base, or something that is
@@ -263,7 +273,8 @@ class _MapperConfig(object):
"Mapper properties (i.e. deferred,"
"column_property(), relationship(), etc.) must "
"be declared as @declared_attr callables "
- "on declarative mixin classes.")
+ "on declarative mixin classes."
+ )
elif isinstance(obj, declarative_props):
oldclassprop = isinstance(obj, util.classproperty)
if not oldclassprop and obj._cascading:
@@ -278,15 +289,18 @@ class _MapperConfig(object):
"Attribute '%s' on class %s cannot be "
"processed due to "
"@declared_attr.cascading; "
- "skipping" % (name, cls))
- dict_[name] = column_copies[obj] = \
- ret = obj.__get__(obj, cls)
+ "skipping" % (name, cls)
+ )
+ dict_[name] = column_copies[
+ obj
+ ] = ret = obj.__get__(obj, cls)
setattr(cls, name, ret)
else:
if oldclassprop:
util.warn_deprecated(
"Use of sqlalchemy.util.classproperty on "
- "declarative classes is deprecated.")
+ "declarative classes is deprecated."
+ )
# access attribute using normal class access
ret = getattr(cls, name)
@@ -294,14 +308,20 @@ class _MapperConfig(object):
# or similar. note there is no known case that
# produces nested proxies, so we are only
# looking one level deep right now.
- if isinstance(ret, InspectionAttr) and \
- ret._is_internal_proxy and not isinstance(
- ret.original_property, MapperProperty):
+ if (
+ isinstance(ret, InspectionAttr)
+ and ret._is_internal_proxy
+ and not isinstance(
+ ret.original_property, MapperProperty
+ )
+ ):
ret = ret.descriptor
dict_[name] = column_copies[obj] = ret
- if isinstance(ret, (Column, MapperProperty)) and \
- ret.doc is None:
+ if (
+ isinstance(ret, (Column, MapperProperty))
+ and ret.doc is None
+ ):
ret.doc = obj.__doc__
# here, the attribute is some other kind of property that
# we assume is not part of the declarative mapping.
@@ -321,8 +341,9 @@ class _MapperConfig(object):
util.warn(
"Attribute '%s' on class %s appears to be a non-schema "
"'sqlalchemy.sql.column()' "
- "object; this won't be part of the declarative mapping" %
- (key, cls))
+ "object; this won't be part of the declarative mapping"
+ % (key, cls)
+ )
def _produce_column_copies(self, base):
cls = self.cls
@@ -340,10 +361,11 @@ class _MapperConfig(object):
raise exc.InvalidRequestError(
"Columns with foreign keys to other columns "
"must be declared as @declared_attr callables "
- "on declarative mixin classes. ")
+ "on declarative mixin classes. "
+ )
elif name not in dict_ and not (
- '__table__' in dict_ and
- (obj.name or name) in dict_['__table__'].c
+ "__table__" in dict_
+ and (obj.name or name) in dict_["__table__"].c
):
column_copies[obj] = copy_ = obj.copy()
copy_._creation_order = obj._creation_order
@@ -357,11 +379,12 @@ class _MapperConfig(object):
our_stuff = self.properties
late_mapped = _get_immediate_cls_attr(
- cls, '_sa_decl_prepare_nocascade', strict=True)
+ cls, "_sa_decl_prepare_nocascade", strict=True
+ )
for k in list(dict_):
- if k in ('__table__', '__tablename__', '__mapper_args__'):
+ if k in ("__table__", "__tablename__", "__mapper_args__"):
continue
value = dict_[k]
@@ -371,29 +394,37 @@ class _MapperConfig(object):
"Use of @declared_attr.cascading only applies to "
"Declarative 'mixin' and 'abstract' classes. "
"Currently, this flag is ignored on mapped class "
- "%s" % self.cls)
+ "%s" % self.cls
+ )
value = getattr(cls, k)
- elif isinstance(value, QueryableAttribute) and \
- value.class_ is not cls and \
- value.key != k:
+ elif (
+ isinstance(value, QueryableAttribute)
+ and value.class_ is not cls
+ and value.key != k
+ ):
# detect a QueryableAttribute that's already mapped being
# assigned elsewhere in userland, turn into a synonym()
value = synonym(value.key)
setattr(cls, k, value)
- if (isinstance(value, tuple) and len(value) == 1 and
- isinstance(value[0], (Column, MapperProperty))):
- util.warn("Ignoring declarative-like tuple value of attribute "
- "'%s': possibly a copy-and-paste error with a comma "
- "accidentally placed at the end of the line?" % k)
+ if (
+ isinstance(value, tuple)
+ and len(value) == 1
+ and isinstance(value[0], (Column, MapperProperty))
+ ):
+ util.warn(
+ "Ignoring declarative-like tuple value of attribute "
+ "'%s': possibly a copy-and-paste error with a comma "
+ "accidentally placed at the end of the line?" % k
+ )
continue
elif not isinstance(value, (Column, MapperProperty)):
# using @declared_attr for some object that
# isn't Column/MapperProperty; remove from the dict_
# and place the evaluated value onto the class.
- if not k.startswith('__'):
+ if not k.startswith("__"):
dict_.pop(k)
self._warn_for_decl_attributes(cls, k, value)
if not late_mapped:
@@ -402,7 +433,7 @@ class _MapperConfig(object):
# we expect to see the name 'metadata' in some valid cases;
# however at this point we see it's assigned to something trying
# to be mapped, so raise for that.
- elif k == 'metadata':
+ elif k == "metadata":
raise exc.InvalidRequestError(
"Attribute name 'metadata' is reserved "
"for the MetaData instance when using a "
@@ -423,8 +454,7 @@ class _MapperConfig(object):
for key, c in list(our_stuff.items()):
if isinstance(c, (ColumnProperty, CompositeProperty)):
for col in c.columns:
- if isinstance(col, Column) and \
- col.table is None:
+ if isinstance(col, Column) and col.table is None:
_undefer_column_name(key, col)
if not isinstance(c, CompositeProperty):
name_to_prop_key[col.name].add(key)
@@ -447,8 +477,8 @@ class _MapperConfig(object):
"On class %r, Column object %r named "
"directly multiple times, "
"only one will be used: %s. "
- "Consider using orm.synonym instead" %
- (self.classname, name, (", ".join(sorted(keys))))
+ "Consider using orm.synonym instead"
+ % (self.classname, name, (", ".join(sorted(keys))))
)
def _setup_table(self):
@@ -459,15 +489,16 @@ class _MapperConfig(object):
declared_columns = self.declared_columns
declared_columns = self.declared_columns = sorted(
- declared_columns, key=lambda c: c._creation_order)
+ declared_columns, key=lambda c: c._creation_order
+ )
table = None
- if hasattr(cls, '__table_cls__'):
+ if hasattr(cls, "__table_cls__"):
table_cls = util.unbound_method_to_callable(cls.__table_cls__)
else:
table_cls = Table
- if '__table__' not in dict_:
+ if "__table__" not in dict_:
if tablename is not None:
args, table_kw = (), {}
@@ -480,14 +511,16 @@ class _MapperConfig(object):
else:
args = table_args
- autoload = dict_.get('__autoload__')
+ autoload = dict_.get("__autoload__")
if autoload:
- table_kw['autoload'] = True
+ table_kw["autoload"] = True
cls.__table__ = table = table_cls(
- tablename, cls.metadata,
+ tablename,
+ cls.metadata,
*(tuple(declared_columns) + tuple(args)),
- **table_kw)
+ **table_kw
+ )
else:
table = cls.__table__
if declared_columns:
@@ -512,21 +545,27 @@ class _MapperConfig(object):
c = _resolve_for_abstract_or_classical(c)
if c is None:
continue
- if _declared_mapping_info(c) is not None and \
- not _get_immediate_cls_attr(
- c, '_sa_decl_prepare_nocascade', strict=True):
+ if _declared_mapping_info(
+ c
+ ) is not None and not _get_immediate_cls_attr(
+ c, "_sa_decl_prepare_nocascade", strict=True
+ ):
inherits.append(c)
if inherits:
if len(inherits) > 1:
raise exc.InvalidRequestError(
- "Class %s has multiple mapped bases: %r" % (cls, inherits))
+ "Class %s has multiple mapped bases: %r" % (cls, inherits)
+ )
self.inherits = inherits[0]
else:
self.inherits = None
- if table is None and self.inherits is None and \
- not _get_immediate_cls_attr(cls, '__no_table__'):
+ if (
+ table is None
+ and self.inherits is None
+ and not _get_immediate_cls_attr(cls, "__no_table__")
+ ):
raise exc.InvalidRequestError(
"Class %r does not have a __table__ or __tablename__ "
@@ -553,8 +592,8 @@ class _MapperConfig(object):
continue
raise exc.ArgumentError(
"Column '%s' on class %s conflicts with "
- "existing column '%s'" %
- (c, cls, inherited_table.c[c.name])
+ "existing column '%s'"
+ % (c, cls, inherited_table.c[c.name])
)
if c.primary_key:
raise exc.ArgumentError(
@@ -562,8 +601,10 @@ class _MapperConfig(object):
"class with no table."
)
inherited_table.append_column(c)
- if inherited_mapped_table is not None and \
- inherited_mapped_table is not inherited_table:
+ if (
+ inherited_mapped_table is not None
+ and inherited_mapped_table is not inherited_table
+ ):
inherited_mapped_table._refresh_for_new_column(c)
def _prepare_mapper_arguments(self):
@@ -575,18 +616,19 @@ class _MapperConfig(object):
# make sure that column copies are used rather
# than the original columns from any mixins
- for k in ('version_id_col', 'polymorphic_on',):
+ for k in ("version_id_col", "polymorphic_on"):
if k in mapper_args:
v = mapper_args[k]
mapper_args[k] = self.column_copies.get(v, v)
- assert 'inherits' not in mapper_args, \
- "Can't specify 'inherits' explicitly with declarative mappings"
+ assert (
+ "inherits" not in mapper_args
+ ), "Can't specify 'inherits' explicitly with declarative mappings"
if self.inherits:
- mapper_args['inherits'] = self.inherits
+ mapper_args["inherits"] = self.inherits
- if self.inherits and not mapper_args.get('concrete', False):
+ if self.inherits and not mapper_args.get("concrete", False):
# single or joined inheritance
# exclude any cols on the inherited table which are
# not mapped on the parent class, to avoid
@@ -594,16 +636,17 @@ class _MapperConfig(object):
inherited_mapper = _declared_mapping_info(self.inherits)
inherited_table = inherited_mapper.local_table
- if 'exclude_properties' not in mapper_args:
- mapper_args['exclude_properties'] = exclude_properties = \
- set(
- [c.key for c in inherited_table.c
- if c not in inherited_mapper._columntoproperty]
- ).union(
- inherited_mapper.exclude_properties or ()
- )
+ if "exclude_properties" not in mapper_args:
+ mapper_args["exclude_properties"] = exclude_properties = set(
+ [
+ c.key
+ for c in inherited_table.c
+ if c not in inherited_mapper._columntoproperty
+ ]
+ ).union(inherited_mapper.exclude_properties or ())
exclude_properties.difference_update(
- [c.key for c in self.declared_columns])
+ [c.key for c in self.declared_columns]
+ )
# look through columns in the current mapper that
# are keyed to a propname different than the colname
@@ -621,21 +664,20 @@ class _MapperConfig(object):
# first. See [ticket:1892] for background.
properties[k] = [col] + p.columns
result_mapper_args = mapper_args.copy()
- result_mapper_args['properties'] = properties
+ result_mapper_args["properties"] = properties
self.mapper_args = result_mapper_args
def map(self):
self._prepare_mapper_arguments()
- if hasattr(self.cls, '__mapper_cls__'):
+ if hasattr(self.cls, "__mapper_cls__"):
mapper_cls = util.unbound_method_to_callable(
- self.cls.__mapper_cls__)
+ self.cls.__mapper_cls__
+ )
else:
mapper_cls = mapper
self.cls.__mapper__ = mp_ = mapper_cls(
- self.cls,
- self.local_table,
- **self.mapper_args
+ self.cls, self.local_table, **self.mapper_args
)
del self.cls._sa_declared_attr_reg
return mp_
@@ -663,8 +705,7 @@ class _DeferredMapperConfig(_MapperConfig):
@classmethod
def has_cls(cls, class_):
# 2.6 fails on weakref if class_ is an old style class
- return isinstance(class_, type) and \
- weakref.ref(class_) in cls._configs
+ return isinstance(class_, type) and weakref.ref(class_) in cls._configs
@classmethod
def config_for_cls(cls, class_):
@@ -673,18 +714,15 @@ class _DeferredMapperConfig(_MapperConfig):
@classmethod
def classes_for_base(cls, base_cls, sort=True):
classes_for_base = [
- m for m, cls_ in
- [(m, m.cls) for m in cls._configs.values()]
+ m
+ for m, cls_ in [(m, m.cls) for m in cls._configs.values()]
if cls_ is not None and issubclass(cls_, base_cls)
]
if not sort:
return classes_for_base
- all_m_by_cls = dict(
- (m.cls, m)
- for m in classes_for_base
- )
+ all_m_by_cls = dict((m.cls, m) for m in classes_for_base)
tuples = []
for m_cls in all_m_by_cls:
@@ -693,12 +731,7 @@ class _DeferredMapperConfig(_MapperConfig):
for base_cls in m_cls.__bases__
if base_cls in all_m_by_cls
)
- return list(
- topological.sort(
- tuples,
- classes_for_base
- )
- )
+ return list(topological.sort(tuples, classes_for_base))
def map(self):
self._configs.pop(self._cls, None)
@@ -713,7 +746,7 @@ def _add_attribute(cls, key, value):
"""
- if '__mapper__' in cls.__dict__:
+ if "__mapper__" in cls.__dict__:
if isinstance(value, Column):
_undefer_column_name(key, value)
cls.__table__.append_column(value)
@@ -726,16 +759,14 @@ def _add_attribute(cls, key, value):
cls.__mapper__.add_property(key, value)
elif isinstance(value, MapperProperty):
cls.__mapper__.add_property(
- key,
- clsregistry._deferred_relationship(cls, value)
+ key, clsregistry._deferred_relationship(cls, value)
)
elif isinstance(value, QueryableAttribute) and value.key != key:
# detect a QueryableAttribute that's already mapped being
# assigned elsewhere in userland, turn into a synonym()
value = synonym(value.key)
cls.__mapper__.add_property(
- key,
- clsregistry._deferred_relationship(cls, value)
+ key, clsregistry._deferred_relationship(cls, value)
)
else:
type.__setattr__(cls, key, value)
@@ -746,15 +777,18 @@ def _add_attribute(cls, key, value):
def _del_attribute(cls, key):
- if '__mapper__' in cls.__dict__ and \
- key in cls.__dict__ and not cls.__mapper__._dispose_called:
+ if (
+ "__mapper__" in cls.__dict__
+ and key in cls.__dict__
+ and not cls.__mapper__._dispose_called
+ ):
value = cls.__dict__[key]
if isinstance(
- value,
- (Column, ColumnProperty, MapperProperty, QueryableAttribute)
+ value, (Column, ColumnProperty, MapperProperty, QueryableAttribute)
):
raise NotImplementedError(
- "Can't un-map individual mapped attributes on a mapped class.")
+ "Can't un-map individual mapped attributes on a mapped class."
+ )
else:
type.__delattr__(cls, key)
cls.__mapper__._expire_memoizations()
@@ -776,10 +810,12 @@ def _declarative_constructor(self, **kwargs):
for k in kwargs:
if not hasattr(cls_, k):
raise TypeError(
- "%r is an invalid keyword argument for %s" %
- (k, cls_.__name__))
+ "%r is an invalid keyword argument for %s" % (k, cls_.__name__)
+ )
setattr(self, k, kwargs[k])
-_declarative_constructor.__name__ = '__init__'
+
+
+_declarative_constructor.__name__ = "__init__"
def _undefer_column_name(key, column):
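
The ``_scan_attributes`` hunks above only re-wrap the walk over the MRO that
collects ``__tablename__``, ``__table_args__``, ``__mapper_args__`` and
``@declared_attr`` callables from mixins. For context, a minimal sketch of the
declarative conventions that code consumes (class and table names here are
illustrative, not part of this patch)::

    from sqlalchemy import Column, Integer, String
    from sqlalchemy.ext.declarative import declarative_base, declared_attr

    Base = declarative_base()

    class Mixin(object):
        # a plain Column on a mixin is cloned onto each subclass by
        # _produce_column_copies()
        id = Column(Integer, primary_key=True)

        @declared_attr
        def __tablename__(cls):
            # consumed by the __tablename__ branch of _scan_attributes()
            return cls.__name__.lower()

    class User(Mixin, Base):
        name = Column(String(50))
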
diff --git a/lib/sqlalchemy/ext/declarative/clsregistry.py b/lib/sqlalchemy/ext/declarative/clsregistry.py
index e941b9ed3..c52ae4a2f 100644
--- a/lib/sqlalchemy/ext/declarative/clsregistry.py
+++ b/lib/sqlalchemy/ext/declarative/clsregistry.py
@@ -10,8 +10,11 @@ This system allows specification of classes and expressions used in
:func:`.relationship` using strings.
"""
-from ...orm.properties import ColumnProperty, RelationshipProperty, \
- SynonymProperty
+from ...orm.properties import (
+ ColumnProperty,
+ RelationshipProperty,
+ SynonymProperty,
+)
from ...schema import _get_table_key
from ...orm import class_mapper, interfaces
from ... import util
@@ -35,17 +38,18 @@ def add_class(classname, cls):
# class already exists.
existing = cls._decl_class_registry[classname]
if not isinstance(existing, _MultipleClassMarker):
- existing = \
- cls._decl_class_registry[classname] = \
- _MultipleClassMarker([cls, existing])
+ existing = cls._decl_class_registry[
+ classname
+ ] = _MultipleClassMarker([cls, existing])
else:
cls._decl_class_registry[classname] = cls
try:
- root_module = cls._decl_class_registry['_sa_module_registry']
+ root_module = cls._decl_class_registry["_sa_module_registry"]
except KeyError:
- cls._decl_class_registry['_sa_module_registry'] = \
- root_module = _ModuleMarker('_sa_module_registry', None)
+ cls._decl_class_registry[
+ "_sa_module_registry"
+ ] = root_module = _ModuleMarker("_sa_module_registry", None)
tokens = cls.__module__.split(".")
@@ -71,12 +75,13 @@ class _MultipleClassMarker(object):
"""
- __slots__ = 'on_remove', 'contents', '__weakref__'
+ __slots__ = "on_remove", "contents", "__weakref__"
def __init__(self, classes, on_remove=None):
self.on_remove = on_remove
- self.contents = set([
- weakref.ref(item, self._remove_item) for item in classes])
+ self.contents = set(
+ [weakref.ref(item, self._remove_item) for item in classes]
+ )
_registries.add(self)
def __iter__(self):
@@ -85,10 +90,10 @@ class _MultipleClassMarker(object):
def attempt_get(self, path, key):
if len(self.contents) > 1:
raise exc.InvalidRequestError(
- "Multiple classes found for path \"%s\" "
+ 'Multiple classes found for path "%s" '
"in the registry of this declarative "
- "base. Please use a fully module-qualified path." %
- (".".join(path + [key]))
+ "base. Please use a fully module-qualified path."
+ % (".".join(path + [key]))
)
else:
ref = list(self.contents)[0]
@@ -108,17 +113,19 @@ class _MultipleClassMarker(object):
# protect against class registration race condition against
# asynchronous garbage collection calling _remove_item,
# [ticket:3208]
- modules = set([
- cls.__module__ for cls in
- [ref() for ref in self.contents] if cls is not None])
+ modules = set(
+ [
+ cls.__module__
+ for cls in [ref() for ref in self.contents]
+ if cls is not None
+ ]
+ )
if item.__module__ in modules:
util.warn(
"This declarative base already contains a class with the "
"same class name and module name as %s.%s, and will "
- "be replaced in the string-lookup table." % (
- item.__module__,
- item.__name__
- )
+ "be replaced in the string-lookup table."
+ % (item.__module__, item.__name__)
)
self.contents.add(weakref.ref(item, self._remove_item))
@@ -129,7 +136,7 @@ class _ModuleMarker(object):
"""
- __slots__ = 'parent', 'name', 'contents', 'mod_ns', 'path', '__weakref__'
+ __slots__ = "parent", "name", "contents", "mod_ns", "path", "__weakref__"
def __init__(self, name, parent):
self.parent = parent
@@ -170,13 +177,13 @@ class _ModuleMarker(object):
existing = self.contents[name]
existing.add_item(cls)
else:
- existing = self.contents[name] = \
- _MultipleClassMarker([cls],
- on_remove=lambda: self._remove_item(name))
+ existing = self.contents[name] = _MultipleClassMarker(
+ [cls], on_remove=lambda: self._remove_item(name)
+ )
class _ModNS(object):
- __slots__ = '__parent',
+ __slots__ = ("__parent",)
def __init__(self, parent):
self.__parent = parent
@@ -193,13 +200,14 @@ class _ModNS(object):
else:
assert isinstance(value, _MultipleClassMarker)
return value.attempt_get(self.__parent.path, key)
- raise AttributeError("Module %r has no mapped classes "
- "registered under the name %r" % (
- self.__parent.name, key))
+ raise AttributeError(
+ "Module %r has no mapped classes "
+ "registered under the name %r" % (self.__parent.name, key)
+ )
class _GetColumns(object):
- __slots__ = 'cls',
+ __slots__ = ("cls",)
def __init__(self, cls):
self.cls = cls
@@ -210,7 +218,8 @@ class _GetColumns(object):
if key not in mp.all_orm_descriptors:
raise exc.InvalidRequestError(
"Class %r does not have a mapped column named %r"
- % (self.cls, key))
+ % (self.cls, key)
+ )
desc = mp.all_orm_descriptors[key]
if desc.extension_type is interfaces.NOT_EXTENSION:
@@ -221,24 +230,25 @@ class _GetColumns(object):
raise exc.InvalidRequestError(
"Property %r is not an instance of"
" ColumnProperty (i.e. does not correspond"
- " directly to a Column)." % key)
+ " directly to a Column)." % key
+ )
return getattr(self.cls, key)
+
inspection._inspects(_GetColumns)(
- lambda target: inspection.inspect(target.cls))
+ lambda target: inspection.inspect(target.cls)
+)
class _GetTable(object):
- __slots__ = 'key', 'metadata'
+ __slots__ = "key", "metadata"
def __init__(self, key, metadata):
self.key = key
self.metadata = metadata
def __getattr__(self, key):
- return self.metadata.tables[
- _get_table_key(key, self.key)
- ]
+ return self.metadata.tables[_get_table_key(key, self.key)]
def _determine_container(key, value):
@@ -264,9 +274,11 @@ class _class_resolver(object):
return cls.metadata.tables[key]
elif key in cls.metadata._schemas:
return _GetTable(key, cls.metadata)
- elif '_sa_module_registry' in cls._decl_class_registry and \
- key in cls._decl_class_registry['_sa_module_registry']:
- registry = cls._decl_class_registry['_sa_module_registry']
+ elif (
+ "_sa_module_registry" in cls._decl_class_registry
+ and key in cls._decl_class_registry["_sa_module_registry"]
+ ):
+ registry = cls._decl_class_registry["_sa_module_registry"]
return registry.resolve_attr(key)
elif self._resolvers:
for resolv in self._resolvers:
@@ -289,8 +301,8 @@ class _class_resolver(object):
"When initializing mapper %s, expression %r failed to "
"locate a name (%r). If this is a class name, consider "
"adding this relationship() to the %r class after "
- "both dependent classes have been defined." %
- (self.prop.parent, self.arg, n.args[0], self.cls)
+ "both dependent classes have been defined."
+ % (self.prop.parent, self.arg, n.args[0], self.cls)
)
@@ -299,10 +311,11 @@ def _resolver(cls, prop):
from sqlalchemy.orm import foreign, remote
fallback = sqlalchemy.__dict__.copy()
- fallback.update({'foreign': foreign, 'remote': remote})
+ fallback.update({"foreign": foreign, "remote": remote})
def resolve_arg(arg):
return _class_resolver(cls, prop, fallback, arg)
+
return resolve_arg
@@ -311,18 +324,32 @@ def _deferred_relationship(cls, prop):
if isinstance(prop, RelationshipProperty):
resolve_arg = _resolver(cls, prop)
- for attr in ('argument', 'order_by', 'primaryjoin', 'secondaryjoin',
- 'secondary', '_user_defined_foreign_keys', 'remote_side'):
+ for attr in (
+ "argument",
+ "order_by",
+ "primaryjoin",
+ "secondaryjoin",
+ "secondary",
+ "_user_defined_foreign_keys",
+ "remote_side",
+ ):
v = getattr(prop, attr)
if isinstance(v, util.string_types):
setattr(prop, attr, resolve_arg(v))
if prop.backref and isinstance(prop.backref, tuple):
key, kwargs = prop.backref
- for attr in ('primaryjoin', 'secondaryjoin', 'secondary',
- 'foreign_keys', 'remote_side', 'order_by'):
- if attr in kwargs and isinstance(kwargs[attr],
- util.string_types):
+ for attr in (
+ "primaryjoin",
+ "secondaryjoin",
+ "secondary",
+ "foreign_keys",
+ "remote_side",
+ "order_by",
+ ):
+ if attr in kwargs and isinstance(
+ kwargs[attr], util.string_types
+ ):
kwargs[attr] = resolve_arg(kwargs[attr])
return prop
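
``clsregistry`` backs the string arguments accepted by :func:`.relationship`;
``_deferred_relationship`` above wraps each string-valued attribute in a
``_class_resolver`` that is evaluated against the declarative class registry
once mappers are configured. A sketch of the user-facing behavior this enables
(the classes are illustrative)::

    from sqlalchemy import Column, ForeignKey, Integer
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import relationship

    Base = declarative_base()

    class User(Base):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        # "Address" and "Address.id" resolve lazily via _class_resolver,
        # so Address may be defined after User
        addresses = relationship("Address", order_by="Address.id")

    class Address(Base):
        __tablename__ = "address"
        id = Column(Integer, primary_key=True)
        user_id = Column(ForeignKey("user.id"))
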
diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py
index f86e4fc93..7248e5b4d 100644
--- a/lib/sqlalchemy/ext/horizontal_shard.py
+++ b/lib/sqlalchemy/ext/horizontal_shard.py
@@ -20,7 +20,7 @@ from .. import util
from ..orm.session import Session
from ..orm.query import Query
-__all__ = ['ShardedSession', 'ShardedQuery']
+__all__ = ["ShardedSession", "ShardedQuery"]
class ShardedQuery(Query):
@@ -43,12 +43,10 @@ class ShardedQuery(Query):
def _execute_and_instances(self, context):
def iter_for_shard(shard_id):
- context.attributes['shard_id'] = context.identity_token = shard_id
+ context.attributes["shard_id"] = context.identity_token = shard_id
result = self._connection_from_session(
- mapper=self._bind_mapper(),
- shard_id=shard_id).execute(
- context.statement,
- self._params)
+ mapper=self._bind_mapper(), shard_id=shard_id
+ ).execute(context.statement, self._params)
return self.instances(result, context)
if context.identity_token is not None:
@@ -70,7 +68,8 @@ class ShardedQuery(Query):
mapper=mapper,
shard_id=shard_id,
clause=stmt,
- close_with_result=True)
+ close_with_result=True,
+ )
result = conn.execute(stmt, self._params)
return result
@@ -87,8 +86,13 @@ class ShardedQuery(Query):
return ShardedResult(results, rowcount)
def _identity_lookup(
- self, mapper, primary_key_identity, identity_token=None,
- lazy_loaded_from=None, **kw):
+ self,
+ mapper,
+ primary_key_identity,
+ identity_token=None,
+ lazy_loaded_from=None,
+ **kw
+ ):
"""override the default Query._identity_lookup method so that we
search for a given non-token primary key identity across all
possible identity tokens (e.g. shard ids).
@@ -97,8 +101,10 @@ class ShardedQuery(Query):
if identity_token is not None:
return super(ShardedQuery, self)._identity_lookup(
- mapper, primary_key_identity,
- identity_token=identity_token, **kw
+ mapper,
+ primary_key_identity,
+ identity_token=identity_token,
+ **kw
)
else:
q = self.session.query(mapper)
@@ -113,13 +119,13 @@ class ShardedQuery(Query):
return None
- def _get_impl(
- self, primary_key_identity, db_load_fn, identity_token=None):
+ def _get_impl(self, primary_key_identity, db_load_fn, identity_token=None):
"""Override the default Query._get_impl() method so that we emit
a query to the DB for each possible identity token, if we don't
have one already.
"""
+
def _db_load_fn(query, primary_key_identity):
# load from the database. The original db_load_fn will
# use the given Query object to load from the DB, so our
@@ -142,7 +148,8 @@ class ShardedQuery(Query):
identity_token = self._shard_id
return super(ShardedQuery, self)._get_impl(
- primary_key_identity, _db_load_fn, identity_token=identity_token)
+ primary_key_identity, _db_load_fn, identity_token=identity_token
+ )
class ShardedResult(object):
@@ -158,7 +165,7 @@ class ShardedResult(object):
.. versionadded:: 1.3
"""
- __slots__ = ('result_proxies', 'aggregate_rowcount',)
+ __slots__ = ("result_proxies", "aggregate_rowcount")
def __init__(self, result_proxies, aggregate_rowcount):
self.result_proxies = result_proxies
@@ -168,9 +175,17 @@ class ShardedResult(object):
def rowcount(self):
return self.aggregate_rowcount
+
class ShardedSession(Session):
- def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None,
- query_cls=ShardedQuery, **kwargs):
+ def __init__(
+ self,
+ shard_chooser,
+ id_chooser,
+ query_chooser,
+ shards=None,
+ query_cls=ShardedQuery,
+ **kwargs
+ ):
"""Construct a ShardedSession.
:param shard_chooser: A callable which, passed a Mapper, a mapped
@@ -225,16 +240,16 @@ class ShardedSession(Session):
return self.transaction.connection(mapper, shard_id=shard_id)
else:
return self.get_bind(
- mapper,
- shard_id=shard_id,
- instance=instance
+ mapper, shard_id=shard_id, instance=instance
).contextual_connect(**kwargs)
- def get_bind(self, mapper, shard_id=None,
- instance=None, clause=None, **kw):
+ def get_bind(
+ self, mapper, shard_id=None, instance=None, clause=None, **kw
+ ):
if shard_id is None:
shard_id = self._choose_shard_and_assign(
- mapper, instance, clause=clause)
+ mapper, instance, clause=clause
+ )
return self.__binds[shard_id]
def bind_shard(self, shard_id, bind):
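
The three chooser callables reformatted in ``ShardedSession.__init__`` carry
all of the routing logic: ``shard_chooser`` picks a shard for flushes,
``id_chooser`` narrows a primary-key lookup to candidate shards, and
``query_chooser`` selects shards for an arbitrary query. A minimal two-shard
sketch (the URLs and routing rules are illustrative)::

    from sqlalchemy import create_engine
    from sqlalchemy.ext.horizontal_shard import ShardedSession

    engines = {
        "east": create_engine("sqlite:///east.db"),
        "west": create_engine("sqlite:///west.db"),
    }

    def shard_chooser(mapper, instance, clause=None):
        # route all new objects to one shard in this toy setup
        return "east"

    def id_chooser(query, ident):
        # a bare primary key could live on either shard
        return ["east", "west"]

    def query_chooser(query):
        # no criteria to narrow by, so search everywhere
        return ["east", "west"]

    session = ShardedSession(
        shard_chooser=shard_chooser,
        id_chooser=id_chooser,
        query_chooser=query_chooser,
        shards=engines,
    )
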
diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py
index 95eecb93f..d51a083da 100644
--- a/lib/sqlalchemy/ext/hybrid.py
+++ b/lib/sqlalchemy/ext/hybrid.py
@@ -778,7 +778,7 @@ there's probably a whole lot of amazing things it can be used for.
from .. import util
from ..orm import attributes, interfaces
-HYBRID_METHOD = util.symbol('HYBRID_METHOD')
+HYBRID_METHOD = util.symbol("HYBRID_METHOD")
"""Symbol indicating an :class:`InspectionAttr` that's
of type :class:`.hybrid_method`.
@@ -791,7 +791,7 @@ HYBRID_METHOD = util.symbol('HYBRID_METHOD')
"""
-HYBRID_PROPERTY = util.symbol('HYBRID_PROPERTY')
+HYBRID_PROPERTY = util.symbol("HYBRID_PROPERTY")
"""Symbol indicating an :class:`InspectionAttr` that's
of type :class:`.hybrid_property`.
@@ -860,8 +860,14 @@ class hybrid_property(interfaces.InspectionAttrInfo):
extension_type = HYBRID_PROPERTY
def __init__(
- self, fget, fset=None, fdel=None,
- expr=None, custom_comparator=None, update_expr=None):
+ self,
+ fget,
+ fset=None,
+ fdel=None,
+ expr=None,
+ custom_comparator=None,
+ update_expr=None,
+ ):
"""Create a new :class:`.hybrid_property`.
Usage is typically via decorator::
@@ -906,7 +912,8 @@ class hybrid_property(interfaces.InspectionAttrInfo):
defaults = {
key: value
for key, value in self.__dict__.items()
- if not key.startswith("_")}
+ if not key.startswith("_")
+ }
defaults.update(**kw)
return type(self)(**defaults)
@@ -1078,9 +1085,9 @@ class hybrid_property(interfaces.InspectionAttrInfo):
return self._get_expr(self.fget)
def _get_expr(self, expr):
-
def _expr(cls):
return ExprComparator(cls, expr(cls), self)
+
util.update_wrapper(_expr, expr)
return self._get_comparator(_expr)
@@ -1091,8 +1098,13 @@ class hybrid_property(interfaces.InspectionAttrInfo):
def expr_comparator(owner):
return proxy_attr(
- owner, self.__name__, self, comparator(owner),
- doc=comparator.__doc__ or self.__doc__)
+ owner,
+ self.__name__,
+ self,
+ comparator(owner),
+ doc=comparator.__doc__ or self.__doc__,
+ )
+
return expr_comparator
@@ -1108,7 +1120,7 @@ class Comparator(interfaces.PropComparator):
def __clause_element__(self):
expr = self.expression
- if hasattr(expr, '__clause_element__'):
+ if hasattr(expr, "__clause_element__"):
expr = expr.__clause_element__()
return expr
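
For context on the ``expr_comparator`` changes above: a ``hybrid_property``
evaluates as plain Python at the instance level and as a SQL expression at the
class level, via the comparator built here. A minimal sketch::

    from sqlalchemy import Column, Integer
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.ext.hybrid import hybrid_property

    Base = declarative_base()

    class Interval(Base):
        __tablename__ = "interval"
        id = Column(Integer, primary_key=True)
        start = Column(Integer, nullable=False)
        end = Column(Integer, nullable=False)

        @hybrid_property
        def length(self):
            # instance level: ordinary arithmetic
            return self.end - self.start

    Interval(start=5, end=15).length  # 10
    Interval.length > 10              # SQL expression via ExprComparator
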
diff --git a/lib/sqlalchemy/ext/indexable.py b/lib/sqlalchemy/ext/indexable.py
index 0bc2b65bb..368e5b00a 100644
--- a/lib/sqlalchemy/ext/indexable.py
+++ b/lib/sqlalchemy/ext/indexable.py
@@ -232,7 +232,7 @@ from ..orm.attributes import flag_modified
from ..ext.hybrid import hybrid_property
-__all__ = ['index_property']
+__all__ = ["index_property"]
class index_property(hybrid_property): # noqa
@@ -251,8 +251,14 @@ class index_property(hybrid_property): # noqa
_NO_DEFAULT_ARGUMENT = object()
def __init__(
- self, attr_name, index, default=_NO_DEFAULT_ARGUMENT,
- datatype=None, mutable=True, onebased=True):
+ self,
+ attr_name,
+ index,
+ default=_NO_DEFAULT_ARGUMENT,
+ datatype=None,
+ mutable=True,
+ onebased=True,
+ ):
"""Create a new :class:`.index_property`.
:param attr_name:
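
``index_property`` proxies a single key or index of an indexable column type
(JSON, ARRAY, and similar) as a regular attribute, building on
``hybrid_property``. A short sketch of the parameters reformatted above (the
model is illustrative)::

    from sqlalchemy import Column, Integer, JSON
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.ext.indexable import index_property

    Base = declarative_base()

    class Person(Base):
        __tablename__ = "person"
        id = Column(Integer, primary_key=True)
        data = Column(JSON)

        # attr_name="data", index="name": proxies person.data["name"]
        name = index_property("data", "name")

    p = Person(data={"name": "Alice"})
    p.name          # "Alice"
    p.name = "Bob"  # writes through to p.data["name"]
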
diff --git a/lib/sqlalchemy/ext/instrumentation.py b/lib/sqlalchemy/ext/instrumentation.py
index 30a0ab7d7..b2b8dd7c5 100644
--- a/lib/sqlalchemy/ext/instrumentation.py
+++ b/lib/sqlalchemy/ext/instrumentation.py
@@ -28,15 +28,18 @@ see the example :ref:`examples_instrumentation`.
"""
from ..orm import instrumentation as orm_instrumentation
from ..orm.instrumentation import (
- ClassManager, InstrumentationFactory, _default_state_getter,
- _default_dict_getter, _default_manager_getter
+ ClassManager,
+ InstrumentationFactory,
+ _default_state_getter,
+ _default_dict_getter,
+ _default_manager_getter,
)
from ..orm import attributes, collections, base as orm_base
from .. import util
from ..orm import exc as orm_exc
import weakref
-INSTRUMENTATION_MANAGER = '__sa_instrumentation_manager__'
+INSTRUMENTATION_MANAGER = "__sa_instrumentation_manager__"
"""Attribute, elects custom instrumentation when present on a mapped class.
Allows a class to specify a slightly or wildly different technique for
@@ -66,6 +69,7 @@ def find_native_user_instrumentation_hook(cls):
"""Find user-specified instrumentation management for a class."""
return getattr(cls, INSTRUMENTATION_MANAGER, None)
+
instrumentation_finders = [find_native_user_instrumentation_hook]
"""An extensible sequence of callables which return instrumentation
implementations
@@ -89,6 +93,7 @@ class ExtendedInstrumentationRegistry(InstrumentationFactory):
class managers.
"""
+
_manager_finders = weakref.WeakKeyDictionary()
_state_finders = weakref.WeakKeyDictionary()
_dict_finders = weakref.WeakKeyDictionary()
@@ -104,13 +109,15 @@ class ExtendedInstrumentationRegistry(InstrumentationFactory):
return None, None
def _check_conflicts(self, class_, factory):
- existing_factories = self._collect_management_factories_for(class_).\
- difference([factory])
+ existing_factories = self._collect_management_factories_for(
+ class_
+ ).difference([factory])
if existing_factories:
raise TypeError(
"multiple instrumentation implementations specified "
- "in %s inheritance hierarchy: %r" % (
- class_.__name__, list(existing_factories)))
+ "in %s inheritance hierarchy: %r"
+ % (class_.__name__, list(existing_factories))
+ )
def _extended_class_manager(self, class_, factory):
manager = factory(class_)
@@ -178,17 +185,20 @@ class ExtendedInstrumentationRegistry(InstrumentationFactory):
if instance is None:
raise AttributeError("None has no persistent state.")
return self._state_finders.get(
- instance.__class__, _default_state_getter)(instance)
+ instance.__class__, _default_state_getter
+ )(instance)
def dict_of(self, instance):
if instance is None:
raise AttributeError("None has no persistent state.")
return self._dict_finders.get(
- instance.__class__, _default_dict_getter)(instance)
+ instance.__class__, _default_dict_getter
+ )(instance)
-orm_instrumentation._instrumentation_factory = \
- _instrumentation_factory = ExtendedInstrumentationRegistry()
+orm_instrumentation._instrumentation_factory = (
+ _instrumentation_factory
+) = ExtendedInstrumentationRegistry()
orm_instrumentation.instrumentation_finders = instrumentation_finders
@@ -222,14 +232,15 @@ class InstrumentationManager(object):
pass
def manage(self, class_, manager):
- setattr(class_, '_default_class_manager', manager)
+ setattr(class_, "_default_class_manager", manager)
def dispose(self, class_, manager):
- delattr(class_, '_default_class_manager')
+ delattr(class_, "_default_class_manager")
def manager_getter(self, class_):
def get(cls):
return cls._default_class_manager
+
return get
def instrument_attribute(self, class_, key, inst):
@@ -260,13 +271,13 @@ class InstrumentationManager(object):
pass
def install_state(self, class_, instance, state):
- setattr(instance, '_default_state', state)
+ setattr(instance, "_default_state", state)
def remove_state(self, class_, instance):
- delattr(instance, '_default_state')
+ delattr(instance, "_default_state")
def state_getter(self, class_):
- return lambda instance: getattr(instance, '_default_state')
+ return lambda instance: getattr(instance, "_default_state")
def dict_getter(self, class_):
return lambda inst: self.get_instance_dict(class_, inst)
@@ -314,15 +325,17 @@ class _ClassInstrumentationAdapter(ClassManager):
def instrument_collection_class(self, key, collection_class):
return self._adapted.instrument_collection_class(
- self.class_, key, collection_class)
+ self.class_, key, collection_class
+ )
def initialize_collection(self, key, state, factory):
- delegate = getattr(self._adapted, 'initialize_collection', None)
+ delegate = getattr(self._adapted, "initialize_collection", None)
if delegate:
return delegate(key, state, factory)
else:
- return ClassManager.initialize_collection(self, key,
- state, factory)
+ return ClassManager.initialize_collection(
+ self, key, state, factory
+ )
def new_instance(self, state=None):
instance = self.class_.__new__(self.class_)
@@ -384,7 +397,7 @@ def _install_instrumented_lookups():
dict(
instance_state=_instrumentation_factory.state_of,
instance_dict=_instrumentation_factory.dict_of,
- manager_of_class=_instrumentation_factory.manager_of_class
+ manager_of_class=_instrumentation_factory.manager_of_class,
)
)
@@ -395,7 +408,7 @@ def _reinstall_default_lookups():
dict(
instance_state=_default_state_getter,
instance_dict=_default_dict_getter,
- manager_of_class=_default_manager_getter
+ manager_of_class=_default_manager_getter,
)
)
_instrumentation_factory._extended = False
@@ -403,12 +416,15 @@ def _reinstall_default_lookups():
def _install_lookups(lookups):
global instance_state, instance_dict, manager_of_class
- instance_state = lookups['instance_state']
- instance_dict = lookups['instance_dict']
- manager_of_class = lookups['manager_of_class']
- orm_base.instance_state = attributes.instance_state = \
- orm_instrumentation.instance_state = instance_state
- orm_base.instance_dict = attributes.instance_dict = \
- orm_instrumentation.instance_dict = instance_dict
- orm_base.manager_of_class = attributes.manager_of_class = \
- orm_instrumentation.manager_of_class = manager_of_class
+ instance_state = lookups["instance_state"]
+ instance_dict = lookups["instance_dict"]
+ manager_of_class = lookups["manager_of_class"]
+ orm_base.instance_state = (
+ attributes.instance_state
+ ) = orm_instrumentation.instance_state = instance_state
+ orm_base.instance_dict = (
+ attributes.instance_dict
+ ) = orm_instrumentation.instance_dict = instance_dict
+ orm_base.manager_of_class = (
+ attributes.manager_of_class
+ ) = orm_instrumentation.manager_of_class = manager_of_class
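
The ``InstrumentationManager`` methods touched above (``manage()``,
``install_state()``, ``state_getter()`` and friends) are the hooks a class may
override to relocate SQLAlchemy's per-class and per-instance bookkeeping; the
``__sa_instrumentation_manager__`` attribute elects the custom manager. A
hedged sketch, overriding only where instance state is stored::

    from sqlalchemy.ext.instrumentation import InstrumentationManager

    class MyManager(InstrumentationManager):
        def install_state(self, class_, instance, state):
            instance.__dict__["_my_state"] = state

        def remove_state(self, class_, instance):
            del instance.__dict__["_my_state"]

        def state_getter(self, class_):
            return lambda instance: instance.__dict__["_my_state"]

    class MyEntity(object):
        # presence of this attribute opts the class into MyManager
        __sa_instrumentation_manager__ = MyManager
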
diff --git a/lib/sqlalchemy/ext/mutable.py b/lib/sqlalchemy/ext/mutable.py
index 014cef3cc..0f6ccdc33 100644
--- a/lib/sqlalchemy/ext/mutable.py
+++ b/lib/sqlalchemy/ext/mutable.py
@@ -502,27 +502,29 @@ class MutableBase(object):
def pickle(state, state_dict):
val = state.dict.get(key, None)
if val is not None:
- if 'ext.mutable.values' not in state_dict:
- state_dict['ext.mutable.values'] = []
- state_dict['ext.mutable.values'].append(val)
+ if "ext.mutable.values" not in state_dict:
+ state_dict["ext.mutable.values"] = []
+ state_dict["ext.mutable.values"].append(val)
def unpickle(state, state_dict):
- if 'ext.mutable.values' in state_dict:
- for val in state_dict['ext.mutable.values']:
+ if "ext.mutable.values" in state_dict:
+ for val in state_dict["ext.mutable.values"]:
val._parents[state.obj()] = key
- event.listen(parent_cls, 'load', load,
- raw=True, propagate=True)
- event.listen(parent_cls, 'refresh', load_attrs,
- raw=True, propagate=True)
- event.listen(parent_cls, 'refresh_flush', load_attrs,
- raw=True, propagate=True)
- event.listen(attribute, 'set', set,
- raw=True, retval=True, propagate=True)
- event.listen(parent_cls, 'pickle', pickle,
- raw=True, propagate=True)
- event.listen(parent_cls, 'unpickle', unpickle,
- raw=True, propagate=True)
+ event.listen(parent_cls, "load", load, raw=True, propagate=True)
+ event.listen(
+ parent_cls, "refresh", load_attrs, raw=True, propagate=True
+ )
+ event.listen(
+ parent_cls, "refresh_flush", load_attrs, raw=True, propagate=True
+ )
+ event.listen(
+ attribute, "set", set, raw=True, retval=True, propagate=True
+ )
+ event.listen(parent_cls, "pickle", pickle, raw=True, propagate=True)
+ event.listen(
+ parent_cls, "unpickle", unpickle, raw=True, propagate=True
+ )
class Mutable(MutableBase):
@@ -572,7 +574,7 @@ class Mutable(MutableBase):
if isinstance(prop.columns[0].type, sqltype):
cls.associate_with_attribute(getattr(class_, prop.key))
- event.listen(mapper, 'mapper_configured', listen_for_type)
+ event.listen(mapper, "mapper_configured", listen_for_type)
@classmethod
def as_mutable(cls, sqltype):
@@ -613,9 +615,11 @@ class Mutable(MutableBase):
# and we'll lose our ability to link that type back to the original.
# so track our original type w/ columns
if isinstance(sqltype, SchemaEventTarget):
+
@event.listens_for(sqltype, "before_parent_attach")
def _add_column_memo(sqltyp, parent):
- parent.info['_ext_mutable_orig_type'] = sqltyp
+ parent.info["_ext_mutable_orig_type"] = sqltyp
+
schema_event_check = True
else:
schema_event_check = False
@@ -625,16 +629,14 @@ class Mutable(MutableBase):
return
for prop in mapper.column_attrs:
if (
- schema_event_check and
- hasattr(prop.expression, 'info') and
- prop.expression.info.get('_ext_mutable_orig_type')
- is sqltype
- ) or (
- prop.columns[0].type is sqltype
- ):
+ schema_event_check
+ and hasattr(prop.expression, "info")
+ and prop.expression.info.get("_ext_mutable_orig_type")
+ is sqltype
+ ) or (prop.columns[0].type is sqltype):
cls.associate_with_attribute(getattr(class_, prop.key))
- event.listen(mapper, 'mapper_configured', listen_for_type)
+ event.listen(mapper, "mapper_configured", listen_for_type)
return sqltype
@@ -659,21 +661,27 @@ class MutableComposite(MutableBase):
prop = object_mapper(parent).get_property(key)
for value, attr_name in zip(
- self.__composite_values__(),
- prop._attribute_keys):
+ self.__composite_values__(), prop._attribute_keys
+ ):
setattr(parent, attr_name, value)
def _setup_composite_listener():
def _listen_for_type(mapper, class_):
for prop in mapper.iterate_properties:
- if (hasattr(prop, 'composite_class') and
- isinstance(prop.composite_class, type) and
- issubclass(prop.composite_class, MutableComposite)):
+ if (
+ hasattr(prop, "composite_class")
+ and isinstance(prop.composite_class, type)
+ and issubclass(prop.composite_class, MutableComposite)
+ ):
prop.composite_class._listen_on_attribute(
- getattr(class_, prop.key), False, class_)
+ getattr(class_, prop.key), False, class_
+ )
+
if not event.contains(Mapper, "mapper_configured", _listen_for_type):
- event.listen(Mapper, 'mapper_configured', _listen_for_type)
+ event.listen(Mapper, "mapper_configured", _listen_for_type)
+
+
_setup_composite_listener()
@@ -947,4 +955,4 @@ class MutableSet(Mutable, set):
self.update(state)
def __reduce_ex__(self, proto):
- return (self.__class__, (list(self), ))
+ return (self.__class__, (list(self),))
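
The listener rewiring above is what lets ``Mutable`` subclasses propagate
in-place changes back to the owning attribute; the common entry point is
``as_mutable()`` on a column type. A minimal sketch with ``MutableDict`` (the
model is illustrative)::

    from sqlalchemy import Column, Integer, JSON
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.ext.mutable import MutableDict

    Base = declarative_base()

    class Config(Base):
        __tablename__ = "config"
        id = Column(Integer, primary_key=True)
        # as_mutable() associates MutableDict with this JSON column via
        # the "mapper_configured" listener shown above
        data = Column(MutableDict.as_mutable(JSON), default=dict)

    cfg = Config(data={})
    cfg.data["theme"] = "dark"  # in-place change is flagged for flush
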
diff --git a/lib/sqlalchemy/ext/orderinglist.py b/lib/sqlalchemy/ext/orderinglist.py
index 316742a67..2a8522120 100644
--- a/lib/sqlalchemy/ext/orderinglist.py
+++ b/lib/sqlalchemy/ext/orderinglist.py
@@ -122,7 +122,7 @@ start numbering at 1 or some other integer, provide ``count_from=1``.
from ..orm.collections import collection, collection_adapter
from .. import util
-__all__ = ['ordering_list']
+__all__ = ["ordering_list"]
def ordering_list(attr, count_from=None, **kw):
@@ -180,8 +180,9 @@ def count_from_n_factory(start):
def f(index, collection):
return index + start
+
try:
- f.__name__ = 'count_from_%i' % start
+ f.__name__ = "count_from_%i" % start
except TypeError:
pass
return f
@@ -194,14 +195,14 @@ def _unsugar_count_from(**kw):
``count_from`` argument, otherwise passes ``ordering_func`` on unchanged.
"""
- count_from = kw.pop('count_from', None)
- if kw.get('ordering_func', None) is None and count_from is not None:
+ count_from = kw.pop("count_from", None)
+ if kw.get("ordering_func", None) is None and count_from is not None:
if count_from == 0:
- kw['ordering_func'] = count_from_0
+ kw["ordering_func"] = count_from_0
elif count_from == 1:
- kw['ordering_func'] = count_from_1
+ kw["ordering_func"] = count_from_1
else:
- kw['ordering_func'] = count_from_n_factory(count_from)
+ kw["ordering_func"] = count_from_n_factory(count_from)
return kw
@@ -214,8 +215,9 @@ class OrderingList(list):
"""
- def __init__(self, ordering_attr=None, ordering_func=None,
- reorder_on_append=False):
+ def __init__(
+ self, ordering_attr=None, ordering_func=None, reorder_on_append=False
+ ):
"""A custom list that manages position information for its children.
``OrderingList`` is a ``collection_class`` list implementation that
@@ -311,6 +313,7 @@ class OrderingList(list):
"""Append without any ordering behavior."""
super(OrderingList, self).append(entity)
+
_raw_append = collection.adds(1)(_raw_append)
def insert(self, index, entity):
@@ -361,8 +364,12 @@ class OrderingList(list):
return _reconstitute, (self.__class__, self.__dict__, list(self))
for func_name, func in list(locals().items()):
- if (util.callable(func) and func.__name__ == func_name and
- not func.__doc__ and hasattr(list, func_name)):
+ if (
+ util.callable(func)
+ and func.__name__ == func_name
+ and not func.__doc__
+ and hasattr(list, func_name)
+ ):
func.__doc__ = getattr(list, func_name).__doc__
del func_name, func
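
``ordering_list`` ties a relationship's list position to an integer column
using the ``ordering_func`` plumbing reformatted above. A brief sketch (the
models are illustrative)::

    from sqlalchemy import Column, ForeignKey, Integer, String
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.ext.orderinglist import ordering_list
    from sqlalchemy.orm import relationship

    Base = declarative_base()

    class Slide(Base):
        __tablename__ = "slide"
        id = Column(Integer, primary_key=True)
        bullets = relationship(
            "Bullet",
            order_by="Bullet.position",
            collection_class=ordering_list("position"),
        )

    class Bullet(Base):
        __tablename__ = "bullet"
        id = Column(Integer, primary_key=True)
        slide_id = Column(ForeignKey("slide.id"))
        position = Column(Integer)
        text = Column(String)

    s = Slide()
    s.bullets.append(Bullet(text="a"))     # position -> 0
    s.bullets.insert(0, Bullet(text="b"))  # renumbers: b=0, a=1
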
diff --git a/lib/sqlalchemy/ext/serializer.py b/lib/sqlalchemy/ext/serializer.py
index 2fded51d1..3adcec34f 100644
--- a/lib/sqlalchemy/ext/serializer.py
+++ b/lib/sqlalchemy/ext/serializer.py
@@ -64,7 +64,7 @@ from ..util import pickle, byte_buffer, b64encode, b64decode, text_type
import re
-__all__ = ['Serializer', 'Deserializer', 'dumps', 'loads']
+__all__ = ["Serializer", "Deserializer", "dumps", "loads"]
def Serializer(*args, **kw):
@@ -79,13 +79,18 @@ def Serializer(*args, **kw):
elif isinstance(obj, Mapper) and not obj.non_primary:
id = "mapper:" + b64encode(pickle.dumps(obj.class_))
elif isinstance(obj, MapperProperty) and not obj.parent.non_primary:
- id = "mapperprop:" + b64encode(pickle.dumps(obj.parent.class_)) + \
- ":" + obj.key
+ id = (
+ "mapperprop:"
+ + b64encode(pickle.dumps(obj.parent.class_))
+ + ":"
+ + obj.key
+ )
elif isinstance(obj, Table):
id = "table:" + text_type(obj.key)
elif isinstance(obj, Column) and isinstance(obj.table, Table):
- id = "column:" + \
- text_type(obj.table.key) + ":" + text_type(obj.key)
+ id = (
+ "column:" + text_type(obj.table.key) + ":" + text_type(obj.key)
+ )
elif isinstance(obj, Session):
id = "session:"
elif isinstance(obj, Engine):
@@ -97,8 +102,10 @@ def Serializer(*args, **kw):
pickler.persistent_id = persistent_id
return pickler
+
our_ids = re.compile(
- r'(mapperprop|mapper|table|column|session|attribute|engine):(.*)')
+ r"(mapperprop|mapper|table|column|session|attribute|engine):(.*)"
+)
def Deserializer(file, metadata=None, scoped_session=None, engine=None):
@@ -120,7 +127,7 @@ def Deserializer(file, metadata=None, scoped_session=None, engine=None):
return None
else:
type_, args = m.group(1, 2)
- if type_ == 'attribute':
+ if type_ == "attribute":
key, clsarg = args.split(":")
cls = pickle.loads(b64decode(clsarg))
return getattr(cls, key)
@@ -128,13 +135,13 @@ def Deserializer(file, metadata=None, scoped_session=None, engine=None):
cls = pickle.loads(b64decode(args))
return class_mapper(cls)
elif type_ == "mapperprop":
- mapper, keyname = args.split(':')
+ mapper, keyname = args.split(":")
cls = pickle.loads(b64decode(mapper))
return class_mapper(cls).attrs[keyname]
elif type_ == "table":
return metadata.tables[args]
elif type_ == "column":
- table, colname = args.split(':')
+ table, colname = args.split(":")
return metadata.tables[table].c[colname]
elif type_ == "session":
return scoped_session()
@@ -142,6 +149,7 @@ def Deserializer(file, metadata=None, scoped_session=None, engine=None):
return get_engine()
else:
raise Exception("Unknown token: %s" % type_)
+
unpickler.persistent_load = persistent_load
return unpickler
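
``persistent_id`` and ``persistent_load`` above tokenize mapped classes,
tables, columns and sessions rather than pickling them, so only query
*structure* round-trips; the metadata and session are re-supplied at load
time. A minimal round-trip sketch::

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.ext.serializer import dumps, loads
    from sqlalchemy.orm import scoped_session, sessionmaker

    Base = declarative_base()

    class User(Base):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    Session = scoped_session(sessionmaker(bind=engine))

    # the query pickles as "mapper:"/"table:"/"session:" tokens
    blob = dumps(Session.query(User).filter(User.name == "ed"))
    restored = loads(blob, Base.metadata, Session)
    print(restored.all())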