summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorThomas Grainger <tagrain@gmail.com>2023-04-22 22:41:43 +0100
committerGitHub <noreply@github.com>2023-04-22 16:41:43 -0500
commit26013d00201d772d6572c8830ed314f2834feb3a (patch)
treed1f1891b57a9243c96cfb7ab096dbd919ec6fbe1
parent2c02fc98c0e2c67be5b79a50f8bc8c7f5cfe6c07 (diff)
downloadurllib3-26013d00201d772d6572c8830ed314f2834feb3a.tar.gz
Fix io_refs bug with pyopenssl.WrappedSocket and securetransport.WrappedSocket.close()
-rw-r--r--changelog/2970.bugfix.rst1
-rw-r--r--src/urllib3/contrib/pyopenssl.py17
-rw-r--r--src/urllib3/contrib/securetransport.py43
-rw-r--r--test/with_dummyserver/test_socketlevel.py58
4 files changed, 91 insertions, 28 deletions
diff --git a/changelog/2970.bugfix.rst b/changelog/2970.bugfix.rst
new file mode 100644
index 00000000..cc20d222
--- /dev/null
+++ b/changelog/2970.bugfix.rst
@@ -0,0 +1 @@
+fix ``urllib3.contrib.pyopenssl.WrappedSocket`` and ``urllib3.contrib.securetransport.WrappedSocket`` close
diff --git a/src/urllib3/contrib/pyopenssl.py b/src/urllib3/contrib/pyopenssl.py
index 8ff0326b..0089cd27 100644
--- a/src/urllib3/contrib/pyopenssl.py
+++ b/src/urllib3/contrib/pyopenssl.py
@@ -380,14 +380,15 @@ class WrappedSocket:
self.connection.shutdown()
def close(self) -> None:
- if self._io_refs < 1:
- try:
- self._closed = True
- return self.connection.close() # type: ignore[no-any-return]
- except OpenSSL.SSL.Error:
- return
- else:
- self._io_refs -= 1
+ self._closed = True
+ if self._io_refs <= 0:
+ self._real_close()
+
+ def _real_close(self) -> None:
+ try:
+ return self.connection.close() # type: ignore[no-any-return]
+ except OpenSSL.SSL.Error:
+ return
def getpeercert(
self, binary_form: bool = False
diff --git a/src/urllib3/contrib/securetransport.py b/src/urllib3/contrib/securetransport.py
index b90d78e5..11beb3df 100644
--- a/src/urllib3/contrib/securetransport.py
+++ b/src/urllib3/contrib/securetransport.py
@@ -317,6 +317,7 @@ class WrappedSocket:
self.context = None
self._io_refs = 0
self._closed = False
+ self._real_closed = False
self._exception: Exception | None = None
self._keychain = None
self._keychain_dir: str | None = None
@@ -348,7 +349,7 @@ class WrappedSocket:
yield
if self._exception is not None:
exception, self._exception = self._exception, None
- self.close()
+ self._real_close()
raise exception
def _set_alpn_protocols(self, protocols: list[bytes] | None) -> None:
@@ -398,7 +399,7 @@ class WrappedSocket:
# l_linger = 0, linger for 0 seoncds
opts = struct.pack("ii", 1, 0)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, opts)
- self.close()
+ self._real_close()
raise ssl.SSLError(f"certificate verify failed, {reason}") from exc
def _evaluate_trust(self, trust_bundle: bytes) -> int:
@@ -553,7 +554,7 @@ class WrappedSocket:
self, buffer: ctypes.Array[ctypes.c_char], nbytes: int | None = None
) -> int:
# Read short on EOF.
- if self._closed:
+ if self._real_closed:
return 0
if nbytes is None:
@@ -586,7 +587,7 @@ class WrappedSocket:
# well. Note that we don't actually return here because in
# principle this could actually be fired along with return data.
# It's unlikely though.
- self.close()
+ self._real_close()
else:
_assert_no_error(result)
@@ -628,23 +629,25 @@ class WrappedSocket:
Security.SSLClose(self.context)
def close(self) -> None:
+ self._closed = True
# TODO: should I do clean shutdown here? Do I have to?
- if self._io_refs < 1:
- self._closed = True
- if self.context:
- CoreFoundation.CFRelease(self.context)
- self.context = None
- if self._client_cert_chain:
- CoreFoundation.CFRelease(self._client_cert_chain)
- self._client_cert_chain = None
- if self._keychain:
- Security.SecKeychainDelete(self._keychain)
- CoreFoundation.CFRelease(self._keychain)
- shutil.rmtree(self._keychain_dir)
- self._keychain = self._keychain_dir = None
- return self.socket.close()
- else:
- self._io_refs -= 1
+ if self._io_refs <= 0:
+ self._real_close()
+
+ def _real_close(self) -> None:
+ self._real_closed = True
+ if self.context:
+ CoreFoundation.CFRelease(self.context)
+ self.context = None
+ if self._client_cert_chain:
+ CoreFoundation.CFRelease(self._client_cert_chain)
+ self._client_cert_chain = None
+ if self._keychain:
+ Security.SecKeychainDelete(self._keychain)
+ CoreFoundation.CFRelease(self._keychain)
+ shutil.rmtree(self._keychain_dir)
+ self._keychain = self._keychain_dir = None
+ return self.socket.close()
def getpeercert(self, binary_form: bool = False) -> bytes | None:
# Urgh, annoying.
diff --git a/test/with_dummyserver/test_socketlevel.py b/test/with_dummyserver/test_socketlevel.py
index 0a6a8ae3..809c6852 100644
--- a/test/with_dummyserver/test_socketlevel.py
+++ b/test/with_dummyserver/test_socketlevel.py
@@ -2,6 +2,7 @@
# rather than the socket level-ness of it.
from __future__ import annotations
+import contextlib
import errno
import io
import os
@@ -966,6 +967,63 @@ class TestSocketClosing(SocketDummyServerTestCase):
assert pool.pool.qsize() == 1
assert response.connection is None
+ def test_socket_close_socket_then_file(self) -> None:
+ def consume_ssl_socket(listener: socket.socket) -> None:
+ try:
+ with listener.accept()[0] as sock, original_ssl_wrap_socket(
+ sock,
+ server_side=True,
+ keyfile=DEFAULT_CERTS["keyfile"],
+ certfile=DEFAULT_CERTS["certfile"],
+ ca_certs=DEFAULT_CA,
+ ) as ssl_sock:
+ consume_socket(ssl_sock)
+ except (ConnectionResetError, ConnectionAbortedError, OSError):
+ pass
+
+ self._start_server(consume_ssl_socket)
+ with socket.create_connection(
+ (self.host, self.port)
+ ) as sock, contextlib.closing(
+ ssl_wrap_socket(sock, server_hostname=self.host, ca_certs=DEFAULT_CA)
+ ) as ssl_sock, ssl_sock.makefile(
+ "rb"
+ ) as f:
+ ssl_sock.close()
+ f.close()
+ # SecureTransport is supposed to raise OSError but raises
+ # ssl.SSLError when closed because ssl_sock.context is None
+ with pytest.raises((OSError, ssl.SSLError)):
+ ssl_sock.sendall(b"hello")
+ assert ssl_sock.fileno() == -1
+
+ def test_socket_close_stays_open_with_makefile_open(self) -> None:
+ def consume_ssl_socket(listener: socket.socket) -> None:
+ try:
+ with listener.accept()[0] as sock, original_ssl_wrap_socket(
+ sock,
+ server_side=True,
+ keyfile=DEFAULT_CERTS["keyfile"],
+ certfile=DEFAULT_CERTS["certfile"],
+ ca_certs=DEFAULT_CA,
+ ) as ssl_sock:
+ consume_socket(ssl_sock)
+ except (ConnectionResetError, ConnectionAbortedError, OSError):
+ pass
+
+ self._start_server(consume_ssl_socket)
+ with socket.create_connection(
+ (self.host, self.port)
+ ) as sock, contextlib.closing(
+ ssl_wrap_socket(sock, server_hostname=self.host, ca_certs=DEFAULT_CA)
+ ) as ssl_sock, ssl_sock.makefile(
+ "rb"
+ ):
+ ssl_sock.close()
+ ssl_sock.close()
+ ssl_sock.sendall(b"hello")
+ assert ssl_sock.fileno() > 0
+
class TestProxyManager(SocketDummyServerTestCase):
def test_simple(self) -> None: