diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/orm/interfaces.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 13 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/path_registry.py | 52 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/query.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/relationships.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/strategies.py | 33 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/strategy_options.py | 101 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/util.py | 15 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/annotation.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 6 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/traversals.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/util.py | 7 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/__init__.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/_collections.py | 20 |
16 files changed, 187 insertions, 81 deletions
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 704ce9df7..2bb2eb767 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -525,9 +525,7 @@ class StrategizedProperty(MapperProperty): def _get_context_loader(self, context, path): load = None - # use EntityRegistry.__getitem__()->PropRegistry here so - # that the path is stated in terms of our base - search_path = dict.__getitem__(path, self) + search_path = path[self] # search among: exact match, "attr.*", "default" strategy # if any. diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index aa350c7ba..c23aaf9ef 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -2807,9 +2807,6 @@ class Mapper(sql_base.HasCacheKey, InspectionAttr): if self.base_mapper.local_table in tables: return None - class ColumnsNotAvailable(Exception): - pass - def visit_binary(binary): leftcol = binary.left rightcol = binary.right @@ -2824,7 +2821,7 @@ class Mapper(sql_base.HasCacheKey, InspectionAttr): passive=attributes.PASSIVE_NO_INITIALIZE, ) if leftval in orm_util._none_set: - raise ColumnsNotAvailable() + raise _OptGetColumnsNotAvailable() binary.left = sql.bindparam( None, leftval, type_=binary.right.type ) @@ -2836,7 +2833,7 @@ class Mapper(sql_base.HasCacheKey, InspectionAttr): passive=attributes.PASSIVE_NO_INITIALIZE, ) if rightval in orm_util._none_set: - raise ColumnsNotAvailable() + raise _OptGetColumnsNotAvailable() binary.right = sql.bindparam( None, rightval, type_=binary.right.type ) @@ -2860,7 +2857,7 @@ class Mapper(sql_base.HasCacheKey, InspectionAttr): {"binary": visit_binary}, ) ) - except ColumnsNotAvailable: + except _OptGetColumnsNotAvailable: return None cond = sql.and_(*allconds) @@ -3128,6 +3125,10 @@ class Mapper(sql_base.HasCacheKey, InspectionAttr): return result +class _OptGetColumnsNotAvailable(Exception): + pass + + def configure_mappers(): """Initialize the inter-mapper relationships of all mappers that have been constructed thus far. diff --git a/lib/sqlalchemy/orm/path_registry.py b/lib/sqlalchemy/orm/path_registry.py index 585cb80bc..d2b82459e 100644 --- a/lib/sqlalchemy/orm/path_registry.py +++ b/lib/sqlalchemy/orm/path_registry.py @@ -55,6 +55,8 @@ class PathRegistry(HasCacheKey): """ + __slots__ = () + is_token = False is_root = False @@ -167,7 +169,10 @@ class PathRegistry(HasCacheKey): @classmethod def per_mapper(cls, mapper): - return EntityRegistry(cls.root, mapper) + if mapper.is_mapper: + return CachingEntityRegistry(cls.root, mapper) + else: + return SlotsEntityRegistry(cls.root, mapper) @classmethod def coerce(cls, raw): @@ -207,6 +212,8 @@ PathRegistry.root = RootRegistry() class TokenRegistry(PathRegistry): + __slots__ = ("token", "parent", "path", "natural_path") + def __init__(self, parent, token): self.token = token self.parent = parent @@ -257,7 +264,7 @@ class PropRegistry(PathRegistry): self._wildcard_path_loader_key = ( "loader", - self.parent.path + self.prop._wildcard_token, + parent.path + self.prop._wildcard_token, ) self._default_path_loader_key = self.prop._default_path_loader_key self._loader_key = ("loader", self.path) @@ -285,11 +292,12 @@ class PropRegistry(PathRegistry): if isinstance(entity, (int, slice)): return self.path[entity] else: - return EntityRegistry(self, entity) + return SlotsEntityRegistry(self, entity) -class EntityRegistry(PathRegistry, dict): - is_aliased_class = False +class AbstractEntityRegistry(PathRegistry): + __slots__ = () + has_entity = True def __init__(self, parent, entity): @@ -307,7 +315,10 @@ class EntityRegistry(PathRegistry, dict): # are usually not present in mappings. So here we track both the # "enhanced" path in self.path and the "natural" path that doesn't # include those objects so these two traversals can be matched up. + if parent.path and self.is_aliased_class: + # this is an infrequent code path used only for loader strategies + # that also make use of of_type(). if entity.mapper.isa(parent.natural_path[-1].entity): self.natural_path = parent.natural_path + (entity.mapper,) else: @@ -316,7 +327,10 @@ class EntityRegistry(PathRegistry, dict): ) else: self.natural_path = self.path - self.entity_path = self + + @property + def entity_path(self): + return self @property def mapper(self): @@ -331,8 +345,34 @@ class EntityRegistry(PathRegistry, dict): if isinstance(entity, (int, slice)): return self.path[entity] else: + return PropRegistry(self, entity) + + +class SlotsEntityRegistry(AbstractEntityRegistry): + # for aliased class, return lightweight, no-cycles created + # version + + __slots__ = ( + "key", + "parent", + "is_aliased_class", + "entity", + "path", + "natural_path", + ) + + +class CachingEntityRegistry(AbstractEntityRegistry, dict): + # for long lived mapper, return dict based caching + # version that creates reference cycles + + def __getitem__(self, entity): + if isinstance(entity, (int, slice)): + return self.path[entity] + else: return dict.__getitem__(self, entity) def __missing__(self, key): self[key] = item = PropRegistry(self, key) + return item diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 22e24be84..f492fa254 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -2235,7 +2235,7 @@ class Query(Generative): while "prev" in jp: f, prev = jp["prev"] prev = prev.copy() - prev[f] = jp + prev[f] = jp.copy() jp["prev"] = (f, prev) jp = prev self._joinpath = jp diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 7ee12ca6c..656e0b53d 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -2283,6 +2283,7 @@ def _annotate_columns(element, annotations): if element is not None: element = clone(element) + clone = None # remove gc cycles return element diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 38d379a4c..38ddb70f7 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -1070,11 +1070,11 @@ class SubqueryLoader(PostLoader): # build up a path indicating the path from the leftmost # entity to the thing we're subquery loading. - with_poly_info = path.get( + with_poly_entity = path.get( context.attributes, "path_with_polymorphic", None ) - if with_poly_info is not None: - effective_entity = with_poly_info.entity + if with_poly_entity is not None: + effective_entity = with_poly_entity else: effective_entity = self.entity @@ -1571,11 +1571,13 @@ class JoinedLoader(AbstractRelationshipLoader): chained_from_outerjoin, ) - with_poly_info = path.get( + with_poly_entity = path.get( context.attributes, "path_with_polymorphic", None ) - if with_poly_info is not None: - with_polymorphic = with_poly_info.with_polymorphic_mappers + if with_poly_entity is not None: + with_polymorphic = inspect( + with_poly_entity + ).with_polymorphic_mappers else: with_polymorphic = None @@ -1593,7 +1595,7 @@ class JoinedLoader(AbstractRelationshipLoader): chained_from_outerjoin=chained_from_outerjoin, ) - if with_poly_info is not None and None in set( + if with_poly_entity is not None and None in set( context.secondary_columns ): raise sa_exc.InvalidRequestError( @@ -1622,7 +1624,6 @@ class JoinedLoader(AbstractRelationshipLoader): # otherwise figure it out. alias = loadopt.local_opts["eager_from_alias"] - root_mapper, prop = path[-2:] if alias is not None: @@ -1633,11 +1634,11 @@ class JoinedLoader(AbstractRelationshipLoader): ) else: if path.contains(context.attributes, "path_with_polymorphic"): - with_poly_info = path.get( + with_poly_entity = path.get( context.attributes, "path_with_polymorphic" ) adapter = orm_util.ORMAdapter( - with_poly_info.entity, + with_poly_entity, equivalents=prop.mapper._equivalent_columns, ) else: @@ -1721,11 +1722,11 @@ class JoinedLoader(AbstractRelationshipLoader): parentmapper, chained_from_outerjoin, ): - with_poly_info = path.get( + with_poly_entity = path.get( context.attributes, "path_with_polymorphic", None ) - if with_poly_info: - to_adapt = with_poly_info.entity + if with_poly_entity: + to_adapt = with_poly_entity else: to_adapt = self._gen_pooled_aliased_class(context) @@ -2284,12 +2285,12 @@ class SelectInLoader(PostLoader, util.MemoizedSlots): # build up a path indicating the path from the leftmost # entity to the thing we're subquery loading. - with_poly_info = path_w_prop.get( + with_poly_entity = path_w_prop.get( context.attributes, "path_with_polymorphic", None ) - if with_poly_info is not None: - effective_entity = with_poly_info.entity + if with_poly_entity is not None: + effective_entity = with_poly_entity else: effective_entity = self.entity diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index 99bbbe37c..a2529e3ce 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -90,7 +90,6 @@ class Load(HasCacheKey, Generative, MapperOption): # Load objects self.context = util.OrderedDict() self.local_opts = {} - self._of_type = None self.is_class_strategy = False @classmethod @@ -105,6 +104,8 @@ class Load(HasCacheKey, Generative, MapperOption): @property def _context_cache_key(self): serialized = [] + if self.context is None: + return [] for (key, loader_path), obj in self.context.items(): if key != "loader": continue @@ -190,6 +191,7 @@ class Load(HasCacheKey, Generative, MapperOption): is_class_strategy = False strategy = None propagate_to_loaders = False + _of_type = None def process_query(self, query): self._process(query, True) @@ -316,12 +318,15 @@ class Load(HasCacheKey, Generative, MapperOption): ext_info.mapper, aliased=True, _use_mapper_path=True, - _existing_alias=existing, + _existing_alias=inspect(existing) + if existing is not None + else None, ) + ext_info = inspect(ac) path.entity_path[prop].set( - self.context, "path_with_polymorphic", ext_info + self.context, "path_with_polymorphic", ac ) path = path[prop][ext_info] @@ -399,13 +404,13 @@ class Load(HasCacheKey, Generative, MapperOption): self, attr, strategy, propagate_to_loaders=True ): strategy = self._coerce_strat(strategy) - self.is_class_strategy = False + self.propagate_to_loaders = propagate_to_loaders - # if the path is a wildcard, this will set propagate_to_loaders=False - self._generate_path(self.path, attr, strategy, "relationship") - self.strategy = strategy - if strategy is not None: - self._set_path_strategy() + cloned = self._clone_for_bind_strategy(attr, strategy, "relationship") + self.path = cloned.path + self._of_type = cloned._of_type + cloned.is_class_strategy = self.is_class_strategy = False + self.propagate_to_loaders = cloned.propagate_to_loaders @_generative def set_column_strategy(self, attrs, strategy, opts=None, opts_only=False): @@ -413,41 +418,50 @@ class Load(HasCacheKey, Generative, MapperOption): self.is_class_strategy = False for attr in attrs: - cloned = self._generate() - cloned.strategy = strategy - cloned._generate_path(self.path, attr, strategy, "column") + cloned = self._clone_for_bind_strategy( + attr, strategy, "column", opts_only=opts_only, opts=opts + ) cloned.propagate_to_loaders = True - if opts: - cloned.local_opts.update(opts) - if opts_only: - cloned.is_opts_only = True - cloned._set_path_strategy() - self.is_class_strategy = False @_generative def set_generic_strategy(self, attrs, strategy): strategy = self._coerce_strat(strategy) for attr in attrs: - path = self._generate_path(self.path, attr, strategy, None) - cloned = self._generate() - cloned.strategy = strategy - cloned.path = path + cloned = self._clone_for_bind_strategy(attr, strategy, None) cloned.propagate_to_loaders = True - cloned._set_path_strategy() @_generative def set_class_strategy(self, strategy, opts): strategy = self._coerce_strat(strategy) - cloned = self._generate() + cloned = self._clone_for_bind_strategy(None, strategy, None) cloned.is_class_strategy = True - path = cloned._generate_path(self.path, None, strategy, None) - cloned.strategy = strategy - cloned.path = path cloned.propagate_to_loaders = True - cloned._set_path_strategy() cloned.local_opts.update(opts) + def _clone_for_bind_strategy( + self, attr, strategy, wildcard_key, opts_only=False, opts=None + ): + """Create an anonymous clone of the Load/_UnboundLoad that is suitable + to be placed in the context / _to_bind collection of this Load + object. The clone will then lose references to context/_to_bind + in order to not create reference cycles. + + """ + cloned = self._generate() + cloned._generate_path(self.path, attr, strategy, wildcard_key) + cloned.strategy = strategy + + cloned.local_opts = self.local_opts + if opts: + cloned.local_opts.update(opts) + if opts_only: + cloned.is_opts_only = True + + if strategy or cloned.is_opts_only: + cloned._set_path_strategy() + return cloned + def _set_for_path(self, context, path, replace=True, merge_opts=False): if merge_opts or not replace: existing = path.get(self.context, "loader") @@ -485,18 +499,24 @@ class Load(HasCacheKey, Generative, MapperOption): merge_opts=self.is_opts_only, ) + # remove cycles; _set_path_strategy is always invoked on an + # anonymous clone of the Load / UnboundLoad object since #5056 + self.context = None + def __getstate__(self): d = self.__dict__.copy() - d["context"] = PathRegistry.serialize_context_dict( - d["context"], ("loader",) - ) + if d["context"] is not None: + d["context"] = PathRegistry.serialize_context_dict( + d["context"], ("loader",) + ) d["path"] = self.path.serialize() return d def __setstate__(self, state): self.__dict__.update(state) self.path = PathRegistry.deserialize(self.path) - self.context = PathRegistry.deserialize_context_dict(self.context) + if self.context is not None: + self.context = PathRegistry.deserialize_context_dict(self.context) def _chop_path(self, to_chop, path): i = -1 @@ -575,10 +595,17 @@ class _UnboundLoad(Load): def _set_path_strategy(self): self._to_bind.append(self) - def _apply_to_parent(self, parent, applied, bound): + # remove cycles; _set_path_strategy is always invoked on an + # anonymous clone of the Load / UnboundLoad object since #5056 + self._to_bind = None + + def _apply_to_parent(self, parent, applied, bound, to_bind=None): if self in applied: return applied[self] + if to_bind is None: + to_bind = self._to_bind + cloned = self._generate() applied[self] = cloned @@ -601,8 +628,8 @@ class _UnboundLoad(Load): assert cloned.is_opts_only == self.is_opts_only new_to_bind = { - elem._apply_to_parent(parent, applied, bound) - for elem in self._to_bind + elem._apply_to_parent(parent, applied, bound, to_bind) + for elem in to_bind } cloned._to_bind = parent._to_bind cloned._to_bind.extend(new_to_bind) @@ -681,11 +708,13 @@ class _UnboundLoad(Load): all_tokens = [token for key in keys for token in _split_key(key)] for token in all_tokens[0:-1]: + # set _is_chain_link first so that clones of the + # object also inherit this flag + opt._is_chain_link = True if chained: opt = meth(opt, token, **kw) else: opt = opt.defaultload(token) - opt._is_chain_link = True opt = meth(opt, all_tokens[-1], **kw) opt._is_chain_link = False diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 2f78fa535..047eb9c16 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -8,6 +8,7 @@ import re import types +import weakref from . import attributes # noqa from .base import _class_to_mapper # noqa @@ -583,14 +584,14 @@ class AliasedInsp(sql_base.HasCacheKey, InspectionAttr): adapt_on_names, represents_outer_join, ): - self.entity = entity + self._weak_entity = weakref.ref(entity) self.mapper = mapper self.selectable = ( self.persist_selectable ) = self.local_table = selectable self.name = name self.polymorphic_on = polymorphic_on - self._base_alias = _base_alias or self + self._base_alias = weakref.ref(_base_alias or self) self._use_mapper_path = _use_mapper_path self.represents_outer_join = represents_outer_join @@ -625,6 +626,10 @@ class AliasedInsp(sql_base.HasCacheKey, InspectionAttr): self._adapt_on_names = adapt_on_names self._target = mapper.class_ + @property + def entity(self): + return self._weak_entity() + is_aliased_class = True "always returns True" @@ -643,7 +648,7 @@ class AliasedInsp(sql_base.HasCacheKey, InspectionAttr): :class:`.AliasedInsp`.""" return self.mapper.class_ - @util.memoized_property + @property def _path_registry(self): if self._use_mapper_path: return self.mapper._path_registry @@ -659,7 +664,7 @@ class AliasedInsp(sql_base.HasCacheKey, InspectionAttr): "adapt_on_names": self._adapt_on_names, "with_polymorphic_mappers": self.with_polymorphic_mappers, "with_polymorphic_discriminator": self.polymorphic_on, - "base_alias": self._base_alias, + "base_alias": self._base_alias(), "use_mapper_path": self._use_mapper_path, "represents_outer_join": self.represents_outer_join, } @@ -1241,7 +1246,7 @@ def _entity_corresponds_to(given, entity): """ if entity.is_aliased_class: if given.is_aliased_class: - if entity._base_alias is given._base_alias: + if entity._base_alias() is given._base_alias(): return True return False elif given.is_aliased_class: diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index 9853cef2a..447bbe667 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -247,6 +247,7 @@ def _deep_annotate(element, annotations, exclude=None): if element is not None: element = clone(element) + clone = None # remove gc cycles return element @@ -271,6 +272,7 @@ def _deep_deannotate(element, values=None): if element is not None: element = clone(element) + clone = None # remove gc cycles return element diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 7d857d4fe..c16c2f0ca 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -255,6 +255,12 @@ class ClauseElement( """ s = util.column_set() f = self + + # note this creates a cycle, asserted in test_memusage. however, + # turning this into a plain @property adds tends of thousands of method + # calls to Core / ORM performance tests, so the small overhead + # introduced by the relatively small amount of short term cycles + # produced here is preferable while f is not None: s.add(f) f = f._is_clone_of diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index ed7a6c2b9..428acae6c 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -3695,10 +3695,10 @@ class Select( clone(f, **kw) for f in self._from_obj ).union(f for f in new_froms.values() if isinstance(f, Join)) - self._correlate = set(clone(f) for f in self._correlate) + self._correlate = set(clone(f, **kw) for f in self._correlate) if self._correlate_except: self._correlate_except = set( - clone(f) for f in self._correlate_except + clone(f, **kw) for f in self._correlate_except ) # 4. clone other things. The difficulty here is that Column diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 84a5623d3..68a3a0749 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -29,6 +29,8 @@ def compare(obj1, obj2, **kw): class HasCacheKey(object): _cache_key_traversal = NO_CACHE + __slots__ = () + def _gen_cache_key(self, anon_map, bindparams): """return an optional cache key. diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 8539f4845..546d989eb 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -226,6 +226,7 @@ def visit_binary_product(fn, expr): yield e list(visit(expr)) + visit = None # remove gc cycles def find_tables( @@ -881,7 +882,7 @@ class ColumnAdapter(ClauseAdapter): anonymize_labels=anonymize_labels, ) - self.columns = util.populate_column_dict(self._locate_col) + self.columns = util.WeakPopulateDict(self._locate_col) if self.include_fn or self.exclude_fn: self.columns = self._IncludeExcludeMapping(self, self.columns) self.adapt_required = adapt_required @@ -907,7 +908,7 @@ class ColumnAdapter(ClauseAdapter): ac = self.__class__.__new__(self.__class__) ac.__dict__.update(self.__dict__) ac._wrap = adapter - ac.columns = util.populate_column_dict(ac._locate_col) + ac.columns = util.WeakPopulateDict(ac._locate_col) if ac.include_fn or ac.exclude_fn: ac.columns = self._IncludeExcludeMapping(ac, ac.columns) @@ -942,4 +943,4 @@ class ColumnAdapter(ClauseAdapter): def __setstate__(self, state): self.__dict__.update(state) - self.columns = util.PopulateDict(self._locate_col) + self.columns = util.WeakPopulateDict(self._locate_col) diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index dcded3484..afa1506d2 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -722,7 +722,7 @@ def cloned_traverse(obj, opts, visitors): return newelem cloned[id(elem)] = newelem = elem._clone() - newelem._copy_internals(clone=clone) + newelem._copy_internals(clone=clone, **kw) meth = visitors.get(newelem.__visit_name__, None) if meth: meth(newelem) @@ -730,6 +730,7 @@ def cloned_traverse(obj, opts, visitors): if obj is not None: obj = clone(obj) + clone = None # remove gc cycles return obj @@ -786,4 +787,5 @@ def replacement_traverse(obj, opts, replace): if obj is not None: obj = clone(obj, **opts) + clone = None # remove gc cycles return obj diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index 78155b08a..a0821a3cc 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -31,7 +31,6 @@ from ._collections import OrderedDict # noqa from ._collections import OrderedIdentitySet # noqa from ._collections import OrderedProperties # noqa from ._collections import OrderedSet # noqa -from ._collections import populate_column_dict # noqa from ._collections import PopulateDict # noqa from ._collections import Properties # noqa from ._collections import ScopedRegistry # noqa @@ -42,6 +41,7 @@ from ._collections import to_set # noqa from ._collections import unique_list # noqa from ._collections import UniqueAppender # noqa from ._collections import update_copy # noqa +from ._collections import WeakPopulateDict # noqa from ._collections import WeakSequence # noqa from .compat import b # noqa from .compat import b64decode # noqa diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index 04409cfd9..7b46bc8e6 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -676,10 +676,13 @@ class IdentitySet(object): class WeakSequence(object): def __init__(self, __elements=()): + # adapted from weakref.WeakKeyDictionary, prevent reference + # cycles in the collection itself def _remove(item, selfref=weakref.ref(self)): self = selfref() if self is not None: self._storage.remove(item) + self._remove = _remove self._storage = [ weakref.ref(element, _remove) for element in __elements @@ -737,6 +740,22 @@ class PopulateDict(dict): return val +class WeakPopulateDict(dict): + """Like PopulateDict, but assumes a self + a method and does not create + a reference cycle. + + """ + + def __init__(self, creator_method): + self.creator = creator_method.__func__ + weakself = creator_method.__self__ + self.weakself = weakref.ref(weakself) + + def __missing__(self, key): + self[key] = val = self.creator(self.weakself(), key) + return val + + # Define collections that are capable of storing # ColumnElement objects as hashable keys/elements. # At this point, these are mostly historical, things @@ -744,7 +763,6 @@ class PopulateDict(dict): column_set = set column_dict = dict ordered_column_set = OrderedSet -populate_column_dict = PopulateDict _getters = PopulateDict(operator.itemgetter) |
