diff options
| author | Jean-Paul Calderone <exarkun@twistedmatrix.com> | 2015-04-13 22:14:07 -0400 |
|---|---|---|
| committer | Jean-Paul Calderone <exarkun@twistedmatrix.com> | 2015-04-13 22:14:07 -0400 |
| commit | 218a0146a82b1e392ef7bbde6fd63f58f2d35ad7 (patch) | |
| tree | a412d198374e74f8ab4a4662642aa913a54d900c /OpenSSL/SSL.py | |
| parent | ba1820dfb1b03c849b272410d2c955478e633fbc (diff) | |
| parent | 5c3b748846ad1f9597d51b24d04ac394980c2480 (diff) | |
| download | pyopenssl-218a0146a82b1e392ef7bbde6fd63f58f2d35ad7.tar.gz | |
merge master
Diffstat (limited to 'OpenSSL/SSL.py')
| -rw-r--r-- | OpenSSL/SSL.py | 165 |
1 files changed, 165 insertions, 0 deletions
diff --git a/OpenSSL/SSL.py b/OpenSSL/SSL.py index 87492af..bba59ac 100644 --- a/OpenSSL/SSL.py +++ b/OpenSSL/SSL.py @@ -5,6 +5,7 @@ from weakref import WeakValueDictionary from errno import errorcode from six import text_type as _text_type +from six import binary_type as _binary_type from six import integer_types as integer_types from six import int2byte, indexbytes @@ -13,6 +14,7 @@ from OpenSSL._util import ( lib as _lib, exception_from_error_queue as _exception_from_error_queue, native as _native, + text_to_bytes_and_warn as _text_to_bytes_and_warn, path_string as _path_string, ) @@ -318,6 +320,56 @@ class _NpnSelectHelper(_CallbackExceptionHelper): ) +class _ALPNSelectHelper(_CallbackExceptionHelper): + """ + Wrap a callback such that it can be used as an ALPN selection callback. + """ + def __init__(self, callback): + _CallbackExceptionHelper.__init__(self) + + @wraps(callback) + def wrapper(ssl, out, outlen, in_, inlen, arg): + try: + conn = Connection._reverse_mapping[ssl] + + # The string passed to us is made up of multiple + # length-prefixed bytestrings. We need to split that into a + # list. + instr = _ffi.buffer(in_, inlen)[:] + protolist = [] + while instr: + encoded_len = indexbytes(instr, 0) + proto = instr[1:encoded_len + 1] + protolist.append(proto) + instr = instr[encoded_len + 1:] + + # Call the callback + outstr = callback(conn, protolist) + + if not isinstance(outstr, _binary_type): + raise TypeError("ALPN callback must return a bytestring.") + + # Save our callback arguments on the connection object to make + # sure that they don't get freed before OpenSSL can use them. + # Then, return them in the appropriate output parameters. + conn._alpn_select_callback_args = [ + _ffi.new("unsigned char *", len(outstr)), + _ffi.new("unsigned char[]", outstr), + ] + outlen[0] = conn._alpn_select_callback_args[0][0] + out[0] = conn._alpn_select_callback_args[1] + return 0 + except Exception as e: + self._problems.append(e) + return 2 # SSL_TLSEXT_ERR_ALERT_FATAL + + self.callback = _ffi.callback( + "int (*)(SSL *, unsigned char **, unsigned char *, " + "const unsigned char *, unsigned int, void *)", + wrapper + ) + + def _asFileDescriptor(obj): fd = None if not isinstance(obj, integer_types): @@ -363,6 +415,22 @@ def _requires_npn(func): +def _requires_alpn(func): + """ + Wraps any function that requires ALPN support in OpenSSL, ensuring that + NotImplementedError is raised if ALPN support is not present. + """ + @wraps(func) + def wrapper(*args, **kwargs): + if not _lib.Cryptography_HAS_ALPN: + raise NotImplementedError("ALPN not available.") + + return func(*args, **kwargs) + + return wrapper + + + class Session(object): pass @@ -424,6 +492,8 @@ class Context(object): self._npn_advertise_callback = None self._npn_select_helper = None self._npn_select_callback = None + self._alpn_select_helper = None + self._alpn_select_callback = None # SSL_CTX_set_app_data(self->ctx, self); # SSL_CTX_set_mode(self->ctx, SSL_MODE_ENABLE_PARTIAL_WRITE | @@ -973,6 +1043,44 @@ class Context(object): _lib.SSL_CTX_set_next_proto_select_cb( self._context, self._npn_select_callback, _ffi.NULL) + @_requires_alpn + def set_alpn_protos(self, protos): + """ + Specify the clients ALPN protocol list. + + These protocols are offered to the server during protocol negotiation. + + :param protos: A list of the protocols to be offered to the server. + This list should be a Python list of bytestrings representing the + protocols to offer, e.g. ``[b'http/1.1', b'spdy/2']``. + """ + # Take the list of protocols and join them together, prefixing them + # with their lengths. + protostr = b''.join( + chain.from_iterable((int2byte(len(p)), p) for p in protos) + ) + + # Build a C string from the list. We don't need to save this off + # because OpenSSL immediately copies the data out. + input_str = _ffi.new("unsigned char[]", protostr) + input_str_len = _ffi.cast("unsigned", len(protostr)) + _lib.SSL_CTX_set_alpn_protos(self._context, input_str, input_str_len) + + @_requires_alpn + def set_alpn_select_callback(self, callback): + """ + Set the callback to handle ALPN protocol choice. + + :param callback: The callback function. It will be invoked with two + arguments: the Connection, and a list of offered protocols as + bytestrings, e.g ``[b'http/1.1', b'spdy/2']``. It should return + one of those bytestrings, the chosen protocol. + """ + self._alpn_select_helper = _ALPNSelectHelper(callback) + self._alpn_select_callback = self._alpn_select_helper.callback + _lib.SSL_CTX_set_alpn_select_cb( + self._context, self._alpn_select_callback, _ffi.NULL) + ContextType = Context @@ -1004,6 +1112,12 @@ class Connection(object): self._npn_advertise_callback_args = None self._npn_select_callback_args = None + # References to strings used for Application Layer Protocol + # Negotiation. These strings get copied at some point but it's well + # after the callback returns, so we have to hang them somewhere to + # avoid them getting freed. + self._alpn_select_callback_args = None + self._reverse_mapping[self._ssl] = self if socket is None: @@ -1042,6 +1156,8 @@ class Connection(object): self._context._npn_advertise_helper.raise_if_problem() if self._context._npn_select_helper is not None: self._context._npn_select_helper.raise_if_problem() + if self._context._alpn_select_helper is not None: + self._context._alpn_select_helper.raise_if_problem() error = _lib.SSL_get_error(ssl, result) if error == _lib.SSL_ERROR_WANT_READ: @@ -1142,6 +1258,9 @@ class Connection(object): API, the value is ignored :return: The number of bytes written """ + # Backward compatibility + buf = _text_to_bytes_and_warn("buf", buf) + if isinstance(buf, _memoryview): buf = buf.tobytes() if isinstance(buf, _buffer): @@ -1166,6 +1285,8 @@ class Connection(object): API, the value is ignored :return: The number of bytes written """ + buf = _text_to_bytes_and_warn("buf", buf) + if isinstance(buf, _memoryview): buf = buf.tobytes() if isinstance(buf, _buffer): @@ -1290,6 +1411,8 @@ class Connection(object): :param buf: The string to put into the memory BIO. :return: The number of bytes written """ + buf = _text_to_bytes_and_warn("buf", buf) + if self._into_ssl is None: raise TypeError("Connection sock was not None") @@ -1776,6 +1899,48 @@ class Connection(object): return _ffi.buffer(data[0], data_len[0])[:] + @_requires_alpn + def set_alpn_protos(self, protos): + """ + Specify the client's ALPN protocol list. + + These protocols are offered to the server during protocol negotiation. + + :param protos: A list of the protocols to be offered to the server. + This list should be a Python list of bytestrings representing the + protocols to offer, e.g. ``[b'http/1.1', b'spdy/2']``. + """ + # Take the list of protocols and join them together, prefixing them + # with their lengths. + protostr = b''.join( + chain.from_iterable((int2byte(len(p)), p) for p in protos) + ) + + # Build a C string from the list. We don't need to save this off + # because OpenSSL immediately copies the data out. + input_str = _ffi.new("unsigned char[]", protostr) + input_str_len = _ffi.cast("unsigned", len(protostr)) + _lib.SSL_set_alpn_protos(self._ssl, input_str, input_str_len) + + + def get_alpn_proto_negotiated(self): + """ + Get the protocol that was negotiated by ALPN. + """ + if not _lib.Cryptography_HAS_ALPN: + raise NotImplementedError("ALPN not available") + + data = _ffi.new("unsigned char **") + data_len = _ffi.new("unsigned int *") + + _lib.SSL_get0_alpn_selected(self._ssl, data, data_len) + + if not data_len: + return b'' + + return _ffi.buffer(data[0], data_len[0])[:] + + ConnectionType = Connection |
