diff options
-rw-r--r-- | tests/test_websocketproxy.py | 14 | ||||
-rw-r--r-- | websockify/auth_plugins.py | 50 | ||||
-rw-r--r-- | websockify/websocket.py | 12 | ||||
-rwxr-xr-x | websockify/websocketproxy.py | 33 |
4 files changed, 88 insertions, 21 deletions
diff --git a/tests/test_websocketproxy.py b/tests/test_websocketproxy.py index 8103ef6..92fd5db 100644 --- a/tests/test_websocketproxy.py +++ b/tests/test_websocketproxy.py @@ -106,11 +106,11 @@ class ProxyRequestHandlerTestCase(unittest.TestCase): def lookup(self, token): return (self.source + token).split(',') - self.stubs.Set(websocketproxy.ProxyRequestHandler, 'do_proxy', - lambda *args, **kwargs: None) + self.stubs.Set(websocketproxy.ProxyRequestHandler, 'send_auth_error', + staticmethod(lambda *args, **kwargs: None)) self.handler.server.token_plugin = TestPlugin("somehost,") - self.handler.new_websocket_client() + self.handler.validate_connection() self.assertEqual(self.handler.server.target_host, "somehost") self.assertEqual(self.handler.server.target_port, "blah") @@ -119,9 +119,9 @@ class ProxyRequestHandlerTestCase(unittest.TestCase): class TestPlugin(auth_plugins.BasePlugin): def authenticate(self, headers, target_host, target_port): if target_host == self.source: - raise auth_plugins.AuthenticationError("some error") + raise auth_plugins.AuthenticationError(response_msg="some_error") - self.stubs.Set(websocketproxy.ProxyRequestHandler, 'do_proxy', + self.stubs.Set(websocketproxy.ProxyRequestHandler, 'send_auth_error', staticmethod(lambda *args, **kwargs: None)) self.handler.server.auth_plugin = TestPlugin("somehost") @@ -129,8 +129,8 @@ class ProxyRequestHandlerTestCase(unittest.TestCase): self.handler.server.target_port = "someport" self.assertRaises(auth_plugins.AuthenticationError, - self.handler.new_websocket_client) + self.handler.validate_connection) self.handler.server.target_host = "someotherhost" - self.handler.new_websocket_client() + self.handler.validate_connection() diff --git a/websockify/auth_plugins.py b/websockify/auth_plugins.py index 647c26e..924d5de 100644 --- a/websockify/auth_plugins.py +++ b/websockify/auth_plugins.py @@ -7,7 +7,15 @@ class BasePlugin(object): class AuthenticationError(Exception): - pass + def __init__(self, log_msg=None, response_code=403, response_headers={}, response_msg=None): + self.code = response_code + self.headers = response_headers + self.msg = response_msg + + if log_msg is None: + log_msg = response_msg + + super(AuthenticationError, self).__init__('%s %s' % (self.code, log_msg)) class InvalidOriginError(AuthenticationError): @@ -16,8 +24,44 @@ class InvalidOriginError(AuthenticationError): self.actual_origin = actual super(InvalidOriginError, self).__init__( - "Invalid Origin Header: Expected one of " - "%s, got '%s'" % (expected, actual)) + response_msg='Invalid Origin', + log_msg="Invalid Origin Header: Expected one of " + "%s, got '%s'" % (expected, actual)) + + +class BasicHTTPAuth(object): + def __init__(self, src=None): + self.src = src + + def authenticate(self, headers, target_host, target_port): + import base64 + + auth_header = headers.get('Authorization') + if auth_header: + if not auth_header.startswith('Basic '): + raise AuthenticationError(response_code=403) + + try: + user_pass_raw = base64.b64decode(auth_header[6:]) + except TypeError: + raise AuthenticationError(response_code=403) + + user_pass = user_pass_raw.split(':', 1) + if len(user_pass) != 2: + raise AuthenticationError(response_code=403) + + if not self.validate_creds: + raise AuthenticationError(response_code=403) + + else: + raise AuthenticationError(response_code=401, + response_headers={'WWW-Authenticate': 'Basic realm="Websockify"'}) + + def validate_creds(username, password): + if '%s:%s' % (username, password) == self.src: + return True + else: + return False class ExpectOrigin(object): def __init__(self, src=None): diff --git a/websockify/websocket.py b/websockify/websocket.py index 1cbf583..7fa9651 100644 --- a/websockify/websocket.py +++ b/websockify/websocket.py @@ -474,9 +474,13 @@ class WebSocketRequestHandler(SimpleHTTPRequestHandler): """Upgrade a connection to Websocket, if requested. If this succeeds, new_websocket_client() will be called. Otherwise, False is returned. """ + if (self.headers.get('upgrade') and self.headers.get('upgrade').lower() == 'websocket'): + # ensure connection is authorized, and determine the target + self.validate_connection() + if not self.do_websocket_handshake(): return False @@ -549,6 +553,10 @@ class WebSocketRequestHandler(SimpleHTTPRequestHandler): """ Do something with a WebSockets client connection. """ raise Exception("WebSocketRequestHandler.new_websocket_client() must be overloaded") + def validate_connection(self): + """ Ensure that the connection is a valid connection, and set the target. """ + pass + def do_HEAD(self): if self.only_upgrade: self.send_error(405, "Method Not Allowed") @@ -789,7 +797,7 @@ class WebSocketServer(object): """ ready = select.select([sock], [], [], 3)[0] - + if not ready: raise self.EClose("ignoring socket not ready") # Peek, but do not read the data so that we have a opportunity @@ -903,7 +911,7 @@ class WebSocketServer(object): def top_new_client(self, startsock, address): """ Do something with a WebSockets client connection. """ - # handler process + # handler process client = None try: try: diff --git a/websockify/websocketproxy.py b/websockify/websocketproxy.py index 029b6f3..46ab545 100755 --- a/websockify/websocketproxy.py +++ b/websockify/websocketproxy.py @@ -18,6 +18,7 @@ try: from http.server import HTTPServer except: from BaseHTTPServer import HTTPServer import select from websockify import websocket +from websockify import auth_plugins as auth try: from urllib.parse import parse_qs, urlparse except: @@ -37,20 +38,34 @@ Traffic Legend: < - Client send <. - Client send partial """ + + def send_auth_error(self, ex): + self.send_response(ex.code, ex.msg) + self.send_header('Content-Type', 'text/html') + for name, val in ex.headers.items(): + self.send_header(name, val) + + self.end_headers() + + def validate_connection(self): + if self.server.token_plugin: + (self.server.target_host, self.server.target_port) = self.get_target(self.server.token_plugin, self.path) + + if self.server.auth_plugin: + try: + self.server.auth_plugin.authenticate( + headers=self.headers, target_host=self.server.target_host, + target_port=self.server.target_port) + except auth.AuthenticationError: + ex = sys.exc_info()[1] + self.send_auth_error(ex) + raise def new_websocket_client(self): """ Called after a new WebSocket connection has been established. """ - # Checks if we receive a token, and look - # for a valid target for it then - if self.server.token_plugin: - (self.server.target_host, self.server.target_port) = self.get_target(self.server.token_plugin, self.path) - - if self.server.auth_plugin: - self.server.auth_plugin.authenticate( - headers=self.headers, target_host=self.server.target_host, - target_port=self.server.target_port) + # Checking for a token is done in validate_connection() # Connect to the target if self.server.wrap_cmd: |