diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2015-05-01 12:06:34 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2015-05-01 12:33:45 -0400 |
commit | 95949db715ff54be01bfd260a51903ede60597ae (patch) | |
tree | 926fadafe63839b78c472deae3c899b43b607acc | |
parent | 635f06c3ebc787b98cf0ee1e94eff12fc96daff0 (diff) | |
download | sqlalchemy-95949db715ff54be01bfd260a51903ede60597ae.tar.gz |
- Repair _reinstall_default_lookups to also flip the _extended flag
off again so that test fixtures setup/teardown instrumentation as
expected
- clean up test_extendedattr.py and fix it to no longer leak
itself outside by ensuring _reinstall_default_lookups is always called,
part of #3408
- Fixed bug where when using extended attribute instrumentation system,
the correct exception would not be raised when :func:`.class_mapper`
were called with an invalid input that also happened to not
be weak referencable, such as an integer.
fixes #3408
-rw-r--r-- | doc/build/changelog/changelog_09.rst | 10 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/instrumentation.py | 9 | ||||
-rw-r--r-- | test/ext/test_extendedattr.py | 458 |
3 files changed, 327 insertions, 150 deletions
diff --git a/doc/build/changelog/changelog_09.rst b/doc/build/changelog/changelog_09.rst index e66203ed3..2506d21bd 100644 --- a/doc/build/changelog/changelog_09.rst +++ b/doc/build/changelog/changelog_09.rst @@ -15,6 +15,16 @@ :version: 0.9.10 .. change:: + :tags: bug, ext + :tickets: 3408 + :versions: 1.0.4 + + Fixed bug where when using extended attribute instrumentation system, + the correct exception would not be raised when :func:`.class_mapper` + were called with an invalid input that also happened to not + be weak referencable, such as an integer. + + .. change:: :tags: bug, tests, pypy :tickets: 3406 :versions: 1.0.4 diff --git a/lib/sqlalchemy/ext/instrumentation.py b/lib/sqlalchemy/ext/instrumentation.py index 024136661..30a0ab7d7 100644 --- a/lib/sqlalchemy/ext/instrumentation.py +++ b/lib/sqlalchemy/ext/instrumentation.py @@ -166,7 +166,13 @@ class ExtendedInstrumentationRegistry(InstrumentationFactory): def manager_of_class(self, cls): if cls is None: return None - return self._manager_finders.get(cls, _default_manager_getter)(cls) + try: + finder = self._manager_finders.get(cls, _default_manager_getter) + except TypeError: + # due to weakref lookup on invalid object + return None + else: + return finder(cls) def state_of(self, instance): if instance is None: @@ -392,6 +398,7 @@ def _reinstall_default_lookups(): manager_of_class=_default_manager_getter ) ) + _instrumentation_factory._extended = False def _install_lookups(lookups): diff --git a/test/ext/test_extendedattr.py b/test/ext/test_extendedattr.py index c7627c8b2..653418ac4 100644 --- a/test/ext/test_extendedattr.py +++ b/test/ext/test_extendedattr.py @@ -1,10 +1,12 @@ from sqlalchemy.testing import eq_, assert_raises, assert_raises_message, ne_ from sqlalchemy import util +import sqlalchemy as sa +from sqlalchemy.orm import class_mapper from sqlalchemy.orm import attributes -from sqlalchemy.orm.attributes import set_attribute, get_attribute, del_attribute +from sqlalchemy.orm.attributes import set_attribute, \ + get_attribute, del_attribute from sqlalchemy.orm.instrumentation import is_instrumented from sqlalchemy.orm import clear_mappers -from sqlalchemy import testing from sqlalchemy.testing import fixtures from sqlalchemy.ext import instrumentation from sqlalchemy.orm.instrumentation import register_class @@ -12,6 +14,7 @@ from sqlalchemy.testing.util import decorator from sqlalchemy.orm import events from sqlalchemy import event + @decorator def modifies_instrumentation_finders(fn, *args, **kw): pristine = instrumentation.instrumentation_finders[:] @@ -21,15 +24,11 @@ def modifies_instrumentation_finders(fn, *args, **kw): del instrumentation.instrumentation_finders[:] instrumentation.instrumentation_finders.extend(pristine) -def with_lookup_strategy(strategy): - @decorator - def decorate(fn, *args, **kw): - try: - ext_instrumentation._install_instrumented_lookups() - return fn(*args, **kw) - finally: - ext_instrumentation._reinstall_default_lookups() - return decorate + +class _ExtBase(object): + @classmethod + def teardown_class(cls): + instrumentation._reinstall_default_lookups() class MyTypesManager(instrumentation.InstrumentationManager): @@ -58,16 +57,19 @@ class MyTypesManager(instrumentation.InstrumentationManager): def state_getter(self, class_): return lambda instance: instance.__dict__['_my_state'] + class MyListLike(list): # add @appender, @remover decorators as needed _sa_iterator = list.__iter__ _sa_linker = None _sa_converter = None + def _sa_appender(self, item, _sa_initiator=None): if _sa_initiator is not False: self._sa_adapter.fire_append_event(item, _sa_initiator) list.append(self, item) append = _sa_appender + def _sa_remover(self, item, _sa_initiator=None): self._sa_adapter.fire_pre_remove_event(_sa_initiator) if _sa_initiator is not False: @@ -75,57 +77,64 @@ class MyListLike(list): list.remove(self, item) remove = _sa_remover -class MyBaseClass(object): - __sa_instrumentation_manager__ = instrumentation.InstrumentationManager - -class MyClass(object): - - # This proves that a staticmethod will work here; don't - # flatten this back to a class assignment! - def __sa_instrumentation_manager__(cls): - return MyTypesManager(cls) - - __sa_instrumentation_manager__ = staticmethod(__sa_instrumentation_manager__) - - # This proves SA can handle a class with non-string dict keys - if not util.pypy and not util.jython: - locals()[42] = 99 # Don't remove this line! - - def __init__(self, **kwargs): - for k in kwargs: - setattr(self, k, kwargs[k]) - - def __getattr__(self, key): - if is_instrumented(self, key): - return get_attribute(self, key) - else: - try: - return self._goofy_dict[key] - except KeyError: - raise AttributeError(key) - - def __setattr__(self, key, value): - if is_instrumented(self, key): - set_attribute(self, key, value) - else: - self._goofy_dict[key] = value - - def __hasattr__(self, key): - if is_instrumented(self, key): - return True - else: - return key in self._goofy_dict - - def __delattr__(self, key): - if is_instrumented(self, key): - del_attribute(self, key) - else: - del self._goofy_dict[key] - -class UserDefinedExtensionTest(fixtures.ORMTest): + +MyBaseClass, MyClass = None, None + + +class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest): + @classmethod - def teardown_class(cls): - instrumentation._reinstall_default_lookups() + def setup_class(cls): + global MyBaseClass, MyClass + + class MyBaseClass(object): + __sa_instrumentation_manager__ = \ + instrumentation.InstrumentationManager + + class MyClass(object): + + # This proves that a staticmethod will work here; don't + # flatten this back to a class assignment! + def __sa_instrumentation_manager__(cls): + return MyTypesManager(cls) + + __sa_instrumentation_manager__ = staticmethod( + __sa_instrumentation_manager__) + + # This proves SA can handle a class with non-string dict keys + if not util.pypy and not util.jython: + locals()[42] = 99 # Don't remove this line! + + def __init__(self, **kwargs): + for k in kwargs: + setattr(self, k, kwargs[k]) + + def __getattr__(self, key): + if is_instrumented(self, key): + return get_attribute(self, key) + else: + try: + return self._goofy_dict[key] + except KeyError: + raise AttributeError(key) + + def __setattr__(self, key, value): + if is_instrumented(self, key): + set_attribute(self, key, value) + else: + self._goofy_dict[key] = value + + def __hasattr__(self, key): + if is_instrumented(self, key): + return True + else: + return key in self._goofy_dict + + def __delattr__(self, key): + if is_instrumented(self, key): + del_attribute(self, key) + else: + del self._goofy_dict[key] def teardown(self): clear_mappers() @@ -135,15 +144,25 @@ class UserDefinedExtensionTest(fixtures.ORMTest): pass register_class(User) - attributes.register_attribute(User, 'user_id', uselist = False, useobject=False) - attributes.register_attribute(User, 'user_name', uselist = False, useobject=False) - attributes.register_attribute(User, 'email_address', uselist = False, useobject=False) + attributes.register_attribute( + User, 'user_id', uselist=False, useobject=False) + attributes.register_attribute( + User, 'user_name', uselist=False, useobject=False) + attributes.register_attribute( + User, 'email_address', uselist=False, useobject=False) u = User() u.user_id = 7 u.user_name = 'john' u.email_address = 'lala@123.com' - self.assert_(u.__dict__ == {'_my_state':u._my_state, '_goofy_dict':{'user_id':7, 'user_name':'john', 'email_address':'lala@123.com'}}, u.__dict__) + eq_( + u.__dict__, + { + '_my_state': u._my_state, + '_goofy_dict': { + 'user_id': 7, 'user_name': 'john', + 'email_address': 'lala@123.com'}} + ) def test_basic(self): for base in (object, MyBaseClass, MyClass): @@ -151,29 +170,40 @@ class UserDefinedExtensionTest(fixtures.ORMTest): pass register_class(User) - attributes.register_attribute(User, 'user_id', uselist = False, useobject=False) - attributes.register_attribute(User, 'user_name', uselist = False, useobject=False) - attributes.register_attribute(User, 'email_address', uselist = False, useobject=False) + attributes.register_attribute( + User, 'user_id', uselist=False, useobject=False) + attributes.register_attribute( + User, 'user_name', uselist=False, useobject=False) + attributes.register_attribute( + User, 'email_address', uselist=False, useobject=False) u = User() u.user_id = 7 u.user_name = 'john' u.email_address = 'lala@123.com' - self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com') - attributes.instance_state(u)._commit_all(attributes.instance_dict(u)) - self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com') + eq_(u.user_id, 7) + eq_(u.user_name, "john") + eq_(u.email_address, "lala@123.com") + attributes.instance_state(u)._commit_all( + attributes.instance_dict(u)) + eq_(u.user_id, 7) + eq_(u.user_name, "john") + eq_(u.email_address, "lala@123.com") u.user_name = 'heythere' u.email_address = 'foo@bar.com' - self.assert_(u.user_id == 7 and u.user_name == 'heythere' and u.email_address == 'foo@bar.com') + eq_(u.user_id, 7) + eq_(u.user_name, "heythere") + eq_(u.email_address, "foo@bar.com") def test_deferred(self): for base in (object, MyBaseClass, MyClass): class Foo(base): pass - data = {'a':'this is a', 'b':12} + data = {'a': 'this is a', 'b': 12} + def loader(state, keys): for k in keys: state.dict[k] = data[k] @@ -181,30 +211,38 @@ class UserDefinedExtensionTest(fixtures.ORMTest): manager = register_class(Foo) manager.deferred_scalar_loader = loader - attributes.register_attribute(Foo, 'a', uselist=False, useobject=False) - attributes.register_attribute(Foo, 'b', uselist=False, useobject=False) + attributes.register_attribute( + Foo, 'a', uselist=False, useobject=False) + attributes.register_attribute( + Foo, 'b', uselist=False, useobject=False) if base is object: - assert Foo not in instrumentation._instrumentation_factory._state_finders + assert Foo not in \ + instrumentation._instrumentation_factory._state_finders else: - assert Foo in instrumentation._instrumentation_factory._state_finders + assert Foo in \ + instrumentation._instrumentation_factory._state_finders f = Foo() - attributes.instance_state(f)._expire(attributes.instance_dict(f), set()) + attributes.instance_state(f)._expire( + attributes.instance_dict(f), set()) eq_(f.a, "this is a") eq_(f.b, 12) f.a = "this is some new a" - attributes.instance_state(f)._expire(attributes.instance_dict(f), set()) + attributes.instance_state(f)._expire( + attributes.instance_dict(f), set()) eq_(f.a, "this is a") eq_(f.b, 12) - attributes.instance_state(f)._expire(attributes.instance_dict(f), set()) + attributes.instance_state(f)._expire( + attributes.instance_dict(f), set()) f.a = "this is another new a" eq_(f.a, "this is another new a") eq_(f.b, 12) - attributes.instance_state(f)._expire(attributes.instance_dict(f), set()) + attributes.instance_state(f)._expire( + attributes.instance_dict(f), set()) eq_(f.a, "this is a") eq_(f.b, 12) @@ -212,7 +250,8 @@ class UserDefinedExtensionTest(fixtures.ORMTest): eq_(f.a, None) eq_(f.b, 12) - attributes.instance_state(f)._commit_all(attributes.instance_dict(f)) + attributes.instance_state(f)._commit_all( + attributes.instance_dict(f)) eq_(f.a, None) eq_(f.b, 12) @@ -220,27 +259,32 @@ class UserDefinedExtensionTest(fixtures.ORMTest): """tests that attributes are polymorphic""" for base in (object, MyBaseClass, MyClass): - class Foo(base):pass - class Bar(Foo):pass + class Foo(base): + pass + + class Bar(Foo): + pass register_class(Foo) register_class(Bar) def func1(state, passive): return "this is the foo attr" + def func2(state, passive): return "this is the bar attr" + def func3(state, passive): return "this is the shared attr" attributes.register_attribute(Foo, 'element', - uselist=False, callable_=func1, - useobject=True) + uselist=False, callable_=func1, + useobject=True) attributes.register_attribute(Foo, 'element2', - uselist=False, callable_=func3, - useobject=True) + uselist=False, callable_=func3, + useobject=True) attributes.register_attribute(Bar, 'element', - uselist=False, callable_=func2, - useobject=True) + uselist=False, callable_=func2, + useobject=True) x = Foo() y = Bar() @@ -251,15 +295,20 @@ class UserDefinedExtensionTest(fixtures.ORMTest): def test_collection_with_backref(self): for base in (object, MyBaseClass, MyClass): - class Post(base):pass - class Blog(base):pass + class Post(base): + pass + + class Blog(base): + pass register_class(Post) register_class(Blog) - attributes.register_attribute(Post, 'blog', uselist=False, - backref='posts', trackparent=True, useobject=True) - attributes.register_attribute(Blog, 'posts', uselist=True, - backref='blog', trackparent=True, useobject=True) + attributes.register_attribute( + Post, 'blog', uselist=False, + backref='posts', trackparent=True, useobject=True) + attributes.register_attribute( + Blog, 'posts', uselist=True, + backref='blog', trackparent=True, useobject=True) b = Blog() (p1, p2, p3) = (Post(), Post(), Post()) b.posts.append(p1) @@ -287,47 +336,77 @@ class UserDefinedExtensionTest(fixtures.ORMTest): for base in (object, MyBaseClass, MyClass): class Foo(base): pass + class Bar(base): pass register_class(Foo) register_class(Bar) - attributes.register_attribute(Foo, "name", uselist=False, useobject=False) - attributes.register_attribute(Foo, "bars", uselist=True, trackparent=True, useobject=True) - attributes.register_attribute(Bar, "name", uselist=False, useobject=False) - + attributes.register_attribute( + Foo, "name", uselist=False, useobject=False) + attributes.register_attribute( + Foo, "bars", uselist=True, trackparent=True, useobject=True) + attributes.register_attribute( + Bar, "name", uselist=False, useobject=False) f1 = Foo() f1.name = 'f1' - eq_(attributes.get_state_history(attributes.instance_state(f1), 'name'), (['f1'], (), ())) + eq_( + attributes.get_state_history( + attributes.instance_state(f1), 'name'), + (['f1'], (), ())) b1 = Bar() b1.name = 'b1' f1.bars.append(b1) - eq_(attributes.get_state_history(attributes.instance_state(f1), 'bars'), ([b1], [], [])) - - attributes.instance_state(f1)._commit_all(attributes.instance_dict(f1)) - attributes.instance_state(b1)._commit_all(attributes.instance_dict(b1)) - - eq_(attributes.get_state_history(attributes.instance_state(f1), 'name'), ((), ['f1'], ())) - eq_(attributes.get_state_history(attributes.instance_state(f1), 'bars'), ((), [b1], ())) + eq_( + attributes.get_state_history( + attributes.instance_state(f1), 'bars'), + ([b1], [], [])) + + attributes.instance_state(f1)._commit_all( + attributes.instance_dict(f1)) + attributes.instance_state(b1)._commit_all( + attributes.instance_dict(b1)) + + eq_( + attributes.get_state_history( + attributes.instance_state(f1), + 'name'), + ((), ['f1'], ())) + eq_( + attributes.get_state_history( + attributes.instance_state(f1), + 'bars'), + ((), [b1], ())) f1.name = 'f1mod' b2 = Bar() b2.name = 'b2' f1.bars.append(b2) - eq_(attributes.get_state_history(attributes.instance_state(f1), 'name'), (['f1mod'], (), ['f1'])) - eq_(attributes.get_state_history(attributes.instance_state(f1), 'bars'), ([b2], [b1], [])) + eq_( + attributes.get_state_history( + attributes.instance_state(f1), 'name'), + (['f1mod'], (), ['f1'])) + eq_( + attributes.get_state_history( + attributes.instance_state(f1), 'bars'), + ([b2], [b1], [])) f1.bars.remove(b1) - eq_(attributes.get_state_history(attributes.instance_state(f1), 'bars'), ([b2], [], [b1])) + eq_( + attributes.get_state_history( + attributes.instance_state(f1), 'bars'), + ([b2], [], [b1])) def test_null_instrumentation(self): class Foo(MyBaseClass): pass register_class(Foo) - attributes.register_attribute(Foo, "name", uselist=False, useobject=False) - attributes.register_attribute(Foo, "bars", uselist=True, trackparent=True, useobject=True) + attributes.register_attribute( + Foo, "name", uselist=False, useobject=False) + attributes.register_attribute( + Foo, "bars", uselist=True, trackparent=True, useobject=True) assert Foo.name == attributes.manager_of_class(Foo)['name'] assert Foo.bars == attributes.manager_of_class(Foo)['bars'] @@ -335,8 +414,11 @@ class UserDefinedExtensionTest(fixtures.ORMTest): def test_alternate_finders(self): """Ensure the generic finder front-end deals with edge cases.""" - class Unknown(object): pass - class Known(MyBaseClass): pass + class Unknown(object): + pass + + class Known(MyBaseClass): + pass register_class(Known) k, u = Known(), Unknown() @@ -347,28 +429,59 @@ class UserDefinedExtensionTest(fixtures.ORMTest): assert attributes.instance_state(k) is not None assert_raises((AttributeError, KeyError), - attributes.instance_state, u) + attributes.instance_state, u) assert_raises((AttributeError, KeyError), - attributes.instance_state, None) + attributes.instance_state, None) + + def test_unmapped_not_type_error(self): + """extension version of the same test in test_mapper. + + fixes #3408 + """ + assert_raises_message( + sa.exc.ArgumentError, + "Class object expected, got '5'.", + class_mapper, 5 + ) + def test_unmapped_not_type_error_iter_ok(self): + """extension version of the same test in test_mapper. + + fixes #3408 + """ + assert_raises_message( + sa.exc.ArgumentError, + r"Class object expected, got '\(5, 6\)'.", + class_mapper, (5, 6) + ) + + +class FinderTest(_ExtBase, fixtures.ORMTest): -class FinderTest(fixtures.ORMTest): def test_standard(self): - class A(object): pass + class A(object): + pass register_class(A) - eq_(type(instrumentation.manager_of_class(A)), instrumentation.ClassManager) + eq_( + type(instrumentation.manager_of_class(A)), + instrumentation.ClassManager) def test_nativeext_interfaceexact(self): class A(object): - __sa_instrumentation_manager__ = instrumentation.InstrumentationManager + __sa_instrumentation_manager__ = \ + instrumentation.InstrumentationManager register_class(A) - ne_(type(instrumentation.manager_of_class(A)), instrumentation.ClassManager) + ne_( + type(instrumentation.manager_of_class(A)), + instrumentation.ClassManager) def test_nativeext_submanager(self): - class Mine(instrumentation.ClassManager): pass + class Mine(instrumentation.ClassManager): + pass + class A(object): __sa_instrumentation_manager__ = Mine @@ -377,8 +490,12 @@ class FinderTest(fixtures.ORMTest): @modifies_instrumentation_finders def test_customfinder_greedy(self): - class Mine(instrumentation.ClassManager): pass - class A(object): pass + class Mine(instrumentation.ClassManager): + pass + + class A(object): + pass + def find(cls): return Mine @@ -388,20 +505,28 @@ class FinderTest(fixtures.ORMTest): @modifies_instrumentation_finders def test_customfinder_pass(self): - class A(object): pass + class A(object): + pass + def find(cls): return None instrumentation.instrumentation_finders.insert(0, find) register_class(A) - eq_(type(instrumentation.manager_of_class(A)), instrumentation.ClassManager) + eq_( + type(instrumentation.manager_of_class(A)), + instrumentation.ClassManager) + + +class InstrumentationCollisionTest(_ExtBase, fixtures.ORMTest): -class InstrumentationCollisionTest(fixtures.ORMTest): def test_none(self): - class A(object): pass + class A(object): + pass register_class(A) mgr_factory = lambda cls: instrumentation.ClassManager(cls) + class B(object): __sa_instrumentation_manager__ = staticmethod(mgr_factory) register_class(B) @@ -411,79 +536,114 @@ class InstrumentationCollisionTest(fixtures.ORMTest): register_class(C) def test_single_down(self): - class A(object): pass + class A(object): + pass register_class(A) mgr_factory = lambda cls: instrumentation.ClassManager(cls) + class B(A): __sa_instrumentation_manager__ = staticmethod(mgr_factory) - assert_raises_message(TypeError, "multiple instrumentation implementations", register_class, B) + assert_raises_message( + TypeError, "multiple instrumentation implementations", + register_class, B) def test_single_up(self): - class A(object): pass + class A(object): + pass # delay registration mgr_factory = lambda cls: instrumentation.ClassManager(cls) + class B(A): __sa_instrumentation_manager__ = staticmethod(mgr_factory) register_class(B) - assert_raises_message(TypeError, "multiple instrumentation implementations", register_class, A) + assert_raises_message( + TypeError, "multiple instrumentation implementations", + register_class, A) def test_diamond_b1(self): mgr_factory = lambda cls: instrumentation.ClassManager(cls) - class A(object): pass - class B1(A): pass + class A(object): + pass + + class B1(A): + pass + class B2(A): __sa_instrumentation_manager__ = staticmethod(mgr_factory) - class C(object): pass - assert_raises_message(TypeError, "multiple instrumentation implementations", register_class, B1) + class C(object): + pass + + assert_raises_message( + TypeError, "multiple instrumentation implementations", + register_class, B1) def test_diamond_b2(self): mgr_factory = lambda cls: instrumentation.ClassManager(cls) - class A(object): pass - class B1(A): pass + class A(object): + pass + + class B1(A): + pass + class B2(A): __sa_instrumentation_manager__ = staticmethod(mgr_factory) - class C(object): pass + + class C(object): + pass register_class(B2) - assert_raises_message(TypeError, "multiple instrumentation implementations", register_class, B1) + assert_raises_message( + TypeError, "multiple instrumentation implementations", + register_class, B1) def test_diamond_c_b(self): mgr_factory = lambda cls: instrumentation.ClassManager(cls) - class A(object): pass - class B1(A): pass + class A(object): + pass + + class B1(A): + pass + class B2(A): __sa_instrumentation_manager__ = staticmethod(mgr_factory) - class C(object): pass + + class C(object): + pass register_class(C) - assert_raises_message(TypeError, "multiple instrumentation implementations", register_class, B1) + assert_raises_message( + TypeError, "multiple instrumentation implementations", + register_class, B1) -class ExtendedEventsTest(fixtures.ORMTest): +class ExtendedEventsTest(_ExtBase, fixtures.ORMTest): + """Allow custom Events implementations.""" @modifies_instrumentation_finders def test_subclassed(self): class MyEvents(events.InstanceEvents): pass + class MyClassManager(instrumentation.ClassManager): dispatch = event.dispatcher(MyEvents) - instrumentation.instrumentation_finders.insert(0, lambda cls: MyClassManager) + instrumentation.instrumentation_finders.insert( + 0, lambda cls: MyClassManager) - class A(object): pass + class A(object): + pass register_class(A) manager = instrumentation.manager_of_class(A) assert issubclass(manager.dispatch._events, MyEvents) - |