diff options
| author | Maximilian Hils <git@maximilianhils.com> | 2015-08-17 19:27:20 +0200 |
|---|---|---|
| committer | Maximilian Hils <git@maximilianhils.com> | 2015-08-17 19:27:20 +0200 |
| commit | 1d95dea7fea03c7c0df345a5ea30c12d8a0378d2 (patch) | |
| tree | 01ee197401d1b37617fc5dbcae5f2226b17b2ea8 | |
| parent | cd0d22c16447cfea0e0e2d3147a905f1866530ac (diff) | |
| download | pyopenssl-1d95dea7fea03c7c0df345a5ea30c12d8a0378d2.tar.gz | |
add ssl_peek functionality
| -rw-r--r-- | ChangeLog | 8 | ||||
| -rw-r--r-- | OpenSSL/SSL.py | 19 | ||||
| -rw-r--r-- | OpenSSL/test/test_ssl.py | 24 | ||||
| -rw-r--r-- | doc/api/ssl.rst | 8 |
4 files changed, 47 insertions, 12 deletions
@@ -1,6 +1,12 @@ +2015-08-17 Maximilian Hils <pyopenssl@maximilianhils.com> + + * OpenSSL/SSL.py, OpenSSL/test/test_ssl.py: Add support for + the ``MSG_PEEK`` flag to ``Connection.recv()`` and + ``Connection.recv_into()``. + 2015-05-27 Jim Shaver <dcypherd@gmail.com> - * OpenSSL/SSL.py, : Add ``get_protocol_version()`` and + * OpenSSL/SSL.py: Add ``get_protocol_version()`` and ``get_protocol_version_name()`` to ``Connection``. Based on work from Rich Moore. diff --git a/OpenSSL/SSL.py b/OpenSSL/SSL.py index 8c87c34..9b27013 100644 --- a/OpenSSL/SSL.py +++ b/OpenSSL/SSL.py @@ -1,3 +1,4 @@ +import socket from sys import platform from functools import wraps, partial from itertools import count, chain @@ -1311,12 +1312,15 @@ class Connection(object): method again with the SAME buffer. :param bufsiz: The maximum number of bytes to read - :param flags: (optional) Included for compatibility with the socket - API, the value is ignored + :param flags: (optional) The only supported flag is ``MSG_PEEK``, + all other flags are ignored. :return: The string read from the Connection """ buf = _ffi.new("char[]", bufsiz) - result = _lib.SSL_read(self._ssl, buf, bufsiz) + if flags is not None and flags & socket.MSG_PEEK: + result = _lib.SSL_peek(self._ssl, buf, bufsiz) + else: + result = _lib.SSL_read(self._ssl, buf, bufsiz) self._raise_ssl_error(self._ssl, result) return _ffi.buffer(buf, result)[:] read = recv @@ -1332,8 +1336,8 @@ class Connection(object): buffer. If not present, defaults to the size of the buffer. If larger than the size of the buffer, is reduced to the size of the buffer. - :param flags: (optional) Included for compatibility with the socket - API, the value is ignored. + :param flags: (optional) The only supported flag is ``MSG_PEEK``, + all other flags are ignored. :return: The number of bytes read into the buffer. """ if nbytes is None: @@ -1345,7 +1349,10 @@ class Connection(object): # better if we could pass memoryviews straight into the SSL_read call, # but right now we can't. Revisit this if CFFI gets that ability. buf = _ffi.new("char[]", nbytes) - result = _lib.SSL_read(self._ssl, buf, nbytes) + if flags is not None and flags & socket.MSG_PEEK: + result = _lib.SSL_peek(self._ssl, buf, nbytes) + else: + result = _lib.SSL_read(self._ssl, buf, nbytes) self._raise_ssl_error(self._ssl, result) # This strange line is all to avoid a memory copy. The buffer protocol diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py index e586537..787d636 100644 --- a/OpenSSL/test/test_ssl.py +++ b/OpenSSL/test/test_ssl.py @@ -8,7 +8,7 @@ Unit tests for :py:obj:`OpenSSL.SSL`. from gc import collect, get_referrers from errno import ECONNREFUSED, EINPROGRESS, EWOULDBLOCK, EPIPE, ESHUTDOWN from sys import platform, getfilesystemencoding -from socket import SHUT_RDWR, error, socket +from socket import MSG_PEEK, SHUT_RDWR, error, socket from os import makedirs from os.path import join from unittest import main @@ -2172,6 +2172,17 @@ class ConnectionTests(TestCase, _LoopbackMixin): self.assertRaises(TypeError, connection.pending, None) + def test_peek(self): + """ + :py:obj:`Connection.recv` peeks into the connection if :py:obj:`socket.MSG_PEEK` is passed. + """ + server, client = self._loopback() + server.send(b('xy')) + self.assertEqual(client.recv(2, MSG_PEEK), b('xy')) + self.assertEqual(client.recv(2, MSG_PEEK), b('xy')) + self.assertEqual(client.recv(2), b('xy')) + + def test_connect_wrong_args(self): """ :py:obj:`Connection.connect` raises :py:obj:`TypeError` if called with a non-address @@ -2999,6 +3010,17 @@ class ConnectionRecvIntoTests(TestCase, _LoopbackMixin): self._doesnt_overfill_test(bytearray) + def test_peek(self): + + server, client = self._loopback() + server.send(b('xy')) + + for _ in range(2): + output_buffer = bytearray(5) + self.assertEqual(client.recv_into(output_buffer, flags=MSG_PEEK), 2) + self.assertEqual(output_buffer, bytearray(b('xy\x00\x00\x00'))) + + try: memoryview except NameError: diff --git a/doc/api/ssl.rst b/doc/api/ssl.rst index 89ae6a1..0548678 100644 --- a/doc/api/ssl.rst +++ b/doc/api/ssl.rst @@ -669,11 +669,12 @@ Connection objects have the following methods: (**not** the underlying transport buffer). -.. py:method:: Connection.recv(bufsize) +.. py:method:: Connection.recv(bufsize[, flags]) Receive data from the Connection. The return value is a string representing the data received. The maximum amount of data to be received at once, is specified - by *bufsize*. + by *bufsize*. The only supported flag is ``MSG_PEEK``, all other flags are + ignored. .. py:method:: Connection.recv_into(buffer[, nbytes[, flags]]) @@ -681,8 +682,7 @@ Connection objects have the following methods: Receive data from the Connection and copy it directly into the provided buffer. The return value is the number of bytes read from the connection. The maximum amount of data to be received at once is specified by *nbytes*. - *flags* is accepted for compatibility with ``socket.recv_into`` but its - value is ignored. + The only supported flag is ``MSG_PEEK``, all other flags are ignored. .. py:method:: Connection.bio_write(bytes) |
