summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2015-05-01 12:06:34 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2015-05-01 12:33:45 -0400
commit95949db715ff54be01bfd260a51903ede60597ae (patch)
tree926fadafe63839b78c472deae3c899b43b607acc
parent635f06c3ebc787b98cf0ee1e94eff12fc96daff0 (diff)
downloadsqlalchemy-95949db715ff54be01bfd260a51903ede60597ae.tar.gz
- Repair _reinstall_default_lookups to also flip the _extended flag
off again so that test fixtures setup/teardown instrumentation as expected - clean up test_extendedattr.py and fix it to no longer leak itself outside by ensuring _reinstall_default_lookups is always called, part of #3408 - Fixed bug where when using extended attribute instrumentation system, the correct exception would not be raised when :func:`.class_mapper` were called with an invalid input that also happened to not be weak referencable, such as an integer. fixes #3408
-rw-r--r--doc/build/changelog/changelog_09.rst10
-rw-r--r--lib/sqlalchemy/ext/instrumentation.py9
-rw-r--r--test/ext/test_extendedattr.py458
3 files changed, 327 insertions, 150 deletions
diff --git a/doc/build/changelog/changelog_09.rst b/doc/build/changelog/changelog_09.rst
index e66203ed3..2506d21bd 100644
--- a/doc/build/changelog/changelog_09.rst
+++ b/doc/build/changelog/changelog_09.rst
@@ -15,6 +15,16 @@
:version: 0.9.10
.. change::
+ :tags: bug, ext
+ :tickets: 3408
+ :versions: 1.0.4
+
+ Fixed bug where when using extended attribute instrumentation system,
+ the correct exception would not be raised when :func:`.class_mapper`
+ were called with an invalid input that also happened to not
+ be weak referencable, such as an integer.
+
+ .. change::
:tags: bug, tests, pypy
:tickets: 3406
:versions: 1.0.4
diff --git a/lib/sqlalchemy/ext/instrumentation.py b/lib/sqlalchemy/ext/instrumentation.py
index 024136661..30a0ab7d7 100644
--- a/lib/sqlalchemy/ext/instrumentation.py
+++ b/lib/sqlalchemy/ext/instrumentation.py
@@ -166,7 +166,13 @@ class ExtendedInstrumentationRegistry(InstrumentationFactory):
def manager_of_class(self, cls):
if cls is None:
return None
- return self._manager_finders.get(cls, _default_manager_getter)(cls)
+ try:
+ finder = self._manager_finders.get(cls, _default_manager_getter)
+ except TypeError:
+ # due to weakref lookup on invalid object
+ return None
+ else:
+ return finder(cls)
def state_of(self, instance):
if instance is None:
@@ -392,6 +398,7 @@ def _reinstall_default_lookups():
manager_of_class=_default_manager_getter
)
)
+ _instrumentation_factory._extended = False
def _install_lookups(lookups):
diff --git a/test/ext/test_extendedattr.py b/test/ext/test_extendedattr.py
index c7627c8b2..653418ac4 100644
--- a/test/ext/test_extendedattr.py
+++ b/test/ext/test_extendedattr.py
@@ -1,10 +1,12 @@
from sqlalchemy.testing import eq_, assert_raises, assert_raises_message, ne_
from sqlalchemy import util
+import sqlalchemy as sa
+from sqlalchemy.orm import class_mapper
from sqlalchemy.orm import attributes
-from sqlalchemy.orm.attributes import set_attribute, get_attribute, del_attribute
+from sqlalchemy.orm.attributes import set_attribute, \
+ get_attribute, del_attribute
from sqlalchemy.orm.instrumentation import is_instrumented
from sqlalchemy.orm import clear_mappers
-from sqlalchemy import testing
from sqlalchemy.testing import fixtures
from sqlalchemy.ext import instrumentation
from sqlalchemy.orm.instrumentation import register_class
@@ -12,6 +14,7 @@ from sqlalchemy.testing.util import decorator
from sqlalchemy.orm import events
from sqlalchemy import event
+
@decorator
def modifies_instrumentation_finders(fn, *args, **kw):
pristine = instrumentation.instrumentation_finders[:]
@@ -21,15 +24,11 @@ def modifies_instrumentation_finders(fn, *args, **kw):
del instrumentation.instrumentation_finders[:]
instrumentation.instrumentation_finders.extend(pristine)
-def with_lookup_strategy(strategy):
- @decorator
- def decorate(fn, *args, **kw):
- try:
- ext_instrumentation._install_instrumented_lookups()
- return fn(*args, **kw)
- finally:
- ext_instrumentation._reinstall_default_lookups()
- return decorate
+
+class _ExtBase(object):
+ @classmethod
+ def teardown_class(cls):
+ instrumentation._reinstall_default_lookups()
class MyTypesManager(instrumentation.InstrumentationManager):
@@ -58,16 +57,19 @@ class MyTypesManager(instrumentation.InstrumentationManager):
def state_getter(self, class_):
return lambda instance: instance.__dict__['_my_state']
+
class MyListLike(list):
# add @appender, @remover decorators as needed
_sa_iterator = list.__iter__
_sa_linker = None
_sa_converter = None
+
def _sa_appender(self, item, _sa_initiator=None):
if _sa_initiator is not False:
self._sa_adapter.fire_append_event(item, _sa_initiator)
list.append(self, item)
append = _sa_appender
+
def _sa_remover(self, item, _sa_initiator=None):
self._sa_adapter.fire_pre_remove_event(_sa_initiator)
if _sa_initiator is not False:
@@ -75,57 +77,64 @@ class MyListLike(list):
list.remove(self, item)
remove = _sa_remover
-class MyBaseClass(object):
- __sa_instrumentation_manager__ = instrumentation.InstrumentationManager
-
-class MyClass(object):
-
- # This proves that a staticmethod will work here; don't
- # flatten this back to a class assignment!
- def __sa_instrumentation_manager__(cls):
- return MyTypesManager(cls)
-
- __sa_instrumentation_manager__ = staticmethod(__sa_instrumentation_manager__)
-
- # This proves SA can handle a class with non-string dict keys
- if not util.pypy and not util.jython:
- locals()[42] = 99 # Don't remove this line!
-
- def __init__(self, **kwargs):
- for k in kwargs:
- setattr(self, k, kwargs[k])
-
- def __getattr__(self, key):
- if is_instrumented(self, key):
- return get_attribute(self, key)
- else:
- try:
- return self._goofy_dict[key]
- except KeyError:
- raise AttributeError(key)
-
- def __setattr__(self, key, value):
- if is_instrumented(self, key):
- set_attribute(self, key, value)
- else:
- self._goofy_dict[key] = value
-
- def __hasattr__(self, key):
- if is_instrumented(self, key):
- return True
- else:
- return key in self._goofy_dict
-
- def __delattr__(self, key):
- if is_instrumented(self, key):
- del_attribute(self, key)
- else:
- del self._goofy_dict[key]
-
-class UserDefinedExtensionTest(fixtures.ORMTest):
+
+MyBaseClass, MyClass = None, None
+
+
+class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
+
@classmethod
- def teardown_class(cls):
- instrumentation._reinstall_default_lookups()
+ def setup_class(cls):
+ global MyBaseClass, MyClass
+
+ class MyBaseClass(object):
+ __sa_instrumentation_manager__ = \
+ instrumentation.InstrumentationManager
+
+ class MyClass(object):
+
+ # This proves that a staticmethod will work here; don't
+ # flatten this back to a class assignment!
+ def __sa_instrumentation_manager__(cls):
+ return MyTypesManager(cls)
+
+ __sa_instrumentation_manager__ = staticmethod(
+ __sa_instrumentation_manager__)
+
+ # This proves SA can handle a class with non-string dict keys
+ if not util.pypy and not util.jython:
+ locals()[42] = 99 # Don't remove this line!
+
+ def __init__(self, **kwargs):
+ for k in kwargs:
+ setattr(self, k, kwargs[k])
+
+ def __getattr__(self, key):
+ if is_instrumented(self, key):
+ return get_attribute(self, key)
+ else:
+ try:
+ return self._goofy_dict[key]
+ except KeyError:
+ raise AttributeError(key)
+
+ def __setattr__(self, key, value):
+ if is_instrumented(self, key):
+ set_attribute(self, key, value)
+ else:
+ self._goofy_dict[key] = value
+
+ def __hasattr__(self, key):
+ if is_instrumented(self, key):
+ return True
+ else:
+ return key in self._goofy_dict
+
+ def __delattr__(self, key):
+ if is_instrumented(self, key):
+ del_attribute(self, key)
+ else:
+ del self._goofy_dict[key]
def teardown(self):
clear_mappers()
@@ -135,15 +144,25 @@ class UserDefinedExtensionTest(fixtures.ORMTest):
pass
register_class(User)
- attributes.register_attribute(User, 'user_id', uselist = False, useobject=False)
- attributes.register_attribute(User, 'user_name', uselist = False, useobject=False)
- attributes.register_attribute(User, 'email_address', uselist = False, useobject=False)
+ attributes.register_attribute(
+ User, 'user_id', uselist=False, useobject=False)
+ attributes.register_attribute(
+ User, 'user_name', uselist=False, useobject=False)
+ attributes.register_attribute(
+ User, 'email_address', uselist=False, useobject=False)
u = User()
u.user_id = 7
u.user_name = 'john'
u.email_address = 'lala@123.com'
- self.assert_(u.__dict__ == {'_my_state':u._my_state, '_goofy_dict':{'user_id':7, 'user_name':'john', 'email_address':'lala@123.com'}}, u.__dict__)
+ eq_(
+ u.__dict__,
+ {
+ '_my_state': u._my_state,
+ '_goofy_dict': {
+ 'user_id': 7, 'user_name': 'john',
+ 'email_address': 'lala@123.com'}}
+ )
def test_basic(self):
for base in (object, MyBaseClass, MyClass):
@@ -151,29 +170,40 @@ class UserDefinedExtensionTest(fixtures.ORMTest):
pass
register_class(User)
- attributes.register_attribute(User, 'user_id', uselist = False, useobject=False)
- attributes.register_attribute(User, 'user_name', uselist = False, useobject=False)
- attributes.register_attribute(User, 'email_address', uselist = False, useobject=False)
+ attributes.register_attribute(
+ User, 'user_id', uselist=False, useobject=False)
+ attributes.register_attribute(
+ User, 'user_name', uselist=False, useobject=False)
+ attributes.register_attribute(
+ User, 'email_address', uselist=False, useobject=False)
u = User()
u.user_id = 7
u.user_name = 'john'
u.email_address = 'lala@123.com'
- self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
- attributes.instance_state(u)._commit_all(attributes.instance_dict(u))
- self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
+ eq_(u.user_id, 7)
+ eq_(u.user_name, "john")
+ eq_(u.email_address, "lala@123.com")
+ attributes.instance_state(u)._commit_all(
+ attributes.instance_dict(u))
+ eq_(u.user_id, 7)
+ eq_(u.user_name, "john")
+ eq_(u.email_address, "lala@123.com")
u.user_name = 'heythere'
u.email_address = 'foo@bar.com'
- self.assert_(u.user_id == 7 and u.user_name == 'heythere' and u.email_address == 'foo@bar.com')
+ eq_(u.user_id, 7)
+ eq_(u.user_name, "heythere")
+ eq_(u.email_address, "foo@bar.com")
def test_deferred(self):
for base in (object, MyBaseClass, MyClass):
class Foo(base):
pass
- data = {'a':'this is a', 'b':12}
+ data = {'a': 'this is a', 'b': 12}
+
def loader(state, keys):
for k in keys:
state.dict[k] = data[k]
@@ -181,30 +211,38 @@ class UserDefinedExtensionTest(fixtures.ORMTest):
manager = register_class(Foo)
manager.deferred_scalar_loader = loader
- attributes.register_attribute(Foo, 'a', uselist=False, useobject=False)
- attributes.register_attribute(Foo, 'b', uselist=False, useobject=False)
+ attributes.register_attribute(
+ Foo, 'a', uselist=False, useobject=False)
+ attributes.register_attribute(
+ Foo, 'b', uselist=False, useobject=False)
if base is object:
- assert Foo not in instrumentation._instrumentation_factory._state_finders
+ assert Foo not in \
+ instrumentation._instrumentation_factory._state_finders
else:
- assert Foo in instrumentation._instrumentation_factory._state_finders
+ assert Foo in \
+ instrumentation._instrumentation_factory._state_finders
f = Foo()
- attributes.instance_state(f)._expire(attributes.instance_dict(f), set())
+ attributes.instance_state(f)._expire(
+ attributes.instance_dict(f), set())
eq_(f.a, "this is a")
eq_(f.b, 12)
f.a = "this is some new a"
- attributes.instance_state(f)._expire(attributes.instance_dict(f), set())
+ attributes.instance_state(f)._expire(
+ attributes.instance_dict(f), set())
eq_(f.a, "this is a")
eq_(f.b, 12)
- attributes.instance_state(f)._expire(attributes.instance_dict(f), set())
+ attributes.instance_state(f)._expire(
+ attributes.instance_dict(f), set())
f.a = "this is another new a"
eq_(f.a, "this is another new a")
eq_(f.b, 12)
- attributes.instance_state(f)._expire(attributes.instance_dict(f), set())
+ attributes.instance_state(f)._expire(
+ attributes.instance_dict(f), set())
eq_(f.a, "this is a")
eq_(f.b, 12)
@@ -212,7 +250,8 @@ class UserDefinedExtensionTest(fixtures.ORMTest):
eq_(f.a, None)
eq_(f.b, 12)
- attributes.instance_state(f)._commit_all(attributes.instance_dict(f))
+ attributes.instance_state(f)._commit_all(
+ attributes.instance_dict(f))
eq_(f.a, None)
eq_(f.b, 12)
@@ -220,27 +259,32 @@ class UserDefinedExtensionTest(fixtures.ORMTest):
"""tests that attributes are polymorphic"""
for base in (object, MyBaseClass, MyClass):
- class Foo(base):pass
- class Bar(Foo):pass
+ class Foo(base):
+ pass
+
+ class Bar(Foo):
+ pass
register_class(Foo)
register_class(Bar)
def func1(state, passive):
return "this is the foo attr"
+
def func2(state, passive):
return "this is the bar attr"
+
def func3(state, passive):
return "this is the shared attr"
attributes.register_attribute(Foo, 'element',
- uselist=False, callable_=func1,
- useobject=True)
+ uselist=False, callable_=func1,
+ useobject=True)
attributes.register_attribute(Foo, 'element2',
- uselist=False, callable_=func3,
- useobject=True)
+ uselist=False, callable_=func3,
+ useobject=True)
attributes.register_attribute(Bar, 'element',
- uselist=False, callable_=func2,
- useobject=True)
+ uselist=False, callable_=func2,
+ useobject=True)
x = Foo()
y = Bar()
@@ -251,15 +295,20 @@ class UserDefinedExtensionTest(fixtures.ORMTest):
def test_collection_with_backref(self):
for base in (object, MyBaseClass, MyClass):
- class Post(base):pass
- class Blog(base):pass
+ class Post(base):
+ pass
+
+ class Blog(base):
+ pass
register_class(Post)
register_class(Blog)
- attributes.register_attribute(Post, 'blog', uselist=False,
- backref='posts', trackparent=True, useobject=True)
- attributes.register_attribute(Blog, 'posts', uselist=True,
- backref='blog', trackparent=True, useobject=True)
+ attributes.register_attribute(
+ Post, 'blog', uselist=False,
+ backref='posts', trackparent=True, useobject=True)
+ attributes.register_attribute(
+ Blog, 'posts', uselist=True,
+ backref='blog', trackparent=True, useobject=True)
b = Blog()
(p1, p2, p3) = (Post(), Post(), Post())
b.posts.append(p1)
@@ -287,47 +336,77 @@ class UserDefinedExtensionTest(fixtures.ORMTest):
for base in (object, MyBaseClass, MyClass):
class Foo(base):
pass
+
class Bar(base):
pass
register_class(Foo)
register_class(Bar)
- attributes.register_attribute(Foo, "name", uselist=False, useobject=False)
- attributes.register_attribute(Foo, "bars", uselist=True, trackparent=True, useobject=True)
- attributes.register_attribute(Bar, "name", uselist=False, useobject=False)
-
+ attributes.register_attribute(
+ Foo, "name", uselist=False, useobject=False)
+ attributes.register_attribute(
+ Foo, "bars", uselist=True, trackparent=True, useobject=True)
+ attributes.register_attribute(
+ Bar, "name", uselist=False, useobject=False)
f1 = Foo()
f1.name = 'f1'
- eq_(attributes.get_state_history(attributes.instance_state(f1), 'name'), (['f1'], (), ()))
+ eq_(
+ attributes.get_state_history(
+ attributes.instance_state(f1), 'name'),
+ (['f1'], (), ()))
b1 = Bar()
b1.name = 'b1'
f1.bars.append(b1)
- eq_(attributes.get_state_history(attributes.instance_state(f1), 'bars'), ([b1], [], []))
-
- attributes.instance_state(f1)._commit_all(attributes.instance_dict(f1))
- attributes.instance_state(b1)._commit_all(attributes.instance_dict(b1))
-
- eq_(attributes.get_state_history(attributes.instance_state(f1), 'name'), ((), ['f1'], ()))
- eq_(attributes.get_state_history(attributes.instance_state(f1), 'bars'), ((), [b1], ()))
+ eq_(
+ attributes.get_state_history(
+ attributes.instance_state(f1), 'bars'),
+ ([b1], [], []))
+
+ attributes.instance_state(f1)._commit_all(
+ attributes.instance_dict(f1))
+ attributes.instance_state(b1)._commit_all(
+ attributes.instance_dict(b1))
+
+ eq_(
+ attributes.get_state_history(
+ attributes.instance_state(f1),
+ 'name'),
+ ((), ['f1'], ()))
+ eq_(
+ attributes.get_state_history(
+ attributes.instance_state(f1),
+ 'bars'),
+ ((), [b1], ()))
f1.name = 'f1mod'
b2 = Bar()
b2.name = 'b2'
f1.bars.append(b2)
- eq_(attributes.get_state_history(attributes.instance_state(f1), 'name'), (['f1mod'], (), ['f1']))
- eq_(attributes.get_state_history(attributes.instance_state(f1), 'bars'), ([b2], [b1], []))
+ eq_(
+ attributes.get_state_history(
+ attributes.instance_state(f1), 'name'),
+ (['f1mod'], (), ['f1']))
+ eq_(
+ attributes.get_state_history(
+ attributes.instance_state(f1), 'bars'),
+ ([b2], [b1], []))
f1.bars.remove(b1)
- eq_(attributes.get_state_history(attributes.instance_state(f1), 'bars'), ([b2], [], [b1]))
+ eq_(
+ attributes.get_state_history(
+ attributes.instance_state(f1), 'bars'),
+ ([b2], [], [b1]))
def test_null_instrumentation(self):
class Foo(MyBaseClass):
pass
register_class(Foo)
- attributes.register_attribute(Foo, "name", uselist=False, useobject=False)
- attributes.register_attribute(Foo, "bars", uselist=True, trackparent=True, useobject=True)
+ attributes.register_attribute(
+ Foo, "name", uselist=False, useobject=False)
+ attributes.register_attribute(
+ Foo, "bars", uselist=True, trackparent=True, useobject=True)
assert Foo.name == attributes.manager_of_class(Foo)['name']
assert Foo.bars == attributes.manager_of_class(Foo)['bars']
@@ -335,8 +414,11 @@ class UserDefinedExtensionTest(fixtures.ORMTest):
def test_alternate_finders(self):
"""Ensure the generic finder front-end deals with edge cases."""
- class Unknown(object): pass
- class Known(MyBaseClass): pass
+ class Unknown(object):
+ pass
+
+ class Known(MyBaseClass):
+ pass
register_class(Known)
k, u = Known(), Unknown()
@@ -347,28 +429,59 @@ class UserDefinedExtensionTest(fixtures.ORMTest):
assert attributes.instance_state(k) is not None
assert_raises((AttributeError, KeyError),
- attributes.instance_state, u)
+ attributes.instance_state, u)
assert_raises((AttributeError, KeyError),
- attributes.instance_state, None)
+ attributes.instance_state, None)
+
+ def test_unmapped_not_type_error(self):
+ """extension version of the same test in test_mapper.
+
+ fixes #3408
+ """
+ assert_raises_message(
+ sa.exc.ArgumentError,
+ "Class object expected, got '5'.",
+ class_mapper, 5
+ )
+ def test_unmapped_not_type_error_iter_ok(self):
+ """extension version of the same test in test_mapper.
+
+ fixes #3408
+ """
+ assert_raises_message(
+ sa.exc.ArgumentError,
+ r"Class object expected, got '\(5, 6\)'.",
+ class_mapper, (5, 6)
+ )
+
+
+class FinderTest(_ExtBase, fixtures.ORMTest):
-class FinderTest(fixtures.ORMTest):
def test_standard(self):
- class A(object): pass
+ class A(object):
+ pass
register_class(A)
- eq_(type(instrumentation.manager_of_class(A)), instrumentation.ClassManager)
+ eq_(
+ type(instrumentation.manager_of_class(A)),
+ instrumentation.ClassManager)
def test_nativeext_interfaceexact(self):
class A(object):
- __sa_instrumentation_manager__ = instrumentation.InstrumentationManager
+ __sa_instrumentation_manager__ = \
+ instrumentation.InstrumentationManager
register_class(A)
- ne_(type(instrumentation.manager_of_class(A)), instrumentation.ClassManager)
+ ne_(
+ type(instrumentation.manager_of_class(A)),
+ instrumentation.ClassManager)
def test_nativeext_submanager(self):
- class Mine(instrumentation.ClassManager): pass
+ class Mine(instrumentation.ClassManager):
+ pass
+
class A(object):
__sa_instrumentation_manager__ = Mine
@@ -377,8 +490,12 @@ class FinderTest(fixtures.ORMTest):
@modifies_instrumentation_finders
def test_customfinder_greedy(self):
- class Mine(instrumentation.ClassManager): pass
- class A(object): pass
+ class Mine(instrumentation.ClassManager):
+ pass
+
+ class A(object):
+ pass
+
def find(cls):
return Mine
@@ -388,20 +505,28 @@ class FinderTest(fixtures.ORMTest):
@modifies_instrumentation_finders
def test_customfinder_pass(self):
- class A(object): pass
+ class A(object):
+ pass
+
def find(cls):
return None
instrumentation.instrumentation_finders.insert(0, find)
register_class(A)
- eq_(type(instrumentation.manager_of_class(A)), instrumentation.ClassManager)
+ eq_(
+ type(instrumentation.manager_of_class(A)),
+ instrumentation.ClassManager)
+
+
+class InstrumentationCollisionTest(_ExtBase, fixtures.ORMTest):
-class InstrumentationCollisionTest(fixtures.ORMTest):
def test_none(self):
- class A(object): pass
+ class A(object):
+ pass
register_class(A)
mgr_factory = lambda cls: instrumentation.ClassManager(cls)
+
class B(object):
__sa_instrumentation_manager__ = staticmethod(mgr_factory)
register_class(B)
@@ -411,79 +536,114 @@ class InstrumentationCollisionTest(fixtures.ORMTest):
register_class(C)
def test_single_down(self):
- class A(object): pass
+ class A(object):
+ pass
register_class(A)
mgr_factory = lambda cls: instrumentation.ClassManager(cls)
+
class B(A):
__sa_instrumentation_manager__ = staticmethod(mgr_factory)
- assert_raises_message(TypeError, "multiple instrumentation implementations", register_class, B)
+ assert_raises_message(
+ TypeError, "multiple instrumentation implementations",
+ register_class, B)
def test_single_up(self):
- class A(object): pass
+ class A(object):
+ pass
# delay registration
mgr_factory = lambda cls: instrumentation.ClassManager(cls)
+
class B(A):
__sa_instrumentation_manager__ = staticmethod(mgr_factory)
register_class(B)
- assert_raises_message(TypeError, "multiple instrumentation implementations", register_class, A)
+ assert_raises_message(
+ TypeError, "multiple instrumentation implementations",
+ register_class, A)
def test_diamond_b1(self):
mgr_factory = lambda cls: instrumentation.ClassManager(cls)
- class A(object): pass
- class B1(A): pass
+ class A(object):
+ pass
+
+ class B1(A):
+ pass
+
class B2(A):
__sa_instrumentation_manager__ = staticmethod(mgr_factory)
- class C(object): pass
- assert_raises_message(TypeError, "multiple instrumentation implementations", register_class, B1)
+ class C(object):
+ pass
+
+ assert_raises_message(
+ TypeError, "multiple instrumentation implementations",
+ register_class, B1)
def test_diamond_b2(self):
mgr_factory = lambda cls: instrumentation.ClassManager(cls)
- class A(object): pass
- class B1(A): pass
+ class A(object):
+ pass
+
+ class B1(A):
+ pass
+
class B2(A):
__sa_instrumentation_manager__ = staticmethod(mgr_factory)
- class C(object): pass
+
+ class C(object):
+ pass
register_class(B2)
- assert_raises_message(TypeError, "multiple instrumentation implementations", register_class, B1)
+ assert_raises_message(
+ TypeError, "multiple instrumentation implementations",
+ register_class, B1)
def test_diamond_c_b(self):
mgr_factory = lambda cls: instrumentation.ClassManager(cls)
- class A(object): pass
- class B1(A): pass
+ class A(object):
+ pass
+
+ class B1(A):
+ pass
+
class B2(A):
__sa_instrumentation_manager__ = staticmethod(mgr_factory)
- class C(object): pass
+
+ class C(object):
+ pass
register_class(C)
- assert_raises_message(TypeError, "multiple instrumentation implementations", register_class, B1)
+ assert_raises_message(
+ TypeError, "multiple instrumentation implementations",
+ register_class, B1)
-class ExtendedEventsTest(fixtures.ORMTest):
+class ExtendedEventsTest(_ExtBase, fixtures.ORMTest):
+
"""Allow custom Events implementations."""
@modifies_instrumentation_finders
def test_subclassed(self):
class MyEvents(events.InstanceEvents):
pass
+
class MyClassManager(instrumentation.ClassManager):
dispatch = event.dispatcher(MyEvents)
- instrumentation.instrumentation_finders.insert(0, lambda cls: MyClassManager)
+ instrumentation.instrumentation_finders.insert(
+ 0, lambda cls: MyClassManager)
- class A(object): pass
+ class A(object):
+ pass
register_class(A)
manager = instrumentation.manager_of_class(A)
assert issubclass(manager.dispatch._events, MyEvents)
-