summaryrefslogtreecommitdiff
path: root/OpenSSL/SSL.py
diff options
context:
space:
mode:
authorJean-Paul Calderone <exarkun@twistedmatrix.com>2015-04-13 22:14:07 -0400
committerJean-Paul Calderone <exarkun@twistedmatrix.com>2015-04-13 22:14:07 -0400
commit218a0146a82b1e392ef7bbde6fd63f58f2d35ad7 (patch)
treea412d198374e74f8ab4a4662642aa913a54d900c /OpenSSL/SSL.py
parentba1820dfb1b03c849b272410d2c955478e633fbc (diff)
parent5c3b748846ad1f9597d51b24d04ac394980c2480 (diff)
downloadpyopenssl-218a0146a82b1e392ef7bbde6fd63f58f2d35ad7.tar.gz
merge master
Diffstat (limited to 'OpenSSL/SSL.py')
-rw-r--r--OpenSSL/SSL.py165
1 files changed, 165 insertions, 0 deletions
diff --git a/OpenSSL/SSL.py b/OpenSSL/SSL.py
index 87492af..bba59ac 100644
--- a/OpenSSL/SSL.py
+++ b/OpenSSL/SSL.py
@@ -5,6 +5,7 @@ from weakref import WeakValueDictionary
from errno import errorcode
from six import text_type as _text_type
+from six import binary_type as _binary_type
from six import integer_types as integer_types
from six import int2byte, indexbytes
@@ -13,6 +14,7 @@ from OpenSSL._util import (
lib as _lib,
exception_from_error_queue as _exception_from_error_queue,
native as _native,
+ text_to_bytes_and_warn as _text_to_bytes_and_warn,
path_string as _path_string,
)
@@ -318,6 +320,56 @@ class _NpnSelectHelper(_CallbackExceptionHelper):
)
+class _ALPNSelectHelper(_CallbackExceptionHelper):
+ """
+ Wrap a callback such that it can be used as an ALPN selection callback.
+ """
+ def __init__(self, callback):
+ _CallbackExceptionHelper.__init__(self)
+
+ @wraps(callback)
+ def wrapper(ssl, out, outlen, in_, inlen, arg):
+ try:
+ conn = Connection._reverse_mapping[ssl]
+
+ # The string passed to us is made up of multiple
+ # length-prefixed bytestrings. We need to split that into a
+ # list.
+ instr = _ffi.buffer(in_, inlen)[:]
+ protolist = []
+ while instr:
+ encoded_len = indexbytes(instr, 0)
+ proto = instr[1:encoded_len + 1]
+ protolist.append(proto)
+ instr = instr[encoded_len + 1:]
+
+ # Call the callback
+ outstr = callback(conn, protolist)
+
+ if not isinstance(outstr, _binary_type):
+ raise TypeError("ALPN callback must return a bytestring.")
+
+ # Save our callback arguments on the connection object to make
+ # sure that they don't get freed before OpenSSL can use them.
+ # Then, return them in the appropriate output parameters.
+ conn._alpn_select_callback_args = [
+ _ffi.new("unsigned char *", len(outstr)),
+ _ffi.new("unsigned char[]", outstr),
+ ]
+ outlen[0] = conn._alpn_select_callback_args[0][0]
+ out[0] = conn._alpn_select_callback_args[1]
+ return 0
+ except Exception as e:
+ self._problems.append(e)
+ return 2 # SSL_TLSEXT_ERR_ALERT_FATAL
+
+ self.callback = _ffi.callback(
+ "int (*)(SSL *, unsigned char **, unsigned char *, "
+ "const unsigned char *, unsigned int, void *)",
+ wrapper
+ )
+
+
def _asFileDescriptor(obj):
fd = None
if not isinstance(obj, integer_types):
@@ -363,6 +415,22 @@ def _requires_npn(func):
+def _requires_alpn(func):
+ """
+ Wraps any function that requires ALPN support in OpenSSL, ensuring that
+ NotImplementedError is raised if ALPN support is not present.
+ """
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ if not _lib.Cryptography_HAS_ALPN:
+ raise NotImplementedError("ALPN not available.")
+
+ return func(*args, **kwargs)
+
+ return wrapper
+
+
+
class Session(object):
pass
@@ -424,6 +492,8 @@ class Context(object):
self._npn_advertise_callback = None
self._npn_select_helper = None
self._npn_select_callback = None
+ self._alpn_select_helper = None
+ self._alpn_select_callback = None
# SSL_CTX_set_app_data(self->ctx, self);
# SSL_CTX_set_mode(self->ctx, SSL_MODE_ENABLE_PARTIAL_WRITE |
@@ -973,6 +1043,44 @@ class Context(object):
_lib.SSL_CTX_set_next_proto_select_cb(
self._context, self._npn_select_callback, _ffi.NULL)
+ @_requires_alpn
+ def set_alpn_protos(self, protos):
+ """
+ Specify the clients ALPN protocol list.
+
+ These protocols are offered to the server during protocol negotiation.
+
+ :param protos: A list of the protocols to be offered to the server.
+ This list should be a Python list of bytestrings representing the
+ protocols to offer, e.g. ``[b'http/1.1', b'spdy/2']``.
+ """
+ # Take the list of protocols and join them together, prefixing them
+ # with their lengths.
+ protostr = b''.join(
+ chain.from_iterable((int2byte(len(p)), p) for p in protos)
+ )
+
+ # Build a C string from the list. We don't need to save this off
+ # because OpenSSL immediately copies the data out.
+ input_str = _ffi.new("unsigned char[]", protostr)
+ input_str_len = _ffi.cast("unsigned", len(protostr))
+ _lib.SSL_CTX_set_alpn_protos(self._context, input_str, input_str_len)
+
+ @_requires_alpn
+ def set_alpn_select_callback(self, callback):
+ """
+ Set the callback to handle ALPN protocol choice.
+
+ :param callback: The callback function. It will be invoked with two
+ arguments: the Connection, and a list of offered protocols as
+ bytestrings, e.g ``[b'http/1.1', b'spdy/2']``. It should return
+ one of those bytestrings, the chosen protocol.
+ """
+ self._alpn_select_helper = _ALPNSelectHelper(callback)
+ self._alpn_select_callback = self._alpn_select_helper.callback
+ _lib.SSL_CTX_set_alpn_select_cb(
+ self._context, self._alpn_select_callback, _ffi.NULL)
+
ContextType = Context
@@ -1004,6 +1112,12 @@ class Connection(object):
self._npn_advertise_callback_args = None
self._npn_select_callback_args = None
+ # References to strings used for Application Layer Protocol
+ # Negotiation. These strings get copied at some point but it's well
+ # after the callback returns, so we have to hang them somewhere to
+ # avoid them getting freed.
+ self._alpn_select_callback_args = None
+
self._reverse_mapping[self._ssl] = self
if socket is None:
@@ -1042,6 +1156,8 @@ class Connection(object):
self._context._npn_advertise_helper.raise_if_problem()
if self._context._npn_select_helper is not None:
self._context._npn_select_helper.raise_if_problem()
+ if self._context._alpn_select_helper is not None:
+ self._context._alpn_select_helper.raise_if_problem()
error = _lib.SSL_get_error(ssl, result)
if error == _lib.SSL_ERROR_WANT_READ:
@@ -1142,6 +1258,9 @@ class Connection(object):
API, the value is ignored
:return: The number of bytes written
"""
+ # Backward compatibility
+ buf = _text_to_bytes_and_warn("buf", buf)
+
if isinstance(buf, _memoryview):
buf = buf.tobytes()
if isinstance(buf, _buffer):
@@ -1166,6 +1285,8 @@ class Connection(object):
API, the value is ignored
:return: The number of bytes written
"""
+ buf = _text_to_bytes_and_warn("buf", buf)
+
if isinstance(buf, _memoryview):
buf = buf.tobytes()
if isinstance(buf, _buffer):
@@ -1290,6 +1411,8 @@ class Connection(object):
:param buf: The string to put into the memory BIO.
:return: The number of bytes written
"""
+ buf = _text_to_bytes_and_warn("buf", buf)
+
if self._into_ssl is None:
raise TypeError("Connection sock was not None")
@@ -1776,6 +1899,48 @@ class Connection(object):
return _ffi.buffer(data[0], data_len[0])[:]
+ @_requires_alpn
+ def set_alpn_protos(self, protos):
+ """
+ Specify the client's ALPN protocol list.
+
+ These protocols are offered to the server during protocol negotiation.
+
+ :param protos: A list of the protocols to be offered to the server.
+ This list should be a Python list of bytestrings representing the
+ protocols to offer, e.g. ``[b'http/1.1', b'spdy/2']``.
+ """
+ # Take the list of protocols and join them together, prefixing them
+ # with their lengths.
+ protostr = b''.join(
+ chain.from_iterable((int2byte(len(p)), p) for p in protos)
+ )
+
+ # Build a C string from the list. We don't need to save this off
+ # because OpenSSL immediately copies the data out.
+ input_str = _ffi.new("unsigned char[]", protostr)
+ input_str_len = _ffi.cast("unsigned", len(protostr))
+ _lib.SSL_set_alpn_protos(self._ssl, input_str, input_str_len)
+
+
+ def get_alpn_proto_negotiated(self):
+ """
+ Get the protocol that was negotiated by ALPN.
+ """
+ if not _lib.Cryptography_HAS_ALPN:
+ raise NotImplementedError("ALPN not available")
+
+ data = _ffi.new("unsigned char **")
+ data_len = _ffi.new("unsigned int *")
+
+ _lib.SSL_get0_alpn_selected(self._ssl, data, data_len)
+
+ if not data_len:
+ return b''
+
+ return _ffi.buffer(data[0], data_len[0])[:]
+
+
ConnectionType = Connection