summaryrefslogtreecommitdiff
path: root/OpenSSL/SSL.py
diff options
context:
space:
mode:
authorCory Benfield <lukasaoz@gmail.com>2015-03-22 09:05:28 +0000
committerCory Benfield <lukasaoz@gmail.com>2015-03-22 09:07:30 +0000
commit0ea76e7d977b19f2bb4ce4dee9bee8aa179eaff0 (patch)
treec584cfc4c2c722da4272f4d696ef77103bb47e03 /OpenSSL/SSL.py
parent4969c22cf65f91f5bc87bf25ce45875dae32f576 (diff)
downloadpyopenssl-0ea76e7d977b19f2bb4ce4dee9bee8aa179eaff0.tar.gz
Handle exceptions in NPN callbacks.
Diffstat (limited to 'OpenSSL/SSL.py')
-rw-r--r--OpenSSL/SSL.py169
1 files changed, 105 insertions, 64 deletions
diff --git a/OpenSSL/SSL.py b/OpenSSL/SSL.py
index e97df8b..5af4969 100644
--- a/OpenSSL/SSL.py
+++ b/OpenSSL/SSL.py
@@ -165,8 +165,24 @@ class SysCallError(Error):
pass
+class _CallbackExceptionHelper(object):
+ """
+ A base class for wrapper classes that allow for intelligent exception
+ handling in OpenSSL callbacks.
+ """
+ def __init__(self, callback):
+ pass
+
+ def raise_if_problem(self):
+ if self._problems:
+ try:
+ _raise_current_error()
+ except Error:
+ pass
+ raise self._problems.pop(0)
+
-class _VerifyHelper(object):
+class _VerifyHelper(_CallbackExceptionHelper):
def __init__(self, callback):
self._problems = []
@@ -197,15 +213,87 @@ class _VerifyHelper(object):
"int (*)(int, X509_STORE_CTX *)", wrapper)
- def raise_if_problem(self):
- if self._problems:
+class _NpnAdvertiseHelper(_CallbackExceptionHelper):
+ def __init__(self, callback):
+ self._problems = []
+
+ @wraps(callback)
+ def wrapper(ssl, out, outlen, arg):
try:
- _raise_current_error()
- except Error:
- pass
- raise self._problems.pop(0)
+ conn = Connection._reverse_mapping[ssl]
+ protos = callback(conn)
+
+ # Join the protocols into a Python bytestring, length-prefixing
+ # each element.
+ protostr = b''.join(
+ chain.from_iterable((int2byte(len(p)), p) for p in protos)
+ )
+
+ # Save our callback arguments on the connection object. This is
+ # done to make sure that they don't get freed before OpenSSL
+ # uses them. Then, return them appropriately in the output
+ # parameters.
+ conn._npn_advertise_callback_args = [
+ _ffi.new("unsigned int *", len(protostr)),
+ _ffi.new("unsigned char[]", protostr),
+ ]
+ outlen[0] = conn._npn_advertise_callback_args[0][0]
+ out[0] = conn._npn_advertise_callback_args[1]
+ return 0
+ except Exception as e:
+ self._problems.append(e)
+ return 2 # SSL_TLSEXT_ERR_ALERT_FATAL
+
+ self.callback = _ffi.callback(
+ "int (*)(SSL *, const unsigned char **, unsigned int *, void *)",
+ wrapper
+ )
+class _NpnSelectHelper(_CallbackExceptionHelper):
+ def __init__(self, callback):
+ self._problems = []
+
+ @wraps(callback)
+ def wrapper(ssl, out, outlen, in_, inlen, arg):
+ try:
+ conn = Connection._reverse_mapping[ssl]
+
+ # The string passed to us is actually made up of multiple
+ # length-prefixed bytestrings. We need to split that into a
+ # list.
+ instr = _ffi.buffer(in_, inlen)[:]
+ protolist = []
+ while instr:
+ l = indexbytes(instr, 0)
+ proto = instr[1:l+1]
+ protolist.append(proto)
+ instr = instr[l+1:]
+
+ # Call the callback
+ outstr = callback(conn, protolist)
+
+ # Save our callback arguments on the connection object. This is
+ # done to make sure that they don't get freed before OpenSSL
+ # uses them. Then, return them appropriately in the output
+ # parameters.
+ conn._npn_select_callback_args = [
+ _ffi.new("unsigned char *", len(outstr)),
+ _ffi.new("unsigned char[]", outstr),
+ ]
+ outlen[0] = conn._npn_select_callback_args[0][0]
+ out[0] = conn._npn_select_callback_args[1]
+ return 0
+ except Exception as e:
+ self._problems.append(e)
+ return 2 # SSL_TLSEXT_ERR_ALERT_FATAL
+
+ self.callback = _ffi.callback(
+ "int (*)(SSL *, unsigned char **, unsigned char *, "
+ "const unsigned char *, unsigned int, void *)",
+ wrapper
+ )
+
def _asFileDescriptor(obj):
fd = None
@@ -294,7 +382,9 @@ class Context(object):
self._info_callback = None
self._tlsext_servername_callback = None
self._app_data = None
+ self._npn_advertise_helper = None
self._npn_advertise_callback = None
+ self._npn_select_helper = None
self._npn_select_callback = None
# SSL_CTX_set_app_data(self->ctx, self);
@@ -824,31 +914,8 @@ class Context(object):
bytestrings representing the advertised protocols, like
``[b'http/1.1', b'spdy/2']``.
"""
- @wraps(callback)
- def wrapper(ssl, out, outlen, arg):
- conn = Connection._reverse_mapping[ssl]
- protos = callback(conn)
-
- # Join the protocols into a Python bytestring, length-prefixing
- # each element.
- protostr = b''.join(
- chain.from_iterable((int2byte(len(p)), p) for p in protos)
- )
-
- # Save our callback arguments on the connection object. This is
- # done to make sure that they don't get freed before OpenSSL uses
- # them. Then, return them appropriately in the output parameters.
- conn._npn_advertise_callback_args = [
- _ffi.new("unsigned int *", len(protostr)),
- _ffi.new("unsigned char[]", protostr),
- ]
- outlen[0] = conn._npn_advertise_callback_args[0][0]
- out[0] = conn._npn_advertise_callback_args[1]
- return 0
-
- self._npn_advertise_callback = _ffi.callback(
- "int (*)(SSL *, const unsigned char **, unsigned int *, void *)",
- wrapper)
+ self._npn_advertise_helper = _NpnAdvertiseHelper(callback)
+ self._npn_advertise_callback = self._npn_advertise_helper.callback
_lib.SSL_CTX_set_next_protos_advertised_cb(
self._context, self._npn_advertise_callback, _ffi.NULL)
@@ -863,38 +930,8 @@ class Context(object):
bytestrings, e.g. ``[b'http/1.1', b'spdy/2']``. It should return
one of those bytestrings, the chosen protocol.
"""
- @wraps(callback)
- def wrapper(ssl, out, outlen, in_, inlen, arg):
- conn = Connection._reverse_mapping[ssl]
-
- # The string passed to us is actually made up of multiple
- # length-prefixed bytestrings. We need to split that into a list.
- instr = _ffi.buffer(in_, inlen)[:]
- protolist = []
- while instr:
- l = indexbytes(instr, 0)
- proto = instr[1:l+1]
- protolist.append(proto)
- instr = instr[l+1:]
-
- # Call the callback
- outstr = callback(conn, protolist)
-
- # Save our callback arguments on the connection object. This is
- # done to make sure that they don't get freed before OpenSSL uses
- # them. Then, return them appropriately in the output parameters.
- conn._npn_select_callback_args = [
- _ffi.new("unsigned char *", len(outstr)),
- _ffi.new("unsigned char[]", outstr),
- ]
- outlen[0] = conn._npn_select_callback_args[0][0]
- out[0] = conn._npn_select_callback_args[1]
- return 0
-
- self._npn_select_callback = _ffi.callback(
- "int (*)(SSL *, unsigned char **, unsigned char *, "
- "const unsigned char *, unsigned int, void *)",
- wrapper)
+ self._npn_select_helper = _NpnSelectHelper(callback)
+ self._npn_select_callback = self._npn_select_helper.callback
_lib.SSL_CTX_set_next_proto_select_cb(
self._context, self._npn_select_callback, _ffi.NULL)
@@ -963,6 +1000,10 @@ class Connection(object):
def _raise_ssl_error(self, ssl, result):
if self._context._verify_helper is not None:
self._context._verify_helper.raise_if_problem()
+ if self._context._npn_advertise_helper is not None:
+ self._context._npn_advertise_helper.raise_if_problem()
+ if self._context._npn_select_helper is not None:
+ self._context._npn_select_helper.raise_if_problem()
error = _lib.SSL_get_error(ssl, result)
if error == _lib.SSL_ERROR_WANT_READ: