summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/session.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/session.py')
-rw-r--r--lib/sqlalchemy/orm/session.py34
1 files changed, 20 insertions, 14 deletions
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 9a0438fc9..9e504c104 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -4,7 +4,7 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-from sqlalchemy import util, exceptions, sql
+from sqlalchemy import util, exceptions, sql, engine
from sqlalchemy.orm import unitofwork, query
from sqlalchemy.orm.mapper import object_mapper as _object_mapper
from sqlalchemy.orm.mapper import class_mapper as _class_mapper
@@ -30,8 +30,6 @@ class SessionTransaction(object):
def connection(self, mapper_or_class, entity_name=None):
if isinstance(mapper_or_class, type):
mapper_or_class = _class_mapper(mapper_or_class, entity_name=entity_name)
- if self.parent is not None:
- return self.parent.connection(mapper_or_class)
engine = self.session.get_bind(mapper_or_class)
return self.get_or_add(engine)
@@ -39,28 +37,36 @@ class SessionTransaction(object):
return SessionTransaction(self.session, self)
def add(self, bind):
+ if self.parent is not None:
+ return self.parent.add(bind)
+
if self.connections.has_key(bind.engine):
raise exceptions.InvalidRequestError("Session already has a Connection associated for the given Connection's Engine")
return self.get_or_add(bind)
def get_or_add(self, bind):
- # we reference the 'engine' attribute on the given object, which in the case of
- # Connection, ProxyEngine, Engine, whatever, should return the original
- # "Engine" object that is handling the connection.
- if self.connections.has_key(bind.engine):
- return self.connections[bind.engine][0]
- e = bind.engine
- c = bind.contextual_connect()
- if not self.connections.has_key(e):
- self.connections[e] = (c, c.begin(), c is not bind)
- return self.connections[e][0]
+ if self.parent is not None:
+ return self.parent.get_or_add(bind)
+
+ if self.connections.has_key(bind):
+ return self.connections[bind][0]
+
+ if not isinstance(bind, engine.Connection):
+ e = bind
+ c = bind.contextual_connect()
+ else:
+ e = bind.engine
+ c = bind
+
+ self.connections[bind] = self.connections[e] = (c, c.begin(), c is not bind)
+ return self.connections[bind][0]
def commit(self):
if self.parent is not None:
return
if self.autoflush:
self.session.flush()
- for t in self.connections.values():
+ for t in util.Set(self.connections.values()):
t[1].commit()
self.close()