summaryrefslogtreecommitdiff
path: root/src/werkzeug/wsgi.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/werkzeug/wsgi.py')
-rw-r--r--src/werkzeug/wsgi.py230
1 files changed, 124 insertions, 106 deletions
diff --git a/src/werkzeug/wsgi.py b/src/werkzeug/wsgi.py
index 60966503..f069f2d8 100644
--- a/src/werkzeug/wsgi.py
+++ b/src/werkzeug/wsgi.py
@@ -11,14 +11,24 @@
import io
import re
import warnings
-from functools import partial, update_wrapper
+from functools import partial
+from functools import update_wrapper
from itertools import chain
-from werkzeug._compat import BytesIO, implements_iterator, \
- make_literal_wrapper, string_types, text_type, to_bytes, to_unicode, \
- try_coerce_native, wsgi_get_bytes
-from werkzeug._internal import _encode_idna
-from werkzeug.urls import uri_to_iri, url_join, url_parse, url_quote
+from ._compat import BytesIO
+from ._compat import implements_iterator
+from ._compat import make_literal_wrapper
+from ._compat import string_types
+from ._compat import text_type
+from ._compat import to_bytes
+from ._compat import to_unicode
+from ._compat import try_coerce_native
+from ._compat import wsgi_get_bytes
+from ._internal import _encode_idna
+from .urls import uri_to_iri
+from .urls import url_join
+from .urls import url_parse
+from .urls import url_quote
def responder(f):
@@ -34,8 +44,13 @@ def responder(f):
return update_wrapper(lambda *a: f(*a)(*a[-2:]), f)
-def get_current_url(environ, root_only=False, strip_querystring=False,
- host_only=False, trusted_hosts=None):
+def get_current_url(
+ environ,
+ root_only=False,
+ strip_querystring=False,
+ host_only=False,
+ trusted_hosts=None,
+):
"""A handy helper function that recreates the full URL as IRI for the
current request or parts of it. Here's an example:
@@ -70,19 +85,19 @@ def get_current_url(environ, root_only=False, strip_querystring=False,
:param trusted_hosts: a list of trusted hosts, see :func:`host_is_trusted`
for more information.
"""
- tmp = [environ['wsgi.url_scheme'], '://', get_host(environ, trusted_hosts)]
+ tmp = [environ["wsgi.url_scheme"], "://", get_host(environ, trusted_hosts)]
cat = tmp.append
if host_only:
- return uri_to_iri(''.join(tmp) + '/')
- cat(url_quote(wsgi_get_bytes(environ.get('SCRIPT_NAME', ''))).rstrip('/'))
- cat('/')
+ return uri_to_iri("".join(tmp) + "/")
+ cat(url_quote(wsgi_get_bytes(environ.get("SCRIPT_NAME", ""))).rstrip("/"))
+ cat("/")
if not root_only:
- cat(url_quote(wsgi_get_bytes(environ.get('PATH_INFO', '')).lstrip(b'/')))
+ cat(url_quote(wsgi_get_bytes(environ.get("PATH_INFO", "")).lstrip(b"/")))
if not strip_querystring:
qs = get_query_string(environ)
if qs:
- cat('?' + qs)
- return uri_to_iri(''.join(tmp))
+ cat("?" + qs)
+ return uri_to_iri("".join(tmp))
def host_is_trusted(hostname, trusted_list):
@@ -103,8 +118,8 @@ def host_is_trusted(hostname, trusted_list):
trusted_list = [trusted_list]
def _normalize(hostname):
- if ':' in hostname:
- hostname = hostname.rsplit(':', 1)[0]
+ if ":" in hostname:
+ hostname = hostname.rsplit(":", 1)[0]
return _encode_idna(hostname)
try:
@@ -112,7 +127,7 @@ def host_is_trusted(hostname, trusted_list):
except UnicodeError:
return False
for ref in trusted_list:
- if ref.startswith('.'):
+ if ref.startswith("."):
ref = ref[1:]
suffix_match = True
else:
@@ -123,7 +138,7 @@ def host_is_trusted(hostname, trusted_list):
return False
if ref == hostname:
return True
- if suffix_match and hostname.endswith(b'.' + ref):
+ if suffix_match and hostname.endswith(b"." + ref):
return True
return False
@@ -144,20 +159,23 @@ def get_host(environ, trusted_hosts=None):
:raise ~werkzeug.exceptions.SecurityError: If the host is not
trusted.
"""
- if 'HTTP_HOST' in environ:
- rv = environ['HTTP_HOST']
- if environ['wsgi.url_scheme'] == 'http' and rv.endswith(':80'):
+ if "HTTP_HOST" in environ:
+ rv = environ["HTTP_HOST"]
+ if environ["wsgi.url_scheme"] == "http" and rv.endswith(":80"):
rv = rv[:-3]
- elif environ['wsgi.url_scheme'] == 'https' and rv.endswith(':443'):
+ elif environ["wsgi.url_scheme"] == "https" and rv.endswith(":443"):
rv = rv[:-4]
else:
- rv = environ['SERVER_NAME']
- if (environ['wsgi.url_scheme'], environ['SERVER_PORT']) not \
- in (('https', '443'), ('http', '80')):
- rv += ':' + environ['SERVER_PORT']
+ rv = environ["SERVER_NAME"]
+ if (environ["wsgi.url_scheme"], environ["SERVER_PORT"]) not in (
+ ("https", "443"),
+ ("http", "80"),
+ ):
+ rv += ":" + environ["SERVER_PORT"]
if trusted_hosts is not None:
if not host_is_trusted(rv, trusted_hosts):
- from werkzeug.exceptions import SecurityError
+ from .exceptions import SecurityError
+
raise SecurityError('Host "%s" is not trusted' % rv)
return rv
@@ -171,10 +189,10 @@ def get_content_length(environ):
:param environ: the WSGI environ to fetch the content length from.
"""
- if environ.get('HTTP_TRANSFER_ENCODING', '') == 'chunked':
+ if environ.get("HTTP_TRANSFER_ENCODING", "") == "chunked":
return None
- content_length = environ.get('CONTENT_LENGTH')
+ content_length = environ.get("CONTENT_LENGTH")
if content_length is not None:
try:
return max(0, int(content_length))
@@ -199,13 +217,13 @@ def get_input_stream(environ, safe_fallback=True):
content length is not set. Disabling this allows infinite streams,
which can be a denial-of-service risk.
"""
- stream = environ['wsgi.input']
+ stream = environ["wsgi.input"]
content_length = get_content_length(environ)
# A wsgi extension that tells us if the input is terminated. In
# that case we return the stream unchanged as we know we can safely
# read it until the end.
- if environ.get('wsgi.input_terminated'):
+ if environ.get("wsgi.input_terminated"):
return stream
# If the request doesn't specify a content length, returning the stream is
@@ -228,14 +246,14 @@ def get_query_string(environ):
:param environ: the WSGI environment object to get the query string from.
"""
- qs = wsgi_get_bytes(environ.get('QUERY_STRING', ''))
+ qs = wsgi_get_bytes(environ.get("QUERY_STRING", ""))
# QUERY_STRING really should be ascii safe but some browsers
# will send us some unicode stuff (I am looking at you IE).
# In that case we want to urllib quote it badly.
- return try_coerce_native(url_quote(qs, safe=':&%=+$!*\'(),'))
+ return try_coerce_native(url_quote(qs, safe=":&%=+$!*'(),"))
-def get_path_info(environ, charset='utf-8', errors='replace'):
+def get_path_info(environ, charset="utf-8", errors="replace"):
"""Returns the `PATH_INFO` from the WSGI environment and properly
decodes it. This also takes care about the WSGI decoding dance
on Python 3 environments. if the `charset` is set to `None` a
@@ -248,11 +266,11 @@ def get_path_info(environ, charset='utf-8', errors='replace'):
decoding should be performed.
:param errors: the decoding error handling.
"""
- path = wsgi_get_bytes(environ.get('PATH_INFO', ''))
+ path = wsgi_get_bytes(environ.get("PATH_INFO", ""))
return to_unicode(path, charset, errors, allow_none_charset=True)
-def get_script_name(environ, charset='utf-8', errors='replace'):
+def get_script_name(environ, charset="utf-8", errors="replace"):
"""Returns the `SCRIPT_NAME` from the WSGI environment and properly
decodes it. This also takes care about the WSGI decoding dance
on Python 3 environments. if the `charset` is set to `None` a
@@ -265,11 +283,11 @@ def get_script_name(environ, charset='utf-8', errors='replace'):
decoding should be performed.
:param errors: the decoding error handling.
"""
- path = wsgi_get_bytes(environ.get('SCRIPT_NAME', ''))
+ path = wsgi_get_bytes(environ.get("SCRIPT_NAME", ""))
return to_unicode(path, charset, errors, allow_none_charset=True)
-def pop_path_info(environ, charset='utf-8', errors='replace'):
+def pop_path_info(environ, charset="utf-8", errors="replace"):
"""Removes and returns the next segment of `PATH_INFO`, pushing it onto
`SCRIPT_NAME`. Returns `None` if there is nothing left on `PATH_INFO`.
@@ -296,32 +314,32 @@ def pop_path_info(environ, charset='utf-8', errors='replace'):
:param environ: the WSGI environment that is modified.
"""
- path = environ.get('PATH_INFO')
+ path = environ.get("PATH_INFO")
if not path:
return None
- script_name = environ.get('SCRIPT_NAME', '')
+ script_name = environ.get("SCRIPT_NAME", "")
# shift multiple leading slashes over
old_path = path
- path = path.lstrip('/')
+ path = path.lstrip("/")
if path != old_path:
- script_name += '/' * (len(old_path) - len(path))
+ script_name += "/" * (len(old_path) - len(path))
- if '/' not in path:
- environ['PATH_INFO'] = ''
- environ['SCRIPT_NAME'] = script_name + path
+ if "/" not in path:
+ environ["PATH_INFO"] = ""
+ environ["SCRIPT_NAME"] = script_name + path
rv = wsgi_get_bytes(path)
else:
- segment, path = path.split('/', 1)
- environ['PATH_INFO'] = '/' + path
- environ['SCRIPT_NAME'] = script_name + segment
+ segment, path = path.split("/", 1)
+ environ["PATH_INFO"] = "/" + path
+ environ["SCRIPT_NAME"] = script_name + segment
rv = wsgi_get_bytes(segment)
return to_unicode(rv, charset, errors, allow_none_charset=True)
-def peek_path_info(environ, charset='utf-8', errors='replace'):
+def peek_path_info(environ, charset="utf-8", errors="replace"):
"""Returns the next segment on the `PATH_INFO` or `None` if there
is none. Works like :func:`pop_path_info` without modifying the
environment:
@@ -342,18 +360,19 @@ def peek_path_info(environ, charset='utf-8', errors='replace'):
:param environ: the WSGI environment that is checked.
"""
- segments = environ.get('PATH_INFO', '').lstrip('/').split('/', 1)
+ segments = environ.get("PATH_INFO", "").lstrip("/").split("/", 1)
if segments:
- return to_unicode(wsgi_get_bytes(segments[0]),
- charset, errors, allow_none_charset=True)
+ return to_unicode(
+ wsgi_get_bytes(segments[0]), charset, errors, allow_none_charset=True
+ )
def extract_path_info(
environ_or_baseurl,
path_or_url,
- charset='utf-8',
- errors='werkzeug.url_quote',
- collapse_http_schemes=True
+ charset="utf-8",
+ errors="werkzeug.url_quote",
+ collapse_http_schemes=True,
):
"""Extracts the path info from the given URL (or WSGI environment) and
path. The path info returned is a unicode string, not a bytestring
@@ -395,29 +414,29 @@ def extract_path_info(
.. versionadded:: 0.6
"""
+
def _normalize_netloc(scheme, netloc):
- parts = netloc.split(u'@', 1)[-1].split(u':', 1)
+ parts = netloc.split(u"@", 1)[-1].split(u":", 1)
if len(parts) == 2:
netloc, port = parts
- if (scheme == u'http' and port == u'80') or \
- (scheme == u'https' and port == u'443'):
+ if (scheme == u"http" and port == u"80") or (
+ scheme == u"https" and port == u"443"
+ ):
port = None
else:
netloc = parts[0]
port = None
if port is not None:
- netloc += u':' + port
+ netloc += u":" + port
return netloc
# make sure whatever we are working on is a IRI and parse it
path = uri_to_iri(path_or_url, charset, errors)
if isinstance(environ_or_baseurl, dict):
- environ_or_baseurl = get_current_url(environ_or_baseurl,
- root_only=True)
+ environ_or_baseurl = get_current_url(environ_or_baseurl, root_only=True)
base_iri = uri_to_iri(environ_or_baseurl, charset, errors)
base_scheme, base_netloc, base_path = url_parse(base_iri)[:3]
- cur_scheme, cur_netloc, cur_path, = \
- url_parse(url_join(base_iri, path))[:3]
+ cur_scheme, cur_netloc, cur_path, = url_parse(url_join(base_iri, path))[:3]
# normalize the network location
base_netloc = _normalize_netloc(base_scheme, base_netloc)
@@ -426,11 +445,10 @@ def extract_path_info(
# is that IRI even on a known HTTP scheme?
if collapse_http_schemes:
for scheme in base_scheme, cur_scheme:
- if scheme not in (u'http', u'https'):
+ if scheme not in (u"http", u"https"):
return None
else:
- if not (base_scheme in (u'http', u'https')
- and base_scheme == cur_scheme):
+ if not (base_scheme in (u"http", u"https") and base_scheme == cur_scheme):
return None
# are the netlocs compatible?
@@ -438,16 +456,15 @@ def extract_path_info(
return None
# are we below the application path?
- base_path = base_path.rstrip(u'/')
+ base_path = base_path.rstrip(u"/")
if not cur_path.startswith(base_path):
return None
- return u'/' + cur_path[len(base_path):].lstrip(u'/')
+ return u"/" + cur_path[len(base_path) :].lstrip(u"/")
@implements_iterator
class ClosingIterator(object):
-
"""The WSGI specification requires that all middlewares and gateways
respect the `close` callback of the iterable returned by the application.
Because it is useful to add another close action to a returned iterable
@@ -478,7 +495,7 @@ class ClosingIterator(object):
callbacks = [callbacks]
else:
callbacks = list(callbacks)
- iterable_close = getattr(iterable, 'close', None)
+ iterable_close = getattr(iterable, "close", None)
if iterable_close:
callbacks.insert(0, iterable_close)
self._callbacks = callbacks
@@ -510,12 +527,11 @@ def wrap_file(environ, file, buffer_size=8192):
:param file: a :class:`file`-like object with a :meth:`~file.read` method.
:param buffer_size: number of bytes for one iteration.
"""
- return environ.get('wsgi.file_wrapper', FileWrapper)(file, buffer_size)
+ return environ.get("wsgi.file_wrapper", FileWrapper)(file, buffer_size)
@implements_iterator
class FileWrapper(object):
-
"""This class can be used to convert a :class:`file`-like object into
an iterable. It yields `buffer_size` blocks until the file is fully
read.
@@ -538,22 +554,22 @@ class FileWrapper(object):
self.buffer_size = buffer_size
def close(self):
- if hasattr(self.file, 'close'):
+ if hasattr(self.file, "close"):
self.file.close()
def seekable(self):
- if hasattr(self.file, 'seekable'):
+ if hasattr(self.file, "seekable"):
return self.file.seekable()
- if hasattr(self.file, 'seek'):
+ if hasattr(self.file, "seek"):
return True
return False
def seek(self, *args):
- if hasattr(self.file, 'seek'):
+ if hasattr(self.file, "seek"):
self.file.seek(*args)
def tell(self):
- if hasattr(self.file, 'tell'):
+ if hasattr(self.file, "tell"):
return self.file.tell()
return None
@@ -593,7 +609,7 @@ class _RangeWrapper(object):
if byte_range is not None:
self.end_byte = self.start_byte + self.byte_range
self.read_length = 0
- self.seekable = hasattr(iterable, 'seekable') and iterable.seekable()
+ self.seekable = hasattr(iterable, "seekable") and iterable.seekable()
self.end_reached = False
def __iter__(self):
@@ -618,7 +634,7 @@ class _RangeWrapper(object):
while self.read_length <= self.start_byte:
chunk = self._next_chunk()
if chunk is not None:
- chunk = chunk[self.start_byte - self.read_length:]
+ chunk = chunk[self.start_byte - self.read_length :]
contextual_read_length = self.start_byte
return chunk, contextual_read_length
@@ -633,7 +649,7 @@ class _RangeWrapper(object):
chunk = self._next_chunk()
if self.end_byte is not None and self.read_length >= self.end_byte:
self.end_reached = True
- return chunk[:self.end_byte - contextual_read_length]
+ return chunk[: self.end_byte - contextual_read_length]
return chunk
def __next__(self):
@@ -644,16 +660,17 @@ class _RangeWrapper(object):
raise StopIteration()
def close(self):
- if hasattr(self.iterable, 'close'):
+ if hasattr(self.iterable, "close"):
self.iterable.close()
def _make_chunk_iter(stream, limit, buffer_size):
"""Helper for the line and chunk iter functions."""
if isinstance(stream, (bytes, bytearray, text_type)):
- raise TypeError('Passed a string or byte object instead of '
- 'true iterator or stream.')
- if not hasattr(stream, 'read'):
+ raise TypeError(
+ "Passed a string or byte object instead of true iterator or stream."
+ )
+ if not hasattr(stream, "read"):
for item in stream:
if item:
yield item
@@ -668,8 +685,7 @@ def _make_chunk_iter(stream, limit, buffer_size):
yield item
-def make_line_iter(stream, limit=None, buffer_size=10 * 1024,
- cap_at_buffer=False):
+def make_line_iter(stream, limit=None, buffer_size=10 * 1024, cap_at_buffer=False):
"""Safely iterates line-based over an input stream. If the input stream
is not a :class:`LimitedStream` the `limit` parameter is mandatory.
@@ -703,15 +719,15 @@ def make_line_iter(stream, limit=None, buffer_size=10 * 1024,
"""
_iter = _make_chunk_iter(stream, limit, buffer_size)
- first_item = next(_iter, '')
+ first_item = next(_iter, "")
if not first_item:
return
s = make_literal_wrapper(first_item)
- empty = s('')
- cr = s('\r')
- lf = s('\n')
- crlf = s('\r\n')
+ empty = s("")
+ cr = s("\r")
+ lf = s("\n")
+ crlf = s("\r\n")
_iter = chain((first_item,), _iter)
@@ -719,7 +735,7 @@ def make_line_iter(stream, limit=None, buffer_size=10 * 1024,
_join = empty.join
buffer = []
while 1:
- new_data = next(_iter, '')
+ new_data = next(_iter, "")
if not new_data:
break
new_buf = []
@@ -754,8 +770,9 @@ def make_line_iter(stream, limit=None, buffer_size=10 * 1024,
yield previous
-def make_chunk_iter(stream, separator, limit=None, buffer_size=10 * 1024,
- cap_at_buffer=False):
+def make_chunk_iter(
+ stream, separator, limit=None, buffer_size=10 * 1024, cap_at_buffer=False
+):
"""Works like :func:`make_line_iter` but accepts a separator
which divides chunks. If you want newline based processing
you should use :func:`make_line_iter` instead as it
@@ -782,23 +799,23 @@ def make_chunk_iter(stream, separator, limit=None, buffer_size=10 * 1024,
"""
_iter = _make_chunk_iter(stream, limit, buffer_size)
- first_item = next(_iter, '')
+ first_item = next(_iter, "")
if not first_item:
return
_iter = chain((first_item,), _iter)
if isinstance(first_item, text_type):
separator = to_unicode(separator)
- _split = re.compile(r'(%s)' % re.escape(separator)).split
- _join = u''.join
+ _split = re.compile(r"(%s)" % re.escape(separator)).split
+ _join = u"".join
else:
separator = to_bytes(separator)
- _split = re.compile(b'(' + re.escape(separator) + b')').split
- _join = b''.join
+ _split = re.compile(b"(" + re.escape(separator) + b")").split
+ _join = b"".join
buffer = []
while 1:
- new_data = next(_iter, '')
+ new_data = next(_iter, "")
if not new_data:
break
chunks = _split(new_data)
@@ -828,7 +845,6 @@ def make_chunk_iter(stream, separator, limit=None, buffer_size=10 * 1024,
@implements_iterator
class LimitedStream(io.IOBase):
-
"""Wraps a stream so that it doesn't read more than n bytes. If the
stream is exhausted and the caller tries to get more bytes from it
:func:`on_exhausted` is called which by default returns an empty
@@ -891,7 +907,8 @@ class LimitedStream(io.IOBase):
the client went away. By default a
:exc:`~werkzeug.exceptions.ClientDisconnected` exception is raised.
"""
- from werkzeug.exceptions import ClientDisconnected
+ from .exceptions import ClientDisconnected
+
raise ClientDisconnected()
def exhaust(self, chunk_size=1024 * 64):
@@ -984,9 +1001,10 @@ class LimitedStream(io.IOBase):
return True
-from werkzeug.middleware.dispatcher import DispatcherMiddleware as _DispatcherMiddleware
-from werkzeug.middleware.http_proxy import ProxyMiddleware as _ProxyMiddleware
-from werkzeug.middleware.shared_data import SharedDataMiddleware as _SharedDataMiddleware
+# DEPRECATED
+from .middleware.dispatcher import DispatcherMiddleware as _DispatcherMiddleware
+from .middleware.http_proxy import ProxyMiddleware as _ProxyMiddleware
+from .middleware.shared_data import SharedDataMiddleware as _SharedDataMiddleware
class ProxyMiddleware(_ProxyMiddleware):