diff options
| author | Maximilian Hils <git@maximilianhils.com> | 2015-08-17 19:27:20 +0200 |
|---|---|---|
| committer | Maximilian Hils <git@maximilianhils.com> | 2015-08-17 19:27:20 +0200 |
| commit | 1d95dea7fea03c7c0df345a5ea30c12d8a0378d2 (patch) | |
| tree | 01ee197401d1b37617fc5dbcae5f2226b17b2ea8 /OpenSSL | |
| parent | cd0d22c16447cfea0e0e2d3147a905f1866530ac (diff) | |
| download | pyopenssl-1d95dea7fea03c7c0df345a5ea30c12d8a0378d2.tar.gz | |
add ssl_peek functionality
Diffstat (limited to 'OpenSSL')
| -rw-r--r-- | OpenSSL/SSL.py | 19 | ||||
| -rw-r--r-- | OpenSSL/test/test_ssl.py | 24 |
2 files changed, 36 insertions, 7 deletions
diff --git a/OpenSSL/SSL.py b/OpenSSL/SSL.py index 8c87c34..9b27013 100644 --- a/OpenSSL/SSL.py +++ b/OpenSSL/SSL.py @@ -1,3 +1,4 @@ +import socket from sys import platform from functools import wraps, partial from itertools import count, chain @@ -1311,12 +1312,15 @@ class Connection(object): method again with the SAME buffer. :param bufsiz: The maximum number of bytes to read - :param flags: (optional) Included for compatibility with the socket - API, the value is ignored + :param flags: (optional) The only supported flag is ``MSG_PEEK``, + all other flags are ignored. :return: The string read from the Connection """ buf = _ffi.new("char[]", bufsiz) - result = _lib.SSL_read(self._ssl, buf, bufsiz) + if flags is not None and flags & socket.MSG_PEEK: + result = _lib.SSL_peek(self._ssl, buf, bufsiz) + else: + result = _lib.SSL_read(self._ssl, buf, bufsiz) self._raise_ssl_error(self._ssl, result) return _ffi.buffer(buf, result)[:] read = recv @@ -1332,8 +1336,8 @@ class Connection(object): buffer. If not present, defaults to the size of the buffer. If larger than the size of the buffer, is reduced to the size of the buffer. - :param flags: (optional) Included for compatibility with the socket - API, the value is ignored. + :param flags: (optional) The only supported flag is ``MSG_PEEK``, + all other flags are ignored. :return: The number of bytes read into the buffer. """ if nbytes is None: @@ -1345,7 +1349,10 @@ class Connection(object): # better if we could pass memoryviews straight into the SSL_read call, # but right now we can't. Revisit this if CFFI gets that ability. buf = _ffi.new("char[]", nbytes) - result = _lib.SSL_read(self._ssl, buf, nbytes) + if flags is not None and flags & socket.MSG_PEEK: + result = _lib.SSL_peek(self._ssl, buf, nbytes) + else: + result = _lib.SSL_read(self._ssl, buf, nbytes) self._raise_ssl_error(self._ssl, result) # This strange line is all to avoid a memory copy. The buffer protocol diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py index e586537..787d636 100644 --- a/OpenSSL/test/test_ssl.py +++ b/OpenSSL/test/test_ssl.py @@ -8,7 +8,7 @@ Unit tests for :py:obj:`OpenSSL.SSL`. from gc import collect, get_referrers from errno import ECONNREFUSED, EINPROGRESS, EWOULDBLOCK, EPIPE, ESHUTDOWN from sys import platform, getfilesystemencoding -from socket import SHUT_RDWR, error, socket +from socket import MSG_PEEK, SHUT_RDWR, error, socket from os import makedirs from os.path import join from unittest import main @@ -2172,6 +2172,17 @@ class ConnectionTests(TestCase, _LoopbackMixin): self.assertRaises(TypeError, connection.pending, None) + def test_peek(self): + """ + :py:obj:`Connection.recv` peeks into the connection if :py:obj:`socket.MSG_PEEK` is passed. + """ + server, client = self._loopback() + server.send(b('xy')) + self.assertEqual(client.recv(2, MSG_PEEK), b('xy')) + self.assertEqual(client.recv(2, MSG_PEEK), b('xy')) + self.assertEqual(client.recv(2), b('xy')) + + def test_connect_wrong_args(self): """ :py:obj:`Connection.connect` raises :py:obj:`TypeError` if called with a non-address @@ -2999,6 +3010,17 @@ class ConnectionRecvIntoTests(TestCase, _LoopbackMixin): self._doesnt_overfill_test(bytearray) + def test_peek(self): + + server, client = self._loopback() + server.send(b('xy')) + + for _ in range(2): + output_buffer = bytearray(5) + self.assertEqual(client.recv_into(output_buffer, flags=MSG_PEEK), 2) + self.assertEqual(output_buffer, bytearray(b('xy\x00\x00\x00'))) + + try: memoryview except NameError: |
