diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2015-04-30 19:44:16 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2015-04-30 19:55:28 -0400 |
commit | e0f9b279f43759886c61e6c82f97d95d0093fdf7 (patch) | |
tree | 30752e0689319850969ebd52174532afeab9980e | |
parent | 0e98795ff2c7a164b4da164d7b26af3faabf84d1 (diff) | |
download | sqlalchemy-e0f9b279f43759886c61e6c82f97d95d0093fdf7.tar.gz |
- work the wrapping of the "creator" to be as resilient to
old / new style, direct access, and ad-hoc patching and
unpatching as possible
-rw-r--r-- | lib/sqlalchemy/pool.py | 39 | ||||
-rw-r--r-- | test/engine/test_execute.py | 2 | ||||
-rw-r--r-- | test/engine/test_pool.py | 55 |
3 files changed, 87 insertions, 9 deletions
diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py index 902309d75..8eb9d796d 100644 --- a/lib/sqlalchemy/pool.py +++ b/lib/sqlalchemy/pool.py @@ -219,7 +219,7 @@ class Pool(log.Identified): log.instance_logger(self, echoflag=echo) self._threadconns = threading.local() self._creator = creator - self._wrapped_creator = self._maybe_wrap_callable(creator) + self._set_should_wrap_creator() self._recycle = recycle self._invalidate_time = 0 self._use_threadlocal = use_threadlocal @@ -250,16 +250,17 @@ class Pool(log.Identified): for l in listeners: self.add_listener(l) - def _maybe_wrap_callable(self, fn): + def _set_should_wrap_creator(self): """Detect if creator accepts a single argument, or is sent as a legacy style no-arg function. """ try: - argspec = util.get_callable_argspec(fn, no_self=True) + argspec = util.get_callable_argspec(self._creator, no_self=True) except TypeError: - return lambda ctx: fn() + self._should_wrap_creator = (True, self._creator) + return defaulted = argspec[3] is not None and len(argspec[3]) or 0 positionals = len(argspec[0]) - defaulted @@ -267,14 +268,36 @@ class Pool(log.Identified): # look for the exact arg signature that DefaultStrategy # sends us if (argspec[0], argspec[3]) == (['connection_record'], (None,)): - return fn + self._should_wrap_creator = (False, self._creator) # or just a single positional elif positionals == 1: - return fn + self._should_wrap_creator = (False, self._creator) # all other cases, just wrap and assume legacy "creator" callable # thing else: - return lambda ctx: fn() + self._should_wrap_creator = (True, self._creator) + + def _invoke_creator(self, connection_record): + """adjust for old or new style "creator" callable. + + This function is spending extra effort in order to accommodate + any degree of manipulation of the _creator callable by end-user + applications, including ad-hoc patching in test suites. + + """ + + should_wrap, against_creator = self._should_wrap_creator + creator = self._creator + + if creator is not against_creator: + # check if the _creator function has been patched since + # we last looked at it + self._set_should_wrap_creator() + return self._invoke_creator(connection_record) + elif should_wrap: + return self._creator() + else: + return self._creator(connection_record) def _close_connection(self, connection): self.logger.debug("Closing connection %r", connection) @@ -591,7 +614,7 @@ class _ConnectionRecord(object): def __connect(self): try: self.starttime = time.time() - connection = self.__pool._wrapped_creator(self) + connection = self.__pool._invoke_creator(self) self.__pool.logger.debug("Created new connection %r", connection) return connection except Exception as e: diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index cba3972f6..761ac102a 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -2174,7 +2174,7 @@ class HandleInvalidatedOnConnectTest(fixtures.TestBase): conn.invalidate() - eng.pool._wrapped_creator = Mock( + eng.pool._creator = Mock( side_effect=self.ProgrammingError( "Cannot operate on a closed database.")) diff --git a/test/engine/test_pool.py b/test/engine/test_pool.py index 3d93cda89..912c6c3fe 100644 --- a/test/engine/test_pool.py +++ b/test/engine/test_pool.py @@ -1807,3 +1807,58 @@ class StaticPoolTest(PoolTestBase): p = pool.StaticPool(creator) p2 = p.recreate() assert p._creator is p2._creator + + +class CreatorCompatibilityTest(PoolTestBase): + def test_creator_callable_outside_noarg(self): + e = testing_engine() + + creator = e.pool._creator + try: + conn = creator() + finally: + conn.close() + + def test_creator_callable_outside_witharg(self): + e = testing_engine() + + creator = e.pool._creator + try: + conn = creator(Mock()) + finally: + conn.close() + + def test_creator_patching_arg_to_noarg(self): + e = testing_engine() + creator = e.pool._creator + try: + # the creator is the two-arg form + conn = creator(Mock()) + finally: + conn.close() + + def mock_create(): + return creator() + + conn = e.connect() + conn.invalidate() + conn.close() + + # test that the 'should_wrap_creator' memoized attribute + # will dynamically switch if the _creator is monkeypatched. + + is_(e.pool.__dict__.get("_should_wrap_creator")[0], False) + + # patch it with a zero-arg form + with patch.object(e.pool, "_creator", mock_create): + conn = e.connect() + conn.invalidate() + conn.close() + + is_(e.pool.__dict__.get("_should_wrap_creator")[0], True) + + conn = e.connect() + conn.close() + + is_(e.pool.__dict__.get("_should_wrap_creator")[0], False) + |