diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-02-15 23:43:51 -0500 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-02-17 14:45:04 -0500 |
| commit | 5157e0aa542f390242dd7a6d27a6ce1663230e46 (patch) | |
| tree | 113f0e5a83e8229c7d0cb9e9c47387e1d703cb29 /lib/sqlalchemy/pool/impl.py | |
| parent | 20213fd1f27fea51015d753bf94c6f40674ae86f (diff) | |
| download | sqlalchemy-5157e0aa542f390242dd7a6d27a6ce1663230e46.tar.gz | |
pep-484 for pool
also extends into some areas of utils, events and others
as needed.
Formalizes a public hierarchy for pool API,
with ManagesConnection -> PoolProxiedConnection /
ConnectionPoolEntry for connectionfairy / connectionrecord,
which are now what's exposed in the event API and other
APIs. all public API docs moved to the new objects.
Corrects the mypy plugin's check for sqlalchemy-stubs
not being insatlled, which has to be imported using the
dash in the name to be effective.
Change-Id: I16c2cb43b2e840d28e70a015f370a768e70f3581
Diffstat (limited to 'lib/sqlalchemy/pool/impl.py')
| -rw-r--r-- | lib/sqlalchemy/pool/impl.py | 171 |
1 files changed, 102 insertions, 69 deletions
diff --git a/lib/sqlalchemy/pool/impl.py b/lib/sqlalchemy/pool/impl.py index 7a422cd2a..d1be3f541 100644 --- a/lib/sqlalchemy/pool/impl.py +++ b/lib/sqlalchemy/pool/impl.py @@ -9,19 +9,36 @@ """Pool implementation classes. """ +from __future__ import annotations import threading import traceback +import typing +from typing import Any +from typing import cast +from typing import List +from typing import Optional +from typing import Set +from typing import Type +from typing import Union import weakref from .base import _AsyncConnDialect from .base import _ConnectionFairy from .base import _ConnectionRecord +from .base import _CreatorFnType +from .base import _CreatorWRecFnType +from .base import ConnectionPoolEntry from .base import Pool +from .base import PoolProxiedConnection from .. import exc from .. import util from ..util import chop_traceback from ..util import queue as sqla_queue +from ..util.typing import Literal + +if typing.TYPE_CHECKING: + from ..engine.interfaces import DBAPIConnection class QueuePool(Pool): @@ -34,17 +51,22 @@ class QueuePool(Pool): """ - _is_asyncio = False - _queue_class = sqla_queue.Queue + _is_asyncio = False # type: ignore[assignment] + + _queue_class: Type[ + sqla_queue.QueueCommon[ConnectionPoolEntry] + ] = sqla_queue.Queue + + _pool: sqla_queue.QueueCommon[ConnectionPoolEntry] def __init__( self, - creator, - pool_size=5, - max_overflow=10, - timeout=30.0, - use_lifo=False, - **kw, + creator: Union[_CreatorFnType, _CreatorWRecFnType], + pool_size: int = 5, + max_overflow: int = 10, + timeout: float = 30.0, + use_lifo: bool = False, + **kw: Any, ): r""" Construct a QueuePool. @@ -107,20 +129,20 @@ class QueuePool(Pool): self._timeout = timeout self._overflow_lock = threading.Lock() - def _do_return_conn(self, conn): + def _do_return_conn(self, record: ConnectionPoolEntry) -> None: try: - self._pool.put(conn, False) + self._pool.put(record, False) except sqla_queue.Full: try: - conn.close() + record.close() finally: self._dec_overflow() - def _do_get(self): + def _do_get(self) -> ConnectionPoolEntry: use_overflow = self._max_overflow > -1 + wait = use_overflow and self._overflow >= self._max_overflow try: - wait = use_overflow and self._overflow >= self._max_overflow return self._pool.get(wait, self._timeout) except sqla_queue.Empty: # don't do things inside of "except Empty", because when we say @@ -144,10 +166,11 @@ class QueuePool(Pool): except: with util.safe_reraise(): self._dec_overflow() + raise else: return self._do_get() - def _inc_overflow(self): + def _inc_overflow(self) -> bool: if self._max_overflow == -1: self._overflow += 1 return True @@ -158,7 +181,7 @@ class QueuePool(Pool): else: return False - def _dec_overflow(self): + def _dec_overflow(self) -> Literal[True]: if self._max_overflow == -1: self._overflow -= 1 return True @@ -166,7 +189,7 @@ class QueuePool(Pool): self._overflow -= 1 return True - def recreate(self): + def recreate(self) -> QueuePool: self.logger.info("Pool recreating") return self.__class__( self._creator, @@ -183,7 +206,7 @@ class QueuePool(Pool): dialect=self._dialect, ) - def dispose(self): + def dispose(self) -> None: while True: try: conn = self._pool.get(False) @@ -194,7 +217,7 @@ class QueuePool(Pool): self._overflow = 0 - self.size() self.logger.info("Pool disposed. %s", self.status()) - def status(self): + def status(self) -> str: return ( "Pool size: %d Connections in pool: %d " "Current Overflow: %d Current Checked out " @@ -207,25 +230,28 @@ class QueuePool(Pool): ) ) - def size(self): + def size(self) -> int: return self._pool.maxsize - def timeout(self): + def timeout(self) -> float: return self._timeout - def checkedin(self): + def checkedin(self) -> int: return self._pool.qsize() - def overflow(self): + def overflow(self) -> int: return self._overflow - def checkedout(self): + def checkedout(self) -> int: return self._pool.maxsize - self._pool.qsize() + self._overflow class AsyncAdaptedQueuePool(QueuePool): - _is_asyncio = True - _queue_class = sqla_queue.AsyncAdaptedQueue + _is_asyncio = True # type: ignore[assignment] + _queue_class: Type[ + sqla_queue.QueueCommon[ConnectionPoolEntry] + ] = sqla_queue.AsyncAdaptedQueue + _dialect = _AsyncConnDialect() @@ -246,16 +272,16 @@ class NullPool(Pool): """ - def status(self): + def status(self) -> str: return "NullPool" - def _do_return_conn(self, conn): - conn.close() + def _do_return_conn(self, record: ConnectionPoolEntry) -> None: + record.close() - def _do_get(self): + def _do_get(self) -> ConnectionPoolEntry: return self._create_connection() - def recreate(self): + def recreate(self) -> NullPool: self.logger.info("Pool recreating") return self.__class__( @@ -269,7 +295,7 @@ class NullPool(Pool): dialect=self._dialect, ) - def dispose(self): + def dispose(self) -> None: pass @@ -304,16 +330,21 @@ class SingletonThreadPool(Pool): """ - _is_asyncio = False + _is_asyncio = False # type: ignore[assignment] - def __init__(self, creator, pool_size=5, **kw): + def __init__( + self, + creator: Union[_CreatorFnType, _CreatorWRecFnType], + pool_size: int = 5, + **kw: Any, + ): Pool.__init__(self, creator, **kw) self._conn = threading.local() self._fairy = threading.local() - self._all_conns = set() + self._all_conns: Set[ConnectionPoolEntry] = set() self.size = pool_size - def recreate(self): + def recreate(self) -> SingletonThreadPool: self.logger.info("Pool recreating") return self.__class__( self._creator, @@ -327,7 +358,7 @@ class SingletonThreadPool(Pool): dialect=self._dialect, ) - def dispose(self): + def dispose(self) -> None: """Dispose of this pool.""" for conn in self._all_conns: @@ -340,23 +371,26 @@ class SingletonThreadPool(Pool): self._all_conns.clear() - def _cleanup(self): + def _cleanup(self) -> None: while len(self._all_conns) >= self.size: c = self._all_conns.pop() c.close() - def status(self): + def status(self) -> str: return "SingletonThreadPool id:%d size: %d" % ( id(self), len(self._all_conns), ) - def _do_return_conn(self, conn): - pass + def _do_return_conn(self, record: ConnectionPoolEntry) -> None: + try: + del self._fairy.current # type: ignore + except AttributeError: + pass - def _do_get(self): + def _do_get(self) -> ConnectionPoolEntry: try: - c = self._conn.current() + c = cast(ConnectionPoolEntry, self._conn.current()) if c: return c except AttributeError: @@ -368,11 +402,11 @@ class SingletonThreadPool(Pool): self._all_conns.add(c) return c - def connect(self): + def connect(self) -> PoolProxiedConnection: # vendored from Pool to include the now removed use_threadlocal # behavior try: - rec = self._fairy.current() + rec = cast(_ConnectionFairy, self._fairy.current()) except AttributeError: pass else: @@ -381,13 +415,6 @@ class SingletonThreadPool(Pool): return _ConnectionFairy._checkout(self, self._fairy) - def _return_conn(self, record): - try: - del self._fairy.current - except AttributeError: - pass - self._do_return_conn(record) - class StaticPool(Pool): @@ -401,13 +428,13 @@ class StaticPool(Pool): """ @util.memoized_property - def connection(self): + def connection(self) -> _ConnectionRecord: return _ConnectionRecord(self) - def status(self): + def status(self) -> str: return "StaticPool" - def dispose(self): + def dispose(self) -> None: if ( "connection" in self.__dict__ and self.connection.dbapi_connection is not None @@ -415,7 +442,7 @@ class StaticPool(Pool): self.connection.close() del self.__dict__["connection"] - def recreate(self): + def recreate(self) -> StaticPool: self.logger.info("Pool recreating") return self.__class__( creator=self._creator, @@ -428,20 +455,23 @@ class StaticPool(Pool): dialect=self._dialect, ) - def _transfer_from(self, other_static_pool): + def _transfer_from(self, other_static_pool: StaticPool) -> None: # used by the test suite to make a new engine / pool without # losing the state of an existing SQLite :memory: connection - self._invoke_creator = ( - lambda crec: other_static_pool.connection.dbapi_connection - ) + def creator(rec: ConnectionPoolEntry) -> DBAPIConnection: + conn = other_static_pool.connection.dbapi_connection + assert conn is not None + return conn - def _create_connection(self): + self._invoke_creator = creator + + def _create_connection(self) -> ConnectionPoolEntry: raise NotImplementedError() - def _do_return_conn(self, conn): + def _do_return_conn(self, record: ConnectionPoolEntry) -> None: pass - def _do_get(self): + def _do_get(self) -> ConnectionPoolEntry: rec = self.connection if rec._is_hard_or_soft_invalidated(): del self.__dict__["connection"] @@ -461,28 +491,31 @@ class AssertionPool(Pool): """ - def __init__(self, *args, **kw): + _conn: Optional[ConnectionPoolEntry] + _checkout_traceback: Optional[List[str]] + + def __init__(self, *args: Any, **kw: Any): self._conn = None self._checked_out = False self._store_traceback = kw.pop("store_traceback", True) self._checkout_traceback = None Pool.__init__(self, *args, **kw) - def status(self): + def status(self) -> str: return "AssertionPool" - def _do_return_conn(self, conn): + def _do_return_conn(self, record: ConnectionPoolEntry) -> None: if not self._checked_out: raise AssertionError("connection is not checked out") self._checked_out = False - assert conn is self._conn + assert record is self._conn - def dispose(self): + def dispose(self) -> None: self._checked_out = False if self._conn: self._conn.close() - def recreate(self): + def recreate(self) -> AssertionPool: self.logger.info("Pool recreating") return self.__class__( self._creator, @@ -495,7 +528,7 @@ class AssertionPool(Pool): dialect=self._dialect, ) - def _do_get(self): + def _do_get(self) -> ConnectionPoolEntry: if self._checked_out: if self._checkout_traceback: suffix = " at:\n%s" % "".join( |
