summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2015-04-30 19:44:16 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2015-04-30 19:55:28 -0400
commite0f9b279f43759886c61e6c82f97d95d0093fdf7 (patch)
tree30752e0689319850969ebd52174532afeab9980e
parent0e98795ff2c7a164b4da164d7b26af3faabf84d1 (diff)
downloadsqlalchemy-e0f9b279f43759886c61e6c82f97d95d0093fdf7.tar.gz
- work the wrapping of the "creator" to be as resilient to
old / new style, direct access, and ad-hoc patching and unpatching as possible
-rw-r--r--lib/sqlalchemy/pool.py39
-rw-r--r--test/engine/test_execute.py2
-rw-r--r--test/engine/test_pool.py55
3 files changed, 87 insertions, 9 deletions
diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py
index 902309d75..8eb9d796d 100644
--- a/lib/sqlalchemy/pool.py
+++ b/lib/sqlalchemy/pool.py
@@ -219,7 +219,7 @@ class Pool(log.Identified):
log.instance_logger(self, echoflag=echo)
self._threadconns = threading.local()
self._creator = creator
- self._wrapped_creator = self._maybe_wrap_callable(creator)
+ self._set_should_wrap_creator()
self._recycle = recycle
self._invalidate_time = 0
self._use_threadlocal = use_threadlocal
@@ -250,16 +250,17 @@ class Pool(log.Identified):
for l in listeners:
self.add_listener(l)
- def _maybe_wrap_callable(self, fn):
+ def _set_should_wrap_creator(self):
"""Detect if creator accepts a single argument, or is sent
as a legacy style no-arg function.
"""
try:
- argspec = util.get_callable_argspec(fn, no_self=True)
+ argspec = util.get_callable_argspec(self._creator, no_self=True)
except TypeError:
- return lambda ctx: fn()
+ self._should_wrap_creator = (True, self._creator)
+ return
defaulted = argspec[3] is not None and len(argspec[3]) or 0
positionals = len(argspec[0]) - defaulted
@@ -267,14 +268,36 @@ class Pool(log.Identified):
# look for the exact arg signature that DefaultStrategy
# sends us
if (argspec[0], argspec[3]) == (['connection_record'], (None,)):
- return fn
+ self._should_wrap_creator = (False, self._creator)
# or just a single positional
elif positionals == 1:
- return fn
+ self._should_wrap_creator = (False, self._creator)
# all other cases, just wrap and assume legacy "creator" callable
# thing
else:
- return lambda ctx: fn()
+ self._should_wrap_creator = (True, self._creator)
+
+ def _invoke_creator(self, connection_record):
+ """adjust for old or new style "creator" callable.
+
+ This function is spending extra effort in order to accommodate
+ any degree of manipulation of the _creator callable by end-user
+ applications, including ad-hoc patching in test suites.
+
+ """
+
+ should_wrap, against_creator = self._should_wrap_creator
+ creator = self._creator
+
+ if creator is not against_creator:
+ # check if the _creator function has been patched since
+ # we last looked at it
+ self._set_should_wrap_creator()
+ return self._invoke_creator(connection_record)
+ elif should_wrap:
+ return self._creator()
+ else:
+ return self._creator(connection_record)
def _close_connection(self, connection):
self.logger.debug("Closing connection %r", connection)
@@ -591,7 +614,7 @@ class _ConnectionRecord(object):
def __connect(self):
try:
self.starttime = time.time()
- connection = self.__pool._wrapped_creator(self)
+ connection = self.__pool._invoke_creator(self)
self.__pool.logger.debug("Created new connection %r", connection)
return connection
except Exception as e:
diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py
index cba3972f6..761ac102a 100644
--- a/test/engine/test_execute.py
+++ b/test/engine/test_execute.py
@@ -2174,7 +2174,7 @@ class HandleInvalidatedOnConnectTest(fixtures.TestBase):
conn.invalidate()
- eng.pool._wrapped_creator = Mock(
+ eng.pool._creator = Mock(
side_effect=self.ProgrammingError(
"Cannot operate on a closed database."))
diff --git a/test/engine/test_pool.py b/test/engine/test_pool.py
index 3d93cda89..912c6c3fe 100644
--- a/test/engine/test_pool.py
+++ b/test/engine/test_pool.py
@@ -1807,3 +1807,58 @@ class StaticPoolTest(PoolTestBase):
p = pool.StaticPool(creator)
p2 = p.recreate()
assert p._creator is p2._creator
+
+
+class CreatorCompatibilityTest(PoolTestBase):
+ def test_creator_callable_outside_noarg(self):
+ e = testing_engine()
+
+ creator = e.pool._creator
+ try:
+ conn = creator()
+ finally:
+ conn.close()
+
+ def test_creator_callable_outside_witharg(self):
+ e = testing_engine()
+
+ creator = e.pool._creator
+ try:
+ conn = creator(Mock())
+ finally:
+ conn.close()
+
+ def test_creator_patching_arg_to_noarg(self):
+ e = testing_engine()
+ creator = e.pool._creator
+ try:
+ # the creator is the two-arg form
+ conn = creator(Mock())
+ finally:
+ conn.close()
+
+ def mock_create():
+ return creator()
+
+ conn = e.connect()
+ conn.invalidate()
+ conn.close()
+
+ # test that the 'should_wrap_creator' memoized attribute
+ # will dynamically switch if the _creator is monkeypatched.
+
+ is_(e.pool.__dict__.get("_should_wrap_creator")[0], False)
+
+ # patch it with a zero-arg form
+ with patch.object(e.pool, "_creator", mock_create):
+ conn = e.connect()
+ conn.invalidate()
+ conn.close()
+
+ is_(e.pool.__dict__.get("_should_wrap_creator")[0], True)
+
+ conn = e.connect()
+ conn.close()
+
+ is_(e.pool.__dict__.get("_should_wrap_creator")[0], False)
+