import sqlalchemy as sa
from sqlalchemy import event
from sqlalchemy import exc
from sqlalchemy import ForeignKey
from sqlalchemy import ForeignKeyConstraint
from sqlalchemy import Index
from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy import String
from sqlalchemy import testing
from sqlalchemy import util
from sqlalchemy.ext import declarative as decl
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.ext.declarative import synonym_for
from sqlalchemy.orm import backref
from sqlalchemy.orm import class_mapper
from sqlalchemy.orm import clear_mappers
from sqlalchemy.orm import column_property
from sqlalchemy.orm import composite
from sqlalchemy.orm import configure_mappers
from sqlalchemy.orm import create_session
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import mapper
from sqlalchemy.orm import properties
from sqlalchemy.orm import relationship
from sqlalchemy.orm import Session
from sqlalchemy.orm.events import MapperEvents
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import assertions
from sqlalchemy.testing import eq_
from sqlalchemy.testing import expect_warnings
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
from sqlalchemy.testing import mock
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
from sqlalchemy.util import with_metaclass


Base = None

User = Address = None


class DeclarativeTestBase(
    fixtures.TestBase,
    testing.AssertsExecutionResults,
    testing.AssertsCompiledSQL,
):
    __dialect__ = "default"

    def setup(self):
        global Base
        Base = decl.declarative_base(testing.db)

    def teardown(self):
        Session.close_all()
        clear_mappers()
        Base.metadata.drop_all()


class DeclarativeTest(DeclarativeTestBase):
    def test_basic(self):
        class User(Base, fixtures.ComparableEntity):
            __tablename__ = "users"

            id = Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            )
            name = Column("name", String(50))
            addresses = relationship("Address", backref="user")

        class Address(Base, fixtures.ComparableEntity):
            __tablename__ = "addresses"

            id = Column(
                Integer, primary_key=True, test_needs_autoincrement=True
            )
            email = Column(String(50), key="_email")
            user_id = Column(
                "user_id", Integer, ForeignKey("users.id"), key="_user_id"
            )

        Base.metadata.create_all()

        eq_(Address.__table__.c["id"].name, "id")
        eq_(Address.__table__.c["_email"].name, "email")
        eq_(Address.__table__.c["_user_id"].name, "user_id")

        u1 = User(
            name="u1", addresses=[Address(email="one"), Address(email="two")]
        )
        sess = create_session()
        sess.add(u1)
        sess.flush()
        sess.expunge_all()

        eq_(
            sess.query(User).all(),
            [
                User(
                    name="u1",
                    addresses=[Address(email="one"), Address(email="two")],
                )
            ],
        )

        a1 = sess.query(Address).filter(Address.email == "two").one()
        eq_(a1, Address(email="two"))
        eq_(a1.user, User(name="u1"))

    def test_unicode_string_resolve(self):
        class User(Base, fixtures.ComparableEntity):
            __tablename__ = "users"

            id = Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            )
            name = Column("name", String(50))
            addresses = relationship(util.u("Address"), backref="user")

        class Address(Base, fixtures.ComparableEntity):
            __tablename__ = "addresses"

            id = Column(
                Integer, primary_key=True, test_needs_autoincrement=True
            )
            email = Column(String(50), key="_email")
            user_id = Column(
                "user_id", Integer, ForeignKey("users.id"), key="_user_id"
            )

        assert User.addresses.property.mapper.class_ is Address

    def test_unicode_string_resolve_backref(self):
        class User(Base, fixtures.ComparableEntity):
            __tablename__ = "users"

            id = Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            )
            name = Column("name", String(50))

        class Address(Base, fixtures.ComparableEntity):
            __tablename__ = "addresses"

            id = Column(
                Integer, primary_key=True, test_needs_autoincrement=True
            )
            email = Column(String(50), key="_email")
            user_id = Column(
                "user_id", Integer, ForeignKey("users.id"), key="_user_id"
            )
            user = relationship(
                User,
                backref=backref("addresses", order_by=util.u("Address.email")),
            )

        assert Address.user.property.mapper.class_ is User

    def test_no_table(self):
        def go():
            class User(Base):
                id = Column("id", Integer, primary_key=True)

        assert_raises_message(
            sa.exc.InvalidRequestError, "does not have a __table__", go
        )

    def test_table_args_empty_dict(self):
        class MyModel(Base):
            __tablename__ = "test"
            id = Column(Integer, primary_key=True)
            __table_args__ = {}

    def test_table_args_empty_tuple(self):
        class MyModel(Base):
            __tablename__ = "test"
            id = Column(Integer, primary_key=True)
            __table_args__ = ()

    def test_cant_add_columns(self):
        t = Table(
            "t",
            Base.metadata,
            Column("id", Integer, primary_key=True),
            Column("data", String),
        )

        def go():
            class User(Base):
                __table__ = t
                foo = Column(Integer, primary_key=True)

        # can't specify new columns not already in the table

        assert_raises_message(
            sa.exc.ArgumentError,
            "Can't add additional column 'foo' when "
            "specifying __table__",
            go,
        )

        # regular re-mapping works tho

        class Bar(Base):
            __table__ = t
            some_data = t.c.data

        assert (
            class_mapper(Bar).get_property("some_data").columns[0] is t.c.data
        )

    def test_lower_case_c_column_warning(self):
        with assertions.expect_warnings(r"Attribute 'x' on class "):

            class MyBase(Base):
                __tablename__ = "foo"

                id = Column(Integer, primary_key=True)

                @declared_attr.cascading
                def somecol(cls):
                    return Column(Integer)

    def test_column(self):
        class User(Base, fixtures.ComparableEntity):
            __tablename__ = "users"

            id = Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            )
            name = Column("name", String(50))

        User.a = Column("a", String(10))
        User.b = Column(String(10))
        Base.metadata.create_all()
        u1 = User(name="u1", a="a", b="b")
        eq_(u1.a, "a")
        eq_(User.a.get_history(u1), (["a"], (), ()))
        sess = create_session()
        sess.add(u1)
        sess.flush()
        sess.expunge_all()
        eq_(sess.query(User).all(), [User(name="u1", a="a", b="b")])

    def test_column_properties(self):
        class Address(Base, fixtures.ComparableEntity):
            __tablename__ = "addresses"

            id = Column(
                Integer, primary_key=True, test_needs_autoincrement=True
            )
            email = Column(String(50))
            user_id = Column(Integer, ForeignKey("users.id"))

        class User(Base, fixtures.ComparableEntity):
            __tablename__ = "users"

            id = Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            )
            name = Column("name", String(50))

            adr_count = sa.orm.column_property(
                sa.select(
                    [sa.func.count(Address.id)], Address.user_id == id
                ).as_scalar()
            )
            addresses = relationship(Address)

        Base.metadata.create_all()
        u1 = User(
            name="u1", addresses=[Address(email="one"), Address(email="two")]
        )
        sess = create_session()
        sess.add(u1)
        sess.flush()
        sess.expunge_all()
        eq_(
            sess.query(User).all(),
            [
                User(
                    name="u1",
                    adr_count=2,
                    addresses=[Address(email="one"), Address(email="two")],
                )
            ],
        )

    def test_column_properties_2(self):
        class Address(Base, fixtures.ComparableEntity):
            __tablename__ = "addresses"

            id = Column(Integer, primary_key=True)
            email = Column(String(50))
            user_id = Column(Integer, ForeignKey("users.id"))

        class User(Base, fixtures.ComparableEntity):
            __tablename__ = "users"

            id = Column("id", Integer, primary_key=True)
            name = Column("name", String(50))

            # this is not "valid" but we want to test that Address.id
            # doesn't get stuck into user's table

            adr_count = Address.id

        eq_(set(User.__table__.c.keys()), set(["id", "name"]))
        eq_(set(Address.__table__.c.keys()), set(["id", "email", "user_id"]))

    def test_deferred(self):
        class User(Base, fixtures.ComparableEntity):
            __tablename__ = "users"

            id = Column(
                Integer, primary_key=True, test_needs_autoincrement=True
            )
            name = sa.orm.deferred(Column(String(50)))

        Base.metadata.create_all()
        sess = create_session()
        sess.add(User(name="u1"))
        sess.flush()
        sess.expunge_all()

        u1 = sess.query(User).filter(User.name == "u1").one()
        assert "name" not in u1.__dict__

        def go():
            eq_(u1.name, "u1")

        self.assert_sql_count(testing.db, go, 1)

    def test_composite_inline(self):
        class AddressComposite(fixtures.ComparableEntity):
            def __init__(self, street, state):
                self.street = street
                self.state = state

            def __composite_values__(self):
                return [self.street, self.state]

        class User(Base, fixtures.ComparableEntity):
            __tablename__ = "user"

            id = Column(
                Integer, primary_key=True, test_needs_autoincrement=True
            )
            address = composite(
                AddressComposite,
                Column("street", String(50)),
                Column("state", String(2)),
            )

        Base.metadata.create_all()
        sess = Session()
        sess.add(User(address=AddressComposite("123 anywhere street", "MD")))
        sess.commit()
        eq_(
            sess.query(User).all(),
            [User(address=AddressComposite("123 anywhere street", "MD"))],
        )

    def test_composite_separate(self):
        class AddressComposite(fixtures.ComparableEntity):
            def __init__(self, street, state):
                self.street = street
                self.state = state

            def __composite_values__(self):
                return [self.street, self.state]

        class User(Base, fixtures.ComparableEntity):
            __tablename__ = "user"

            id = Column(
                Integer, primary_key=True, test_needs_autoincrement=True
            )
            street = Column(String(50))
            state = Column(String(2))
            address = composite(AddressComposite, street, state)

        Base.metadata.create_all()
        sess = Session()
        sess.add(User(address=AddressComposite("123 anywhere street", "MD")))
        sess.commit()
        eq_(
            sess.query(User).all(),
            [User(address=AddressComposite("123 anywhere street", "MD"))],
        )

    def test_mapping_to_join(self):
        users = Table(
            "users", Base.metadata, Column("id", Integer, primary_key=True)
        )
        addresses = Table(
            "addresses",
            Base.metadata,
            Column("id", Integer, primary_key=True),
            Column("user_id", Integer, ForeignKey("users.id")),
        )
        usersaddresses = sa.join(
            users, addresses, users.c.id == addresses.c.user_id
        )

        class User(Base):
            __table__ = usersaddresses
            __table_args__ = {"primary_key": [users.c.id]}

            # need to use column_property for now
            user_id = column_property(users.c.id, addresses.c.user_id)
            address_id = addresses.c.id

        assert User.__mapper__.get_property("user_id").columns[0] is users.c.id
        assert (
            User.__mapper__.get_property("user_id").columns[1]
            is addresses.c.user_id
        )

    def test_synonym_inline(self):
        class User(Base, fixtures.ComparableEntity):
            __tablename__ = "users"

            id = Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            )
            _name = Column("name", String(50))

            def _set_name(self, name):
                self._name = "SOMENAME " + name

            def _get_name(self):
                return self._name

            name = sa.orm.synonym(
                "_name", descriptor=property(_get_name, _set_name)
            )

        Base.metadata.create_all()
        sess = create_session()
        u1 = User(name="someuser")
        eq_(u1.name, "SOMENAME someuser")
        sess.add(u1)
        sess.flush()
        eq_(
            sess.query(User).filter(User.name == "SOMENAME someuser").one(), u1
        )

    def test_synonym_no_descriptor(self):
        from sqlalchemy.orm.properties import ColumnProperty

        class CustomCompare(ColumnProperty.Comparator):
            __hash__ = None

            def __eq__(self, other):
                return self.__clause_element__() == other + " FOO"

        class User(Base, fixtures.ComparableEntity):
            __tablename__ = "users"

            id = Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            )
            _name = Column("name", String(50))
            name = sa.orm.synonym("_name", comparator_factory=CustomCompare)

        Base.metadata.create_all()
        sess = create_session()
        u1 = User(name="someuser FOO")
        sess.add(u1)
        sess.flush()
        eq_(sess.query(User).filter(User.name == "someuser").one(), u1)

    def test_synonym_added(self):
        class User(Base, fixtures.ComparableEntity):
            __tablename__ = "users"

            id = Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            )
            _name = Column("name", String(50))

            def _set_name(self, name):
                self._name = "SOMENAME " + name

            def _get_name(self):
                return self._name

            name = property(_get_name, _set_name)

        User.name = sa.orm.synonym("_name", descriptor=User.name)
        Base.metadata.create_all()
        sess = create_session()
        u1 = User(name="someuser")
        eq_(u1.name, "SOMENAME someuser")
        sess.add(u1)
        sess.flush()
        eq_(
            sess.query(User).filter(User.name == "SOMENAME someuser").one(), u1
        )

    def test_reentrant_compile_via_foreignkey(self):
        class User(Base, fixtures.ComparableEntity):
            __tablename__ = "users"

            id = Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            )
            name = Column("name", String(50))
            addresses = relationship("Address", backref="user")

        class Address(Base, fixtures.ComparableEntity):
            __tablename__ = "addresses"

            id = Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            )
            email = Column("email", String(50))
            user_id = Column("user_id", Integer, ForeignKey(User.id))

        # previous versions would force a re-entrant mapper compile via
        # the User.id inside the ForeignKey but this is no longer the
        # case

        sa.orm.configure_mappers()

        eq_(
            list(Address.user_id.property.columns[0].foreign_keys)[0].column,
            User.__table__.c.id,
        )
        Base.metadata.create_all()
        u1 = User(
            name="u1", addresses=[Address(email="one"), Address(email="two")]
        )
        sess = create_session()
        sess.add(u1)
        sess.flush()
        sess.expunge_all()
        eq_(
            sess.query(User).all(),
            [
                User(
                    name="u1",
                    addresses=[Address(email="one"), Address(email="two")],
                )
            ],
        )

    def test_relationship_reference(self):
        class Address(Base, fixtures.ComparableEntity):
            __tablename__ = "addresses"

            id = Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            )
            email = Column("email", String(50))
            user_id = Column("user_id", Integer, ForeignKey("users.id"))

        class User(Base, fixtures.ComparableEntity):
            __tablename__ = "users"

            id = Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            )
            name = Column("name", String(50))
            addresses = relationship(
                "Address", backref="user", primaryjoin=id == Address.user_id
            )

        User.address_count = sa.orm.column_property(
            sa.select([sa.func.count(Address.id)])
            .where(Address.user_id == User.id)
            .as_scalar()
        )
        Base.metadata.create_all()
        u1 = User(
            name="u1", addresses=[Address(email="one"), Address(email="two")]
        )
        sess = create_session()
        sess.add(u1)
        sess.flush()
        sess.expunge_all()
        eq_(
            sess.query(User).all(),
            [
                User(
                    name="u1",
                    address_count=2,
                    addresses=[Address(email="one"), Address(email="two")],
                )
            ],
        )

    def test_pk_with_fk_init(self):
        class Bar(Base):
            __tablename__ = "bar"
            id = sa.Column(
                sa.Integer, sa.ForeignKey("foo.id"), primary_key=True
            )
            ex = sa.Column(sa.Integer, primary_key=True)

        class Foo(Base):
            __tablename__ = "foo"
            id = sa.Column(sa.Integer, primary_key=True)
            bars = sa.orm.relationship(Bar)

        assert Bar.__mapper__.primary_key[0] is Bar.__table__.c.id
        assert Bar.__mapper__.primary_key[1] is Bar.__table__.c.ex

    def test_with_explicit_autoloaded(self):
        meta = MetaData(testing.db)
        t1 = Table(
            "t1",
            meta,
            Column("id", String(50), primary_key=True),
            Column("data", String(50)),
        )
        meta.create_all()
        try:

            class MyObj(Base):
                __table__ = Table("t1", Base.metadata, autoload=True)

            sess = create_session()
            m = MyObj(id="someid", data="somedata")
            sess.add(m)
            sess.flush()
            eq_(t1.select().execute().fetchall(), [("someid", "somedata")])
        finally:
            meta.drop_all()

    def test_synonym_for(self):
        class User(Base, fixtures.ComparableEntity):
            __tablename__ = "users"

            id = Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            )
            name = Column("name", String(50))

            @decl.synonym_for("name")
            @property
            def namesyn(self):
                return self.name

        Base.metadata.create_all()
        sess = create_session()
        u1 = User(name="someuser")
        eq_(u1.name, "someuser")
        eq_(u1.namesyn, "someuser")
        sess.add(u1)
        sess.flush()

        rt = sess.query(User).filter(User.namesyn == "someuser").one()
        eq_(rt, u1)

    def test_comparable_using(self):
        class NameComparator(sa.orm.PropComparator):
            @property
            def upperself(self):
                cls = self.prop.parent.class_
                col = getattr(cls, "name")
                return sa.func.upper(col)

            def operate(self, op, other, **kw):
                return op(self.upperself, other, **kw)

        class User(Base, fixtures.ComparableEntity):
            __tablename__ = "users"

            id = Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            )
            name = Column("name", String(50))

            @decl.comparable_using(NameComparator)
            @property
            def uc_name(self):
                return self.name is not None and self.name.upper() or None

        Base.metadata.create_all()
        sess = create_session()
        u1 = User(name="someuser")
        eq_(u1.name, "someuser", u1.name)
        eq_(u1.uc_name, "SOMEUSER", u1.uc_name)
        sess.add(u1)
        sess.flush()
        sess.expunge_all()

        rt = sess.query(User).filter(User.uc_name == "SOMEUSER").one()
        eq_(rt, u1)
        sess.expunge_all()

        rt = sess.query(User).filter(User.uc_name.startswith("SOMEUSE")).one()
        eq_(rt, u1)

    def test_duplicate_classes_in_base(self):
        class Test(Base):
            __tablename__ = "a"
            id = Column(Integer, primary_key=True)

        assert_raises_message(
            sa.exc.SAWarning,
            "This declarative base already contains a class with ",
            lambda: type(Base)(
                "Test",
                (Base,),
                dict(__tablename__="b", id=Column(Integer, primary_key=True)),
            ),
        )

    @testing.teardown_events(MapperEvents)
    def test_instrument_class_before_instrumentation(self):
        # test #3388

        canary = mock.Mock()

        @event.listens_for(mapper, "instrument_class")
        def instrument_class(mp, cls):
            canary.instrument_class(mp, cls)

        @event.listens_for(object, "class_instrument")
        def class_instrument(cls):
            canary.class_instrument(cls)

        class Test(Base):
            __tablename__ = "test"
            id = Column(Integer, primary_key=True)

        eq_(
            canary.mock_calls,
            [
                mock.call.instrument_class(Test.__mapper__, Test),
                mock.call.class_instrument(Test),
            ],
        )

    def test_cls_docstring(self):
        class MyBase(object):
            """MyBase Docstring"""

        Base = decl.declarative_base(cls=MyBase)

        eq_(Base.__doc__, MyBase.__doc__)


def _produce_test(inline, stringbased):
    class ExplicitJoinTest(fixtures.MappedTest):
        @classmethod
        def define_tables(cls, metadata):
            global User, Address
            Base = decl.declarative_base(metadata=metadata)

            class User(Base, fixtures.ComparableEntity):
                __tablename__ = "users"

                id = Column(
                    Integer, primary_key=True, test_needs_autoincrement=True
                )
                name = Column(String(50))

            class Address(Base, fixtures.ComparableEntity):
                __tablename__ = "addresses"

                id = Column(
                    Integer, primary_key=True, test_needs_autoincrement=True
                )
                email = Column(String(50))
                user_id = Column(Integer, ForeignKey("users.id"))
                if inline:
                    if stringbased:
                        user = relationship(
                            "User",
                            primaryjoin="User.id==Address.user_id",
                            backref="addresses",
                        )
                    else:
                        user = relationship(
                            User,
                            primaryjoin=User.id == user_id,
                            backref="addresses",
                        )

            if not inline:
                configure_mappers()
                if stringbased:
                    Address.user = relationship(
                        "User",
                        primaryjoin="User.id==Address.user_id",
                        backref="addresses",
                    )
                else:
                    Address.user = relationship(
                        User,
                        primaryjoin=User.id == Address.user_id,
                        backref="addresses",
                    )

        @classmethod
        def insert_data(cls):
            params = [
                dict(list(zip(("id", "name"), column_values)))
                for column_values in [
                    (7, "jack"),
                    (8, "ed"),
                    (9, "fred"),
                    (10, "chuck"),
                ]
            ]

            User.__table__.insert().execute(params)
            Address.__table__.insert().execute(
                [
                    dict(list(zip(("id", "user_id", "email"), column_values)))
                    for column_values in [
                        (1, 7, "jack@bean.com"),
                        (2, 8, "ed@wood.com"),
                        (3, 8, "ed@bettyboop.com"),
                        (4, 8, "ed@lala.com"),
                        (5, 9, "fred@fred.com"),
                    ]
                ]
            )

        def test_aliased_join(self):
            # this query will screw up if the aliasing enabled in
            # query.join() gets applied to the right half of the join
            # condition inside the any(). the join condition inside of
            # any() comes from the "primaryjoin" of the relationship,
            # and should not be annotated with _orm_adapt.
            # PropertyLoader.Comparator will annotate the left side with
            # _orm_adapt, though.

            sess = create_session()
            eq_(
                sess.query(User)
                .join(User.addresses, aliased=True)
                .filter(Address.email == "ed@wood.com")
                .filter(User.addresses.any(Address.email == "jack@bean.com"))
                .all(),
                [],
            )

    ExplicitJoinTest.__name__ = "ExplicitJoinTest%s%s" % (
        inline and "Inline" or "Separate",
        stringbased and "String" or "Literal",
    )
    return ExplicitJoinTest


for inline in True, False:
    for stringbased in True, False:
        testclass = _produce_test(inline, stringbased)
        exec("%s = testclass" % testclass.__name__)
        del testclass