diff options
| -rw-r--r-- | lib/sqlalchemy/util/_collections.py | 13 | ||||
| -rw-r--r-- | test/base/test_utils.py | 36 |
2 files changed, 43 insertions, 6 deletions
diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index b2e5c6250..a43115203 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -649,16 +649,23 @@ class IdentitySet(object): class WeakSequence(object): - def __init__(self, elements): + def __init__(self, __elements=()): self._storage = [ - weakref.ref(element) for element in elements + weakref.ref(element, self._remove) for element in __elements ] + def append(self, item): + self._storage.append(weakref.ref(item, self._remove)) + def _remove(self, ref): self._storage.remove(ref) + def __len__(self): + return len(self._storage) + def __iter__(self): - return (obj for obj in (ref() for ref in self._storage) if obj is not None) + return (obj for obj in + (ref() for ref in self._storage) if obj is not None) def __getitem__(self, index): try: diff --git a/test/base/test_utils.py b/test/base/test_utils.py index aefc6d421..2fd1edbb5 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -1,10 +1,10 @@ import copy -from sqlalchemy import util, sql, exc +from sqlalchemy import util, sql, exc, testing from sqlalchemy.testing import assert_raises, assert_raises_message, fixtures from sqlalchemy.testing import eq_, is_, ne_, fails_if -from sqlalchemy.testing.util import picklers -from sqlalchemy.util import classproperty +from sqlalchemy.testing.util import picklers, gc_collect +from sqlalchemy.util import classproperty, WeakSequence class KeyedTupleTest(): @@ -115,6 +115,36 @@ class KeyedTupleTest(): keyed_tuple[0] = 100 assert_raises(TypeError, should_raise) +class WeakSequenceTest(fixtures.TestBase): + @testing.requires.predictable_gc + def test_cleanout_elements(self): + class Foo(object): + pass + f1, f2, f3 = Foo(), Foo(), Foo() + w = WeakSequence([f1, f2, f3]) + eq_(len(w), 3) + eq_(len(w._storage), 3) + del f2 + gc_collect() + eq_(len(w), 2) + eq_(len(w._storage), 2) + + @testing.requires.predictable_gc + def test_cleanout_appended(self): + class Foo(object): + pass + f1, f2, f3 = Foo(), Foo(), Foo() + w = WeakSequence() + w.append(f1) + w.append(f2) + w.append(f3) + eq_(len(w), 3) + eq_(len(w._storage), 3) + del f2 + gc_collect() + eq_(len(w), 2) + eq_(len(w._storage), 2) + class OrderedDictTest(fixtures.TestBase): |
