diff options
| author | Cory Benfield <lukasaoz@gmail.com> | 2015-03-22 09:05:28 +0000 |
|---|---|---|
| committer | Cory Benfield <lukasaoz@gmail.com> | 2015-03-22 09:07:30 +0000 |
| commit | 0ea76e7d977b19f2bb4ce4dee9bee8aa179eaff0 (patch) | |
| tree | c584cfc4c2c722da4272f4d696ef77103bb47e03 /OpenSSL/SSL.py | |
| parent | 4969c22cf65f91f5bc87bf25ce45875dae32f576 (diff) | |
| download | pyopenssl-0ea76e7d977b19f2bb4ce4dee9bee8aa179eaff0.tar.gz | |
Handle exceptions in NPN callbacks.
Diffstat (limited to 'OpenSSL/SSL.py')
| -rw-r--r-- | OpenSSL/SSL.py | 169 |
1 files changed, 105 insertions, 64 deletions
diff --git a/OpenSSL/SSL.py b/OpenSSL/SSL.py index e97df8b..5af4969 100644 --- a/OpenSSL/SSL.py +++ b/OpenSSL/SSL.py @@ -165,8 +165,24 @@ class SysCallError(Error): pass +class _CallbackExceptionHelper(object): + """ + A base class for wrapper classes that allow for intelligent exception + handling in OpenSSL callbacks. + """ + def __init__(self, callback): + pass + + def raise_if_problem(self): + if self._problems: + try: + _raise_current_error() + except Error: + pass + raise self._problems.pop(0) + -class _VerifyHelper(object): +class _VerifyHelper(_CallbackExceptionHelper): def __init__(self, callback): self._problems = [] @@ -197,15 +213,87 @@ class _VerifyHelper(object): "int (*)(int, X509_STORE_CTX *)", wrapper) - def raise_if_problem(self): - if self._problems: +class _NpnAdvertiseHelper(_CallbackExceptionHelper): + def __init__(self, callback): + self._problems = [] + + @wraps(callback) + def wrapper(ssl, out, outlen, arg): try: - _raise_current_error() - except Error: - pass - raise self._problems.pop(0) + conn = Connection._reverse_mapping[ssl] + protos = callback(conn) + + # Join the protocols into a Python bytestring, length-prefixing + # each element. + protostr = b''.join( + chain.from_iterable((int2byte(len(p)), p) for p in protos) + ) + + # Save our callback arguments on the connection object. This is + # done to make sure that they don't get freed before OpenSSL + # uses them. Then, return them appropriately in the output + # parameters. + conn._npn_advertise_callback_args = [ + _ffi.new("unsigned int *", len(protostr)), + _ffi.new("unsigned char[]", protostr), + ] + outlen[0] = conn._npn_advertise_callback_args[0][0] + out[0] = conn._npn_advertise_callback_args[1] + return 0 + except Exception as e: + self._problems.append(e) + return 2 # SSL_TLSEXT_ERR_ALERT_FATAL + + self.callback = _ffi.callback( + "int (*)(SSL *, const unsigned char **, unsigned int *, void *)", + wrapper + ) +class _NpnSelectHelper(_CallbackExceptionHelper): + def __init__(self, callback): + self._problems = [] + + @wraps(callback) + def wrapper(ssl, out, outlen, in_, inlen, arg): + try: + conn = Connection._reverse_mapping[ssl] + + # The string passed to us is actually made up of multiple + # length-prefixed bytestrings. We need to split that into a + # list. + instr = _ffi.buffer(in_, inlen)[:] + protolist = [] + while instr: + l = indexbytes(instr, 0) + proto = instr[1:l+1] + protolist.append(proto) + instr = instr[l+1:] + + # Call the callback + outstr = callback(conn, protolist) + + # Save our callback arguments on the connection object. This is + # done to make sure that they don't get freed before OpenSSL + # uses them. Then, return them appropriately in the output + # parameters. + conn._npn_select_callback_args = [ + _ffi.new("unsigned char *", len(outstr)), + _ffi.new("unsigned char[]", outstr), + ] + outlen[0] = conn._npn_select_callback_args[0][0] + out[0] = conn._npn_select_callback_args[1] + return 0 + except Exception as e: + self._problems.append(e) + return 2 # SSL_TLSEXT_ERR_ALERT_FATAL + + self.callback = _ffi.callback( + "int (*)(SSL *, unsigned char **, unsigned char *, " + "const unsigned char *, unsigned int, void *)", + wrapper + ) + def _asFileDescriptor(obj): fd = None @@ -294,7 +382,9 @@ class Context(object): self._info_callback = None self._tlsext_servername_callback = None self._app_data = None + self._npn_advertise_helper = None self._npn_advertise_callback = None + self._npn_select_helper = None self._npn_select_callback = None # SSL_CTX_set_app_data(self->ctx, self); @@ -824,31 +914,8 @@ class Context(object): bytestrings representing the advertised protocols, like ``[b'http/1.1', b'spdy/2']``. """ - @wraps(callback) - def wrapper(ssl, out, outlen, arg): - conn = Connection._reverse_mapping[ssl] - protos = callback(conn) - - # Join the protocols into a Python bytestring, length-prefixing - # each element. - protostr = b''.join( - chain.from_iterable((int2byte(len(p)), p) for p in protos) - ) - - # Save our callback arguments on the connection object. This is - # done to make sure that they don't get freed before OpenSSL uses - # them. Then, return them appropriately in the output parameters. - conn._npn_advertise_callback_args = [ - _ffi.new("unsigned int *", len(protostr)), - _ffi.new("unsigned char[]", protostr), - ] - outlen[0] = conn._npn_advertise_callback_args[0][0] - out[0] = conn._npn_advertise_callback_args[1] - return 0 - - self._npn_advertise_callback = _ffi.callback( - "int (*)(SSL *, const unsigned char **, unsigned int *, void *)", - wrapper) + self._npn_advertise_helper = _NpnAdvertiseHelper(callback) + self._npn_advertise_callback = self._npn_advertise_helper.callback _lib.SSL_CTX_set_next_protos_advertised_cb( self._context, self._npn_advertise_callback, _ffi.NULL) @@ -863,38 +930,8 @@ class Context(object): bytestrings, e.g. ``[b'http/1.1', b'spdy/2']``. It should return one of those bytestrings, the chosen protocol. """ - @wraps(callback) - def wrapper(ssl, out, outlen, in_, inlen, arg): - conn = Connection._reverse_mapping[ssl] - - # The string passed to us is actually made up of multiple - # length-prefixed bytestrings. We need to split that into a list. - instr = _ffi.buffer(in_, inlen)[:] - protolist = [] - while instr: - l = indexbytes(instr, 0) - proto = instr[1:l+1] - protolist.append(proto) - instr = instr[l+1:] - - # Call the callback - outstr = callback(conn, protolist) - - # Save our callback arguments on the connection object. This is - # done to make sure that they don't get freed before OpenSSL uses - # them. Then, return them appropriately in the output parameters. - conn._npn_select_callback_args = [ - _ffi.new("unsigned char *", len(outstr)), - _ffi.new("unsigned char[]", outstr), - ] - outlen[0] = conn._npn_select_callback_args[0][0] - out[0] = conn._npn_select_callback_args[1] - return 0 - - self._npn_select_callback = _ffi.callback( - "int (*)(SSL *, unsigned char **, unsigned char *, " - "const unsigned char *, unsigned int, void *)", - wrapper) + self._npn_select_helper = _NpnSelectHelper(callback) + self._npn_select_callback = self._npn_select_helper.callback _lib.SSL_CTX_set_next_proto_select_cb( self._context, self._npn_select_callback, _ffi.NULL) @@ -963,6 +1000,10 @@ class Connection(object): def _raise_ssl_error(self, ssl, result): if self._context._verify_helper is not None: self._context._verify_helper.raise_if_problem() + if self._context._npn_advertise_helper is not None: + self._context._npn_advertise_helper.raise_if_problem() + if self._context._npn_select_helper is not None: + self._context._npn_select_helper.raise_if_problem() error = _lib.SSL_get_error(ssl, result) if error == _lib.SSL_ERROR_WANT_READ: |
