diff options
Diffstat (limited to 'Lib/socket.py')
-rw-r--r-- | Lib/socket.py | 126 |
1 files changed, 108 insertions, 18 deletions
diff --git a/Lib/socket.py b/Lib/socket.py index d0350617ae..ea56a67d55 100644 --- a/Lib/socket.py +++ b/Lib/socket.py @@ -23,7 +23,8 @@ inet_aton() -- convert IP addr string (123.45.67.89) to 32-bit packed format inet_ntoa() -- convert 32-bit packed format IP to string (123.45.67.89) socket.getdefaulttimeout() -- get the default timeout value socket.setdefaulttimeout() -- set the default timeout value -create_connection() -- connects to an address, with an optional timeout +create_connection() -- connects to an address, with an optional timeout and + optional source address. [*] not available on all platforms! @@ -48,9 +49,13 @@ from _socket import * import os, sys, io try: - from errno import EBADF + import errno except ImportError: - EBADF = 9 + errno = None +EBADF = getattr(errno, 'EBADF', 9) +EINTR = getattr(errno, 'EINTR', 4) +EAGAIN = getattr(errno, 'EAGAIN', 11) +EWOULDBLOCK = getattr(errno, 'EWOULDBLOCK', 11) __all__ = ["getfqdn", "create_connection"] __all__.extend(os._get_exports_list(_socket)) @@ -90,13 +95,20 @@ class socket(_socket.socket): self._io_refs = 0 self._closed = False + def __enter__(self): + return self + + def __exit__(self, *args): + if not self._closed: + self.close() + def __repr__(self): """Wrap __repr__() to reveal the real class name.""" s = _socket.socket.__repr__(self) if s.startswith("<socket object"): s = "<%s.%s%s%s" % (self.__class__.__module__, self.__class__.__name__, - (self._closed and " [closed] ") or "", + getattr(self, '_closed', False) and " [closed] " or "", s[7:]) return s @@ -118,7 +130,13 @@ class socket(_socket.socket): For IP sockets, the address info is a pair (hostaddr, port). """ fd, addr = self._accept() - return socket(self.family, self.type, self.proto, fileno=fd), addr + sock = socket(self.family, self.type, self.proto, fileno=fd) + # Issue #7995: if no default timeout is set and the listening + # socket had a (non-zero) timeout, force the new socket in blocking + # mode to override platform-specific socket flags inheritance. + if getdefaulttimeout() is None and self.gettimeout(): + sock.setblocking(True) + return sock, addr def makefile(self, mode="r", buffering=None, *, encoding=None, errors=None, newline=None): @@ -169,14 +187,27 @@ class socket(_socket.socket): if self._closed: self.close() - def _real_close(self): - _socket.socket.close(self) + def _real_close(self, _ss=_socket.socket): + # This function should not reference any globals. See issue #808164. + _ss.close(self) def close(self): + # This function should not reference any globals. See issue #808164. self._closed = True if self._io_refs <= 0: self._real_close() + def detach(self): + """detach() -> file descriptor + + Close the socket object without closing the underlying file descriptor. + The object cannot be used after this call, but the file descriptor + can be reused for other purposes. The file descriptor is returned. + """ + self._closed = True + return super().detach() + + def fromfd(fd, family, type, proto=0): """ fromfd(fd, family, type[, proto]) -> socket object @@ -187,6 +218,29 @@ def fromfd(fd, family, type, proto=0): return socket(family, type, proto, nfd) +if hasattr(_socket, "socketpair"): + + def socketpair(family=None, type=SOCK_STREAM, proto=0): + """socketpair([family[, type[, proto]]]) -> (socket object, socket object) + + Create a pair of socket objects from the sockets returned by the platform + socketpair() function. + The arguments are the same as for socket() except the default family is + AF_UNIX if defined on the platform; otherwise, the default is AF_INET. + """ + if family is None: + try: + family = AF_UNIX + except NameError: + family = AF_INET + a, b = _socket.socketpair(family, type, proto) + a = socket(family, type, proto, a.detach()) + b = socket(family, type, proto, b.detach()) + return a, b + + +_blocking_errnos = { EAGAIN, EWOULDBLOCK } + class SocketIO(io.RawIOBase): """Raw I/O implementation for stream sockets. @@ -214,6 +268,7 @@ class SocketIO(io.RawIOBase): self._mode = mode self._reading = "r" in mode self._writing = "w" in mode + self._timeout_occurred = False def readinto(self, b): """Read up to len(b) bytes into the writable buffer *b* and return @@ -225,7 +280,21 @@ class SocketIO(io.RawIOBase): """ self._checkClosed() self._checkReadable() - return self._sock.recv_into(b) + if self._timeout_occurred: + raise IOError("cannot read from timed out object") + while True: + try: + return self._sock.recv_into(b) + except timeout: + self._timeout_occurred = True + raise + except error as e: + n = e.args[0] + if n == EINTR: + continue + if n in _blocking_errnos: + return None + raise def write(self, b): """Write the given bytes or bytearray object *b* to the socket @@ -235,17 +304,34 @@ class SocketIO(io.RawIOBase): """ self._checkClosed() self._checkWritable() - return self._sock.send(b) + try: + return self._sock.send(b) + except error as e: + # XXX what about EINTR? + if e.args[0] in _blocking_errnos: + return None + raise def readable(self): """True if the SocketIO is open for reading. """ - return self._reading and not self.closed + if self.closed: + raise ValueError("I/O operation on closed socket.") + return self._reading def writable(self): """True if the SocketIO is open for writing. """ - return self._writing and not self.closed + if self.closed: + raise ValueError("I/O operation on closed socket.") + return self._writing + + def seekable(self): + """True if the SocketIO is open for seeking. + """ + if self.closed: + raise ValueError("I/O operation on closed socket.") + return super().seekable() def fileno(self): """Return the file descriptor of the underlying socket. @@ -255,7 +341,10 @@ class SocketIO(io.RawIOBase): @property def name(self): - return self.fileno() + if not self.closed: + return self.fileno() + else: + return -1 @property def mode(self): @@ -271,10 +360,6 @@ class SocketIO(io.RawIOBase): self._sock._decref_socketios() self._sock = None - def __del__(self): - if not self.closed: - self._sock._decref_socketios() - def getfqdn(name=''): """Get fully qualified domain name from name. @@ -304,7 +389,8 @@ def getfqdn(name=''): _GLOBAL_DEFAULT_TIMEOUT = object() -def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT): +def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT, + source_address=None): """Connect to *address* and return the socket object. Convenience function. Connect to *address* (a 2-tuple ``(host, @@ -312,7 +398,9 @@ def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT): *timeout* parameter will set the timeout on the socket instance before attempting to connect. If no *timeout* is supplied, the global default timeout setting returned by :func:`getdefaulttimeout` - is used. + is used. If *source_address* is set it must be a tuple of (host, port) + for the socket to bind as a source address before making the connection. + An host of '' or port 0 tells the OS to use the default. """ host, port = address @@ -324,6 +412,8 @@ def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT): sock = socket(af, socktype, proto) if timeout is not _GLOBAL_DEFAULT_TIMEOUT: sock.settimeout(timeout) + if source_address: + sock.bind(source_address) sock.connect(sa) return sock |