summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/pool/impl.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-02-15 23:43:51 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2022-02-17 14:45:04 -0500
commit5157e0aa542f390242dd7a6d27a6ce1663230e46 (patch)
tree113f0e5a83e8229c7d0cb9e9c47387e1d703cb29 /lib/sqlalchemy/pool/impl.py
parent20213fd1f27fea51015d753bf94c6f40674ae86f (diff)
downloadsqlalchemy-5157e0aa542f390242dd7a6d27a6ce1663230e46.tar.gz
pep-484 for pool
also extends into some areas of utils, events and others as needed. Formalizes a public hierarchy for pool API, with ManagesConnection -> PoolProxiedConnection / ConnectionPoolEntry for connectionfairy / connectionrecord, which are now what's exposed in the event API and other APIs. all public API docs moved to the new objects. Corrects the mypy plugin's check for sqlalchemy-stubs not being installed, which has to be imported using the dash in the name to be effective. Change-Id: I16c2cb43b2e840d28e70a015f370a768e70f3581
Diffstat (limited to 'lib/sqlalchemy/pool/impl.py')
-rw-r--r--lib/sqlalchemy/pool/impl.py171
1 files changed, 102 insertions, 69 deletions
diff --git a/lib/sqlalchemy/pool/impl.py b/lib/sqlalchemy/pool/impl.py
index 7a422cd2a..d1be3f541 100644
--- a/lib/sqlalchemy/pool/impl.py
+++ b/lib/sqlalchemy/pool/impl.py
@@ -9,19 +9,36 @@
"""Pool implementation classes.
"""
+from __future__ import annotations
import threading
import traceback
+import typing
+from typing import Any
+from typing import cast
+from typing import List
+from typing import Optional
+from typing import Set
+from typing import Type
+from typing import Union
import weakref
from .base import _AsyncConnDialect
from .base import _ConnectionFairy
from .base import _ConnectionRecord
+from .base import _CreatorFnType
+from .base import _CreatorWRecFnType
+from .base import ConnectionPoolEntry
from .base import Pool
+from .base import PoolProxiedConnection
from .. import exc
from .. import util
from ..util import chop_traceback
from ..util import queue as sqla_queue
+from ..util.typing import Literal
+
+if typing.TYPE_CHECKING:
+ from ..engine.interfaces import DBAPIConnection
class QueuePool(Pool):
@@ -34,17 +51,22 @@ class QueuePool(Pool):
"""
- _is_asyncio = False
- _queue_class = sqla_queue.Queue
+ _is_asyncio = False # type: ignore[assignment]
+
+ _queue_class: Type[
+ sqla_queue.QueueCommon[ConnectionPoolEntry]
+ ] = sqla_queue.Queue
+
+ _pool: sqla_queue.QueueCommon[ConnectionPoolEntry]
def __init__(
self,
- creator,
- pool_size=5,
- max_overflow=10,
- timeout=30.0,
- use_lifo=False,
- **kw,
+ creator: Union[_CreatorFnType, _CreatorWRecFnType],
+ pool_size: int = 5,
+ max_overflow: int = 10,
+ timeout: float = 30.0,
+ use_lifo: bool = False,
+ **kw: Any,
):
r"""
Construct a QueuePool.
@@ -107,20 +129,20 @@ class QueuePool(Pool):
self._timeout = timeout
self._overflow_lock = threading.Lock()
- def _do_return_conn(self, conn):
+ def _do_return_conn(self, record: ConnectionPoolEntry) -> None:
try:
- self._pool.put(conn, False)
+ self._pool.put(record, False)
except sqla_queue.Full:
try:
- conn.close()
+ record.close()
finally:
self._dec_overflow()
- def _do_get(self):
+ def _do_get(self) -> ConnectionPoolEntry:
use_overflow = self._max_overflow > -1
+ wait = use_overflow and self._overflow >= self._max_overflow
try:
- wait = use_overflow and self._overflow >= self._max_overflow
return self._pool.get(wait, self._timeout)
except sqla_queue.Empty:
# don't do things inside of "except Empty", because when we say
@@ -144,10 +166,11 @@ class QueuePool(Pool):
except:
with util.safe_reraise():
self._dec_overflow()
+ raise
else:
return self._do_get()
- def _inc_overflow(self):
+ def _inc_overflow(self) -> bool:
if self._max_overflow == -1:
self._overflow += 1
return True
@@ -158,7 +181,7 @@ class QueuePool(Pool):
else:
return False
- def _dec_overflow(self):
+ def _dec_overflow(self) -> Literal[True]:
if self._max_overflow == -1:
self._overflow -= 1
return True
@@ -166,7 +189,7 @@ class QueuePool(Pool):
self._overflow -= 1
return True
- def recreate(self):
+ def recreate(self) -> QueuePool:
self.logger.info("Pool recreating")
return self.__class__(
self._creator,
@@ -183,7 +206,7 @@ class QueuePool(Pool):
dialect=self._dialect,
)
- def dispose(self):
+ def dispose(self) -> None:
while True:
try:
conn = self._pool.get(False)
@@ -194,7 +217,7 @@ class QueuePool(Pool):
self._overflow = 0 - self.size()
self.logger.info("Pool disposed. %s", self.status())
- def status(self):
+ def status(self) -> str:
return (
"Pool size: %d Connections in pool: %d "
"Current Overflow: %d Current Checked out "
@@ -207,25 +230,28 @@ class QueuePool(Pool):
)
)
- def size(self):
+ def size(self) -> int:
return self._pool.maxsize
- def timeout(self):
+ def timeout(self) -> float:
return self._timeout
- def checkedin(self):
+ def checkedin(self) -> int:
return self._pool.qsize()
- def overflow(self):
+ def overflow(self) -> int:
return self._overflow
- def checkedout(self):
+ def checkedout(self) -> int:
return self._pool.maxsize - self._pool.qsize() + self._overflow
class AsyncAdaptedQueuePool(QueuePool):
- _is_asyncio = True
- _queue_class = sqla_queue.AsyncAdaptedQueue
+ _is_asyncio = True # type: ignore[assignment]
+ _queue_class: Type[
+ sqla_queue.QueueCommon[ConnectionPoolEntry]
+ ] = sqla_queue.AsyncAdaptedQueue
+
_dialect = _AsyncConnDialect()
@@ -246,16 +272,16 @@ class NullPool(Pool):
"""
- def status(self):
+ def status(self) -> str:
return "NullPool"
- def _do_return_conn(self, conn):
- conn.close()
+ def _do_return_conn(self, record: ConnectionPoolEntry) -> None:
+ record.close()
- def _do_get(self):
+ def _do_get(self) -> ConnectionPoolEntry:
return self._create_connection()
- def recreate(self):
+ def recreate(self) -> NullPool:
self.logger.info("Pool recreating")
return self.__class__(
@@ -269,7 +295,7 @@ class NullPool(Pool):
dialect=self._dialect,
)
- def dispose(self):
+ def dispose(self) -> None:
pass
@@ -304,16 +330,21 @@ class SingletonThreadPool(Pool):
"""
- _is_asyncio = False
+ _is_asyncio = False # type: ignore[assignment]
- def __init__(self, creator, pool_size=5, **kw):
+ def __init__(
+ self,
+ creator: Union[_CreatorFnType, _CreatorWRecFnType],
+ pool_size: int = 5,
+ **kw: Any,
+ ):
Pool.__init__(self, creator, **kw)
self._conn = threading.local()
self._fairy = threading.local()
- self._all_conns = set()
+ self._all_conns: Set[ConnectionPoolEntry] = set()
self.size = pool_size
- def recreate(self):
+ def recreate(self) -> SingletonThreadPool:
self.logger.info("Pool recreating")
return self.__class__(
self._creator,
@@ -327,7 +358,7 @@ class SingletonThreadPool(Pool):
dialect=self._dialect,
)
- def dispose(self):
+ def dispose(self) -> None:
"""Dispose of this pool."""
for conn in self._all_conns:
@@ -340,23 +371,26 @@ class SingletonThreadPool(Pool):
self._all_conns.clear()
- def _cleanup(self):
+ def _cleanup(self) -> None:
while len(self._all_conns) >= self.size:
c = self._all_conns.pop()
c.close()
- def status(self):
+ def status(self) -> str:
return "SingletonThreadPool id:%d size: %d" % (
id(self),
len(self._all_conns),
)
- def _do_return_conn(self, conn):
- pass
+ def _do_return_conn(self, record: ConnectionPoolEntry) -> None:
+ try:
+ del self._fairy.current # type: ignore
+ except AttributeError:
+ pass
- def _do_get(self):
+ def _do_get(self) -> ConnectionPoolEntry:
try:
- c = self._conn.current()
+ c = cast(ConnectionPoolEntry, self._conn.current())
if c:
return c
except AttributeError:
@@ -368,11 +402,11 @@ class SingletonThreadPool(Pool):
self._all_conns.add(c)
return c
- def connect(self):
+ def connect(self) -> PoolProxiedConnection:
# vendored from Pool to include the now removed use_threadlocal
# behavior
try:
- rec = self._fairy.current()
+ rec = cast(_ConnectionFairy, self._fairy.current())
except AttributeError:
pass
else:
@@ -381,13 +415,6 @@ class SingletonThreadPool(Pool):
return _ConnectionFairy._checkout(self, self._fairy)
- def _return_conn(self, record):
- try:
- del self._fairy.current
- except AttributeError:
- pass
- self._do_return_conn(record)
-
class StaticPool(Pool):
@@ -401,13 +428,13 @@ class StaticPool(Pool):
"""
@util.memoized_property
- def connection(self):
+ def connection(self) -> _ConnectionRecord:
return _ConnectionRecord(self)
- def status(self):
+ def status(self) -> str:
return "StaticPool"
- def dispose(self):
+ def dispose(self) -> None:
if (
"connection" in self.__dict__
and self.connection.dbapi_connection is not None
@@ -415,7 +442,7 @@ class StaticPool(Pool):
self.connection.close()
del self.__dict__["connection"]
- def recreate(self):
+ def recreate(self) -> StaticPool:
self.logger.info("Pool recreating")
return self.__class__(
creator=self._creator,
@@ -428,20 +455,23 @@ class StaticPool(Pool):
dialect=self._dialect,
)
- def _transfer_from(self, other_static_pool):
+ def _transfer_from(self, other_static_pool: StaticPool) -> None:
# used by the test suite to make a new engine / pool without
# losing the state of an existing SQLite :memory: connection
- self._invoke_creator = (
- lambda crec: other_static_pool.connection.dbapi_connection
- )
+ def creator(rec: ConnectionPoolEntry) -> DBAPIConnection:
+ conn = other_static_pool.connection.dbapi_connection
+ assert conn is not None
+ return conn
- def _create_connection(self):
+ self._invoke_creator = creator
+
+ def _create_connection(self) -> ConnectionPoolEntry:
raise NotImplementedError()
- def _do_return_conn(self, conn):
+ def _do_return_conn(self, record: ConnectionPoolEntry) -> None:
pass
- def _do_get(self):
+ def _do_get(self) -> ConnectionPoolEntry:
rec = self.connection
if rec._is_hard_or_soft_invalidated():
del self.__dict__["connection"]
@@ -461,28 +491,31 @@ class AssertionPool(Pool):
"""
- def __init__(self, *args, **kw):
+ _conn: Optional[ConnectionPoolEntry]
+ _checkout_traceback: Optional[List[str]]
+
+ def __init__(self, *args: Any, **kw: Any):
self._conn = None
self._checked_out = False
self._store_traceback = kw.pop("store_traceback", True)
self._checkout_traceback = None
Pool.__init__(self, *args, **kw)
- def status(self):
+ def status(self) -> str:
return "AssertionPool"
- def _do_return_conn(self, conn):
+ def _do_return_conn(self, record: ConnectionPoolEntry) -> None:
if not self._checked_out:
raise AssertionError("connection is not checked out")
self._checked_out = False
- assert conn is self._conn
+ assert record is self._conn
- def dispose(self):
+ def dispose(self) -> None:
self._checked_out = False
if self._conn:
self._conn.close()
- def recreate(self):
+ def recreate(self) -> AssertionPool:
self.logger.info("Pool recreating")
return self.__class__(
self._creator,
@@ -495,7 +528,7 @@ class AssertionPool(Pool):
dialect=self._dialect,
)
- def _do_get(self):
+ def _do_get(self) -> ConnectionPoolEntry:
if self._checked_out:
if self._checkout_traceback:
suffix = " at:\n%s" % "".join(