diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2009-12-08 03:10:59 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2009-12-08 03:10:59 +0000 |
commit | a1a60c88ff2233d21e4dd5fb20eee27f93118021 (patch) | |
tree | 07ab91d6d5549e20b5852bd759f5cd266fc9df39 | |
parent | 2305e22d6bde2161f5bee25514f0a8444cce8416 (diff) | |
download | sqlalchemy-a1a60c88ff2233d21e4dd5fb20eee27f93118021.tar.gz |
- merge of trunk r6544
- Session.execute() now locates table- and
mapper-specific binds based on a passed
in expression which is an insert()/update()/delete()
construct. [ticket:1054]
-rw-r--r-- | CHANGES | 7 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/session.py | 16 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/util.py | 10 | ||||
-rw-r--r-- | test/orm/test_session.py | 69 |
4 files changed, 68 insertions, 34 deletions
@@ -22,7 +22,12 @@ CHANGES various unserializable options like those generated by contains_eager() out of individual instance states. [ticket:1553] - + + - Session.execute() now locates table- and + mapper-specific binds based on a passed + in expression which is an insert()/update()/delete() + construct. [ticket:1054] + - Fixed a needless select which would occur when merging transient objects that contained a null primary key identifier. [ticket:1618] diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 3dea2f6a3..be0821e3b 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -570,13 +570,11 @@ class Session(object): self._mapper_flush_opts = {} if binds is not None: - for mapperortable, value in binds.iteritems(): - if isinstance(mapperortable, type): - mapperortable = _class_mapper(mapperortable).base_mapper - self.__binds[mapperortable] = value - if isinstance(mapperortable, Mapper): - for t in mapperortable._all_tables: - self.__binds[t] = value + for mapperortable, bind in binds.iteritems(): + if isinstance(mapperortable, (type, Mapper)): + self.bind_mapper(mapperortable, bind) + else: + self.bind_table(mapperortable, bind) if not self.autocommit: self.begin() @@ -857,7 +855,7 @@ class Session(object): "a binding.") c_mapper = mapper is not None and _class_to_mapper(mapper) or None - + # manually bound? if self.__binds: if c_mapper: @@ -866,7 +864,7 @@ class Session(object): elif c_mapper.mapped_table in self.__binds: return self.__binds[c_mapper.mapped_table] if clause: - for t in sql_util.find_tables(clause): + for t in sql_util.find_tables(clause, include_crud=True): if t in self.__binds: return self.__binds[t] diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 9be405e21..dccd3d462 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -47,7 +47,9 @@ def find_join_source(clauses, join_to): return None, None -def find_tables(clause, check_columns=False, include_aliases=False, include_joins=False, include_selects=False): +def find_tables(clause, check_columns=False, + include_aliases=False, include_joins=False, + include_selects=False, include_crud=False): """locate Table objects within the given expression.""" tables = [] @@ -61,7 +63,11 @@ def find_tables(clause, check_columns=False, include_aliases=False, include_join if include_aliases: _visitors['alias'] = tables.append - + + if include_crud: + _visitors['insert'] = _visitors['update'] = \ + _visitors['delete'] = lambda ent: tables.append(ent.table) + if check_columns: def visit_column(column): tables.append(column.table) diff --git a/test/orm/test_session.py b/test/orm/test_session.py index 89923081a..828dd1316 100644 --- a/test/orm/test_session.py +++ b/test/orm/test_session.py @@ -105,49 +105,74 @@ class SessionTest(_fixtures.FixtureTest): @engines.close_open_connections @testing.resolve_artifact_names - def test_table_binds_from_expression(self): - """Session can extract Table objects from ClauseElements and match them to tables.""" + def test_mapped_binds(self): - mapper(Address, addresses) - mapper(User, users, properties={ + # ensure tables are unbound + m2 = sa.MetaData() + users_unbound =users.tometadata(m2) + addresses_unbound = addresses.tometadata(m2) + + mapper(Address, addresses_unbound) + mapper(User, users_unbound, properties={ 'addresses':relation(Address, backref=backref("user", cascade="all"), cascade="all")}) - Session = sessionmaker(binds={users: self.metadata.bind, - addresses: self.metadata.bind}) + Session = sessionmaker(binds={User: self.metadata.bind, + Address: self.metadata.bind}) sess = Session() - sess.execute(users.insert(), params=dict(id=1, name='ed')) - eq_(sess.execute(users.select(users.c.id == 1)).fetchall(), - [(1, 'ed')]) + u1 = User(id=1, name='ed') + sess.add(u1) + eq_(sess.query(User).filter(User.id==1).all(), + [User(id=1, name='ed')]) + + # test expression binding + sess.execute(users_unbound.insert(), params=dict(id=2, name='jack')) + eq_(sess.execute(users_unbound.select(users_unbound.c.id == 2)).fetchall(), + [(2, 'jack')]) - eq_(sess.execute(users.select(User.id == 1)).fetchall(), - [(1, 'ed')]) + eq_(sess.execute(users_unbound.select(User.id == 2)).fetchall(), + [(2, 'jack')]) + sess.execute(users_unbound.delete()) + eq_(sess.execute(users_unbound.select()).fetchall(), []) + sess.close() @engines.close_open_connections @testing.resolve_artifact_names - def test_mapped_binds_from_expression(self): - """Session can extract Table objects from ClauseElements and match them to tables.""" + def test_table_binds(self): - mapper(Address, addresses) - mapper(User, users, properties={ + # ensure tables are unbound + m2 = sa.MetaData() + users_unbound =users.tometadata(m2) + addresses_unbound = addresses.tometadata(m2) + + mapper(Address, addresses_unbound) + mapper(User, users_unbound, properties={ 'addresses':relation(Address, backref=backref("user", cascade="all"), cascade="all")}) - Session = sessionmaker(binds={User: self.metadata.bind, - Address: self.metadata.bind}) + Session = sessionmaker(binds={users_unbound: self.metadata.bind, + addresses_unbound: self.metadata.bind}) sess = Session() - sess.execute(users.insert(), params=dict(id=1, name='ed')) - eq_(sess.execute(users.select(users.c.id == 1)).fetchall(), - [(1, 'ed')]) + u1 = User(id=1, name='ed') + sess.add(u1) + eq_(sess.query(User).filter(User.id==1).all(), + [User(id=1, name='ed')]) + + sess.execute(users_unbound.insert(), params=dict(id=2, name='jack')) + eq_(sess.execute(users_unbound.select(users_unbound.c.id == 2)).fetchall(), + [(2, 'jack')]) + + eq_(sess.execute(users_unbound.select(User.id == 2)).fetchall(), + [(2, 'jack')]) - eq_(sess.execute(users.select(User.id == 1)).fetchall(), - [(1, 'ed')]) + sess.execute(users_unbound.delete()) + eq_(sess.execute(users_unbound.select()).fetchall(), []) sess.close() |