from sqlalchemy import ForeignKey
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import testing
from sqlalchemy.ext import declarative as decl
from sqlalchemy.ext.declarative.base import _DeferredMapperConfig
from sqlalchemy.orm import clear_mappers
from sqlalchemy.orm import create_session
from sqlalchemy.orm import exc as orm_exc
from sqlalchemy.orm import relationship
from sqlalchemy.orm import Session
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
from sqlalchemy.testing.util import gc_collect


class DeclarativeReflectionBase(fixtures.TablesTest):
    """Common fixture: a fresh declarative ``Base`` bound to the test
    database for every test, with mappers cleared afterwards."""

    __requires__ = ("reflectable_autoincrement",)

    def setup(self):
        # The declarative base is published as a module-level global so that
        # the classes declared inside each test body can subclass it.
        global Base
        Base = decl.declarative_base(testing.db)

    def teardown(self):
        super(DeclarativeReflectionBase, self).teardown()
        clear_mappers()


class DeclarativeReflectionTest(DeclarativeReflectionBase):
    """Tests for classes declared with ``__autoload__ = True``."""

    @classmethod
    def define_tables(cls, metadata):
        """Create the ``users``, ``addresses`` and ``imhandles`` tables.

        ``imhandles.user_id`` deliberately carries no foreign key at the
        schema level.
        """
        Table(
            "users",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("name", String(50)),
            test_needs_fk=True,
        )
        Table(
            "addresses",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("email", String(50)),
            Column("user_id", Integer, ForeignKey("users.id")),
            test_needs_fk=True,
        )
        Table(
            "imhandles",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("user_id", Integer),
            Column("network", String(50)),
            Column("handle", String(50)),
            test_needs_fk=True,
        )

    def test_basic(self):
        """Round-trip a User with two Addresses through autoloaded tables."""

        class User(Base, fixtures.ComparableEntity):

            __tablename__ = "users"
            __autoload__ = True
            addresses = relationship("Address", backref="user")

        class Address(Base, fixtures.ComparableEntity):

            __tablename__ = "addresses"
            __autoload__ = True

        new_user = User(
            name="u1", addresses=[Address(email="one"), Address(email="two")]
        )
        session = create_session()
        session.add(new_user)
        session.flush()
        session.expunge_all()

        loaded_users = session.query(User).all()
        expected_users = [
            User(
                name="u1",
                addresses=[Address(email="one"), Address(email="two")],
            )
        ]
        eq_(loaded_users, expected_users)

        second_address = (
            session.query(Address).filter(Address.email == "two").one()
        )
        eq_(second_address, Address(email="two"))
        eq_(second_address.user, User(name="u1"))

    def test_rekey(self):
        """Map the reflected ``name`` column under the attribute key ``nom``."""

        class User(Base, fixtures.ComparableEntity):

            __tablename__ = "users"
            __autoload__ = True
            nom = Column("name", String(50), key="nom")
            addresses = relationship("Address", backref="user")

        class Address(Base, fixtures.ComparableEntity):

            __tablename__ = "addresses"
            __autoload__ = True

        new_user = User(
            nom="u1", addresses=[Address(email="one"), Address(email="two")]
        )
        session = create_session()
        session.add(new_user)
        session.flush()
        session.expunge_all()

        loaded_users = session.query(User).all()
        expected_users = [
            User(
                nom="u1",
                addresses=[Address(email="one"), Address(email="two")],
            )
        ]
        eq_(loaded_users, expected_users)

        second_address = (
            session.query(Address).filter(Address.email == "two").one()
        )
        eq_(second_address, Address(email="two"))
        eq_(second_address.user, User(nom="u1"))

        # The original column name is no longer accepted as a keyword.
        assert_raises(TypeError, User, name="u3")

    def test_supplied_fk(self):
        """Supply the ForeignKey missing from ``imhandles`` on the class."""

        class IMHandle(Base, fixtures.ComparableEntity):

            __tablename__ = "imhandles"
            __autoload__ = True
            user_id = Column("user_id", Integer, ForeignKey("users.id"))

        class User(Base, fixtures.ComparableEntity):

            __tablename__ = "users"
            __autoload__ = True
            handles = relationship("IMHandle", backref="user")

        new_user = User(
            name="u1",
            handles=[
                IMHandle(network="blabber", handle="foo"),
                IMHandle(network="lol", handle="zomg"),
            ],
        )
        session = create_session()
        session.add(new_user)
        session.flush()
        session.expunge_all()

        loaded_users = session.query(User).all()
        expected_users = [
            User(
                name="u1",
                handles=[
                    IMHandle(network="blabber", handle="foo"),
                    IMHandle(network="lol", handle="zomg"),
                ],
            )
        ]
        eq_(loaded_users, expected_users)

        zomg_handle = (
            session.query(IMHandle).filter(IMHandle.handle == "zomg").one()
        )
        eq_(zomg_handle, IMHandle(network="lol", handle="zomg"))
        eq_(zomg_handle.user, User(name="u1"))


class DeferredReflectBase(DeclarativeReflectionBase):
    """Fixture for DeferredReflection tests; also empties the pending
    deferred-mapper configurations after each test."""

    def teardown(self):
        super(DeferredReflectBase, self).teardown()
        _DeferredMapperConfig._configs.clear()


# Placeholder; rebound by DeclarativeReflectionBase.setup() before each test.
Base = None


class DeferredReflectPKFKTest(DeferredReflectBase):
    """Deferred reflection where a primary key column is also a foreign key."""

    @classmethod
    def define_tables(cls, metadata):
        """Create ``a`` and ``b``; ``b.id`` is both PK member and FK to ``a``."""
        Table(
            "a",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
        )
        Table(
            "b",
            metadata,
            Column("id", Integer, ForeignKey("a.id"), primary_key=True),
            Column("x", Integer, primary_key=True),
        )

    def test_pk_fk(self):
        """prepare() must complete for B -> A declared in this order."""

        class B(decl.DeferredReflection, fixtures.ComparableEntity, Base):
            __tablename__ = "b"
            a = relationship("A")

        class A(decl.DeferredReflection, fixtures.ComparableEntity, Base):
            __tablename__ = "a"

        decl.DeferredReflection.prepare(testing.db)


class DeferredReflectionTest(DeferredReflectBase):
    """Tests for DeferredReflection against ``users`` / ``addresses``."""

    @classmethod
    def define_tables(cls, metadata):
        """Create the ``users`` and ``addresses`` tables."""
        Table(
            "users",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("name", String(50)),
            test_needs_fk=True,
        )
        Table(
            "addresses",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("email", String(50)),
            Column("user_id", Integer, ForeignKey("users.id")),
            test_needs_fk=True,
        )

    def _roundtrip(self):
        """Persist and reload a User with two Addresses.

        The mapped classes are looked up by name in the declarative
        registry, so the calling test only has to declare and prepare them.
        """
        registry = Base._decl_class_registry
        User = registry["User"]
        Address = registry["Address"]

        new_user = User(
            name="u1", addresses=[Address(email="one"), Address(email="two")]
        )
        session = create_session()
        session.add(new_user)
        session.flush()
        session.expunge_all()

        loaded_users = session.query(User).all()
        expected_users = [
            User(
                name="u1",
                addresses=[Address(email="one"), Address(email="two")],
            )
        ]
        eq_(loaded_users, expected_users)

        second_address = (
            session.query(Address).filter(Address.email == "two").one()
        )
        eq_(second_address, Address(email="two"))
        eq_(second_address.user, User(name="u1"))

    def test_exception_prepare_not_called(self):
        """Querying before prepare() raises UnmappedClassError."""

        class User(decl.DeferredReflection, fixtures.ComparableEntity, Base):
            __tablename__ = "users"
            addresses = relationship("Address", backref="user")

        class Address(
            decl.DeferredReflection, fixtures.ComparableEntity, Base
        ):
            __tablename__ = "addresses"

        # NOTE(review): the spacing inside this pattern has to match the
        # library's error text exactly - confirm against the raised message.
        assert_raises_message(
            orm_exc.UnmappedClassError,
            "Class test.ext.declarative.test_reflection.User is a "
            "subclass of DeferredReflection. Mappings are not produced "
            r"until the .prepare\(\) method is called on the class "
            "hierarchy.",
            Session().query,
            User,
        )

    def test_basic_deferred(self):
        """Declare, prepare, then round-trip."""

        class User(decl.DeferredReflection, fixtures.ComparableEntity, Base):
            __tablename__ = "users"
            addresses = relationship("Address", backref="user")

        class Address(
            decl.DeferredReflection, fixtures.ComparableEntity, Base
        ):
            __tablename__ = "addresses"

        decl.DeferredReflection.prepare(testing.db)
        self._roundtrip()

    def test_abstract_base(self):
        """prepare() invoked on one abstract deferred base only.

        ``Fake`` names a table called "nonexistent" and hangs off a different
        abstract base; only ``DefBase.prepare()`` is called here.
        """

        class DefBase(decl.DeferredReflection, Base):
            __abstract__ = True

        class OtherDefBase(decl.DeferredReflection, Base):
            __abstract__ = True

        class User(fixtures.ComparableEntity, DefBase):
            __tablename__ = "users"
            addresses = relationship("Address", backref="user")

        class Address(fixtures.ComparableEntity, DefBase):
            __tablename__ = "addresses"

        class Fake(OtherDefBase):
            __tablename__ = "nonexistent"

        DefBase.prepare(testing.db)
        self._roundtrip()

    def test_redefine_fk_double(self):
        """Redeclare a foreign key column that the table already defines."""

        class User(decl.DeferredReflection, fixtures.ComparableEntity, Base):
            __tablename__ = "users"
            addresses = relationship("Address", backref="user")

        class Address(
            decl.DeferredReflection, fixtures.ComparableEntity, Base
        ):
            __tablename__ = "addresses"
            user_id = Column(Integer, ForeignKey("users.id"))

        decl.DeferredReflection.prepare(testing.db)
        self._roundtrip()

    def test_mapper_args_deferred(self):
        """Check that __mapper_args__ is evaluated only *after* the table
        has been reflected (it reads ``cls.__table__``)."""

        class User(decl.DeferredReflection, fixtures.ComparableEntity, Base):
            __tablename__ = "users"

            @decl.declared_attr
            def __mapper_args__(cls):
                return {"primary_key": cls.__table__.c.id}

        decl.DeferredReflection.prepare(testing.db)

        session = Session()
        session.add_all(
            [User(name="G"), User(name="Q"), User(name="A"), User(name="C")]
        )
        session.commit()

        users_by_name = session.query(User).order_by(User.name).all()
        expected_users = [
            User(name="A"),
            User(name="C"),
            User(name="G"),
            User(name="Q"),
        ]
        eq_(users_by_name, expected_users)

    @testing.requires.predictable_gc
    def test_cls_not_strong_ref(self):
        """A pending deferred config must not keep its class alive."""

        # No extra references to the classes may be created in this test;
        # the count assertions depend on ``del Address`` freeing the class.
        class User(decl.DeferredReflection, fixtures.ComparableEntity, Base):
            __tablename__ = "users"

        class Address(
            decl.DeferredReflection, fixtures.ComparableEntity, Base
        ):
            __tablename__ = "addresses"

        eq_(len(_DeferredMapperConfig._configs), 2)
        del Address
        gc_collect()
        eq_(len(_DeferredMapperConfig._configs), 1)
        decl.DeferredReflection.prepare(testing.db)
        assert not _DeferredMapperConfig._configs


class DeferredSecondaryReflectionTest(DeferredReflectBase):
    """Deferred reflection with a many-to-many ``secondary`` table."""

    @classmethod
    def define_tables(cls, metadata):
        """Create ``users``, ``items`` and the ``user_items`` link table."""
        Table(
            "users",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("name", String(50)),
            test_needs_fk=True,
        )

        Table(
            "user_items",
            metadata,
            Column("user_id", ForeignKey("users.id"), primary_key=True),
            Column("item_id", ForeignKey("items.id"), primary_key=True),
            test_needs_fk=True,
        )

        Table(
            "items",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("name", String(50)),
            test_needs_fk=True,
        )

    def _roundtrip(self):
        """Persist and reload a User that owns two Items."""
        registry = Base._decl_class_registry
        User = registry["User"]
        Item = registry["Item"]

        new_user = User(name="u1", items=[Item(name="i1"), Item(name="i2")])

        session = Session()
        session.add(new_user)
        session.commit()

        loaded_users = session.query(User).all()
        expected_users = [
            User(name="u1", items=[Item(name="i1"), Item(name="i2")])
        ]
        eq_(loaded_users, expected_users)

    def test_string_resolution(self):
        """``secondary`` given as the table name."""

        class User(decl.DeferredReflection, fixtures.ComparableEntity, Base):
            __tablename__ = "users"

            items = relationship("Item", secondary="user_items")

        class Item(decl.DeferredReflection, fixtures.ComparableEntity, Base):
            __tablename__ = "items"

        decl.DeferredReflection.prepare(testing.db)
        self._roundtrip()

    def test_table_resolution(self):
        """``secondary`` given as a column-less Table on Base.metadata."""

        class User(decl.DeferredReflection, fixtures.ComparableEntity, Base):
            __tablename__ = "users"

            items = relationship(
                "Item", secondary=Table("user_items", Base.metadata)
            )

        class Item(decl.DeferredReflection, fixtures.ComparableEntity, Base):
            __tablename__ = "items"

        decl.DeferredReflection.prepare(testing.db)
        self._roundtrip()


class DeferredInhReflectBase(DeferredReflectBase):
    """Shared round-trip for the Foo/Bar inheritance tests."""

    def _roundtrip(self):
        """Persist three Bar rows and one Foo row, then reload them by id."""
        registry = Base._decl_class_registry
        Foo = registry["Foo"]
        Bar = registry["Bar"]

        session = Session(testing.db)

        session.add_all(
            [
                Bar(data="d1", bar_data="b1"),
                Bar(data="d2", bar_data="b2"),
                Bar(data="d3", bar_data="b3"),
                Foo(data="d4"),
            ]
        )
        session.commit()

        loaded_rows = session.query(Foo).order_by(Foo.id).all()
        expected_rows = [
            Bar(data="d1", bar_data="b1"),
            Bar(data="d2", bar_data="b2"),
            Bar(data="d3", bar_data="b3"),
            Foo(data="d4"),
        ]
        eq_(loaded_rows, expected_rows)


class DeferredSingleInhReflectionTest(DeferredInhReflectBase):
    """Deferred reflection with Foo and Bar sharing the ``foo`` table."""

    @classmethod
    def define_tables(cls, metadata):
        """Create the single ``foo`` table holding both Foo and Bar columns."""
        Table(
            "foo",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("type", String(32)),
            Column("data", String(30)),
            Column("bar_data", String(30)),
        )

    def test_basic(self):
        """All columns come from reflection."""

        class Foo(decl.DeferredReflection, fixtures.ComparableEntity, Base):
            __tablename__ = "foo"
            __mapper_args__ = {
                "polymorphic_on": "type",
                "polymorphic_identity": "foo",
            }

        class Bar(Foo):
            __mapper_args__ = {"polymorphic_identity": "bar"}

        decl.DeferredReflection.prepare(testing.db)
        self._roundtrip()

    def test_add_subclass_column(self):
        """The subclass declares ``bar_data`` explicitly."""

        class Foo(decl.DeferredReflection, fixtures.ComparableEntity, Base):
            __tablename__ = "foo"
            __mapper_args__ = {
                "polymorphic_on": "type",
                "polymorphic_identity": "foo",
            }

        class Bar(Foo):
            __mapper_args__ = {"polymorphic_identity": "bar"}
            bar_data = Column(String(30))

        decl.DeferredReflection.prepare(testing.db)
        self._roundtrip()

    def test_add_pk_column(self):
        """The base class declares the ``id`` primary key explicitly."""

        class Foo(decl.DeferredReflection, fixtures.ComparableEntity, Base):
            __tablename__ = "foo"
            __mapper_args__ = {
                "polymorphic_on": "type",
                "polymorphic_identity": "foo",
            }
            id = Column(Integer, primary_key=True)

        class Bar(Foo):
            __mapper_args__ = {"polymorphic_identity": "bar"}

        decl.DeferredReflection.prepare(testing.db)
        self._roundtrip()


class DeferredJoinedInhReflectionTest(DeferredInhReflectBase):
    """Deferred reflection with Bar in its own ``bar`` table joined to
    ``foo``."""

    @classmethod
    def define_tables(cls, metadata):
        """Create ``foo`` and ``bar``; ``bar.id`` is PK and FK to ``foo.id``."""
        Table(
            "foo",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("type", String(32)),
            Column("data", String(30)),
            test_needs_fk=True,
        )
        Table(
            "bar",
            metadata,
            Column("id", Integer, ForeignKey("foo.id"), primary_key=True),
            Column("bar_data", String(30)),
            test_needs_fk=True,
        )

    def test_basic(self):
        """All columns come from reflection."""

        class Foo(decl.DeferredReflection, fixtures.ComparableEntity, Base):
            __tablename__ = "foo"
            __mapper_args__ = {
                "polymorphic_on": "type",
                "polymorphic_identity": "foo",
            }

        class Bar(Foo):
            __tablename__ = "bar"
            __mapper_args__ = {"polymorphic_identity": "bar"}

        decl.DeferredReflection.prepare(testing.db)
        self._roundtrip()

    def test_add_subclass_column(self):
        """The subclass declares ``bar_data`` explicitly."""

        class Foo(decl.DeferredReflection, fixtures.ComparableEntity, Base):
            __tablename__ = "foo"
            __mapper_args__ = {
                "polymorphic_on": "type",
                "polymorphic_identity": "foo",
            }

        class Bar(Foo):
            __tablename__ = "bar"
            __mapper_args__ = {"polymorphic_identity": "bar"}
            bar_data = Column(String(30))

        decl.DeferredReflection.prepare(testing.db)
        self._roundtrip()

    def test_add_pk_column(self):
        """The base class declares the ``id`` primary key explicitly."""

        class Foo(decl.DeferredReflection, fixtures.ComparableEntity, Base):
            __tablename__ = "foo"
            __mapper_args__ = {
                "polymorphic_on": "type",
                "polymorphic_identity": "foo",
            }
            id = Column(Integer, primary_key=True)

        class Bar(Foo):
            __tablename__ = "bar"
            __mapper_args__ = {"polymorphic_identity": "bar"}

        decl.DeferredReflection.prepare(testing.db)
        self._roundtrip()

    def test_add_fk_pk_column(self):
        """The subclass declares its FK/PK ``id`` column explicitly."""

        class Foo(decl.DeferredReflection, fixtures.ComparableEntity, Base):
            __tablename__ = "foo"
            __mapper_args__ = {
                "polymorphic_on": "type",
                "polymorphic_identity": "foo",
            }

        class Bar(Foo):
            __tablename__ = "bar"
            __mapper_args__ = {"polymorphic_identity": "bar"}
            id = Column(Integer, ForeignKey("foo.id"), primary_key=True)

        decl.DeferredReflection.prepare(testing.db)
        self._roundtrip()