diff options
-rw-r--r-- | doc/build/changelog/changelog_09.rst | 12 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/declarative/api.py | 24 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/declarative/clsregistry.py | 89 | ||||
-rw-r--r-- | test/ext/declarative/test_reflection.py | 112 |
4 files changed, 170 insertions, 67 deletions
diff --git a/doc/build/changelog/changelog_09.rst b/doc/build/changelog/changelog_09.rst index 17116c2c4..367fa1df9 100644 --- a/doc/build/changelog/changelog_09.rst +++ b/doc/build/changelog/changelog_09.rst @@ -15,6 +15,18 @@ :version: 0.9.0b2 .. change:: + :tags: bug, orm, declarative + :tickets: 2865 + + The :class:`.DeferredReflection` class has been enhanced to provide + automatic reflection support for the "secondary" table referred + to by a :func:`.relationship`. "secondary", when specified + either as a string table name, or as a :class:`.Table` object with + only a name and :class:`.MetaData` object will also be included + in the reflection process when :meth:`.DeferredReflection.prepare` + is called. + + .. change:: :tags: feature, orm, backrefs :tickets: 1535 diff --git a/lib/sqlalchemy/ext/declarative/api.py b/lib/sqlalchemy/ext/declarative/api.py index 1cb653a23..64bf7fd9f 100644 --- a/lib/sqlalchemy/ext/declarative/api.py +++ b/lib/sqlalchemy/ext/declarative/api.py @@ -9,15 +9,17 @@ from ...schema import Table, MetaData from ...orm import synonym as _orm_synonym, mapper,\ comparable_property,\ - interfaces + interfaces, properties from ...orm.util import polymorphic_union from ...orm.base import _mapper_or_none +from ...util import compat from ... import exc import weakref from .base import _as_declarative, \ _declarative_constructor,\ _MapperConfig, _add_attribute +from .clsregistry import _class_resolver def instrument_declarative(cls, registry, metadata): @@ -465,11 +467,31 @@ class DeferredReflection(object): def prepare(cls, engine): """Reflect all :class:`.Table` objects for all current :class:`.DeferredReflection` subclasses""" + to_map = [m for m in _MapperConfig.configs.values() if issubclass(m.cls, cls)] for thingy in to_map: cls._sa_decl_prepare(thingy.local_table, engine) thingy.map() + mapper = thingy.cls.__mapper__ + metadata = mapper.class_.metadata + for rel in mapper._props.values(): + if isinstance(rel, properties.RelationshipProperty) and \ + rel.secondary is not None: + if isinstance(rel.secondary, Table): + cls._sa_decl_prepare(rel.secondary, engine) + elif isinstance(rel.secondary, _class_resolver): + rel.secondary._resolvers += ( + cls._sa_deferred_table_resolver(engine, metadata), + ) + + @classmethod + def _sa_deferred_table_resolver(cls, engine, metadata): + def _resolve(key): + t1 = Table(key, metadata) + cls._sa_decl_prepare(t1, engine) + return t1 + return _resolve @classmethod def _sa_decl_prepare(cls, local_table, engine): diff --git a/lib/sqlalchemy/ext/declarative/clsregistry.py b/lib/sqlalchemy/ext/declarative/clsregistry.py index 8fef8f1bc..04567b32c 100644 --- a/lib/sqlalchemy/ext/declarative/clsregistry.py +++ b/lib/sqlalchemy/ext/declarative/clsregistry.py @@ -225,47 +225,62 @@ def _determine_container(key, value): return _GetColumns(value) -def _resolver(cls, prop): - def resolve_arg(arg): - import sqlalchemy - from sqlalchemy.orm import foreign, remote - - fallback = sqlalchemy.__dict__.copy() - fallback.update({'foreign': foreign, 'remote': remote}) - - def access_cls(key): - if key in cls._decl_class_registry: - return _determine_container(key, cls._decl_class_registry[key]) - elif key in cls.metadata.tables: - return cls.metadata.tables[key] - elif key in cls.metadata._schemas: - return _GetTable(key, cls.metadata) - elif '_sa_module_registry' in cls._decl_class_registry and \ - key in cls._decl_class_registry['_sa_module_registry']: - registry = cls._decl_class_registry['_sa_module_registry'] - return registry.resolve_attr(key) +class _class_resolver(object): + def __init__(self, cls, prop, fallback, arg): + self.cls = cls + self.prop = prop + self.arg = self._declarative_arg = arg + self.fallback = fallback + self._dict = util.PopulateDict(self._access_cls) + self._resolvers = () + + def _access_cls(self, key): + cls = self.cls + if key in cls._decl_class_registry: + return _determine_container(key, cls._decl_class_registry[key]) + elif key in cls.metadata.tables: + return cls.metadata.tables[key] + elif key in cls.metadata._schemas: + return _GetTable(key, cls.metadata) + elif '_sa_module_registry' in cls._decl_class_registry and \ + key in cls._decl_class_registry['_sa_module_registry']: + registry = cls._decl_class_registry['_sa_module_registry'] + return registry.resolve_attr(key) + elif self._resolvers: + for resolv in self._resolvers: + value = resolv(key) + if value is not None: + return value + + return self.fallback[key] + + def __call__(self): + try: + x = eval(self.arg, globals(), self._dict) + + if isinstance(x, _GetColumns): + return x.cls else: - return fallback[key] + return x + except NameError as n: + raise exc.InvalidRequestError( + "When initializing mapper %s, expression %r failed to " + "locate a name (%r). If this is a class name, consider " + "adding this relationship() to the %r class after " + "both dependent classes have been defined." % + (self.prop.parent, self.arg, n.args[0], self.cls) + ) - d = util.PopulateDict(access_cls) - def return_cls(): - try: - x = eval(arg, globals(), d) +def _resolver(cls, prop): + import sqlalchemy + from sqlalchemy.orm import foreign, remote - if isinstance(x, _GetColumns): - return x.cls - else: - return x - except NameError as n: - raise exc.InvalidRequestError( - "When initializing mapper %s, expression %r failed to " - "locate a name (%r). If this is a class name, consider " - "adding this relationship() to the %r class after " - "both dependent classes have been defined." % - (prop.parent, arg, n.args[0], cls) - ) - return return_cls + fallback = sqlalchemy.__dict__.copy() + fallback.update({'foreign': foreign, 'remote': remote}) + + def resolve_arg(arg): + return _class_resolver(cls, prop, fallback, arg) return resolve_arg diff --git a/test/ext/declarative/test_reflection.py b/test/ext/declarative/test_reflection.py index 013439f93..26496f1ad 100644 --- a/test/ext/declarative/test_reflection.py +++ b/test/ext/declarative/test_reflection.py @@ -47,9 +47,8 @@ class DeclarativeReflectionTest(DeclarativeReflectionBase): test_needs_fk=True, ) - def test_basic(self): - meta = MetaData(testing.db) + def test_basic(self): class User(Base, fixtures.ComparableEntity): __tablename__ = 'users' @@ -80,8 +79,6 @@ class DeclarativeReflectionTest(DeclarativeReflectionBase): eq_(a1.user, User(name='u1')) def test_rekey(self): - meta = MetaData(testing.db) - class User(Base, fixtures.ComparableEntity): __tablename__ = 'users' @@ -114,8 +111,6 @@ class DeclarativeReflectionTest(DeclarativeReflectionBase): assert_raises(TypeError, User, name='u3') def test_supplied_fk(self): - meta = MetaData(testing.db) - class IMHandle(Base, fixtures.ComparableEntity): __tablename__ = 'imhandles' @@ -151,7 +146,7 @@ class DeclarativeReflectionTest(DeclarativeReflectionBase): class DeferredReflectBase(DeclarativeReflectionBase): def teardown(self): - super(DeferredReflectBase,self).teardown() + super(DeferredReflectBase, self).teardown() from sqlalchemy.ext.declarative.base import _MapperConfig _MapperConfig.configs.clear() @@ -275,7 +270,7 @@ class DeferredReflectionTest(DeferredReflectBase): @decl.declared_attr def __mapper_args__(cls): return { - "order_by":cls.__table__.c.name + "order_by": cls.__table__.c.name } decl.DeferredReflection.prepare(testing.db) @@ -297,6 +292,65 @@ class DeferredReflectionTest(DeferredReflectBase): ] ) +class DeferredSecondaryReflectionTest(DeferredReflectBase): + @classmethod + def define_tables(cls, metadata): + Table('users', metadata, + Column('id', Integer, + primary_key=True, test_needs_autoincrement=True), + Column('name', String(50)), test_needs_fk=True) + + Table('user_items', metadata, + Column('user_id', ForeignKey('users.id'), primary_key=True), + Column('item_id', ForeignKey('items.id'), primary_key=True), + test_needs_fk=True + ) + + Table('items', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('name', String(50)), + test_needs_fk=True + ) + + def _roundtrip(self): + + User = Base._decl_class_registry['User'] + Item = Base._decl_class_registry['Item'] + + u1 = User(name='u1', items=[Item(name='i1'), Item(name='i2')]) + + sess = Session() + sess.add(u1) + sess.commit() + + eq_(sess.query(User).all(), [User(name='u1', + items=[Item(name='i1'), Item(name='i2')])]) + + def test_string_resolution(self): + class User(decl.DeferredReflection, fixtures.ComparableEntity, Base): + __tablename__ = 'users' + + items = relationship("Item", secondary="user_items") + + class Item(decl.DeferredReflection, fixtures.ComparableEntity, Base): + __tablename__ = 'items' + + decl.DeferredReflection.prepare(testing.db) + self._roundtrip() + + def test_table_resolution(self): + class User(decl.DeferredReflection, fixtures.ComparableEntity, Base): + __tablename__ = 'users' + + items = relationship("Item", secondary=Table("user_items", Base.metadata)) + + class Item(decl.DeferredReflection, fixtures.ComparableEntity, Base): + __tablename__ = 'items' + + decl.DeferredReflection.prepare(testing.db) + self._roundtrip() + class DeferredInhReflectBase(DeferredReflectBase): def _roundtrip(self): Foo = Base._decl_class_registry['Foo'] @@ -338,11 +392,11 @@ class DeferredSingleInhReflectionTest(DeferredInhReflectBase): class Foo(decl.DeferredReflection, fixtures.ComparableEntity, Base): __tablename__ = 'foo' - __mapper_args__ = {"polymorphic_on":"type", - "polymorphic_identity":"foo"} + __mapper_args__ = {"polymorphic_on": "type", + "polymorphic_identity": "foo"} class Bar(Foo): - __mapper_args__ = {"polymorphic_identity":"bar"} + __mapper_args__ = {"polymorphic_identity": "bar"} decl.DeferredReflection.prepare(testing.db) self._roundtrip() @@ -351,11 +405,11 @@ class DeferredSingleInhReflectionTest(DeferredInhReflectBase): class Foo(decl.DeferredReflection, fixtures.ComparableEntity, Base): __tablename__ = 'foo' - __mapper_args__ = {"polymorphic_on":"type", - "polymorphic_identity":"foo"} + __mapper_args__ = {"polymorphic_on": "type", + "polymorphic_identity": "foo"} class Bar(Foo): - __mapper_args__ = {"polymorphic_identity":"bar"} + __mapper_args__ = {"polymorphic_identity": "bar"} bar_data = Column(String(30)) decl.DeferredReflection.prepare(testing.db) @@ -365,12 +419,12 @@ class DeferredSingleInhReflectionTest(DeferredInhReflectBase): class Foo(decl.DeferredReflection, fixtures.ComparableEntity, Base): __tablename__ = 'foo' - __mapper_args__ = {"polymorphic_on":"type", - "polymorphic_identity":"foo"} + __mapper_args__ = {"polymorphic_on": "type", + "polymorphic_identity": "foo"} id = Column(Integer, primary_key=True) class Bar(Foo): - __mapper_args__ = {"polymorphic_identity":"bar"} + __mapper_args__ = {"polymorphic_identity": "bar"} decl.DeferredReflection.prepare(testing.db) self._roundtrip() @@ -395,12 +449,12 @@ class DeferredJoinedInhReflectionTest(DeferredInhReflectBase): class Foo(decl.DeferredReflection, fixtures.ComparableEntity, Base): __tablename__ = 'foo' - __mapper_args__ = {"polymorphic_on":"type", - "polymorphic_identity":"foo"} + __mapper_args__ = {"polymorphic_on": "type", + "polymorphic_identity": "foo"} class Bar(Foo): __tablename__ = 'bar' - __mapper_args__ = {"polymorphic_identity":"bar"} + __mapper_args__ = {"polymorphic_identity": "bar"} decl.DeferredReflection.prepare(testing.db) self._roundtrip() @@ -409,12 +463,12 @@ class DeferredJoinedInhReflectionTest(DeferredInhReflectBase): class Foo(decl.DeferredReflection, fixtures.ComparableEntity, Base): __tablename__ = 'foo' - __mapper_args__ = {"polymorphic_on":"type", - "polymorphic_identity":"foo"} + __mapper_args__ = {"polymorphic_on": "type", + "polymorphic_identity": "foo"} class Bar(Foo): __tablename__ = 'bar' - __mapper_args__ = {"polymorphic_identity":"bar"} + __mapper_args__ = {"polymorphic_identity": "bar"} bar_data = Column(String(30)) decl.DeferredReflection.prepare(testing.db) @@ -424,13 +478,13 @@ class DeferredJoinedInhReflectionTest(DeferredInhReflectBase): class Foo(decl.DeferredReflection, fixtures.ComparableEntity, Base): __tablename__ = 'foo' - __mapper_args__ = {"polymorphic_on":"type", - "polymorphic_identity":"foo"} + __mapper_args__ = {"polymorphic_on": "type", + "polymorphic_identity": "foo"} id = Column(Integer, primary_key=True) class Bar(Foo): __tablename__ = 'bar' - __mapper_args__ = {"polymorphic_identity":"bar"} + __mapper_args__ = {"polymorphic_identity": "bar"} decl.DeferredReflection.prepare(testing.db) self._roundtrip() @@ -439,12 +493,12 @@ class DeferredJoinedInhReflectionTest(DeferredInhReflectBase): class Foo(decl.DeferredReflection, fixtures.ComparableEntity, Base): __tablename__ = 'foo' - __mapper_args__ = {"polymorphic_on":"type", - "polymorphic_identity":"foo"} + __mapper_args__ = {"polymorphic_on": "type", + "polymorphic_identity": "foo"} class Bar(Foo): __tablename__ = 'bar' - __mapper_args__ = {"polymorphic_identity":"bar"} + __mapper_args__ = {"polymorphic_identity": "bar"} id = Column(Integer, ForeignKey('foo.id'), primary_key=True) decl.DeferredReflection.prepare(testing.db) |