diff options
Diffstat (limited to 'src/werkzeug/wsgi.py')
| -rw-r--r-- | src/werkzeug/wsgi.py | 230 |
1 files changed, 124 insertions, 106 deletions
diff --git a/src/werkzeug/wsgi.py b/src/werkzeug/wsgi.py index 60966503..f069f2d8 100644 --- a/src/werkzeug/wsgi.py +++ b/src/werkzeug/wsgi.py @@ -11,14 +11,24 @@ import io import re import warnings -from functools import partial, update_wrapper +from functools import partial +from functools import update_wrapper from itertools import chain -from werkzeug._compat import BytesIO, implements_iterator, \ - make_literal_wrapper, string_types, text_type, to_bytes, to_unicode, \ - try_coerce_native, wsgi_get_bytes -from werkzeug._internal import _encode_idna -from werkzeug.urls import uri_to_iri, url_join, url_parse, url_quote +from ._compat import BytesIO +from ._compat import implements_iterator +from ._compat import make_literal_wrapper +from ._compat import string_types +from ._compat import text_type +from ._compat import to_bytes +from ._compat import to_unicode +from ._compat import try_coerce_native +from ._compat import wsgi_get_bytes +from ._internal import _encode_idna +from .urls import uri_to_iri +from .urls import url_join +from .urls import url_parse +from .urls import url_quote def responder(f): @@ -34,8 +44,13 @@ def responder(f): return update_wrapper(lambda *a: f(*a)(*a[-2:]), f) -def get_current_url(environ, root_only=False, strip_querystring=False, - host_only=False, trusted_hosts=None): +def get_current_url( + environ, + root_only=False, + strip_querystring=False, + host_only=False, + trusted_hosts=None, +): """A handy helper function that recreates the full URL as IRI for the current request or parts of it. Here's an example: @@ -70,19 +85,19 @@ def get_current_url(environ, root_only=False, strip_querystring=False, :param trusted_hosts: a list of trusted hosts, see :func:`host_is_trusted` for more information. """ - tmp = [environ['wsgi.url_scheme'], '://', get_host(environ, trusted_hosts)] + tmp = [environ["wsgi.url_scheme"], "://", get_host(environ, trusted_hosts)] cat = tmp.append if host_only: - return uri_to_iri(''.join(tmp) + '/') - cat(url_quote(wsgi_get_bytes(environ.get('SCRIPT_NAME', ''))).rstrip('/')) - cat('/') + return uri_to_iri("".join(tmp) + "/") + cat(url_quote(wsgi_get_bytes(environ.get("SCRIPT_NAME", ""))).rstrip("/")) + cat("/") if not root_only: - cat(url_quote(wsgi_get_bytes(environ.get('PATH_INFO', '')).lstrip(b'/'))) + cat(url_quote(wsgi_get_bytes(environ.get("PATH_INFO", "")).lstrip(b"/"))) if not strip_querystring: qs = get_query_string(environ) if qs: - cat('?' + qs) - return uri_to_iri(''.join(tmp)) + cat("?" + qs) + return uri_to_iri("".join(tmp)) def host_is_trusted(hostname, trusted_list): @@ -103,8 +118,8 @@ def host_is_trusted(hostname, trusted_list): trusted_list = [trusted_list] def _normalize(hostname): - if ':' in hostname: - hostname = hostname.rsplit(':', 1)[0] + if ":" in hostname: + hostname = hostname.rsplit(":", 1)[0] return _encode_idna(hostname) try: @@ -112,7 +127,7 @@ def host_is_trusted(hostname, trusted_list): except UnicodeError: return False for ref in trusted_list: - if ref.startswith('.'): + if ref.startswith("."): ref = ref[1:] suffix_match = True else: @@ -123,7 +138,7 @@ def host_is_trusted(hostname, trusted_list): return False if ref == hostname: return True - if suffix_match and hostname.endswith(b'.' + ref): + if suffix_match and hostname.endswith(b"." + ref): return True return False @@ -144,20 +159,23 @@ def get_host(environ, trusted_hosts=None): :raise ~werkzeug.exceptions.SecurityError: If the host is not trusted. """ - if 'HTTP_HOST' in environ: - rv = environ['HTTP_HOST'] - if environ['wsgi.url_scheme'] == 'http' and rv.endswith(':80'): + if "HTTP_HOST" in environ: + rv = environ["HTTP_HOST"] + if environ["wsgi.url_scheme"] == "http" and rv.endswith(":80"): rv = rv[:-3] - elif environ['wsgi.url_scheme'] == 'https' and rv.endswith(':443'): + elif environ["wsgi.url_scheme"] == "https" and rv.endswith(":443"): rv = rv[:-4] else: - rv = environ['SERVER_NAME'] - if (environ['wsgi.url_scheme'], environ['SERVER_PORT']) not \ - in (('https', '443'), ('http', '80')): - rv += ':' + environ['SERVER_PORT'] + rv = environ["SERVER_NAME"] + if (environ["wsgi.url_scheme"], environ["SERVER_PORT"]) not in ( + ("https", "443"), + ("http", "80"), + ): + rv += ":" + environ["SERVER_PORT"] if trusted_hosts is not None: if not host_is_trusted(rv, trusted_hosts): - from werkzeug.exceptions import SecurityError + from .exceptions import SecurityError + raise SecurityError('Host "%s" is not trusted' % rv) return rv @@ -171,10 +189,10 @@ def get_content_length(environ): :param environ: the WSGI environ to fetch the content length from. """ - if environ.get('HTTP_TRANSFER_ENCODING', '') == 'chunked': + if environ.get("HTTP_TRANSFER_ENCODING", "") == "chunked": return None - content_length = environ.get('CONTENT_LENGTH') + content_length = environ.get("CONTENT_LENGTH") if content_length is not None: try: return max(0, int(content_length)) @@ -199,13 +217,13 @@ def get_input_stream(environ, safe_fallback=True): content length is not set. Disabling this allows infinite streams, which can be a denial-of-service risk. """ - stream = environ['wsgi.input'] + stream = environ["wsgi.input"] content_length = get_content_length(environ) # A wsgi extension that tells us if the input is terminated. In # that case we return the stream unchanged as we know we can safely # read it until the end. - if environ.get('wsgi.input_terminated'): + if environ.get("wsgi.input_terminated"): return stream # If the request doesn't specify a content length, returning the stream is @@ -228,14 +246,14 @@ def get_query_string(environ): :param environ: the WSGI environment object to get the query string from. """ - qs = wsgi_get_bytes(environ.get('QUERY_STRING', '')) + qs = wsgi_get_bytes(environ.get("QUERY_STRING", "")) # QUERY_STRING really should be ascii safe but some browsers # will send us some unicode stuff (I am looking at you IE). # In that case we want to urllib quote it badly. - return try_coerce_native(url_quote(qs, safe=':&%=+$!*\'(),')) + return try_coerce_native(url_quote(qs, safe=":&%=+$!*'(),")) -def get_path_info(environ, charset='utf-8', errors='replace'): +def get_path_info(environ, charset="utf-8", errors="replace"): """Returns the `PATH_INFO` from the WSGI environment and properly decodes it. This also takes care about the WSGI decoding dance on Python 3 environments. if the `charset` is set to `None` a @@ -248,11 +266,11 @@ def get_path_info(environ, charset='utf-8', errors='replace'): decoding should be performed. :param errors: the decoding error handling. """ - path = wsgi_get_bytes(environ.get('PATH_INFO', '')) + path = wsgi_get_bytes(environ.get("PATH_INFO", "")) return to_unicode(path, charset, errors, allow_none_charset=True) -def get_script_name(environ, charset='utf-8', errors='replace'): +def get_script_name(environ, charset="utf-8", errors="replace"): """Returns the `SCRIPT_NAME` from the WSGI environment and properly decodes it. This also takes care about the WSGI decoding dance on Python 3 environments. if the `charset` is set to `None` a @@ -265,11 +283,11 @@ def get_script_name(environ, charset='utf-8', errors='replace'): decoding should be performed. :param errors: the decoding error handling. """ - path = wsgi_get_bytes(environ.get('SCRIPT_NAME', '')) + path = wsgi_get_bytes(environ.get("SCRIPT_NAME", "")) return to_unicode(path, charset, errors, allow_none_charset=True) -def pop_path_info(environ, charset='utf-8', errors='replace'): +def pop_path_info(environ, charset="utf-8", errors="replace"): """Removes and returns the next segment of `PATH_INFO`, pushing it onto `SCRIPT_NAME`. Returns `None` if there is nothing left on `PATH_INFO`. @@ -296,32 +314,32 @@ def pop_path_info(environ, charset='utf-8', errors='replace'): :param environ: the WSGI environment that is modified. """ - path = environ.get('PATH_INFO') + path = environ.get("PATH_INFO") if not path: return None - script_name = environ.get('SCRIPT_NAME', '') + script_name = environ.get("SCRIPT_NAME", "") # shift multiple leading slashes over old_path = path - path = path.lstrip('/') + path = path.lstrip("/") if path != old_path: - script_name += '/' * (len(old_path) - len(path)) + script_name += "/" * (len(old_path) - len(path)) - if '/' not in path: - environ['PATH_INFO'] = '' - environ['SCRIPT_NAME'] = script_name + path + if "/" not in path: + environ["PATH_INFO"] = "" + environ["SCRIPT_NAME"] = script_name + path rv = wsgi_get_bytes(path) else: - segment, path = path.split('/', 1) - environ['PATH_INFO'] = '/' + path - environ['SCRIPT_NAME'] = script_name + segment + segment, path = path.split("/", 1) + environ["PATH_INFO"] = "/" + path + environ["SCRIPT_NAME"] = script_name + segment rv = wsgi_get_bytes(segment) return to_unicode(rv, charset, errors, allow_none_charset=True) -def peek_path_info(environ, charset='utf-8', errors='replace'): +def peek_path_info(environ, charset="utf-8", errors="replace"): """Returns the next segment on the `PATH_INFO` or `None` if there is none. Works like :func:`pop_path_info` without modifying the environment: @@ -342,18 +360,19 @@ def peek_path_info(environ, charset='utf-8', errors='replace'): :param environ: the WSGI environment that is checked. """ - segments = environ.get('PATH_INFO', '').lstrip('/').split('/', 1) + segments = environ.get("PATH_INFO", "").lstrip("/").split("/", 1) if segments: - return to_unicode(wsgi_get_bytes(segments[0]), - charset, errors, allow_none_charset=True) + return to_unicode( + wsgi_get_bytes(segments[0]), charset, errors, allow_none_charset=True + ) def extract_path_info( environ_or_baseurl, path_or_url, - charset='utf-8', - errors='werkzeug.url_quote', - collapse_http_schemes=True + charset="utf-8", + errors="werkzeug.url_quote", + collapse_http_schemes=True, ): """Extracts the path info from the given URL (or WSGI environment) and path. The path info returned is a unicode string, not a bytestring @@ -395,29 +414,29 @@ def extract_path_info( .. versionadded:: 0.6 """ + def _normalize_netloc(scheme, netloc): - parts = netloc.split(u'@', 1)[-1].split(u':', 1) + parts = netloc.split(u"@", 1)[-1].split(u":", 1) if len(parts) == 2: netloc, port = parts - if (scheme == u'http' and port == u'80') or \ - (scheme == u'https' and port == u'443'): + if (scheme == u"http" and port == u"80") or ( + scheme == u"https" and port == u"443" + ): port = None else: netloc = parts[0] port = None if port is not None: - netloc += u':' + port + netloc += u":" + port return netloc # make sure whatever we are working on is a IRI and parse it path = uri_to_iri(path_or_url, charset, errors) if isinstance(environ_or_baseurl, dict): - environ_or_baseurl = get_current_url(environ_or_baseurl, - root_only=True) + environ_or_baseurl = get_current_url(environ_or_baseurl, root_only=True) base_iri = uri_to_iri(environ_or_baseurl, charset, errors) base_scheme, base_netloc, base_path = url_parse(base_iri)[:3] - cur_scheme, cur_netloc, cur_path, = \ - url_parse(url_join(base_iri, path))[:3] + cur_scheme, cur_netloc, cur_path, = url_parse(url_join(base_iri, path))[:3] # normalize the network location base_netloc = _normalize_netloc(base_scheme, base_netloc) @@ -426,11 +445,10 @@ def extract_path_info( # is that IRI even on a known HTTP scheme? if collapse_http_schemes: for scheme in base_scheme, cur_scheme: - if scheme not in (u'http', u'https'): + if scheme not in (u"http", u"https"): return None else: - if not (base_scheme in (u'http', u'https') - and base_scheme == cur_scheme): + if not (base_scheme in (u"http", u"https") and base_scheme == cur_scheme): return None # are the netlocs compatible? @@ -438,16 +456,15 @@ def extract_path_info( return None # are we below the application path? - base_path = base_path.rstrip(u'/') + base_path = base_path.rstrip(u"/") if not cur_path.startswith(base_path): return None - return u'/' + cur_path[len(base_path):].lstrip(u'/') + return u"/" + cur_path[len(base_path) :].lstrip(u"/") @implements_iterator class ClosingIterator(object): - """The WSGI specification requires that all middlewares and gateways respect the `close` callback of the iterable returned by the application. Because it is useful to add another close action to a returned iterable @@ -478,7 +495,7 @@ class ClosingIterator(object): callbacks = [callbacks] else: callbacks = list(callbacks) - iterable_close = getattr(iterable, 'close', None) + iterable_close = getattr(iterable, "close", None) if iterable_close: callbacks.insert(0, iterable_close) self._callbacks = callbacks @@ -510,12 +527,11 @@ def wrap_file(environ, file, buffer_size=8192): :param file: a :class:`file`-like object with a :meth:`~file.read` method. :param buffer_size: number of bytes for one iteration. """ - return environ.get('wsgi.file_wrapper', FileWrapper)(file, buffer_size) + return environ.get("wsgi.file_wrapper", FileWrapper)(file, buffer_size) @implements_iterator class FileWrapper(object): - """This class can be used to convert a :class:`file`-like object into an iterable. It yields `buffer_size` blocks until the file is fully read. @@ -538,22 +554,22 @@ class FileWrapper(object): self.buffer_size = buffer_size def close(self): - if hasattr(self.file, 'close'): + if hasattr(self.file, "close"): self.file.close() def seekable(self): - if hasattr(self.file, 'seekable'): + if hasattr(self.file, "seekable"): return self.file.seekable() - if hasattr(self.file, 'seek'): + if hasattr(self.file, "seek"): return True return False def seek(self, *args): - if hasattr(self.file, 'seek'): + if hasattr(self.file, "seek"): self.file.seek(*args) def tell(self): - if hasattr(self.file, 'tell'): + if hasattr(self.file, "tell"): return self.file.tell() return None @@ -593,7 +609,7 @@ class _RangeWrapper(object): if byte_range is not None: self.end_byte = self.start_byte + self.byte_range self.read_length = 0 - self.seekable = hasattr(iterable, 'seekable') and iterable.seekable() + self.seekable = hasattr(iterable, "seekable") and iterable.seekable() self.end_reached = False def __iter__(self): @@ -618,7 +634,7 @@ class _RangeWrapper(object): while self.read_length <= self.start_byte: chunk = self._next_chunk() if chunk is not None: - chunk = chunk[self.start_byte - self.read_length:] + chunk = chunk[self.start_byte - self.read_length :] contextual_read_length = self.start_byte return chunk, contextual_read_length @@ -633,7 +649,7 @@ class _RangeWrapper(object): chunk = self._next_chunk() if self.end_byte is not None and self.read_length >= self.end_byte: self.end_reached = True - return chunk[:self.end_byte - contextual_read_length] + return chunk[: self.end_byte - contextual_read_length] return chunk def __next__(self): @@ -644,16 +660,17 @@ class _RangeWrapper(object): raise StopIteration() def close(self): - if hasattr(self.iterable, 'close'): + if hasattr(self.iterable, "close"): self.iterable.close() def _make_chunk_iter(stream, limit, buffer_size): """Helper for the line and chunk iter functions.""" if isinstance(stream, (bytes, bytearray, text_type)): - raise TypeError('Passed a string or byte object instead of ' - 'true iterator or stream.') - if not hasattr(stream, 'read'): + raise TypeError( + "Passed a string or byte object instead of true iterator or stream." + ) + if not hasattr(stream, "read"): for item in stream: if item: yield item @@ -668,8 +685,7 @@ def _make_chunk_iter(stream, limit, buffer_size): yield item -def make_line_iter(stream, limit=None, buffer_size=10 * 1024, - cap_at_buffer=False): +def make_line_iter(stream, limit=None, buffer_size=10 * 1024, cap_at_buffer=False): """Safely iterates line-based over an input stream. If the input stream is not a :class:`LimitedStream` the `limit` parameter is mandatory. @@ -703,15 +719,15 @@ def make_line_iter(stream, limit=None, buffer_size=10 * 1024, """ _iter = _make_chunk_iter(stream, limit, buffer_size) - first_item = next(_iter, '') + first_item = next(_iter, "") if not first_item: return s = make_literal_wrapper(first_item) - empty = s('') - cr = s('\r') - lf = s('\n') - crlf = s('\r\n') + empty = s("") + cr = s("\r") + lf = s("\n") + crlf = s("\r\n") _iter = chain((first_item,), _iter) @@ -719,7 +735,7 @@ def make_line_iter(stream, limit=None, buffer_size=10 * 1024, _join = empty.join buffer = [] while 1: - new_data = next(_iter, '') + new_data = next(_iter, "") if not new_data: break new_buf = [] @@ -754,8 +770,9 @@ def make_line_iter(stream, limit=None, buffer_size=10 * 1024, yield previous -def make_chunk_iter(stream, separator, limit=None, buffer_size=10 * 1024, - cap_at_buffer=False): +def make_chunk_iter( + stream, separator, limit=None, buffer_size=10 * 1024, cap_at_buffer=False +): """Works like :func:`make_line_iter` but accepts a separator which divides chunks. If you want newline based processing you should use :func:`make_line_iter` instead as it @@ -782,23 +799,23 @@ def make_chunk_iter(stream, separator, limit=None, buffer_size=10 * 1024, """ _iter = _make_chunk_iter(stream, limit, buffer_size) - first_item = next(_iter, '') + first_item = next(_iter, "") if not first_item: return _iter = chain((first_item,), _iter) if isinstance(first_item, text_type): separator = to_unicode(separator) - _split = re.compile(r'(%s)' % re.escape(separator)).split - _join = u''.join + _split = re.compile(r"(%s)" % re.escape(separator)).split + _join = u"".join else: separator = to_bytes(separator) - _split = re.compile(b'(' + re.escape(separator) + b')').split - _join = b''.join + _split = re.compile(b"(" + re.escape(separator) + b")").split + _join = b"".join buffer = [] while 1: - new_data = next(_iter, '') + new_data = next(_iter, "") if not new_data: break chunks = _split(new_data) @@ -828,7 +845,6 @@ def make_chunk_iter(stream, separator, limit=None, buffer_size=10 * 1024, @implements_iterator class LimitedStream(io.IOBase): - """Wraps a stream so that it doesn't read more than n bytes. If the stream is exhausted and the caller tries to get more bytes from it :func:`on_exhausted` is called which by default returns an empty @@ -891,7 +907,8 @@ class LimitedStream(io.IOBase): the client went away. By default a :exc:`~werkzeug.exceptions.ClientDisconnected` exception is raised. """ - from werkzeug.exceptions import ClientDisconnected + from .exceptions import ClientDisconnected + raise ClientDisconnected() def exhaust(self, chunk_size=1024 * 64): @@ -984,9 +1001,10 @@ class LimitedStream(io.IOBase): return True -from werkzeug.middleware.dispatcher import DispatcherMiddleware as _DispatcherMiddleware -from werkzeug.middleware.http_proxy import ProxyMiddleware as _ProxyMiddleware -from werkzeug.middleware.shared_data import SharedDataMiddleware as _SharedDataMiddleware +# DEPRECATED +from .middleware.dispatcher import DispatcherMiddleware as _DispatcherMiddleware +from .middleware.http_proxy import ProxyMiddleware as _ProxyMiddleware +from .middleware.shared_data import SharedDataMiddleware as _SharedDataMiddleware class ProxyMiddleware(_ProxyMiddleware): |
