summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2009-12-08 03:10:59 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2009-12-08 03:10:59 +0000
commita1a60c88ff2233d21e4dd5fb20eee27f93118021 (patch)
tree07ab91d6d5549e20b5852bd759f5cd266fc9df39
parent2305e22d6bde2161f5bee25514f0a8444cce8416 (diff)
downloadsqlalchemy-a1a60c88ff2233d21e4dd5fb20eee27f93118021.tar.gz
- merge of trunk r6544
- Session.execute() now locates table- and mapper-specific binds based on a passed in expression which is an insert()/update()/delete() construct. [ticket:1054]
-rw-r--r--CHANGES7
-rw-r--r--lib/sqlalchemy/orm/session.py16
-rw-r--r--lib/sqlalchemy/sql/util.py10
-rw-r--r--test/orm/test_session.py69
4 files changed, 68 insertions, 34 deletions
diff --git a/CHANGES b/CHANGES
index b97828ff0..525a9112f 100644
--- a/CHANGES
+++ b/CHANGES
@@ -22,7 +22,12 @@ CHANGES
various unserializable options like those generated
by contains_eager() out of individual instance states.
[ticket:1553]
-
+
+ - Session.execute() now locates table- and
+ mapper-specific binds based on a passed
+ in expression which is an insert()/update()/delete()
+ construct. [ticket:1054]
+
- Fixed a needless select which would occur when merging
transient objects that contained a null primary key
identifier. [ticket:1618]
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 3dea2f6a3..be0821e3b 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -570,13 +570,11 @@ class Session(object):
self._mapper_flush_opts = {}
if binds is not None:
- for mapperortable, value in binds.iteritems():
- if isinstance(mapperortable, type):
- mapperortable = _class_mapper(mapperortable).base_mapper
- self.__binds[mapperortable] = value
- if isinstance(mapperortable, Mapper):
- for t in mapperortable._all_tables:
- self.__binds[t] = value
+ for mapperortable, bind in binds.iteritems():
+ if isinstance(mapperortable, (type, Mapper)):
+ self.bind_mapper(mapperortable, bind)
+ else:
+ self.bind_table(mapperortable, bind)
if not self.autocommit:
self.begin()
@@ -857,7 +855,7 @@ class Session(object):
"a binding.")
c_mapper = mapper is not None and _class_to_mapper(mapper) or None
-
+
# manually bound?
if self.__binds:
if c_mapper:
@@ -866,7 +864,7 @@ class Session(object):
elif c_mapper.mapped_table in self.__binds:
return self.__binds[c_mapper.mapped_table]
if clause:
- for t in sql_util.find_tables(clause):
+ for t in sql_util.find_tables(clause, include_crud=True):
if t in self.__binds:
return self.__binds[t]
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 9be405e21..dccd3d462 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -47,7 +47,9 @@ def find_join_source(clauses, join_to):
return None, None
-def find_tables(clause, check_columns=False, include_aliases=False, include_joins=False, include_selects=False):
+def find_tables(clause, check_columns=False,
+ include_aliases=False, include_joins=False,
+ include_selects=False, include_crud=False):
"""locate Table objects within the given expression."""
tables = []
@@ -61,7 +63,11 @@ def find_tables(clause, check_columns=False, include_aliases=False, include_join
if include_aliases:
_visitors['alias'] = tables.append
-
+
+ if include_crud:
+ _visitors['insert'] = _visitors['update'] = \
+ _visitors['delete'] = lambda ent: tables.append(ent.table)
+
if check_columns:
def visit_column(column):
tables.append(column.table)
diff --git a/test/orm/test_session.py b/test/orm/test_session.py
index 89923081a..828dd1316 100644
--- a/test/orm/test_session.py
+++ b/test/orm/test_session.py
@@ -105,49 +105,74 @@ class SessionTest(_fixtures.FixtureTest):
@engines.close_open_connections
@testing.resolve_artifact_names
- def test_table_binds_from_expression(self):
- """Session can extract Table objects from ClauseElements and match them to tables."""
+ def test_mapped_binds(self):
- mapper(Address, addresses)
- mapper(User, users, properties={
+ # ensure tables are unbound
+ m2 = sa.MetaData()
+ users_unbound =users.tometadata(m2)
+ addresses_unbound = addresses.tometadata(m2)
+
+ mapper(Address, addresses_unbound)
+ mapper(User, users_unbound, properties={
'addresses':relation(Address,
backref=backref("user", cascade="all"),
cascade="all")})
- Session = sessionmaker(binds={users: self.metadata.bind,
- addresses: self.metadata.bind})
+ Session = sessionmaker(binds={User: self.metadata.bind,
+ Address: self.metadata.bind})
sess = Session()
- sess.execute(users.insert(), params=dict(id=1, name='ed'))
- eq_(sess.execute(users.select(users.c.id == 1)).fetchall(),
- [(1, 'ed')])
+ u1 = User(id=1, name='ed')
+ sess.add(u1)
+ eq_(sess.query(User).filter(User.id==1).all(),
+ [User(id=1, name='ed')])
+
+ # test expression binding
+ sess.execute(users_unbound.insert(), params=dict(id=2, name='jack'))
+ eq_(sess.execute(users_unbound.select(users_unbound.c.id == 2)).fetchall(),
+ [(2, 'jack')])
- eq_(sess.execute(users.select(User.id == 1)).fetchall(),
- [(1, 'ed')])
+ eq_(sess.execute(users_unbound.select(User.id == 2)).fetchall(),
+ [(2, 'jack')])
+ sess.execute(users_unbound.delete())
+ eq_(sess.execute(users_unbound.select()).fetchall(), [])
+
sess.close()
@engines.close_open_connections
@testing.resolve_artifact_names
- def test_mapped_binds_from_expression(self):
- """Session can extract Table objects from ClauseElements and match them to tables."""
+ def test_table_binds(self):
- mapper(Address, addresses)
- mapper(User, users, properties={
+ # ensure tables are unbound
+ m2 = sa.MetaData()
+ users_unbound =users.tometadata(m2)
+ addresses_unbound = addresses.tometadata(m2)
+
+ mapper(Address, addresses_unbound)
+ mapper(User, users_unbound, properties={
'addresses':relation(Address,
backref=backref("user", cascade="all"),
cascade="all")})
- Session = sessionmaker(binds={User: self.metadata.bind,
- Address: self.metadata.bind})
+ Session = sessionmaker(binds={users_unbound: self.metadata.bind,
+ addresses_unbound: self.metadata.bind})
sess = Session()
- sess.execute(users.insert(), params=dict(id=1, name='ed'))
- eq_(sess.execute(users.select(users.c.id == 1)).fetchall(),
- [(1, 'ed')])
+ u1 = User(id=1, name='ed')
+ sess.add(u1)
+ eq_(sess.query(User).filter(User.id==1).all(),
+ [User(id=1, name='ed')])
+
+ sess.execute(users_unbound.insert(), params=dict(id=2, name='jack'))
+ eq_(sess.execute(users_unbound.select(users_unbound.c.id == 2)).fetchall(),
+ [(2, 'jack')])
+
+ eq_(sess.execute(users_unbound.select(User.id == 2)).fetchall(),
+ [(2, 'jack')])
- eq_(sess.execute(users.select(User.id == 1)).fetchall(),
- [(1, 'ed')])
+ sess.execute(users_unbound.delete())
+ eq_(sess.execute(users_unbound.select()).fetchall(), [])
sess.close()