summaryrefslogtreecommitdiff
path: root/webtest/http.py
blob: dc4b3d967bac6c59ca52099c229d90727f969ce1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# -*- coding: utf-8 -*-
from waitress.server import WSGIServer
from six.moves import http_client
from six import text_type
import threading
import logging
import socket
import webob
import time
import os


def _free_port():
    """Return an ``(ip, port)`` pair suitable for starting a test server.

    The port is discovered by binding a throwaway TCP socket to port 0,
    which lets the operating system choose an unused port, and then
    closing that socket again.

    The address is *not* the one reported by the probe socket: it is read
    from the ``WEBTEST_SERVER_BIND`` environment variable and defaults to
    ``'127.0.0.1'`` when the variable is unset.

    Note that the probe socket is closed before returning, so nothing
    reserves the port between this call and the caller's own bind.

    :returns: tuple ``(ip, port)`` where ``ip`` is a string and ``port``
        is the integer port number picked by the OS.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.bind(('', 0))
        # getsockname() yields (address, port); only the port is useful
        # because the address is taken from the environment below.
        port = s.getsockname()[1]
    finally:
        # Always release the probe socket, even if bind() failed.
        s.close()
    ip = os.environ.get('WEBTEST_SERVER_BIND', '127.0.0.1')
    return ip, port


class StopableWSGIServer(WSGIServer):
    """A ``WSGIServer`` that runs in a thread and can be stopped from tests.

    The server does not serve the test application directly: it serves
    :meth:`wrapper`, which answers two special paths itself and forwards
    every other request to the wrapped application.

    Attributes set by ``__init__``:

    * ``main_thread`` -- the thread running :meth:`run` (``None`` until
      :meth:`create` assigns it).
    * ``test_app`` -- the WSGI application under test.
    * ``application_url`` -- base URL built from ``self.adj.host`` and
      ``self.adj.port``.

    NOTE(review): ``adj``, ``asyncore``, ``_map``, ``task_dispatcher``,
    ``maintenance`` and ``logger`` are not defined here; they are
    presumably provided by the ``WSGIServer`` base class -- confirm.
    """

    def __init__(self, application, *args, **kwargs):
        """Initialise the base server with :meth:`wrapper` as its app.

        :param application: the WSGI application to forward requests to.
        :param args: positional arguments passed through to the base class.
        :param kwargs: keyword arguments passed through to the base class.
        """
        super(StopableWSGIServer, self).__init__(self.wrapper, *args, **kwargs)
        self.main_thread = None
        self.test_app = application
        self.application_url = 'http://%s:%s/' % (self.adj.host, self.adj.port)

    def wrapper(self, environ, start_response):
        """WSGI entry point that intercepts two helper paths.

        * If ``PATH_INFO`` contains ``__file__``: the file named by the
          ``__file__`` request parameter is read and returned as
          ``text/html``, with every ``http://localhost/`` rewritten to
          point at the host of the current request.
        * If ``PATH_INFO`` contains ``__application__``: a fixed
          ``server started`` response is returned (used by :meth:`wait`).
        * Otherwise the request is handed to ``self.test_app``.

        :param environ: the WSGI environment dictionary.
        :param start_response: the WSGI ``start_response`` callable.
        :returns: the WSGI response iterable.
        """
        if '__file__' in environ['PATH_INFO']:
            req = webob.Request(environ)
            resp = webob.Response()
            resp.content_type = 'text/html; charset=UTF-8'
            # NOTE(review): this serves whatever path the client names;
            # presumably acceptable because the server is only meant for
            # local tests -- confirm before exposing it more widely.
            filename = req.params.get('__file__')
            # Use a context manager so the file handle is always released.
            with open(filename, 'r') as fd:
                body = fd.read()
            # str.replace returns a new string; the result must be kept,
            # otherwise the rewrite silently does nothing.
            body = body.replace('http://localhost/',
                                'http://%s/' % req.host)
            if isinstance(body, text_type):
                body = body.encode('utf8')
            resp.body = body
            return resp(environ, start_response)
        elif '__application__' in environ['PATH_INFO']:
            return webob.Response('server started')(environ, start_response)
        return self.test_app(environ, start_response)

    def run(self):
        """Run the asyncore loop over ``self._map`` with a 0.5s timeout.

        ``SystemExit`` and ``KeyboardInterrupt`` are caught and answered
        by shutting down the task dispatcher instead of propagating.
        """
        try:
            self.asyncore.loop(.5, map=self._map)
        except (SystemExit, KeyboardInterrupt):
            self.task_dispatcher.shutdown()

    def shutdown(self):
        """Close every channel in the socket map and stop the dispatcher.

        The steps are: silence the logger, call ``handle_close()`` on each
        entry of ``self._map`` until the map is empty, run
        ``self.maintenance(0)``, then call ``task_dispatcher.shutdown()``
        repeatedly until it returns a true value.
        """
        # avoid showing traceback related to asyncore
        self.logger.setLevel(logging.FATAL)
        while self._map:
            # Iterate over a snapshot: handle_close() is expected to
            # remove entries from the map while we loop.
            triggers = list(self._map.values())
            for trigger in triggers:
                trigger.handle_close()
        self.maintenance(0)
        while not self.task_dispatcher.shutdown():
            pass

    @classmethod
    def create(cls, application, **kwargs):
        """Build a server on a free port and start it in a new thread.

        The port always comes from ``_free_port()`` and replaces any
        ``port`` given in ``kwargs``; the host from ``_free_port()`` is
        used only when ``kwargs`` has no ``host``.

        :param application: the WSGI application to serve.
        :param kwargs: extra keyword arguments for the constructor.
        :returns: the started server; its thread is stored in
            ``main_thread``.
        """
        host, port = _free_port()
        kwargs['port'] = port
        if 'host' not in kwargs:
            kwargs['host'] = host
        server = cls(application, **kwargs)
        thread = threading.Thread(target=server.run)
        server.main_thread = thread
        thread.start()
        return server

    def wait(self):
        """Block until the server answers ``/__application__``.

        After an initial 0.5s pause the server is polled up to 100 times,
        sleeping 0.3s after each failed attempt. If no attempt succeeds,
        :meth:`shutdown` is attempted on a best-effort basis.

        :returns: ``True`` once a request succeeds, ``False`` if every
            attempt failed.
        """
        conn = http_client.HTTPConnection(self.adj.host, self.adj.port)
        time.sleep(.5)
        try:
            for _ in range(100):
                try:
                    conn.request('GET', '/__application__')
                    conn.getresponse()
                except (socket.error, http_client.HTTPException):
                    # Drop the failed connection so the next attempt
                    # starts from a clean state instead of reusing a
                    # half-finished request.
                    conn.close()
                    time.sleep(.3)
                else:
                    return True
        finally:
            # Release the client socket on every exit path.
            conn.close()
        try:
            self.shutdown()
        except Exception:
            # Best effort: the server never came up, so a failing
            # shutdown is not worth reporting. KeyboardInterrupt and
            # SystemExit are deliberately allowed to propagate.
            pass
        return False